0
点赞
收藏
分享

微信扫一扫

TensorFlow中常见内置回调Callback


class BaseLogger:

计算每个epoch周期的平均指标,这个回调已经被自动应用在每个Keras模型,所以不需要手动设置

callbacks = tf.keras.callbacks.BaseLogger(
stateful_metrics=None
)

model.fit(
train_data,
labels,
epochs=5,
batch_size=32,
validation_split=0.2,
callbacks=callbacks
)

class CSVLogger:

将每个epoch的评估及损失结果导入到一个CSV文件中

  • ​filename​​:CSV保存路径
  • ​separator​​:不同字段之间的分割符
  • ​append​​:是否在原来的文件基础之上追加

callbacks = tf.keras.callbacks.CSVLogger(
filename='./res.log',
separator=',',
append=False
)

model.fit(
train_data,
labels,
epochs=5,
batch_size=32,
validation_split=0.2,
callbacks=callbacks
)

class EarlyStopping:

当一个被监控的指标停止提升的时候停止训练

  • ​monitor​​:需要监控的指标或者损失
  • ​min_delta​​​:最小误差,只有两个epoch的评估值达到这个误差才会认为是一次变化,如果两次的误差小于​​min_delta​​则认为两次训练没有任何变化
  • ​patience​​​:连续没有改进的epoch数,如果连续​​patience​​个epoch还没有改进,则停止训练
  • ​verbose​​:详细模式,用户打印控制台日志
  • ​mode​​​:有三种模式,分别是​​min​​​,​​max​​​,​​auto​​,如果是min那么会判断如果监控的损失不在下降停止训练,如果是max,那么则发现监控的指标不在上升停止训练,如果是auto则会根据传进来的监控指标进行推断
  • ​baseline​​:监控指标的基线值,如果模型在基线上没有显示出改进,则训练将停止
  • ​restore_best_weights​​:是否从具有监控指标最佳值的epoch恢复模型权重

callbacks = tf.keras.callbacks.EarlyStopping(
monitor='val_loss',
min_delta=1e-3,
patience=2,
verbose=0,
mode='min',
baseline=None,
restore_best_weights=False
)

model.fit(
train_data,
labels,
epochs=5,
batch_size=32,
validation_split=0.2,
callbacks=callbacks
)

class History:

将训练事件记录到​​history​​对象中,此回调会自动应用于每个 Keras 模型,history 对象由模型的 fit 方法返回。

模型训练后返回的history对象会包含训练时期每个epoch的精度或者损失值以及验证集的评估指标

class LearningRateScheduler:

学习率时间表

  • ​schedule​​:一个函数,它以epoch为索引(整数,从 0 开始索引)和当前学习率(浮点数)作为输入,并返回一个新的学习率作为输出(浮点数)。
  • ​verbose​​:是否打印学习更新情况

def scheduler(epoch, lr):
if epoch < 10:
return lr
else:
return lr * tf.math.exp(-0.1)


callbacks = tf.keras.callbacks.LearningRateScheduler(scheduler=scheduler,
verbose=1)

model.fit(
train_data,
labels,
epochs=5,
batch_size=32,
validation_split=0.2,
callbacks=callbacks
)

class ModelCheckpoint:

以某个频率保存 Keras 模型或模型权重的回调

  • ​filename​​:保存模型或者权重的路径
  • ​monitor​​:需要监测的损失或者评估指标
  • ​verbose​​:控制台输出状态
  • ​save_best_only​​:是否保存最好的模型
  • ​save_weights_only​​:是否只保存权重,否则是保存整个模型
  • ​mode​​​:监控模式,​​min​​​,​​max​​​,​​auto​​,是按照监控的评估指标来定,如果是损失选择min,如果是准确率这种选择max,如果是auto会根据传入的monitor自动推断
  • ​save_freq​​​:两种选择,分别是​​epoch​​​和​​integer​​,如果是epoch是每个epoch保存一次,如果是填写一个整数,代表每训练多少个批次保存一次
  • ​options​​:其它配置,用于保存模型或者参数

callbacks = tf.keras.callbacks.ModelCheckpoint(
filename='./save_model',
monitor='val_loss',
verbose=1,
save_best_only=False,
save_weights_only=False,
mode='auto',
save_freq='epoch',
options=None
)

model.fit(
train_data,
labels,
epochs=5,
batch_size=32,
validation_split=0.2,
callbacks=callbacks
)

class ProgbarLogger:

打印精度到标准输出


举报

相关推荐

0 条评论