0
点赞
收藏
分享

微信扫一扫

(一)tensorflow2.0 - 自定义layer


最近在用tensorflow2.0搭建一个简单的神经网络,虽然结构简单但是由于对自定义有要求,官方提供的layer和model不能满足要求,因此需要自行对layer、model、loss function进行自定义。由于tensorflow2.0发布不久,国内相关文章较少,我便决定写上这一系列文章。本文讨论tensorflow2.0中如何自定义layer。

​​(一)tensorflow2.0 - 自定义layer​​(二)tensorflow2.0 - 自定义Model
(三)tensorflow2.0 - 自定义loss function(损失函数)
(四)tensorflow2.0 - 实战稀疏自动编码器SAE

本文不讨论tensorflow1和2在版本上自定义layer的区别,只讲述2.0版本下如何自定义layer。

本文架构上不做长篇大论,直接根据代码来解释如何自定义模型。

首先引入相应的库函数

import tensorflow as tf
from tensorflow.keras import *

然后自定义Layer类,这里命名为SAELayer,继承自​​tensorflow.keras.layers.Layer​​​,由于上面引入的库函数为​​from tensorflow.keras import *​​​,所以写起来就比较清爽,可以直接简写为​​layers.Layer​​​,之后的都如此,写法上​​tensorflow.keras​​都省略了,就不做多解释。

需要注意,Layer类中涉及到了三个重要的方法,分别是​​__init__()​​​,​​build()​​​,​​call()​​​,关于他们的关系与作用请看我的另一篇文章(​​tensorflow2.0中Layer的__init__(),build(), call()函数​​)。这里只简单说明,​​__init__()​​​函数在创建​​Layer​​​对象时调用,​​build​​​在第一次调用​​call​​​前调用(只调用一次),往后使用​​Layer​​​的方法都是使用​​call()​​​的方法。
需要注意​​​build()​​​方法的参数,该方法是被自动调用的,所以其参数是固定的(当然改形参名称没关系),但是不能添加或者删除参数。而​​call()​​​方法的官方定义为​​Layer.call(inputs, **kwargs)​​​,因此它至少需要一个​​input​​作为参数(输入该层的数据),其他参数可以按需自定义

下例为进行一个简单的​​sigmoid(w*x + b)​​的功能的自定义层,当然这是一次对一批数据进行操作,所以需要用矩阵(张量)的方式来思考。

class SAELayer(layers.Layer):
# 初始化num_outputs,即当前层输出元素的个数
def __init__(self, num_outputs):
super(SAELayer, self).__init__()
self.num_outputs = num_outputs

# 在第一次调用该Layer的call方法前(自动)调用该函数,可以知道输入数据的shape
# 根据输入数据的shape可以初始化权值、bias的矩阵
def build(self, input_shape):
self.kernel = self.add_variable("kernel",
shape=[int(input_shape[-1]),
self.num_outputs])
self.bias = self.add_variable("bias",
shape=[self.num_outputs])
def call(self, input):
output = tf.matmul(input, self.kernel) + self.bias
# sigmoid激活函数
output = tf.nn.sigmoid(output)
return output

到此​​Layer​​​就定义好了,大家可以根据需要对其各部分进行修改,比如在​​build()​​​中增删参数、在​​call()​​中更改计算方式、激活函数等等。

Layer定义好了,如何使用呢?
大可以按照正常使用其他Layer的方式来调用,如果想看具体实例,可以看下一篇文章,里面将Layer放入了一个简单的自定义Model中进行使用。
​(二)tensorflow2.0 - 自定义Model​​

参考文献:

  • ​​Custom layers​​
  • ​​第四章:Tensorflow 2.0 实现自定义层和自定义模型的编写并实现cifar10 的全连接网络(理论+实战)​​
  • ​​tensorflow2.0中Layer的__init__(),build(), call()函数​​


举报

相关推荐

0 条评论