TensorFlow 动态数组(tf.TensorArray):写入后可以随时读取
import tensorflow as tf
# Eager-mode demo: a growable float32 TensorArray whose slots stay readable
# after a read because clear_after_read is disabled.
ta = tf.TensorArray(
    dtype=tf.float32,
    size=0,
    dynamic_size=True,
    clear_after_read=False,
)

# write() returns the TensorArray to use from then on, so the three writes
# can be chained and the final result rebound to `ta`.
ta = ta.write(0, 10).write(1, 20).write(2, 30)

# Show each slot on its own line, then the whole array stacked into one tensor.
print(ta.read(0), ta.read(1), ta.read(2), sep="\n")
print(ta.stack())
tf.Tensor(10.0, shape=(), dtype=float32)
tf.Tensor(20.0, shape=(), dtype=float32)
tf.Tensor(30.0, shape=(), dtype=float32)
tf.Tensor([10. 20. 30.], shape=(3,), dtype=float32)
@tf.function
def fibonacci(n):
    """Return the first ``n`` Fibonacci numbers as a 1-D float32 tensor.

    Args:
        n: Number of terms requested. It is used as the exclusive upper
            bound of the fill loop.

    Returns:
        The stacked contents of the TensorArray. The two seed values
        ``[0., 1.]`` are always written, so for ``n < 2`` the result still
        holds both seeds.
    """
    # NOTE: an earlier version reassigned ``n = 5`` here, which silently
    # ignored the caller's argument; the parameter is now honored.
    ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
    # Seed the array with F(0) = 0 and F(1) = 1.
    ta = ta.unstack([0., 1.])

    # Each new term is the sum of the two preceding slots.
    for i in range(2, n):
        ta = ta.write(i, ta.read(i - 1) + ta.read(i - 2))

    return ta.stack()


fibonacci(7)
# Expected output:
# <tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 1., 1., 2., 3., 5., 8.], dtype=float32)>
# Module-level counter mutated by f(); its value persists between calls.
v = tf.Variable(1)


@tf.function
def f(x):
    """Accumulate ``0 .. x-1`` into ``v`` and record ``v`` after every step.

    Args:
        x: Exclusive upper bound handed to ``tf.range``.

    Returns:
        An int32 tensor holding the value of ``v`` captured after each
        ``assign_add``, one entry per loop iteration.
    """
    history = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
    for step in tf.range(x):
        # Bump the counter first, then snapshot its updated value.
        v.assign_add(step)
        history = history.write(step, v)
    return history.stack()


f(5)
<tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 1, 2, 4, 7, 11], dtype=int32)>