首页 > 代码库 > 『TensorFlow』迁移学习_他山之石,可以攻玉

『TensorFlow』迁移学习_他山之石,可以攻玉

目的:

使用google已经训练好的模型,将最后的全连接层修改为我们自己的全连接层,将原有的1000分类分类器修改为我们自己的5分类分类器,利用原有模型的特征提取能力实现我们自己数据对应模型的快速训练。实际中对于一个陌生的数据集,原有模型经过不高的迭代次数即可获得很好的准确率。

实战:

实机文件夹如下,两个压缩文件可以忽略:

技术分享

花朵图片数据下载:

1 curl -O http://download.tensorflow.org/example_images/flower_photos.tgz

已经训练好的Inception-v3的1000分类模型下载:

1 wget https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip

迁移学习代码如下:

  1 import glob
  2 import os.path
  3 import random
  4 import numpy as np
  5 import tensorflow as tf
  6 from tensorflow.python.platform import gfile
  7 
  8 ‘‘‘模型及样本路径设置‘‘‘
  9 
 10 BOTTLENECK_TENSOR_SIZE = 2048                          # 瓶颈层节点个数
 11 BOTTLENECK_TENSOR_NAME = pool_3/_reshape:0           # 瓶颈层输出张量名称
 12 JPEG_DATA_TENSOR_NAME  = DecodeJpeg/contents:0       # 输入层张量名称
 13 
 14 MODEL_DIR  = ./inception_dec_2015                    # 模型存放文件夹
 15 MODEL_FILE = tensorflow_inception_graph.pb           # 模型名
 16 
 17 CACHE_DIR  = ./bottleneck                            # 瓶颈输出中转文件夹
 18 INPUT_DATA = http://www.mamicode.com/./flower_photos                         # 数据文件夹
 19 
 20 VALIDATION_PERCENTAGE = 10                             # 验证用数据百分比
 21 TEST_PERCENTAGE       = 10                             # 测试用数据百分比
 22 
 23 ‘‘‘新添加神经网络部参数设置‘‘‘
 24 
 25 LEARNING_RATE = 0.01
 26 STEP          = 4000
 27 BATCH         = 100
 28 
 29 def creat_image_lists(validation_percentage,testing_percentage):
 30     ‘‘‘
 31     将图片(无路径文件名)信息保存在字典中
 32     :param validation_percentage: 验证数据百分比 
 33     :param testing_percentage:    测试数据百分比
 34     :return:                      字典{标签:{文件夹:str,训练:[],验证:[],测试:[]},...}
 35     ‘‘‘
 36     result = {}
 37     sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
 38     # 由于os.walk()列表第一个是‘./‘,所以排除
 39     is_root_dir = True            #<-----
 40     # 遍历各个label文件夹
 41     for sub_dir in sub_dirs:
 42         if is_root_dir:           #<-----
 43             is_root_dir = False
 44             continue
 45 
 46         extensions = [jpg, jpeg, JPG, JPEG]
 47         file_list  = []
 48         dir_name   = os.path.basename(sub_dir)
 49         # 遍历各个可能的文件尾缀
 50         for extension in extensions:
 51             # file_glob = os.path.join(INPUT_DATA,dir_name,‘*.‘+extension)
 52             file_glob = os.path.join(sub_dir, *. + extension)
 53             file_list.extend(glob.glob(file_glob))      # 匹配并收集路径&文件名
 54             # print(file_glob,‘\n‘,glob.glob(file_glob))
 55         if not file_list: continue
 56 
 57         label_name = dir_name.lower()                   # 生成label,实际就是小写文件夹名
 58 
 59         # 初始化各个路径&文件收集list
 60         training_images   = []
 61         testing_images    = []
 62         validation_images = []
 63 
 64         # 去路径,只保留文件名
 65         for file_name in file_list:
 66             base_name = os.path.basename(file_name)
 67 
 68             # 随机划分数据给验证和测试
 69             chance = np.random.randint(100)
 70             if chance < validation_percentage:
 71                 validation_images.append(base_name)
 72             elif chance < (validation_percentage + testing_percentage):
 73                 testing_images.append(base_name)
 74             else:
 75                 training_images.append(base_name)
 76         # 本标签字典项生成
 77         result[label_name] = {
 78             dir        : dir_name,
 79             training   : training_images,
 80             testing    : testing_images,
 81             validation : validation_images
 82         }
 83     return result
 84 
 85 def get_random_cached_bottlenecks(sess,n_class,image_lists,batch,category,jpeg_data_tensor,bottleneck_tensor):
 86     ‘‘‘
 87     函数随机获取一个batch的图片作为训练数据
 88     :param sess: 
 89     :param n_class: 
 90     :param image_lists: 
 91     :param how_many: 
 92     :param category:            training or validation
 93     :param jpeg_data_tensor: 
 94     :param bottleneck_tensor: 
 95     :return:                    瓶颈张量输出 & label
 96     ‘‘‘
 97     bottlenecks   = []
 98     ground_truths = []
 99     for i in range(batch):
100         label_index = random.randrange(n_class)              # 标签索引随机生成
101         label_name  = list(image_lists.keys())[label_index]  # 标签名获取
102         image_index = random.randrange(65536)                # 标签内图片索引随机种子
103         # 瓶颈层张量
104         bottleneck = get_or_create_bottleneck(               # 获取对应标签随机图片瓶颈张量
105             sess,image_lists,label_name,image_index,category,
106             jpeg_data_tensor,bottleneck_tensor)
107         ground_truth = np.zeros(n_class,dtype=np.float32)
108         ground_truth[label_index] = 1.0                      # 标准结果[0,0,1,0...]
109         # 收集瓶颈张量和label
110         bottlenecks.append(bottleneck)
111         ground_truths.append(ground_truth)
112     return bottlenecks,ground_truths
113 
114 def get_or_create_bottleneck(
115         sess,image_lists,label_name,index,category,jpeg_data_tensor,bottleneck_tensor):
116     ‘‘‘
117     寻找已经计算且保存下来的特征向量,如果找不到则先计算这个特征向量,然后保存到文件
118     :param sess: 
119     :param image_lists:       全图像字典
120     :param label_name:        当前标签
121     :param index:             图片索引
122     :param category:          training or validation
123     :param jpeg_data_tensor: 
124     :param bottleneck_tensor: 
125     :return: 
126     ‘‘‘
127     label_lists  = image_lists[label_name]          # 本标签字典获取 标签:{文件夹:str,训练:[],验证:[],测试:[]}
128     sub_dir      = label_lists[dir]               # 获取标签值
129     sub_dir_path = os.path.join(CACHE_DIR,sub_dir)  # 保存文件路径
130     if not os.path.exists(sub_dir_path):os.mkdir(sub_dir_path)
131     bottleneck_path = get_bottleneck_path(image_lists,label_name,index,category)
132     if not os.path.exists(bottleneck_path):
133         image_path = get_image_path(image_lists, INPUT_DATA, label_name, index, category)
134         #image_data = http://www.mamicode.com/gfile.FastGFile(image_path,‘rb‘).read()
135         image_data = http://www.mamicode.com/open(image_path,rb).read()
136         # print(gfile.FastGFile(image_path,‘rb‘).read()==open(image_path,‘rb‘).read())
137         # 生成向前传播后的瓶颈张量
138         bottleneck_values = run_bottleneck_on_images(sess,image_data,jpeg_data_tensor,bottleneck_tensor)
139         # list2string以便于写入文件
140         bottleneck_string = ,.join(str(x) for x in bottleneck_values)
141         # print(bottleneck_values)
142         # print(bottleneck_string)
143         with open(bottleneck_path, w) as bottleneck_file:
144             bottleneck_file.write(bottleneck_string)
145     else:
146         with open(bottleneck_path, r) as bottleneck_file:
147             bottleneck_string = bottleneck_file.read()
148         bottleneck_values = [float(x) for x in bottleneck_string.split(,)]
149     # 返回的是list注意
150     return bottleneck_values
151 
152 def run_bottleneck_on_images(sess,image_data,jpeg_data_tensor,bottleneck_tensor):
153     ‘‘‘
154     使用加载的训练好的Inception-v3模型处理一张图片,得到这个图片的特征向量。
155     :param sess:              会话句柄
156     :param image_data:        图片文件句柄
157     :param jpeg_data_tensor:  输入张量句柄
158     :param bottleneck_tensor: 瓶颈张量句柄
159     :return:                  瓶颈张量值
160     ‘‘‘
161     # print(‘input:‘,len(image_data))
162     bottleneck_values = sess.run(bottleneck_tensor,feed_dict={jpeg_data_tensor:image_data})
163     bottleneck_values = np.squeeze(bottleneck_values)
164     # print(‘bottle:‘,len(bottleneck_values))
165     return bottleneck_values
166 
167 def get_bottleneck_path(image_lists, label_name, index, category):
168     ‘‘‘
169     获取一张图片的中转(featuremap)地址(添加txt)
170     :param image_lists:   全图片字典
171     :param label_name:    标签名
172     :param index:         随机数索引
173     :param category:      training or validation
174     :return:              中转(featuremap)地址(添加txt)
175     ‘‘‘
176     return get_image_path(image_lists, CACHE_DIR, label_name, index, category) + .txt
177 
178 def get_image_path(image_lists, image_dir, label_name, index, category):
179     ‘‘‘
180     通过类别名称、所属数据集和图片编号获取一张图片的中转(featuremap)地址(无txt)
181     :param image_lists: 全图片字典
182     :param image_dir:   外层文件夹(内部是标签文件夹)
183     :param label_name:  标签名
184     :param index:       随机数索引
185     :param category:    training or validation
186     :return:            图片中间变量地址
187     ‘‘‘
188     label_lists   = image_lists[label_name]
189     category_list = label_lists[category]       # 获取目标category图片列表
190     mod_index     = index % len(category_list)  # 随机获取一张图片的索引
191     base_name     = category_list[mod_index]    # 通过索引获取图片名
192     return os.path.join(image_dir,label_lists[dir],base_name)
193 
194 def get_test_bottlenecks(sess,image_lists,n_class,jpeg_data_tensor,bottleneck_tensor):
195     ‘‘‘
196     获取全部的测试数据,计算输出
197     :param sess: 
198     :param image_lists: 
199     :param n_class: 
200     :param jpeg_data_tensor: 
201     :param bottleneck_tensor: 
202     :return:                   瓶颈输出 & label
203     ‘‘‘
204     bottlenecks  = []
205     ground_truths = []
206     label_name_list = list(image_lists.keys())
207     for label_index,label_name in enumerate(image_lists[label_name_list]):
208         category = testing
209         for index, unused_base_name in enumerate(image_lists[label_name][category]): # 索引, {文件名}
210             bottleneck = get_or_create_bottleneck(
211                 sess, image_lists, label_name, index,
212                 category, jpeg_data_tensor, bottleneck_tensor)
213             ground_truth = np.zeros(n_class, dtype=np.float32)
214             ground_truth[label_index] = 1.0
215             bottlenecks.append(bottleneck)
216             ground_truths.append(ground_truth)
217     return bottlenecks, ground_truths
218 
219 def main():
220     # 生成文件字典
221     images_lists = creat_image_lists(VALIDATION_PERCENTAGE,TEST_PERCENTAGE)
222     # 记录label种类(字典项数)
223     n_class = len(images_lists.keys())
224 
225     # 加载模型
226     # with gfile.FastGFile(os.path.join(MODEL_DIR,MODEL_FILE),‘rb‘) as f:   # 阅读器上下文
227     with open(os.path.join(MODEL_DIR, MODEL_FILE), rb) as f:            # 阅读器上下文
228         graph_def = tf.GraphDef()                                         # 生成图
229         graph_def.ParseFromString(f.read())                               # 图加载模型
230     # 加载图上节点张量(按照句柄理解)
231     bottleneck_tensor,jpeg_data_tensor = tf.import_graph_def(             # 从图上读取张量,同时导入默认图
232         graph_def,
233         return_elements=[BOTTLENECK_TENSOR_NAME,JPEG_DATA_TENSOR_NAME])
234 
235     ‘‘‘新的神经网络‘‘‘
236     # 输入层,由原模型输出层feed
237     bottleneck_input   = tf.placeholder(tf.float32,[None,BOTTLENECK_TENSOR_SIZE],name=BottleneckInputPlaceholder)
238     ground_truth_input = tf.placeholder(tf.float32,[None,n_class]               ,name=GroundTruthInput)
239     # 全连接层
240     with tf.name_scope(final_train_ops):
241         Weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE,n_class],stddev=0.001))
242         biases  = tf.Variable(tf.zeros([n_class]))
243         logits  = tf.matmul(bottleneck_input,Weights) + biases
244         final_tensor = tf.nn.softmax(logits)
245     # 交叉熵损失函数
246     cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits,labels=ground_truth_input))
247     # 优化算法选择
248     train_step    = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy)
249 
250     # 正确率
251     with tf.name_scope(evaluation):
252         correct_prediction = tf.equal(tf.argmax(final_tensor,1),tf.argmax(ground_truth_input,1))
253         evaluation_step    = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
254 
255     with tf.Session() as sess:
256         init = tf.global_variables_initializer()
257         sess.run(init)
258         for i in range(STEP):
259             # 随机batch获取瓶颈输出 & label
260             train_bottlenecks,train_ground_truth = get_random_cached_bottlenecks(
261                 sess,n_class,images_lists,BATCH,training,jpeg_data_tensor,bottleneck_tensor)
262             sess.run(train_step,feed_dict={bottleneck_input:train_bottlenecks,ground_truth_input:train_ground_truth})
263 
264             # 每迭代100次运行一次验证程序
265             if i % 100 == 0 or i + 1 == STEP:
266                 validation_bottlenecks, validation_ground_truth = get_random_cached_bottlenecks(
267                     sess, n_class, images_lists, BATCH, validation, jpeg_data_tensor, bottleneck_tensor)
268                 validation_accuracy = sess.run(evaluation_step, feed_dict={
269                     bottleneck_input: validation_bottlenecks, ground_truth_input: validation_ground_truth})
270                 print(Step %d: Validation accuracy on random sampled %d examples = %.1f%% %
271                       (i, BATCH, validation_accuracy * 100))
272 
273         test_bottlenecks,test_ground_truth = get_test_bottlenecks(
274             sess,images_lists,n_class,jpeg_data_tensor,bottleneck_tensor)
275         test_accuracy = sess.run(evaluation_step,feed_dict={
276             bottleneck_input:test_bottlenecks,ground_truth_input:test_ground_truth})
277         print(Final test accuracy = %.1f%% % (test_accuracy * 100))
278 
279 if __name__ == __main__:
280     main()

问题&建议:

1.建议从main函数开始阅读,跳到哪里读到那里;

2.我给的注释很详尽,原书《TensorFlow实战Google深度学习框架》也有更为详尽的注释,所以这里不多说了;

3.比较有借鉴意义的两点:

  • 如何使用把自己的图片数据导入框架中训练测试
  • 如何加载模型,import模型中的张量(在源代码以及[置顶]『TensorFlow』常用函数实践笔记给出了介绍)

4.一个有意思的测试:

在读取图片之前加入PIL包的读取,

1 img = np.asarray(Image.open(image_path))
2 print(np.prod(img.shape))
3 print(img.shape)

在单张图片向前传播中加入了输入图片数据和输出,

 1 def run_bottleneck_on_images(sess,image_data,jpeg_data_tensor,bottleneck_tensor):
 2     ‘‘‘
 3     使用加载的训练好的Inception-v3模型处理一张图片,得到这个图片的特征向量。
 4     :param sess:              会话句柄
 5     :param image_data:        图片文件句柄
 6     :param jpeg_data_tensor:  输入张量句柄
 7     :param bottleneck_tensor: 瓶颈张量句柄
 8     :return:                  瓶颈张量值
 9     ‘‘‘
10     bottleneck_values = sess.run(bottleneck_tensor,feed_dict={jpeg_data_tensor:image_data})
11     bottleneck_values = np.squeeze(bottleneck_values)
12     print(input:,len(image_data))
13     print(bottle:,len(bottleneck_values))
14     return bottleneck_values

输出挺有意思,

230400
(240, 320, 3)
input: 45685 <class ‘bytes‘>
2048
...
172800 (240, 240, 3) input: 30673 <class bytes> 2048

即是说feed的数据是原始的二进制文件,而且即使输入大小不同,输出大小是一致的(也就是说原网络是有裁剪数据的),所以有两个要搞明白的问题:

  • Inception-v3网络的tensorflow的源码
  • 继续研究一下其他的输入数据的方式,交叉印证一下到底怎么传入图片数据,是不是只能二进制输入

更新:

Google的Inception_v3源码

源码好难懂,inception_v3结构也过于复杂,没看明白,不过还是有收获的,可视化图:

 1 import os
 2 import tensorflow as tf
 3 
 4 inception_graph_def_file = os.path.join(./, tensorflow_inception_graph.pb)
 5 with tf.Session() as sess:
 6     with tf.gfile.FastGFile(inception_graph_def_file, rb) as f:
 7         graph_def = tf.GraphDef()
 8         graph_def.ParseFromString(f.read())
 9         tf.import_graph_def(graph_def, name=‘‘)
10     writer = tf.summary.FileWriter(./, sess.graph)
11     writer.close()

有关图片输入:

 1 import tensorflow as tf
 2 import matplotlib.pyplot as plt
 3 
 4 # 使用‘r‘会出错,无法解码,只能以2进制形式读取
 5 # img_raw = tf.gfile.FastGFile(‘./123.png‘,‘rb‘).read()
 6 img_raw = open(./123.png,rb).read()
 7 
 8 # 把二进制文件解码为uint8
 9 img_0 = tf.image.decode_png(img_raw)
10 # img_1 = tf.image.convert_image_dtype(img_0,dtype=tf.uint8)
11 
12 sess = tf.Session()
13 print(sess.run(img_0).shape)
14 plt.imshow(sess.run(img_0))
15 plt.show()

原始读取的是二进制文件,强行‘r‘会出错,因为解码方式不对(utf-8之类都是文字解码器),之后使用tf的解码器可以解码成uint8的可读数组文件,tf.image.convert_image_dtype(img_0,dtype=tf.float32)用于后续处理,对图像的预处理之类的。

也就是说我们feed二进制代码之后原模型可以把它当作原始图片文件,进行解码切割操作,实际上训练的还是解码后的矩阵文件,联想到输入层节点的名称DecodeJpeg/contents:0,问题就解决了。

 

 

 

 

『TensorFlow』迁移学习_他山之石,可以攻玉