pytorch-lightning

参考资料:

  • pytorch-lightning 101: 视频课程, 可以用来理解概念, 共 4 小节课程, 其中第 3 小节是 pytorch-lightning 的基本用法, 第 4 小节介绍了 pytorch-lightning 的实现细节

  • LightningLite: 轻量版的 pytorch-lighting, 目前(2022.9.29)并未完全成熟. 用于尽量做很少的代码改动, 快速将 pytorch 训练代码进行转换, 好处是可以很快地将写好的单 GPU 或 CPU 训练流程变得可以自动支持多卡训练, fp16 训练等.

本文档主要分为两大部分。第一部分从使用 Lightning 的角度,介绍使用方法,以样例为主,尽量不涉及过多的源码。第二部分主要解释 Lightning 的源代码,有利于更好地使用。

第一部分:Lightning 的使用

疑惑

  • LightningModule 中的 self.log(...) 是指什么?(猜测用于传给torchmetric, 类似于tf的Metric) 似乎最终调用的是pytorch_lightning.trainer.connectors.logger_connector.result._ResultCollection.log()

    • 此函数体内涉及到lightning_utilities.core.apply_func.apply_to_collections

Pytorch vs Lightning

  • Dataset, DataLoader: 在 Lightning 可以沿用, 或者使用 LightningDataModule, 多卡训练时, Dataloader 所需的 DistributedSampler 在 Lightning 中无需手动写

  • nn.Module: 在 Lightning 使用 LightningModule, 需要提供 forward, training_step, configure_optimizers 方法

  • 训练过程:

    • for loop: 在 Lightning 无需自己写 for loop

    • loss backward: 在 Lightning 中可以让 training_step 返回 loss, 自动进行 backward, 或者也可以手工写 backward 的执行过程

  • DDP, AMP: 在 Lightning 中用 Trainer 的初始化参数指定

  • 模型加载与保存: 最简单的用法是 Lightning 中用 Trainer 自动处理, 高级用法是初始化 Trainer 时增加 pytorch_lightning.callbacks.ModelCheckpoint 这个 callback, 更复杂的用法是关闭 Trainer 的模型保存功能(enable_checkpointing=False), 在 LightningModuletraining_epoch_end 或者 validation_epoch_end 中检测当前的 local_rank, 只在 local_rank 为0的进程上做保存模型的工作

  • 控制台打印: 可以使用 python 本身的 logging 模块进行, 参考官方文档

使用模板

pytorch_lightning.LightningModule

代码参考自: pytorch-lightning with huggingface transfomers

inference without pytorch-lightning package

为了更好地模块化, 建议采用如下方式组织代码

控制backward的逻辑

更多详细内容参考官方文档

如果需要控制backward的逻辑, 需要做以下事情

  • __init__中设置self.automatic_optimization=False

  • training_step中使用如下API:

    • 使用 optimizer=self.optimizers() 获取所有优化器

    • 使用 optimizer.zero_grad() 清除梯度

    • 使用 self.manual_backward(loss) 而不要使用 loss.backward()

    • 使用 optimizer.step()

多个optimizer与scheduler

建议 configure_optimizers 函数按如下方式返回

当前 rank、step、epoch 等

fit 函数伪代码(hook 编程)

从Lightning的实现上

  • hook 指的是类的一些特定方法, 例如: on_train_epoch_start, on_train_batch_start

  • callback 指的是含有这些特定方法的类。

具体来说,简易版的大致实现方式如下

备注: 下面的写法仅做示意 hook 编程的大致形式, 并非对齐 LightningModule 的真正实现

在理解了 hook/callback 之后, 可以参考官方文档中真正的实现流程理解 fit 函数的整个过程, 并在 LightningModule 的子类中覆盖这些方法或者传入自定义 Callback 类, 摘录如下:

training_step 出入参

  • 入参: 由以下这些 hook 最后的输出得到

  • 出参: 返回一个标量版的loss即可, 或者返回一个字典, 字典中有一个键值对为{"loss": loss}

training_step, validation_step, test_step, predict_step

  • training_step: 训练过程, batch中应包含x与y, 被trainer.fit调用

  • validation_step: 验证过程, batch中应包含x与y, 通常的每个epoch结束后被trainer.fit调用

  • test_step: 训练过程, batch中应包含x与y, 在trainer.fit中不被调用, 被trainer.test调用

  • predict_step: 在不定义predict_step的情况下, trainer.pred会调用model.forward, 否则会调用predict_step, 因此batch中只包含x即可

save checkpoint advanced

方式1 默认情况下, 会自动保存模型(只保存最新的), 参考官方文档

By default Lightning saves a checkpoint for you in your current working directory, with the state of your last training epoch, Checkpoints capture the exact value of all parameters used by a model. To disable automatic checkpointing, set this to False.

备注: 这种情况下默认会保存在lightening_logs/version_n/checkpoints 目录中, 且会保存许多其他非模型权重的东西

方式2 trainer 中传入的callbacks包含一个 ModelCheckpoint 实例, 参考官方文档。用于指定保存路径,保留多少个checkpoint,是否只保留权重,是否根据某个特定的监控值保存最优模型等

方式3 手写, 在各个hook中加上一些条件进行模型保存

方式4(不确定): 继承ModelCheckpoint

怎样控制 log 打印(未解决)

包含以下几个部分:

  • LightningModule.training_step 的返回值在 LightningModule.on_epoch_end 中保存的具体逻辑是?

  • 如何控制 Tensorboard 的打印内容与打印时机

  • 如何分别控制打印到控制台/保存到日志文件/输出到 Tensorboard 的信息

LightningModule 中的 self.log(...) 最终调用的是 trainer._results.log,而这个对象最终对应的是类似于trainer.fit_loop.epoch_loop._results,它是pytorch_lightning.trainer.connectors.logger_connector.result._ResultCollection 对象。

逻辑大约是module.log调用时,在 _results 中记录下信息,之后在 TrainingEpochLoopadvance函数中调用self.trainer._logger_connector.update_train_step_metrics()实际将 log 写入:具体进一步触发 LoggerConnector 中调用 self.log_metrics(self.trainer._results.metrics(not self._epoch_end_reached)['log']),根据是否需要写日志,再触发 TensorBoardLogger 的 log_metrics 函数。

  • trainer 的构造函数 __init__ 中有一个参数 log_every_n_steps 用于控制 update_train_step_metrics 函数将数据写入 tensorboard

  • 也就是 module.log() 最终修改了 trainer._results, 然后利用这个在某个时机写入了 tensorboard

备注:由于 TensorBoardLogger 中没有保存图像的函数,因此,如果想完全发挥它内部包含的 torch.utils.tensorboard import SummaryWriter。可以通过 TensorBoardLogger.experiment进行直接实现。

例如(待验证):

pytorch_lightning.Trainer

备注:相比于 huggingface transformers 中的 Trainer 类,官方文档中鼓励对其使用继承的方法重写一些方法。pytorch-lightning 中的推荐做法是直接使用 Trainer,而对 LightningModule 进行继承以及方法重写。

pytorch_lightning.LightningDataModule

第二部分:Lightning 源码阅读

OpenMMLab对pytorch-lightning也有一篇源码解读文章: https://zhuanlan.zhihu.com/p/389271556

目标: 理解如下代码片段的执行过程

LightningModule

父类

源码中关于 LightningModule 类的定义继承自了多个父类, 特别注意它也继承自torch.nn.Module。因此需要先对几个父类的代码做个了解

HyperparametersMixin

主要使用的方式如下参考官网例子: 通过调用 self.save_hyperparameters 方法, 将所有 __init__ 的传参保存到self._hparams

Trainer.__init__

Trainer 类没有父类, 直接继承自 object.

第 2 行代码: trainer = Trainer(...) 的源码如下:

首先解释一下这个装饰器的作用:

利用 os.environ.get 方法获取形如 PL_TRAINER_{XXX} 环境变量, 并用环境变量的值取代被装饰的函数(上面的例子中为Trainer.__init__函数)中的默认值. 即最终第 2 行代码参数设定的优先顺序为:

装饰器`_defaults_from_env_vars`的具体实现

理解上面的代码所需要的 Python 知识如下:

  • functools.wraps装饰器的作用

  • vars 内置函数的作用

  • typing.cast 函数的作用: 参考stackoverflow,此函数只在类型检查时起作用, 而在实际运行时什么都不做, 直接将入参返回

上述代码中进一步调用了如下两段代码

理解上面的代码所需要的 Python 知识如下:

  • 内置模块 inspect 相关知识: 此处仅用到 inspect.signature(callable).parameters, 用于获取函数或类的__init__方法定义中的变量名, 变量类型, 以及默认值

  • contextlib.suppress: 参考stackoverflow, 这两种写法基本等价:

  • 内置模块 ast.literal_eval: 此函数接受的参数为一个合法的字符串形式的python数据, 例如:

    即功能弱于内置函数eval, 官方文档中建议一切能用 ast.literal_eval 代替 eval 的地方, 都使用 ast.literal_eval, 无法替代的情况下, 应该选择其他实现方式, 而不能依赖 eval

接下来进入 Trainer.__init__ 函数的函数体, 完整源代码如下:

整体流程简要描述

总的来说基本上是一些为Trainer的属性赋值的操作

需要细致展开的部分如下:

TrainerState

涉及的源代码如下:

备注:LightningEnum 实际上继承自 python 原生的 Enum,增加了一个 from_str 方法,并允许它直接与字符串比较(__eq__函数),如果与枚举状态所对应的字符串相同,则返回 True

以下是更为具体的源代码

Connector(DataConnector,AcceleratorConnector,LoggerConnector,CallbackConnector,CheckpointConnector,SignalConnector)

Trainer.__init__ 函数依次进行了 DataConnector, AcceleratorConnector,LoggerConnector, CallbackConnector, CheckpointConnector, SignalConnector几个的初始化

推测: 这种XXXConnector类的作用基本上就是给Trainer添加一些属性, 不知道为啥不直接写在Trainer的内部(也许是写在Trainer内部, Trainer类的定义会变得很冗长?看源码长度pytorch_ligtening/trainer/trainer.py本身已有2000多行, 如果这些Connector也写在Trainer里,估计会更长)

这里几个 Connector__init__ 函数的初始化基本上只是设定 self.trainer=trainer,以及初始化一些状态,并无太多需要说明之处。唯一的例外是AcceleratorConnector做了许多工作(此处从略)。

Loop

为了更好地看出上面三行代码的执行逻辑,将其展开为(|-引导内部的调用顺序)如下

备注:一些简单的操作例如:self.min_steps=min_steps 被省略了

这里代码粗看上有些诡异:

  • FitLoop 中包含了一个TrainingEpochLoop, 继续而这个 TrainingEpochLoop包含TrainingBatchLoop和一个EvaluationLoop,似乎并不在一个层级上

  • *Progress 的定义均在 src/pytorch_lightning/trainer/progress.py 文件中, 主要作用是循环时记录下标和一些状态, 即记录 for i, x in enumerate(x_list) 中的 ix。其细致的源码分析如下:

CallbackConnector.on_trainer_init

主要执行逻辑是依次将如下默认Callback类添加至Trainer.callbacks中, 然后将这些callback按照类型进行重排序,先后顺序为:tuner_callbacks(BatchSizeFinder),other_callbacks, checkpoint_callbacks(ModelCheckpoint)。

这些默认的Callback及及相应控制的Trainer.__init__函数的入参默认值以及所包含的hook列举如下:

  • pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint(默认有, enable_checkpointing=True): on_train_batch_end, on_train_epoch_end, on_train_start, on_validation_end

  • pytorch_lightning.callbacks.timer.Timer(默认无, max_time=None): on_fit_start, on_test_end, on_test_start, on_train_batch_end, on_train_end, on_train_epoch_end, on_train_start, on_validation_end, on_validation_start

  • pytorch_lightning.callbacks.progress.tqdm_progress.TQDMProgressBar(默认有, enable_progress_bar=True): on_predict_batch_end, on_predict_batch_start, on_predict_end, on_predict_start, on_sanity_check_end, on_sanity_check_start, on_test_batch_end, on_test_batch_start, on_test_end, on_test_start, on_train_batch_end, on_train_end, on_train_epoch_end, on_train_epoch_start, on_train_start, on_validation_batch_end, on_validation_batch_start, on_validation_end, on_validation_start

  • pytorch_lightning.callbacks.rich_model_summary.RichModelSummary(默认无, 除非enable_progress_bar=False且手动传入这个callback): on_fit_start

  • pytorch_lightning.callbacks.model_summary.ModelSummary(默认有, enable_model_summary): on_fit_start

  • pytorch_lightning.callbacks.gradient_accumulation_scheduler.GradientAccumulationScheduler(默认无, accumulate_grad_batches=None): on_train_epoch_start

  • pytorch_lightning.callbacks.fault_tolerance._FaultToleranceCheckpoint(必然有): on_exception

  • pytorch_lightning.callbacks.batch_size_finder.BatchSizeFinder(默认无, auto_lr_find=False, auto_scale_batch_size=False): on_fit_start, on_predict_start, on_test_start, on_validation_start

完整源代码如下:

DataConnector.on_trainer_init

作用是根据入参设定trainer的几个属性

除去异常处理的源代码如下

Tuner

Tuner的主要作用是自动尝试学习率与显存大小, 在`Trainer.__init__`函数中仅设定参数, 运行逻辑在 `Trainer.fit`函数中

Trainer.fit

完整源代码如下:

Trainer._call_callback_hooks(hook_name)

Trainer.fit 方法中也会多次调用Trainer._call_callback_hooks方法, 其完整源代码如下:

注: 为何要对"on_init_start", "on_init_end"这两个做单独的处理? 因为其他的hook_name都在Trainer.fit方法内部被调用,Trainer.fit方法的源代码如下:

其他补充

apply_to_collection

lightning_utilities 中包含一些递归应用函数(具体实现可以参看源码)

由于 pytorch-lightning 很多地方将这两个函数用于 dataclass 上,因此补充一个示例:

Last updated

Was this helpful?