pytorch-lightning
参考资料:
pytorch-lightning 101: 视频课程, 可以用来理解概念, 共 4 小节课程, 其中第 3 小节是 pytorch-lightning 的基本用法, 第 4 小节介绍了 pytorch-lightning 的实现细节
LightningLite: 轻量版的 pytorch-lighting, 目前(2022.9.29)并未完全成熟. 用于尽量做很少的代码改动, 快速将 pytorch 训练代码进行转换, 好处是可以很快地将写好的单 GPU 或 CPU 训练流程变得可以自动支持多卡训练, fp16 训练等.
本文档主要分为两大部分。第一部分从使用 Lightning 的角度,介绍使用方法,以样例为主,尽量不涉及过多的源码。第二部分主要解释 Lightning 的源代码,有利于更好地使用。
第一部分:Lightning 的使用
Lightning 的使用疑惑
LightningModule中的self.log(...)是指什么?(猜测用于传给torchmetric, 类似于tf的Metric) 似乎最终调用的是pytorch_lightning.trainer.connectors.logger_connector.result._ResultCollection.log()此函数体内涉及到
lightning_utilities.core.apply_func.apply_to_collections
Pytorch vs Lightning
Dataset, DataLoader: 在 Lightning 可以沿用, 或者使用
LightningDataModule, 多卡训练时, Dataloader 所需的 DistributedSampler 在 Lightning 中无需手动写nn.Module: 在 Lightning 使用
LightningModule, 需要提供forward,training_step,configure_optimizers方法训练过程:
for loop: 在 Lightning 无需自己写 for loop
loss backward: 在 Lightning 中可以让
training_step返回 loss, 自动进行 backward, 或者也可以手工写 backward 的执行过程
DDP, AMP: 在 Lightning 中用
Trainer的初始化参数指定模型加载与保存: 最简单的用法是 Lightning 中用
Trainer自动处理, 高级用法是初始化Trainer时增加pytorch_lightning.callbacks.ModelCheckpoint这个 callback, 更复杂的用法是关闭Trainer的模型保存功能(enable_checkpointing=False), 在LightningModule的training_epoch_end或者validation_epoch_end中检测当前的 local_rank, 只在 local_rank 为0的进程上做保存模型的工作控制台打印: 可以使用 python 本身的
logging模块进行, 参考官方文档
使用模板
pytorch_lightning.LightningModule
代码参考自: pytorch-lightning with huggingface transfomers
inference without pytorch-lightning package
为了更好地模块化, 建议采用如下方式组织代码
控制backward的逻辑
更多详细内容参考官方文档
如果需要控制backward的逻辑, 需要做以下事情
在
__init__中设置self.automatic_optimization=False在
training_step中使用如下API:使用
optimizer=self.optimizers()获取所有优化器使用
optimizer.zero_grad()清除梯度使用
self.manual_backward(loss)而不要使用loss.backward()使用
optimizer.step()
多个optimizer与scheduler
建议 configure_optimizers 函数按如下方式返回
当前 rank、step、epoch 等
fit 函数伪代码(hook 编程)
从Lightning的实现上
hook 指的是类的一些特定方法, 例如:
on_train_epoch_start,on_train_batch_start。callback 指的是含有这些特定方法的类。
具体来说,简易版的大致实现方式如下
备注: 下面的写法仅做示意 hook 编程的大致形式, 并非对齐 LightningModule 的真正实现
在理解了 hook/callback 之后, 可以参考官方文档中真正的实现流程理解 fit 函数的整个过程, 并在 LightningModule 的子类中覆盖这些方法或者传入自定义 Callback 类, 摘录如下:
training_step 出入参
入参: 由以下这些 hook 最后的输出得到
出参: 返回一个标量版的loss即可, 或者返回一个字典, 字典中有一个键值对为{"loss": loss}
training_step, validation_step, test_step, predict_step
training_step, validation_step, test_step, predict_steptraining_step: 训练过程, batch中应包含x与y, 被trainer.fit调用
validation_step: 验证过程, batch中应包含x与y, 通常的每个epoch结束后被trainer.fit调用
test_step: 训练过程, batch中应包含x与y, 在trainer.fit中不被调用, 被trainer.test调用
predict_step: 在不定义predict_step的情况下, trainer.pred会调用model.forward, 否则会调用predict_step, 因此batch中只包含x即可
save checkpoint advanced
方式1 默认情况下, 会自动保存模型(只保存最新的), 参考官方文档
By default Lightning saves a checkpoint for you in your current working directory, with the state of your last training epoch, Checkpoints capture the exact value of all parameters used by a model. To disable automatic checkpointing, set this to False.
备注: 这种情况下默认会保存在lightening_logs/version_n/checkpoints 目录中, 且会保存许多其他非模型权重的东西
方式2 trainer 中传入的callbacks包含一个 ModelCheckpoint 实例, 参考官方文档。用于指定保存路径,保留多少个checkpoint,是否只保留权重,是否根据某个特定的监控值保存最优模型等
方式3 手写, 在各个hook中加上一些条件进行模型保存
方式4(不确定): 继承ModelCheckpoint
怎样控制 log 打印(未解决)
包含以下几个部分:
LightningModule.training_step的返回值在LightningModule.on_epoch_end中保存的具体逻辑是?如何控制 Tensorboard 的打印内容与打印时机
如何分别控制打印到控制台/保存到日志文件/输出到 Tensorboard 的信息
LightningModule 中的 self.log(...) 最终调用的是 trainer._results.log,而这个对象最终对应的是类似于trainer.fit_loop.epoch_loop._results,它是pytorch_lightning.trainer.connectors.logger_connector.result._ResultCollection 对象。
逻辑大约是module.log调用时,在 _results 中记录下信息,之后在 TrainingEpochLoop 的 advance函数中调用self.trainer._logger_connector.update_train_step_metrics()实际将 log 写入:具体进一步触发 LoggerConnector 中调用 self.log_metrics(self.trainer._results.metrics(not self._epoch_end_reached)['log']),根据是否需要写日志,再触发 TensorBoardLogger 的 log_metrics 函数。
trainer的构造函数__init__中有一个参数log_every_n_steps用于控制update_train_step_metrics函数将数据写入 tensorboard也就是
module.log()最终修改了trainer._results, 然后利用这个在某个时机写入了 tensorboard
备注:由于 TensorBoardLogger 中没有保存图像的函数,因此,如果想完全发挥它内部包含的 torch.utils.tensorboard import SummaryWriter。可以通过 TensorBoardLogger.experiment进行直接实现。
例如(待验证):
pytorch_lightning.Trainer
备注:相比于 huggingface transformers 中的 Trainer 类,官方文档中鼓励对其使用继承的方法重写一些方法。pytorch-lightning 中的推荐做法是直接使用 Trainer,而对 LightningModule 进行继承以及方法重写。
pytorch_lightning.LightningDataModule
第二部分:Lightning 源码阅读
Lightning 源码阅读OpenMMLab对pytorch-lightning也有一篇源码解读文章: https://zhuanlan.zhihu.com/p/389271556
目标: 理解如下代码片段的执行过程
LightningModule
LightningModule父类
源码中关于 LightningModule 类的定义继承自了多个父类, 特别注意它也继承自torch.nn.Module。因此需要先对几个父类的代码做个了解
Trainer.__init__
Trainer.__init__Trainer 类没有父类, 直接继承自 object.
第 2 行代码: trainer = Trainer(...) 的源码如下:
首先解释一下这个装饰器的作用:
利用 os.environ.get 方法获取形如 PL_TRAINER_{XXX} 环境变量, 并用环境变量的值取代被装饰的函数(上面的例子中为Trainer.__init__函数)中的默认值. 即最终第 2 行代码参数设定的优先顺序为:
接下来进入 Trainer.__init__ 函数的函数体, 完整源代码如下:
整体流程简要描述:
总的来说基本上是一些为Trainer的属性赋值的操作
需要细致展开的部分如下:
Trainer.fit
Trainer.fit完整源代码如下:
Trainer._call_callback_hooks(hook_name)
Trainer.fit 方法中也会多次调用Trainer._call_callback_hooks方法, 其完整源代码如下:
注: 为何要对"on_init_start", "on_init_end"这两个做单独的处理? 因为其他的hook_name都在Trainer.fit方法内部被调用,Trainer.fit方法的源代码如下:
其他补充
apply_to_collection
lightning_utilities 中包含一些递归应用函数(具体实现可以参看源码)
由于 pytorch-lightning 很多地方将这两个函数用于 dataclass 上,因此补充一个示例:
Last updated
Was this helpful?