首页 > 代码库 > 『TensorFlow』徒手装高达_战斗数据收集模块原型_save&restore

『TensorFlow』徒手装高达_战斗数据收集模块原型_save&restore

顺便一提,上节定义的网络结构有问题,现已修改,之后会陆续整理上来。
两种常用(我会的)的加载方式:
1.
‘‘‘
使用原网络保存的模型加载到自己重新定义的图上
可以使用python变量名加载模型,也可以使用节点名
‘‘‘
import AlexNet as Net
import AlexNet_train as train
import random
import tensorflow as tf

IMAGE_PATH = ‘./flower_photos/daisy/5673728_71b8cb57eb.jpg‘

with tf.Graph().as_default() as g:

    x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
    y = Net.inference_1(x, N_CLASS=5, train=False)

    with tf.Session() as sess:
        # 程序前面得有 Variable 供 save or restore 才不报错
        # 否则会提示没有可保存的变量
        saver = tf.train.Saver()

        ckpt = tf.train.get_checkpoint_state(‘./model/‘)
        img_raw = tf.gfile.FastGFile(IMAGE_PATH, ‘rb‘).read()
        img = sess.run(tf.expand_dims(tf.image.resize_images(
            tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))

        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess,‘./model/model.ckpt-0‘)
            global_step = ckpt.model_checkpoint_path.split(‘/‘)[-1].split(‘-‘)[-1]
            res = sess.run(y, feed_dict={x: img})
            print(global_step,sess.run(tf.argmax(res,1)))

2.

‘‘‘
直接使用使用保存好的图
无需加载python定义的结构,直接使用节点名称加载模型
由于节点形状已经定下来了,所以有不便之处,placeholder定义batch后单张传会报错
现阶段不推荐使用,以后如果理解深入了可能会找到使用方法
‘‘‘
import AlexNet_train as train
import random
import tensorflow as tf

IMAGE_PATH = ‘./flower_photos/daisy/5673728_71b8cb57eb.jpg‘

# x = tf.placeholder(
#     tf.float32, [1, train.INPUT_SIZE[0],train.INPUT_SIZE[1], 3], name=‘Placeholder‘)

ckpt = tf.train.get_checkpoint_state(‘./model/‘)
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +‘.meta‘)

with tf.Session() as sess:
    saver.restore(sess,ckpt.model_checkpoint_path)

    img_raw = tf.gfile.FastGFile(IMAGE_PATH, ‘rb‘).read()
    img = sess.run(tf.image.resize_images(
        tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)))
    imgs = []
    for i in range(128):
       imgs.append(img)
    print(sess.run(tf.get_default_graph().get_tensor_by_name(‘fc3:0‘),feed_dict={‘Placeholder:0‘: imgs}))

    ‘‘‘
    img = sess.run(tf.expand_dims(tf.image.resize_images(
        tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)), 0))
    print(img)
    imgs = []
    for i in range(128):
        imgs.append(img)
    print(sess.run(tf.get_default_graph().get_tensor_by_name(‘conv1:0‘),
                   feed_dict={‘Placeholder:0‘:img}))

 

注意,在所有两种方式中都可以通过调用节点名称使用节点输出张量,节点.name属性返回节点名称。

 

『TensorFlow』徒手装高达_战斗数据收集模块原型_save&restore