tf.nn.embedding_lookup(params, ids) 的作用是查表：按照 ids 中给出的索引，从 params 的第 0 维中取出对应的元素（或行），输出的 shape 为 ids.shape + params.shape[1:]。
下面直接看例子
样例1
# Example 1: look up elements of a 1-D constant by index (TensorFlow 1.x graph API).
t = np.asarray([1, 2, 3, 0])  # indices into the first dimension of params
params = tf.constant([10, 20, 30, 40])
# Gathers params[i] for every i in t; with 1-D params the result has t's shape.
embedded_inputs = tf.nn.embedding_lookup(params, t)
with tf.Session() as sess:
    # Must be indented under the `with`: the session is closed once the block exits.
    print(sess.run(embedded_inputs))
结果
[20 30 40 10]
可见结果就是按照 t 中索引的顺序，依次从 params 中取出对应位置的元素，即 [params[1], params[2], params[3], params[0]]。
样例2
# Example 2: look up elements of a float tf.Variable with a 1-D index tensor
# (TensorFlow 1.x graph API).
t = np.asarray([1, 2, 3, 0])
with tf.variable_scope('test', reuse=tf.AUTO_REUSE):
    t = tf.convert_to_tensor(t)
    # NOTE: `reuse` only governs tf.get_variable; tf.Variable always creates a
    # new variable (TF uniquifies the name if it already exists).
    embedding_table = tf.Variable([10, 20, 30, 40],
                                  name='phone_embedding', shape=[4],
                                  dtype=tf.float32)
    # Picks embedding_table[i] for each index i in t.
    embedded_inputs = tf.nn.embedding_lookup(embedding_table, t)
# Variables must be initialized inside the session before they can be read.
init_ops = [tf.global_variables_initializer(),
            tf.local_variables_initializer()]
with tf.Session() as sess:
    sess.run(init_ops)
    print(sess.run(t))                 # the indices
    print(sess.run(embedding_table))   # the lookup table
    print(sess.run(embedded_inputs))   # the looked-up values
结果
[1 2 3 0]
[10. 20. 30. 40.]
[20. 30. 40. 10.]
样例3
索引 t 为二维矩阵（2×2）
# Example 3: the index tensor t is a 2-D (2x2) array; the table stays 1-D.
# (TensorFlow 1.x graph API.)
t = np.asarray([1, 2, 3, 0]).reshape([2, 2])
with tf.variable_scope('test', reuse=tf.AUTO_REUSE):
    t = tf.convert_to_tensor(t)
    # NOTE: `reuse` only governs tf.get_variable; tf.Variable always creates a
    # new variable (TF uniquifies the name if it already exists).
    embedding_table = tf.Variable([10, 20, 30, 40],
                                  name='phone_embedding', shape=[4],
                                  dtype=tf.float32)
    # Output shape is t.shape + embedding_table.shape[1:] == (2, 2) here, so
    # the table itself must NOT be reshaped to match the indices.
    embedded_inputs = tf.nn.embedding_lookup(embedding_table, t)
# Variables must be initialized inside the session before they can be read.
init_ops = [tf.global_variables_initializer(),
            tf.local_variables_initializer()]
with tf.Session() as sess:
    sess.run(init_ops)
    print(sess.run(t))                 # 2x2 index matrix
    print(sess.run(embedding_table))   # 1-D lookup table
    print(sess.run(embedded_inputs))   # 2x2 looked-up values
结果
[[1 2]
[3 0]]
[10. 20. 30. 40.]
[[20. 30.]
[40. 10.]]
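embedding_table 不需要随索引 t 做对应的 shape 变换（仍保持 shape=[4]）：输出的 shape 由 t 的 shape 决定（t.shape + embedding_table.shape[1:]）。如果把 embedding_table 也 reshape 成 [2,2]，其第 0 维只有 2 行，索引 2、3 会越界，从而抛出错误。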