首页 > 代码库 > 『TensorFlow』徒手装高达_战斗数据收集模块原型_save&restore
『TensorFlow』徒手装高达_战斗数据收集模块原型_save&restore
顺便一提,上节定义的网络结构有问题,现已修改,之后会陆续整理上来。
两种常用(我会的)的加载方式:
1.
‘‘‘ 使用原网络保存的模型加载到自己重新定义的图上 可以使用python变量名加载模型,也可以使用节点名 ‘‘‘ import AlexNet as Net import AlexNet_train as train import random import tensorflow as tf IMAGE_PATH = ‘./flower_photos/daisy/5673728_71b8cb57eb.jpg‘ with tf.Graph().as_default() as g: x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3]) y = Net.inference_1(x, N_CLASS=5, train=False) with tf.Session() as sess: # 程序前面得有 Variable 供 save or restore 才不报错 # 否则会提示没有可保存的变量 saver = tf.train.Saver() ckpt = tf.train.get_checkpoint_state(‘./model/‘) img_raw = tf.gfile.FastGFile(IMAGE_PATH, ‘rb‘).read() img = sess.run(tf.expand_dims(tf.image.resize_images( tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0)) if ckpt and ckpt.model_checkpoint_path: print(ckpt.model_checkpoint_path) saver.restore(sess,‘./model/model.ckpt-0‘) global_step = ckpt.model_checkpoint_path.split(‘/‘)[-1].split(‘-‘)[-1] res = sess.run(y, feed_dict={x: img}) print(global_step,sess.run(tf.argmax(res,1)))
2.
‘‘‘ 直接使用使用保存好的图 无需加载python定义的结构,直接使用节点名称加载模型 由于节点形状已经定下来了,所以有不便之处,placeholder定义batch后单张传会报错 现阶段不推荐使用,以后如果理解深入了可能会找到使用方法 ‘‘‘ import AlexNet_train as train import random import tensorflow as tf IMAGE_PATH = ‘./flower_photos/daisy/5673728_71b8cb57eb.jpg‘ # x = tf.placeholder( # tf.float32, [1, train.INPUT_SIZE[0],train.INPUT_SIZE[1], 3], name=‘Placeholder‘) ckpt = tf.train.get_checkpoint_state(‘./model/‘) saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +‘.meta‘) with tf.Session() as sess: saver.restore(sess,ckpt.model_checkpoint_path) img_raw = tf.gfile.FastGFile(IMAGE_PATH, ‘rb‘).read() img = sess.run(tf.image.resize_images( tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3))) imgs = [] for i in range(128): imgs.append(img) print(sess.run(tf.get_default_graph().get_tensor_by_name(‘fc3:0‘),feed_dict={‘Placeholder:0‘: imgs})) ‘‘‘ img = sess.run(tf.expand_dims(tf.image.resize_images( tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)), 0)) print(img) imgs = [] for i in range(128): imgs.append(img) print(sess.run(tf.get_default_graph().get_tensor_by_name(‘conv1:0‘), feed_dict={‘Placeholder:0‘:img}))
注意,在所有两种方式中都可以通过调用节点名称使用节点输出张量,节点.name属性返回节点名称。
『TensorFlow』徒手装高达_战斗数据收集模块原型_save&restore
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。