通过继承 `tf.Module`（而不是 `object`），任何被赋值给对象属性的 `tf.Variable` 或 `tf.Module` 实例，都可以通过 `variables`、`trainable_variables` 或 `submodules` 属性统一收集：
import tensorflow as tf
class Dense(tf.Module):
    """Fully connected layer followed by a ReLU activation.

    Because this class subclasses ``tf.Module`` rather than ``object``, the
    ``tf.Variable`` attributes assigned in ``__init__`` (``w`` and ``b``) are
    tracked by the module and can be collected through its ``variables`` /
    ``trainable_variables`` properties.

    Args:
        in_features: Size of the last dimension of the input tensor.
        out_features: Number of output units.
        name: Optional module name, forwarded to ``tf.Module``.
    """

    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        # Weight matrix of shape (in_features, out_features), randomly
        # initialized with tf.random.normal (no seed, so values vary per run).
        self.w = tf.Variable(
            tf.random.normal([in_features, out_features]), name='w')
        # Bias vector of shape (out_features,), initialized to zeros.
        self.b = tf.Variable(tf.zeros([out_features]), name='b')

    def __call__(self, x):
        """Apply the layer: ``relu(x @ w + b)``.

        Args:
            x: Input tensor; presumably of shape (batch, in_features) so that
                ``tf.matmul(x, self.w)`` is valid — TODO confirm for callers.

        Returns:
            The ReLU-activated affine transform of ``x``; for a
            (batch, in_features) input this is (batch, out_features).
        """
        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)
# Build a layer mapping 3 input features to 2 outputs, then apply it to a
# single all-ones row (a batch of one).
d = Dense(out_features=2, in_features=3)
d(tf.ones(shape=[1, 3]))  # evaluates to a (1, 2) float32 tensor
# 输出（w 为随机初始化，具体数值每次运行可能不同）: <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0. , 0.5372393]], dtype=float32)>