关于 TensorFlow 的一些零散知识
TensorFlow 中的内容相当繁杂, 及时总结是一个好习惯; 平时会收集/总结一些有用的知识点和代码片段, 放在本篇博文下是很合适的. 嘻嘻, 我就是想水一篇文章… ????????????
广而告之
可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号;另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中;
变量初始化
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(tf.tables_initializer())
前两条代码可以处理 Attempting to use uninitialized value
的问题, 最后一条用于处理 LookUpTable not initialized
的问题: 在使用 feature_column
时, 由于 feature
需要查表获取, 这个表也需要进行初始化, 比如:
FailedPreconditionError (see above for traceback): Table not initialized.
[[node hash_table_Lookup (defined at 5.py:23) = LookupTableFindV2[Tin=DT_STRING, Tout=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](relationship_lookup/hash_table, to_sparse_input_1/values, relationship_lookup/hash_table/Const)]]
获取 TensorFlow 中变量或者 Op 的 Name
all_vars = tf.global_variables()
for v in all_vars:
print(v.op.name)
graph = tf.get_default_graph()
for op in graph.get_operations():
print(op.name)
for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
arr = sess.run(var)
from tensorflow.python.framework import ops
print(tf.get_collection(ops.GraphKeys.MODEL_VARIABLES))
读取 Estimator 对象的 Variable
names = linear_est.get_variable_names()
print('name: ', names)
for i in names:
print(type(linear_est.get_variable_value(i)))
还有一种方法, 来自: can tf.estimator.Estimator’s parameters be modified by hand?,
通过访问 tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
来达到目的, 但如果为了得到模型的权重, 而不是整张图上的变量, 应该访问:
from tensorflow.python.framework import ops
print(tf.get_collection(ops.GraphKeys.MODEL_VARIABLES))
上面链接中的代码如下:
# Restore, Update, Save
# tested only on tesorflow 1.4
import tensorflow as tf
tf.reset_default_graph()
CHECKPOINT_DIR = 'CHECKPOIN_DIR' # for example '/my_checkpoints' as in tf.estimator.LinearClassifier(model_dir='/my_checkpoints'...
checkpoint = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
with tf.Session() as sess:
saver = tf.train.import_meta_graph(checkpoint.model_checkpoint_path + '.meta')
saver.restore(sess, checkpoint.model_checkpoint_path)
# just to check all variables values
# sess.run(tf.all_variables())
# get your variable
KEY = 'linear/linear_model/0/weights/part_0:0'# for tf.estimator.LinearClassifier first weight
var_wights_0 = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if v.name == KEY][0]
sess.run(var_wights_0)
# your update operation
var_wights_0_updated = var_wights_0.assign(var_wights_0 - 100)
sess.run(var_wights_0_updated)
# you can check that value is updated
# sess.run(tf.all_variables())
# this saves updated values to last checkpoint saved by estimator
saver.save(sess, checkpoint.model_checkpoint_path)
TensorFlow 将整数转化为字符串
使用 tf.string.format
, 来自 Tensorflow - How to Convert int32 to string (using Python API for Tensorflow)
import tensorflow as tf
x = tf.constant([1, 2, 3], dtype=tf.int32)
x_as_string = tf.map_fn(lambda xi: tf.strings.format('{}', xi), x, dtype=tf.string)
with tf.Session() as sess:
res = sess.run(x_as_string)
print(res)
# [b'1' b'2' b'3']
tf.data 介绍
- https://www.tensorflow.org/guide/data
- https://www.tensorflow.org/guide/data_performance
- 十图详解tensorflow数据读取机制(附代码)
- tf.data.Dataset.interleave() 最通俗易懂的使用详解(有图有真相)
- How to use parallel_interleave in TensorFlow
- Tensorflow踩坑记之tf.data
tf.identity 的作用
- StackOverFlow: tf.identity 的作用
总的来说, 主要是两个, tf.identity
相当于创建了一个和原始结果一样的新节点, 可以和各种控制流的 op 配合使用, 具体看链接中的例子; 另一个是给 op 命名.