首页 > 代码库 > TF-搞不懂的TF矩阵加法
TF-搞不懂的TF矩阵加法
看谷歌的demo mnist,卷积后加偏执量的代码
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
其中的x_image的维数是[-1, 28, 28, 1],W_conv1的维数是[5, 5, 1, 32], b的维数是[32]
conv2d对x_image和W_conv1进行卷积,结果为[-1, 28, 28, 32],结果就是:
[-1, 28, 28, 32]和[32]的加法。
完全搞不清为什么[-1, 28, 28, 32]和[32]两个完全不同维数可以做加法?而且加出的结果还是[-1, 28, 28, 32]?
于是做了下面的测试:
sess = tf.InteractiveSession() test1 = tf.ones([1,2,2,3],tf.float32) b1 = tf.ones([3]) re1 = test1 + b1 print("shap3={},eval=\n{}".format(b1.shape, b1.eval())) print("shap4={},eval=\n{}".format(test1.shape, test1.eval())) print("shap5={},eval=\n{}".format(re1.shape, re1.eval())) test1 = tf.ones([1,2,2,3],tf.float32) b1 = tf.ones([1,1,1,1]) re1 = test1 + b1 print("shap6={},eval=\n{}".format(b1.shape, b1.eval())) print("shap7={},eval=\n{}".format(test1.shape, test1.eval())) print("shap8={},eval=\n{}".format(re1.shape, re1.eval())) test1 = tf.ones([1,2,2,3],tf.float32) b1 = tf.ones([1,1,1,3]) re1 = test1 + b1 print("shap9 ={},eval=\n{}".format(b1.shape, b1.eval())) print("shap10={},eval=\n{}".format(test1.shape, test1.eval())) print("shap11={},eval=\n{}".format(re1.shape, re1.eval())) test1 = tf.ones([1,2,2,3],tf.float32) b1 = tf.ones([1]) re1 = test1 + b1 print("shap12={},eval=\n{}".format(b1.shape, b1.eval())) print("shap13={},eval=\n{}".format(test1.shape, test1.eval())) print("shap14={},eval=\n{}".format(re1.shape, re1.eval()))
test1 = tf.ones([1,2,2,3],tf.float32)
alist = [[[[ 1, 1, 1.],
[ 0, 0, 0.]],
[[ 1, 1, 1.],
[ 0, 0, 0.]]]]
b1 = tf.constant(alist)
re1 = test1 + b1
print("shap15={},eval=\n{}".format(b1.shape, b1.eval()))
print("shap16={},eval=\n{}".format(test1.shape, test1.eval()))
print("shap17={},eval=\n{}".format(re1.shape, re1.eval()))
结果为
shap3=(3,),eval= [ 1. 1. 1.] shap4=(1, 2, 2, 3),eval= [[[[ 1. 1. 1.] [ 1. 1. 1.]] [[ 1. 1. 1.] [ 1. 1. 1.]]]] shap5=(1, 2, 2, 3),eval= [[[[ 2. 2. 2.] [ 2. 2. 2.]] [[ 2. 2. 2.] [ 2. 2. 2.]]]] shap6=(1, 1, 1, 1),eval= [[[[ 1.]]]] shap7=(1, 2, 2, 3),eval= [[[[ 1. 1. 1.] [ 1. 1. 1.]] [[ 1. 1. 1.] [ 1. 1. 1.]]]] shap8=(1, 2, 2, 3),eval= [[[[ 2. 2. 2.] [ 2. 2. 2.]] [[ 2. 2. 2.] [ 2. 2. 2.]]]] shap9 =(1, 1, 1, 3),eval= [[[[ 1. 1. 1.]]]] shap10=(1, 2, 2, 3),eval= [[[[ 1. 1. 1.] [ 1. 1. 1.]] [[ 1. 1. 1.] [ 1. 1. 1.]]]] shap11=(1, 2, 2, 3),eval= [[[[ 2. 2. 2.] [ 2. 2. 2.]] [[ 2. 2. 2.] [ 2. 2. 2.]]]] shap12=(1,),eval= [ 1.] shap13=(1, 2, 2, 3),eval= [[[[ 1. 1. 1.] [ 1. 1. 1.]] [[ 1. 1. 1.] [ 1. 1. 1.]]]] shap14=(1, 2, 2, 3),eval= [[[[ 2. 2. 2.] [ 2. 2. 2.]] [[ 2. 2. 2.] [ 2. 2. 2.]]]]
shap15=(1, 2, 2, 3),eval=
[[[[ 1. 1. 1.]
[ 0. 0. 0.]]
[[ 1. 1. 1.]
[ 0. 0. 0.]]]]
shap16=(1, 2, 2, 3),eval=
[[[[ 1. 1. 1.]
[ 1. 1. 1.]]
[[ 1. 1. 1.]
[ 1. 1. 1.]]]]
shap17=(1, 2, 2, 3),eval=
[[[[ 2. 2. 2.]
[ 1. 1. 1.]]
[[ 2. 2. 2.]
[ 1. 1. 1.]]]]
这个结果说明了什么呢?说明张量加法时,维数不等时会自动扩充,用存在的数字填充。
比如下面这个[4, 3, 2, 3]的矩阵A,
我们把A加上[1, 2, 3]结果为
[[[[1 2 3]
[2 3 4]]
[[3 4 5]
[4 5 6]]
[[5 6 7]
[6 7 8]]]
[[[1 2 3]
[2 3 4]]
[[3 4 5]
[4 5 6]]
[[5 6 7]
[6 7 8]]]
[[[1 2 3]
[2 3 4]]
[[3 4 5]
[4 5 6]]
[[5 6 7]
[6 7 8]]]
[[[1 2 3]
[2 3 4]]
[[3 4 5]
[4 5 6]]
[[5 6 7]
[6 7 8]]]]
TF-搞不懂的TF矩阵加法