Hook Functions

What Is a Hook?


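A hook function, as the name suggests, works like a hook: it is a function mounted between the executing function and the target function. The framework developer exposes a mount point to the caller, and the caller decides which function to attach there, which makes the framework much more flexible.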

Hook functions serve much the same purpose as another term you will often hear, the callback function, and the two can be understood through the same pattern.
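As a minimal sketch of the idea, the function below exposes a single mount point and lets the caller decide what runs there. The names `process`, `on_item`, and `seen` are made up for illustration and do not come from any framework.

```python
from typing import Callable, Iterable, List, Optional


def process(items: Iterable[int],
            on_item: Optional[Callable[[int], None]] = None) -> int:
    """Sum the items, calling the optional hook once per item."""
    total = 0
    for item in items:
        if on_item is not None:
            # Mount point: the caller decides what happens here.
            on_item(item)
        total += item
    return total


seen: List[int] = []
result = process([1, 2, 3], on_item=seen.append)
print(result, seen)
```

If no hook is passed, `process` behaves as if the mount point were not there, which is the usual default for a hook.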

A Hook Implementation Example

class Runner:
    def __init__(self):
        self._hooks = []

    def register_hook(self, hook):
        self._hooks.append(hook)

    def call_hook(self, fn_name):
        # Call the method named fn_name on every hook, in registration order.
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def train(self):
        self.a = 10
        self.b = 20
        self.call_hook('before_train_epoch')
        print('Done Epoch!')
        self.call_hook('after_train_epoch')


class Hook:
    # Base class: every hook point is a no-op until a subclass overrides it.
    def before_train_epoch(self, runner):
        pass

    def after_train_epoch(self, runner):
        pass


class AddHook(Hook):
    def before_train_epoch(self, runner):
        print('i am Add')
        print(f'Add {runner.a} and {runner.b} equal {runner.a + runner.b}\n')


class MulHook(Hook):
    def before_train_epoch(self, runner):
        print('i am Mul')
        print(f'Mul {runner.a} and {runner.b} equal {runner.a * runner.b}\n')


class ExpHook(Hook):
    def after_train_epoch(self, runner):
        print('i am Exp')
        print(f'Exp {runner.a} and {runner.b} equal {runner.a ** runner.b}\n')


class Trainer:
    def __init__(self):
        self.runner = Runner()
        self.runner.register_hook(MulHook())
        self.runner.register_hook(ExpHook())
        self.runner.register_hook(AddHook())

    def run(self):
        self.runner.train()


if __name__ == '__main__':
    trainer = Trainer()
    trainer.run()
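In the example above, hooks run in the order they were registered, and a hook that does not override a stage falls back to the no-op in the base class. The condensed sketch below makes that visible by recording calls in a list instead of printing. The `calls` list and the `BeforeHook` and `AfterHook` classes are additions for illustration only.

```python
class Runner:
    def __init__(self):
        self._hooks = []
        self.calls = []  # illustration only: records what ran, in order

    def register_hook(self, hook):
        self._hooks.append(hook)

    def call_hook(self, fn_name):
        # Hooks are invoked in registration order.
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def train(self):
        self.call_hook('before_train_epoch')
        self.calls.append('train')
        self.call_hook('after_train_epoch')


class Hook:
    def before_train_epoch(self, runner):
        pass

    def after_train_epoch(self, runner):
        pass


class BeforeHook(Hook):
    def __init__(self, name):
        self.name = name

    def before_train_epoch(self, runner):
        runner.calls.append(f'{self.name}:before')


class AfterHook(Hook):
    def __init__(self, name):
        self.name = name

    def after_train_epoch(self, runner):
        runner.calls.append(f'{self.name}:after')


runner = Runner()
runner.register_hook(BeforeHook('mul'))
runner.register_hook(AfterHook('exp'))
runner.register_hook(BeforeHook('add'))
runner.train()
print(runner.calls)
```

The `exp` hook is registered second but only shows up after the training step, because it overrides `after_train_epoch` alone.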

Hooks in Open-Source Frameworks

Keras

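Deep learning training loops are a textbook use of hook functions.

A training run (data preparation aside) iterates over the training set several times. Each pass is called an epoch, and each epoch is split into multiple batches. In order, the flow breaks down into:

  • Start of training
  • Before training an epoch
  • Before training a batch
  • After training a batch
  • After training an epoch
  • Evaluating on the validation set
  • End of training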

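These steps are interleaved with the training of each batch and can be treated as hook points. We often need custom behavior at these points, for example saving the model after each epoch, or evaluating the best model on the test set when training ends.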
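The stages listed above can be sketched as a bare training loop that fires a hook at each point. This is an illustration under stated assumptions, not how any particular framework is implemented: the names `TrainingHook`, `run_training`, `CheckpointHook`, and the `state` dictionary are hypothetical, validation is placed after each epoch simply to follow the order of the list, and no actual training happens.

```python
class TrainingHook:
    """Base class: every stage is a no-op unless a subclass overrides it."""

    def on_train_begin(self, state): pass
    def on_epoch_begin(self, state): pass
    def on_batch_begin(self, state): pass
    def on_batch_end(self, state): pass
    def on_epoch_end(self, state): pass
    def on_validation(self, state): pass
    def on_train_end(self, state): pass


def run_training(hooks, num_epochs, num_batches):
    state = {'epoch': None, 'batch': None, 'events': []}

    def fire(stage):
        # Record the stage, then call the matching method on every hook.
        state['events'].append(stage)
        for hook in hooks:
            getattr(hook, stage)(state)

    fire('on_train_begin')
    for epoch in range(num_epochs):
        state['epoch'] = epoch
        fire('on_epoch_begin')
        for batch in range(num_batches):
            state['batch'] = batch
            fire('on_batch_begin')
            # A real loop would run the forward and backward passes here.
            fire('on_batch_end')
        fire('on_epoch_end')
        fire('on_validation')
    fire('on_train_end')
    return state


class CheckpointHook(TrainingHook):
    def __init__(self):
        self.saved = []

    def on_epoch_end(self, state):
        # Stand-in for saving a model: just remember the epoch index.
        self.saved.append(state['epoch'])


ckpt = CheckpointHook()
final_state = run_training([ckpt], num_epochs=2, num_batches=3)
print(ckpt.saved)
print(final_state['events'][:4])
```

The loop itself never changes. Saving checkpoints, logging, or early stopping would each be one more `TrainingHook` subclass passed in the `hooks` list.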

Keras implements hooks through callbacks. The callback base class is shown here; to customize behavior, subclass it and override only the hooks you care about.

@keras_export('keras.callbacks.Callback')
class Callback(object):
    """Abstract base class used to build new callbacks.
    Attributes:
        params: Dict. Training parameters
            (eg. verbosity, batch size, number of epochs...).
        model: Instance of `keras.models.Model`.
            Reference of the model being trained.
    The `logs` dictionary that callback methods
    take as argument will contain keys for quantities relevant to
    the current batch or epoch (see method-specific docstrings).
    """

    def __init__(self):
        self.validation_data = None  # pylint: disable=g-missing-from-attributes
        self.model = None
        # Whether this Callback should only run on the chief worker in a
        # Multi-Worker setting.
        # TODO(omalleyt): Make this attr public once solution is stable.
        self._chief_worker_only = None
        self._supports_tf_logs = False

    def set_params(self, params):
        self.params = params

    def set_model(self, model):
        self.model = model

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_batch_begin(self, batch, logs=None):
        """A backwards compatibility alias for `on_train_batch_begin`."""

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_batch_end(self, batch, logs=None):
        """A backwards compatibility alias for `on_train_batch_end`."""

    @doc_controls.for_subclass_implementers
    def on_epoch_begin(self, epoch, logs=None):
        """Called at the start of an epoch.
        Subclasses should override for any actions to run. This function should only
        be called during TRAIN mode.
        Arguments:
            epoch: Integer, index of epoch.
            logs: Dict. Currently no data is passed to this argument for this method
                but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    def on_epoch_end(self, epoch, logs=None):
        """Called at the end of an epoch.
        Subclasses should override for any actions to run. This function should only
        be called during TRAIN mode.
        Arguments:
            epoch: Integer, index of epoch.
            logs: Dict, metric results for this training epoch, and for the
                validation epoch if validation is performed. Validation result keys
                are prefixed with `val_`.
        """

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_train_batch_begin(self, batch, logs=None):
        """Called at the beginning of a training batch in `fit` methods.
        Subclasses should override for any actions to run.
        Arguments:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.train_step`. Typically,
                the values of the `Model`'s metrics are returned. Example:
                `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        # For backwards compatibility.
        self.on_batch_begin(batch, logs=logs)

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_train_batch_end(self, batch, logs=None):
        """Called at the end of a training batch in `fit` methods.
        Subclasses should override for any actions to run.
        Arguments:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """
        # For backwards compatibility.
        self.on_batch_end(batch, logs=logs)
    ...
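To show what subclassing looks like, here is a small sketch of a callback that records losses. So that it runs without TensorFlow installed, `Callback` below is a tiny stand-in that copies only three method names and the backwards-compatibility forwarding from the listing above; it is not the Keras class. `LossHistory` and the simulated calls at the bottom are assumptions for illustration.

```python
class Callback:
    """Stand-in for the Keras base class; only three hooks are copied."""

    def on_batch_end(self, batch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_train_batch_end(self, batch, logs=None):
        # Same forwarding as in the listing above.
        self.on_batch_end(batch, logs=logs)


class LossHistory(Callback):
    """Hypothetical callback that remembers the losses it is shown."""

    def __init__(self):
        self.batch_losses = []
        self.epoch_losses = []

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.batch_losses.append(logs.get('loss'))

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epoch_losses.append((epoch, logs.get('loss')))


history = LossHistory()
# Simulate the calls a training loop would make; the loss values are made up.
history.on_train_batch_end(0, logs={'loss': 0.5})
history.on_train_batch_end(1, logs={'loss': 0.3})
history.on_epoch_end(0, logs={'loss': 0.4})
print(history.batch_losses, history.epoch_losses)
```

With real Keras, the subclass would inherit from `keras.callbacks.Callback` and an instance would be passed to `model.fit` through its `callbacks` argument, so the framework makes these calls instead of the simulated lines above.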

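Summary

This article introduced the concept and applications of hooks and walked through a Python implementation. I hope it helps. To summarize:

  • A hook function is a predefined step in a flow that has no implementation by default
  • Once a function is mounted or registered, the flow calls it when execution reaches that step
  • Callback functions and hook functions serve the same purpose
  • Hook-based design adds flexibility: if there is a step in the flow that you want the caller to implement, expose it as a hook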

References

https://blog.csdn.net/pdcfighting/article/details/111243722


Hook Functions
http://example.com/2021/06/30/2021-06-30-hook勾函数/
Author
NSX
Published
June 30, 2021
License