解决tensorflow 的 Saver.restore()无法从本地读取变量的问题
最近做tensorflow 手写数字识别的时候遇到了一个问题,Saver的restore()方法无法从本地恢复变量,导致了每次都会重新训练。
原来代码
saver = tf.train.Saver(max_to_keep=5)
epoch = tf.Variable(0, name='epoch', trainable=False)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
ckpt_dir = "./model/"
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
ckpt = tf.train.latest_checkpoint(ckpt_dir)
if ckpt != None:
saver.restore(sess, ckpt)
else:
print('Train from scratch')
start = sess.run(epoch)
修改代码
epoch = tf.Variable(0, name='epoch', trainable=False)
saver = tf.train.Saver(max_to_keep=5)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
ckpt_dir = "./model/"
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
ckpt = tf.train.latest_checkpoint(ckpt_dir)
if ckpt != None:
saver.restore(sess, ckpt)
else:
print('Train from scratch')
start = sess.run(epoch)
其实主要改变的就是以下两行的顺序
epoch = tf.Variable(0, name='epoch', trainable=False) saver = tf.train.Saver(max_to_keep=5)
|