0
点赞
收藏
分享

微信扫一扫

tf.Module

捌柒陆壹 2023-01-16 阅读 136


通过继承tf.Module代替object任何tf.Variable或tf.Module分配给对象的属性的实例可以使用被收集variables , trainable_variables或submodules属性:

import tensorflow as tf 


class Dense(tf.Module):
def __init__(self, in_features, out_features, name=None):
super(Dense, self).__init__(name=name)
self.w = tf.Variable(
tf.random.normal([in_features, out_features]), name='w')
self.b = tf.Variable(tf.zeros([out_features]), name='b')
def __call__(self, x):
y = tf.matmul(x, self.w) + self.b
return tf.nn.relu(y)

d = Dense(in_features=3, out_features=2)
d(tf.ones([1, 3]))

: <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.       , 0.5372393]], dtype=float32)>


举报

相关推荐

0 条评论