Keras Custom Layers


To define a custom layer in Keras, or to override an existing one, you need to implement three methods (a minimal skeleton is sketched after the list below):

  • `build(input_shape)`: this is where the layer's weights are defined, typically with `self.add_weight(...)` for the trainable parameters. Finish by calling the parent class's `build` via `super([Layer], self).build(input_shape)`, which sets `self.built = True`.
  • `call(x)`: this is where the layer's forward logic lives; it is the computation from input tensor to output tensor. Usually you only need to care about the first argument (the input tensor), unless you want the layer to support masking.
  • `compute_output_shape(input_shape)`: if the layer changes the shape of its input tensor, define the output shape here, so that Keras can automatically infer the shapes of downstream layers.
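
Below is a minimal sketch of the three methods, assuming the standalone Keras 2.x API; the layer name MyScale and its units argument are purely illustrative.

# Minimal custom-layer skeleton (illustrative names).
from keras.layers import Layer
import keras.backend as K

class MyScale(Layer):
    def __init__(self, units, **kwargs):
        super(MyScale, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Define trainable weights here.
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer='glorot_uniform',
            name='kernel'
        )
        super(MyScale, self).build(input_shape)  # sets self.built = True

    def call(self, x):
        # Forward logic: input tensor -> output tensor.
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        # The last axis changes from input_shape[-1] to units.
        return input_shape[:-1] + (self.units,)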

Common questions:

  • 1. On first encountering custom layers, the `input_shape` argument of `build` often causes confusion. In practice, we specify the input dimensions at the input layer, and every layer reports its output shape, so Keras can infer all shapes automatically from the computation graph.
  • 2. Do we need to account for the batch size when overriding a layer?
    A Keras layer is a tensor-to-tensor mapping in which the batch dimension is preserved by default, so when changing dimensions with `Reshape` we do not pass a batch_size axis either (see the sketch after this list).
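
A quick sketch of that convention (the shapes here are illustrative):

# Reshape's target_shape omits the batch axis, which Keras carries through.
from keras.layers import Input, Reshape
import keras.backend as K

x = Input(shape=(12,))   # symbolic shape (None, 12); None is the batch axis
y = Reshape((3, 4))(x)   # target_shape has no batch dimension
print(K.int_shape(y))    # -> (None, 3, 4)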

Reference: keras 自定义层 (Keras custom layers).

Finally, here is a conditional layer normalization example, taken from 基于Conditional Layer Normalization的条件文本生成 (Conditional Text Generation Based on Conditional Layer Normalization).

# A custom layer implements the three methods described above.
# Imports assume the standalone Keras 2.x API.
from keras.layers import Layer, Dense
from keras import activations, initializers
import keras.backend as K


class LayerNormalization(Layer):
    """(Conditional) Layer Normalization

    The hidden_* arguments are only used with conditional input
    (conditional=True):
    hidden_units: dimension to project the condition down to, for when the
        input condition matrix is too large; reduce it first, then transform.
    hidden_activation: usually a linear activation.
    """
    def __init__(
        self,
        center=True,
        scale=True,
        epsilon=None,
        conditional=False,
        hidden_units=None,
        hidden_activation='linear',
        hidden_initializer='glorot_uniform',
        **kwargs
    ):
        super(LayerNormalization, self).__init__(**kwargs)
        self.center = center
        self.scale = scale
        self.conditional = conditional
        self.hidden_units = hidden_units
        self.hidden_activation = activations.get(hidden_activation)
        self.hidden_initializer = initializers.get(hidden_initializer)
        self.epsilon = epsilon or 1e-12

    def build(self, input_shape):
        super(LayerNormalization, self).build(input_shape)  # sets self.built = True

        if self.conditional:
            shape = (input_shape[0][-1],)
        else:
            shape = (input_shape[-1],)

        if self.center:
            self.beta = self.add_weight(
                shape=shape, initializer='zeros', name='beta'
            )
        if self.scale:
            self.gamma = self.add_weight(
                shape=shape, initializer='ones', name='gamma'
            )

        if self.conditional:

            if self.hidden_units is not None:
                # Dense layer used to project the condition down.
                self.hidden_dense = Dense(
                    units=self.hidden_units,
                    activation=self.hidden_activation,
                    use_bias=False,
                    kernel_initializer=self.hidden_initializer
                )

            if self.center:
                self.beta_dense = Dense(
                    units=shape[0], use_bias=False, kernel_initializer='zeros'
                )
            if self.scale:
                self.gamma_dense = Dense(
                    units=shape[0], use_bias=False, kernel_initializer='zeros'
                )

    def call(self, inputs):
        """For conditional Layer Norm, the input is expected to be a list
        whose second element is the condition.
        """
        if self.conditional:
            inputs, cond = inputs
            # Project the condition down if requested.
            if self.hidden_units is not None:
                cond = self.hidden_dense(cond)
            # Expand dims so cond has the same ndim as inputs (broadcasting).
            for _ in range(K.ndim(inputs) - K.ndim(cond)):
                cond = K.expand_dims(cond, 1)
            if self.center:
                beta = self.beta_dense(cond) + self.beta
            if self.scale:
                gamma = self.gamma_dense(cond) + self.gamma
        else:
            if self.center:
                beta = self.beta
            if self.scale:
                gamma = self.gamma

        outputs = inputs
        if self.center:
            # Layer normalization takes statistics over the last (feature)
            # axis of each sample, not across the batch.
            mean = K.mean(outputs, axis=-1, keepdims=True)
            outputs = outputs - mean
        if self.scale:
            variance = K.mean(K.square(outputs), axis=-1, keepdims=True)
            std = K.sqrt(variance + self.epsilon)
            outputs = outputs / std
            outputs = outputs * gamma
        if self.center:
            outputs = outputs + beta

        return outputs

    def compute_output_shape(self, input_shape):
        # For conditional input, input_shape is a list; the output shape is
        # that of the first (main) input.
        if self.conditional:
            return input_shape[0]
        else:
            return input_shape

    def get_config(self):
        # Merge this class's config with the parent class's config.
        config = {
            'center': self.center,
            'scale': self.scale,
            'epsilon': self.epsilon,
            'conditional': self.conditional,
            'hidden_units': self.hidden_units,
            'hidden_activation': activations.serialize(self.hidden_activation),
            'hidden_initializer':
                initializers.serialize(self.hidden_initializer),
        }
        base_config = super(LayerNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
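
A usage sketch under illustrative shapes: the unconditional case takes a single tensor, while the conditional case takes a list [inputs, condition], as call expects.

from keras.layers import Input

seq = Input(shape=(None, 768))   # e.g. a sequence of token representations
cond = Input(shape=(128,))       # e.g. a sentence-level condition vector

# Unconditional: plain Layer Normalization over the feature axis.
x = LayerNormalization()(seq)

# Conditional: the condition is first projected down (hidden_units=64 is
# arbitrary), then shifts beta and gamma; the input is a list.
y = LayerNormalization(conditional=True, hidden_units=64)([seq, cond])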

