首页 > 代码库 > 转载:tensorflow保存训练后的模型
转载:tensorflow保存训练后的模型
训练完一个模型后,为了以后重复使用,通常我们需要对模型的结果进行保存。如果用Tensorflow去实现神经网络,所要保存的就是神经网络中的各项权重值。建议可以使用Saver类保存和加载模型的结果。
1、使用tf.train.Saver.save()方法保存模型
tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix=‘meta‘, write_meta_graph=True, write_state=True)
- sess: 用于保存变量操作的会话。
- save_path: String类型,用于指定训练结果的保存路径。
- global_step: 如果提供的话,这个数字会添加到save_path后面,用于构建checkpoint文件。这个参数有助于我们区分不同训练阶段的结果。
2、使用tf.train.Saver.restore方法价值模型
tf.train.Saver.restore(sess, save_path)
- sess: 用于加载变量操作的会话。
- save_path: 同保存模型是用到的的save_path参数。
下面通过一个代码演示这两个函数的使用方法
import tensorflow as tf
import numpy as np
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4
w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b
loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ‘‘
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = http://www.mamicode.com/np.reshape(np.random.rand(10).astype(np.float32), (10, 1))>
转载:tensorflow保存训练后的模型
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。