pytorch-lightning
参考资料:
- pytorch-lightning 101: 视频课程, 可以用来理解概念, 共 4 小节课程, 其中第 3 小节是 pytorch-lightning 的基本用法, 第 4 小节介绍了 pytorch-lightning 的实现细节
- LightningLite: 轻量版的 pytorch-lightning, 目前(2022.9.29)并未完全成熟. 用于以尽量少的代码改动, 快速将 pytorch 训练代码进行转换, 好处是可以很快地将写好的单 GPU 或 CPU 训练流程变得可以自动支持多卡训练, fp16 训练等.
本文档主要分为两大部分。第一部分从使用 Lightning 的角度,介绍使用方法,以样例为主,尽量不涉及过多的源码。第二部分主要解释 Lightning 的源代码,有利于更好地使用。
第一部分:Lightning 的使用
Lightning 的使用疑惑
LightningModule 中的 self.log(...) 是指什么?(猜测用于传给 torchmetrics, 类似于 tf 的 Metric)似乎最终调用的是 pytorch_lightning.trainer.connectors.logger_connector.result._ResultCollection.log(), 此函数体内涉及到 lightning_utilities.core.apply_func.apply_to_collections
Pytorch vs Lightning
- Dataset, DataLoader: 在 Lightning 中可以沿用, 或者使用 LightningDataModule; 多卡训练时, DataLoader 所需的 DistributedSampler 在 Lightning 中无需手动写
- nn.Module: 在 Lightning 中使用 LightningModule, 需要提供 forward, training_step, configure_optimizers 方法
- 训练过程:
  - for loop: 在 Lightning 中无需自己写 for loop
  - loss backward: 在 Lightning 中可以让 training_step 返回 loss, 自动进行 backward, 或者也可以手工写 backward 的执行过程
  - DDP, AMP: 在 Lightning 中用 Trainer 的初始化参数指定(见下方示例)
- 模型加载与保存: 最简单的用法是在 Lightning 中用 Trainer 自动处理; 高级用法是初始化 Trainer 时增加 pytorch_lightning.callbacks.ModelCheckpoint 这个 callback; 更复杂的用法是关闭 Trainer 的模型保存功能(enable_checkpointing=False), 在 LightningModule 的 training_epoch_end 或者 validation_epoch_end 中检测当前的 local_rank, 只在 local_rank 为 0 的进程上做保存模型的工作
- 控制台打印: 可以使用 python 本身的 logging 模块进行, 参考官方文档
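下面给出与上面 DDP/AMP 一条对应的 Trainer 初始化示意(参数名以 1.7/1.8 左右版本的 Trainer 为准, 仅作演示, 并非唯一写法):
from pytorch_lightning import Trainer

# 多卡 DDP + 混合精度训练: 只需通过 Trainer 的初始化参数指定,
# 无需手写 DistributedSampler, torch.distributed 初始化与 GradScaler
trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy="ddp",
    precision=16,     # AMP 混合精度
    max_epochs=10,
)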
使用模板
from pytorch_lightning import LightningModule, Trainer

model = MyModule(...)   # MyModule 继承自 LightningModule
trainer = Trainer(...)  # max_steps, min_steps 等参数
trainer.fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, ckpt_path=None)
pytorch_lightning.LightningModule
代码参考自: pytorch-lightning with huggingface transformers
class GLUETransformer(LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int,
        task_name: str,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        eval_splits: Optional[list] = None,
        **kwargs,
    ):
        super().__init__()
        # 此函数用于保存__init__函数的入参至self.hparams, 含义是收集所有的超参数, 推荐在此处使用
        # 如果__init__函数的入参中有torch.nn.Module, 可以设置参数ignore将其忽略
        self.save_hyperparameters()
        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)
        self.metric = datasets.load_metric(
            "glue", self.hparams.task_name, experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
        )

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs[0]
        # 返回一个标量版的loss即可, 或者返回一个字典, 字典中有一个键值对为{"loss": loss}
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        val_loss, logits = outputs[:2]
        if self.hparams.num_labels > 1:
            preds = torch.argmax(logits, axis=1)
        elif self.hparams.num_labels == 1:
            preds = logits.squeeze()
        labels = batch["labels"]
        return {"loss": val_loss, "preds": preds, "labels": labels}

    def validation_epoch_end(self, outputs):
        if self.hparams.task_name == "mnli":
            for i, output in enumerate(outputs):
                # matched or mismatched
                split = self.hparams.eval_splits[i].split("_")[-1]
                preds = torch.cat([x["preds"] for x in output]).detach().cpu().numpy()
                labels = torch.cat([x["labels"] for x in output]).detach().cpu().numpy()
                loss = torch.stack([x["loss"] for x in output]).mean()
                self.log(f"val_loss_{split}", loss, prog_bar=True)
                split_metrics = {
                    f"{k}_{split}": v for k, v in self.metric.compute(predictions=preds, references=labels).items()
                }
                self.log_dict(split_metrics, prog_bar=True)
            return loss

        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,  # 注意: 此处可以使用trainer的变量
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
inference without pytorch-lightning package
为了更好地模块化, 建议采用如下方式组织代码
# 存疑: 官方演示视频中此处继承的是LightningModule
class TorchModule(torch.nn.Module):
    def __init__(self, **model_hparams):
        ...
    def forward(self, input):
        ...

class PLModule(LightningModule):
    def __init__(self, model, **kwargs):
        super().__init__()
        self.model = model
    def training_step(self, batch, batch_idx):
        ...
    def validation_step(self, batch, batch_idx):
        ...
    def configure_optimizers(self):
        ...
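按这种组织方式, 训练结束后可以只保存内部 torch.nn.Module 的权重, 推理端完全不依赖 pytorch-lightning。下面是一个示意(其中 hidden_size、train_loader 均为假设的超参数与数据, 仅说明流程):
import torch
from pytorch_lightning import Trainer

torch_model = TorchModule(hidden_size=128)   # 纯 torch 模型
pl_model = PLModule(torch_model)             # 训练时用 PLModule 包装
trainer = Trainer(max_epochs=1)
trainer.fit(pl_model, train_dataloaders=train_loader)

# 只保存内部 torch.nn.Module 的权重
torch.save(torch_model.state_dict(), "model.pth")

# 推理端: 只需要 torch 与 TorchModule 的定义
model = TorchModule(hidden_size=128)
model.load_state_dict(torch.load("model.pth"))
model.eval()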
控制backward的逻辑
更多详细内容参考官方文档
如果需要控制 backward 的逻辑, 需要做以下事情(完整示意见下方代码):
- 在 __init__ 中设置 self.automatic_optimization = False
- 在 training_step 中使用如下 API:
  - 使用 optimizer = self.optimizers() 获取所有优化器
  - 使用 optimizer.zero_grad() 清除梯度
  - 使用 self.manual_backward(loss), 而不要使用 loss.backward()
  - 使用 optimizer.step()
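下面是一个手动优化的最小示意(模型结构、数据形式与学习率均为随意假设):
import torch
from pytorch_lightning import LightningModule

class ManualOptimModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False   # 关闭自动优化
        self.layer = torch.nn.Linear(64, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()               # 只有一个优化器时直接返回该优化器
        opt.zero_grad()
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        self.manual_backward(loss)            # 代替 loss.backward()
        opt.step()

        # 手动优化时, lr scheduler 也需要自己调用 step
        sch = self.lr_schedulers()
        if sch is not None:
            sch.step()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]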
多个optimizer与scheduler
建议 configure_optimizers 函数按如下方式返回
(
    {
        "optimizer": optimizer_1,
        "lr_scheduler": {
            "scheduler": scheduler_1,
            "interval": "step",
            "frequency": 1,
        },
    },
    {
        "optimizer": optimizer_2,
        "lr_scheduler": {
            "scheduler": scheduler_2,
            "interval": "step",
            "frequency": 1,
        },
    },
)
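例如在一个假设的 GAN 场景下(generator/discriminator 为假设的子模块名), configure_optimizers 可以写成:
def configure_optimizers(self):
    opt_g = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
    opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
    sch_g = torch.optim.lr_scheduler.StepLR(opt_g, step_size=1000)
    sch_d = torch.optim.lr_scheduler.StepLR(opt_d, step_size=1000)
    return (
        {"optimizer": opt_g, "lr_scheduler": {"scheduler": sch_g, "interval": "step", "frequency": 1}},
        {"optimizer": opt_d, "lr_scheduler": {"scheduler": sch_d, "interval": "step", "frequency": 1}},
    )
在自动优化且有多个 optimizer 时, training_step 会多一个 optimizer_idx 入参, 用于区分当前在更新哪个优化器。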
当前 rank、step、epoch 等
def training_step(self, batch, batch_idx):
    self.global_rank      # 全局进程编号(多机多卡时跨所有节点)
    self.local_rank       # 当前节点内的进程编号
    self.current_epoch    # 当前epoch数
    self.global_step      # 全局步数
fit 函数伪代码(hook 编程)
从 Lightning 的实现上看, hook 指的是类的一些特定方法, 例如 on_train_epoch_start, on_train_batch_start;callback 指的是含有这些特定方法的类。
具体来说, 简易版的大致实现方式如下
备注: 下面的写法仅做示意 hook 编程的大致形式, 并非对齐 LightningModule 的真正实现
# 不引入Trainer的时候, LightningModule的简易版实现
class MyLightningModule:
    def __init__(self, callbacks):
        self.callbacks = callbacks

    def call_hook(self, hook_name, *args, **kwargs):
        for callback in self.callbacks:
            if hasattr(callback, hook_name):
                func = getattr(callback, hook_name)
                if callable(func):
                    func(*args, **kwargs)  # 返回值怎么处理?

    def fit(self, loader):
        for batch in loader:
            self.call_hook("on_before_batch", batch)
            self.training_step()  # 这个是 hook 吗?
            self.call_hook("on_after_batch", batch)

    def training_step(self):
        raise NotImplementedError()

class MyCustomModule(MyLightningModule):
    def __init__(self, callbacks):
        super().__init__(callbacks)

    def training_step(self):
        ...

class FirstCallback:
    def on_before_batch(self, batch):
        ...
    def on_after_batch(self, batch):
        ...

class SecondCallback:
    def on_before_batch(self, batch):
        ...

if __name__ == "__main__":
    callbacks = [FirstCallback(), SecondCallback()]
    model = MyCustomModule(callbacks)  # 注意: 需要实例化实现了training_step的子类
    loader = ...
    model.fit(loader)
在理解了 hook/callback 之后, 可以参考官方文档中真正的实现流程理解 fit 函数的整个过程, 并在 LightningModule 的子类中覆盖这些方法或者传入自定义 Callback 类, 摘录如下:
def fit(self):
    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()

    configure_callbacks()

    with parallel(devices):
        # devices can be GPUs, TPUs, ...
        train_on_device(model)


def train_on_device(model):
    # called PER DEVICE
    on_fit_start()
    setup("fit")
    configure_optimizers()

    # the sanity check runs here

    on_train_start()
    for epoch in epochs:
        fit_loop()
    on_train_end()

    on_fit_end()
    teardown("fit")


def fit_loop():
    on_train_epoch_start()

    for batch in train_dataloader():
        on_train_batch_start()

        on_before_batch_transfer()
        transfer_batch_to_device()
        on_after_batch_transfer()

        training_step()

        on_before_zero_grad()
        optimizer_zero_grad()

        on_before_backward()
        backward()
        on_after_backward()

        on_before_optimizer_step()
        configure_gradient_clipping()
        optimizer_step()

        on_train_batch_end()

        if should_check_val:
            val_loop()

    # end training epoch
    training_epoch_end()
    on_train_epoch_end()


def val_loop():
    on_validation_model_eval()  # calls `model.eval()`
    torch.set_grad_enabled(False)

    on_validation_start()
    on_validation_epoch_start()

    val_outs = []
    for batch_idx, batch in enumerate(val_dataloader()):
        on_validation_batch_start(batch, batch_idx)

        batch = on_before_batch_transfer(batch)
        batch = transfer_batch_to_device(batch)
        batch = on_after_batch_transfer(batch)

        out = validation_step(batch, batch_idx)

        on_validation_batch_end(batch, batch_idx)
        val_outs.append(out)

    validation_epoch_end(val_outs)
    on_validation_epoch_end()
    on_validation_end()

    # set up for train
    on_validation_model_train()  # calls `model.train()`
    torch.set_grad_enabled(True)
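在理解上面的流程后, 自定义 Callback 只需实现其中感兴趣的 hook, 并通过 Trainer 的 callbacks 参数传入即可。下面是一个最小示意(打印内容为随意举例):
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

class PrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is starting")

    def on_train_epoch_end(self, trainer, pl_module):
        print(f"epoch {trainer.current_epoch} 结束, global_step={trainer.global_step}")

trainer = Trainer(max_epochs=2, callbacks=[PrintingCallback()])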
training_step 出入参
入参: 由以下这些 hook 最后的输出得到
for batch in train_dataloader():
    on_train_batch_start()
    on_before_batch_transfer()
    transfer_batch_to_device()  # 将batch中的tensor转移到相关的device上, 如果默认的方法不能满足要求, 则可以重载这个函数
    on_after_batch_transfer()
    training_step()
出参: 返回一个标量版的loss即可, 或者返回一个字典, 字典中有一个键值对为{"loss": loss}
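一个返回 dict 的 training_step 示意("loss" 以外的键会被收集进下面源码中 ClosureResult 的 extra 字段; 其中 _shared_step 为假设的辅助函数):
def training_step(self, batch, batch_idx):
    loss, acc = self._shared_step(batch)
    return {"loss": loss, "acc": acc}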
@dataclass
class OutputResult:
def asdict(self) -> Dict[str, Any]:
raise NotImplementedError
# src/pytorch_lightning/loops/optimization/optimizer_loop.py
class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
"""Runs over a sequence of optimizers.
This loop implements what is known in Lightning as Automatic Optimization.
"""
output_result_cls = ClosureResult
def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
...
# 注: training_step_output即为training_step的返回结果
result = self.output_result_cls.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)
...
@dataclass
class ClosureResult(OutputResult):
"""A container to hold the result of a :class:`Closure` call.
It is created from the output of :meth:`~pytorch_lightning.core.module.LightningModule.training_step`.
Attributes:
closure_loss: The loss with a graph attached.
loss: A detached copy of the closure loss.
extra: Any keys other than the loss returned.
"""
closure_loss: Optional[Tensor]
loss: Optional[Tensor] = field(init=False, default=None)
extra: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
self._clone_loss()
def _clone_loss(self) -> None:
if self.closure_loss is not None:
# the loss will get scaled for amp. avoid any modifications to it
self.loss = self.closure_loss.detach().clone()
@classmethod
def from_training_step_output(
cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1
) -> "ClosureResult":
closure_loss, extra = None, {}
if isinstance(training_step_output, dict):
# this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
closure_loss = training_step_output.get("loss")
if closure_loss is None:
raise MisconfigurationException(
"In automatic_optimization, when `training_step` returns a dict, the 'loss' key needs to be present"
)
extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
elif isinstance(training_step_output, Tensor):
closure_loss = training_step_output
elif training_step_output is not None:
raise MisconfigurationException(
"In automatic optimization, `training_step` must return a Tensor, "
"a dict, or None (where the step will be skipped)."
)
if closure_loss is not None:
# accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
# note: avoid in-place operation `x /= y` here on purpose
closure_loss = closure_loss / normalize
return cls(closure_loss, extra=extra)
def asdict(self) -> Dict[str, Any]:
return {"loss": self.loss, **self.extra}
training_step, validation_step, test_step, predict_step
- training_step: 训练过程, batch中应包含x与y, 被trainer.fit调用
- validation_step: 验证过程, batch中应包含x与y, 通常在每个epoch结束后被trainer.fit调用
- test_step: 测试过程, batch中应包含x与y, 在trainer.fit中不被调用, 被trainer.test调用
- predict_step: 在不定义predict_step的情况下, trainer.predict会调用model.forward, 否则会调用predict_step, 因此batch中只包含x即可
def training_step(self, batch, batch_idx, optimizer_idx, hiddens):
    # 返回可以是三种: (1) loss tensor (2) dict, 但必须包含"loss"这个key (3) None, 跳过此次training_step的过程, 一般用于手动backward
    ...
def validation_step(self, batch, batch_idx, dataloader_idx):
    # 返回可以是 (1) tensor (2) dict of tensor (3) None, 跳过此次validation_step
    ...
def test_step(self, batch, batch_idx, dataloader_idx):
    # 返回可以是 (1) tensor (2) dict of tensor (3) None, 跳过此次test_step
    ...
def predict_step(self, batch, batch_idx, dataloader_idx):
    # 返回是Any
    ...
save checkpoint advanced
方式1 默认情况下, 会自动保存模型(只保存最新的), 参考官方文档
By default Lightning saves a checkpoint for you in your current working directory, with the state of your last training epoch, Checkpoints capture the exact value of all parameters used by a model. To disable automatic checkpointing, set this to False.
# default used by Trainer, saves the most recent model to a single checkpoint after each epoch
trainer = Trainer(enable_checkpointing=True)
# turn off automatic checkpointing
trainer = Trainer(enable_checkpointing=False)
备注: 这种情况下默认会保存在 lightning_logs/version_n/checkpoints 目录中, 且会保存许多其他非模型权重的东西
import torch
d = torch.load("lightning_logs/version_n/checkpoints/xxx.ckpt")
d.keys()
# epoch, global_step, pytorch-lightning_version, state_dict, loops, callbacks, optimizer_states, lr_schedulers
# 其中state_dict为模型的权重
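拿到 ckpt 之后, 恢复模型常见的两种方式如下(示意, MyModule 为继承自 LightningModule 的类):
# 方式一: 直接取出 state_dict, 加载回 LightningModule
model = MyModule(...)
model.load_state_dict(d["state_dict"])

# 方式二: 使用 LightningModule 的类方法, 会同时恢复 save_hyperparameters() 保存的超参数
model = MyModule.load_from_checkpoint("lightning_logs/version_n/checkpoints/xxx.ckpt")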
方式2 trainer 中传入的 callbacks 包含一个 ModelCheckpoint 实例, 参考官方文档。用于指定保存路径, 保留多少个 checkpoint, 是否只保留权重, 是否根据某个特定的监控值保存最优模型等
from pytorch_lightning.callbacks import ModelCheckpoint
# Init ModelCheckpoint callback, monitoring 'val_loss'
checkpoint_callback = ModelCheckpoint(monitor="val_loss")
# Add your callback to the callbacks list
trainer = Trainer(callbacks=[checkpoint_callback])
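下面是一个配置得更完整一些的示意(参数均为 ModelCheckpoint 的常用入参, 具体取值为随意举例):
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",             # 保存路径
    filename="{epoch}-{val_loss:.3f}",  # 文件名模板
    monitor="val_loss",                 # 监控的指标, 需要在 validation 中用 self.log 记录过
    mode="min",
    save_top_k=3,                       # 只保留指标最优的 3 个 checkpoint
    save_weights_only=False,            # 为 True 时只保存权重, 不保存 optimizer 等状态
)
trainer = Trainer(callbacks=[checkpoint_callback])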
方式3 手写, 在各个hook中加上一些条件进行模型保存
from pytorch_lightning import LightningModule, Trainer
import torch
class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(64, 32)
    # 省略 training_step/configure_optimizers 等必需方法
    def training_epoch_end(self, training_outs):
        if self.current_epoch in [0, 2] and self.local_rank == 0:
            torch.save(self.layer.state_dict(), f"epoch_{self.current_epoch}.pth")

model = MyModel()
trainer = Trainer(max_epochs=4, gpus=2, enable_checkpointing=False)
方式4(不确定): 继承ModelCheckpoint
怎样控制 log 打印(未解决)
包含以下几个部分:
- LightningModule.training_step 的返回值在 LightningModule.on_epoch_end 中保存的具体逻辑是?
- 如何控制 Tensorboard 的打印内容与打印时机
- 如何分别控制打印到控制台/保存到日志文件/输出到 Tensorboard 的信息
LightningModule 中的 self.log(...) 最终调用的是 trainer._results.log, 而这个对象最终对应的是类似于 trainer.fit_loop.epoch_loop._results, 它是 pytorch_lightning.trainer.connectors.logger_connector.result._ResultCollection 对象。逻辑大约是: module.log 调用时, 在 _results 中记录下信息, 之后在 TrainingEpochLoop 的 advance 函数中调用 self.trainer._logger_connector.update_train_step_metrics() 实际将 log 写入; 具体进一步触发 LoggerConnector 中调用 self.log_metrics(self.trainer._results.metrics(not self._epoch_end_reached)['log']), 根据是否需要写日志, 再触发 TensorBoardLogger 的 log_metrics 函数。
trainer 的构造函数 __init__ 中有一个参数 log_every_n_steps, 用于控制 update_train_step_metrics 函数将数据写入 tensorboard 的频率。也就是说, module.log() 最终修改了 trainer._results, 然后利用这个在某个时机写入了 tensorboard。
备注: 由于 TensorBoardLogger 中没有保存图像的函数, 如果想完全发挥其内部包含的 torch.utils.tensorboard.SummaryWriter 的功能(例如 add_image), 可以通过 TensorBoardLogger.experiment 拿到 SummaryWriter 直接实现。
例如(待验证):
import pytorch_lightning as pl

# 在 LightningModule 中即刻将log(此处为图像)写入tensorboard
def training_step(self, batch, batch_idx):
    for logger in self.trainer.loggers:
        if isinstance(logger, pl.loggers.TensorBoardLogger):
            logger.experiment.add_image('a/ori_image', ori_image, global_idx)
pytorch_lightning.Trainer
trainer = Trainer()
trainer.fit(model)
备注: 相比于 huggingface transformers 中的 Trainer 类(官方文档中鼓励对其使用继承的方式重写一些方法), pytorch-lightning 中的推荐做法是直接使用 Trainer, 而对 LightningModule 进行继承以及方法重写。
pytorch_lightning.LightningDataModule
import torch
from pytorch_lightning import LightningDataModule
from transformers import AutoTokenizer

class MyDataset(torch.utils.data.Dataset):
    ...
    def __getitem__(self, idx):
        return transform(data[idx])

class MyDataModule(LightningDataModule):
    def __init__(
        self,
        model_name_or_path: str,
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def setup(self, stage: str):
        # 此部分代码在多卡场景下会被每个进程执行, 因此可以设置变量
        # 建议设置self.datasets
        self.datasets = {
            "train": MyDataset(...),
            "val": MyDataset(...),
        }

    def prepare_data(self):
        # 此部分代码仅在rank0上运行, 建议不要设置类的属性
        # 建议做一些数据的转换工作, 例如切分数据至train/val文件夹, 将数据tokenizer化保存
        pass

    def train_dataloader(self):
        # 此部分代码在多卡场景下会被每个进程执行
        # 建议dataset在self.setup方法中设定, 此处直接使用torch.utils.data.DataLoader进行包装
        return torch.utils.data.DataLoader(self.datasets["train"], batch_size=self.train_batch_size, shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.datasets["val"], batch_size=self.eval_batch_size, shuffle=False)
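使用时将 datamodule 传给 Trainer 即可, train/val 的 DataLoader 均由它提供(下面的模型与参数取值仅作示意, 模型与数据需要配套):
dm = MyDataModule(model_name_or_path="bert-base-uncased")
model = MyModule(...)  # 继承自 LightningModule 的模型
trainer = Trainer(max_epochs=3, accelerator="gpu", devices=1)
trainer.fit(model, datamodule=dm)       # train/val dataloader 均由 datamodule 提供
trainer.validate(model, datamodule=dm)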
第二部分:Lightning 源码阅读
OpenMMLab 对 pytorch-lightning 也有一篇源码解读文章: https://zhuanlan.zhihu.com/p/389271556
目标: 理解如下代码片段的执行过程
from pytorch_lightning import LightningModule, Trainer

model = MyModule(...)   # MyModule 继承自 LightningModule
trainer = Trainer(...)  # max_steps, min_steps 等参数
trainer.fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, ckpt_path=None)
LightningModule
父类
源码中 LightningModule 类的定义继承自多个父类, 特别注意它也继承自 torch.nn.Module。因此需要先对几个父类的代码做个了解
class LightningModule(
    _DeviceDtypeModuleMixin,
    HyperparametersMixin,
    ModelIO,
    ModelHooks,
    DataHooks,
    CheckpointHooks,
    Module,
):
    ...
HyperparametersMixin
主要使用的方式如下(参考官网例子): 通过调用 self.save_hyperparameters 方法, 将所有 __init__ 的传参保存到 self._hparams 中
class MyModule(LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int,
        learning_rate: float = 2e-5,
        train_batch_size: int = 32,
        **kwargs,
    ):
        super().__init__()
        # MyModule中其他地方可以使用self.hparams.num_labels
        self.save_hyperparameters(ignore=["model_name_or_path"])

model = MyModule("x.pth", 10)
print(model.hparams)  # pytorch_lightning.utilities.parsing.AttributeDict
# {"num_labels": 10, "learning_rate": 2e-5, "train_batch_size": 32}
Trainer.__init__
Trainer 类没有父类, 直接继承自 object。
第 2 行代码 trainer = Trainer(...) 的源码如下:
# src/pytorch_lightning/trainer/trainer.py:Trainer
@_defaults_from_env_vars
def __init__(self, logger=True, ...):  # 共有约50个参数
    ...
首先解释一下这个装饰器的作用: 利用 os.environ.get 方法获取形如 PL_TRAINER_{XXX} 的环境变量, 并用环境变量的值取代被装饰的函数(上面的例子中为 Trainer.__init__ 函数)中的默认值. 即最终第 2 行代码参数设定的优先顺序为: 实参 > 环境变量 > 函数定义中形参的默认值
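例如下面这个小示意(基于上述优先级, max_epochs 未显式传参时取环境变量的值):
import os
os.environ["PL_TRAINER_MAX_EPOCHS"] = "3"

from pytorch_lightning import Trainer
trainer = Trainer()        # 未传 max_epochs, 因此取环境变量 PL_TRAINER_MAX_EPOCHS 的值
print(trainer.max_epochs)  # 3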
装饰器`_defaults_from_env_vars`的具体实现
# src/pytorch_lightning/utilities/argparse.py
def _defaults_from_env_vars(fn: _T) -> _T:
@wraps(fn) # 注: functools.wraps
def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any:
cls = self.__class__ # get the class
if args: # in case any args passed move them to kwargs
# parse only the argument names
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
# convert args to kwargs
kwargs.update(dict(zip(cls_arg_names, args))) # 注: 此处的kwargs为实参
env_variables = vars(parse_env_variables(cls)) # 注: 此处为从环境变量处解析得到的默认值
# update the kwargs by env variables
# 注: 这里第2项中的键值对会覆盖第1项的键值对, 因此优先级为实参>环境变量>函数定义中的默认值
kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
# all args were already moved to kwargs
return fn(self, **kwargs)
return cast(_T, insert_env_defaults) # 注: typing.cast
理解上面的代码所需要的 Python 知识如下:
- functools.wraps 装饰器的作用
- vars 内置函数的作用
- typing.cast 函数的作用: 参考 stackoverflow, 此函数只在类型检查时起作用, 而在实际运行时什么都不做, 直接将入参返回
上述代码中进一步调用了如下两段代码
# src/pytorch_lightning/utilities/argparse.py
def parse_env_variables(cls: _ARGPARSE_CLS, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
"""Parse environment arguments if they are defined.
Examples:
>>> from pytorch_lightning import Trainer
>>> parse_env_variables(Trainer)
Namespace()
>>> import os
>>> os.environ["PL_TRAINER_GPUS"] = '42'
>>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23'
>>> parse_env_variables(Trainer)
Namespace(gpus=42)
>>> del os.environ["PL_TRAINER_GPUS"]
"""
cls_arg_defaults = get_init_arguments_and_types(cls)
env_args = {}
for arg_name, _, _ in cls_arg_defaults:
env = template % {"cls_name": cls.__name__.upper(), "cls_argument": arg_name.upper()}
val = os.environ.get(env)
if not (val is None or val == ""):
# todo: specify the possible exception
with suppress(Exception): # 注: contextlib.suppress
# converting to native types like int/float/bool
val = literal_eval(val) # 注: ast.literal_eval
env_args[arg_name] = val
return Namespace(**env_args)
def get_init_arguments_and_types(cls: _ARGPARSE_CLS) -> List[Tuple[str, Tuple, Any]]:
r"""Scans the class signature and returns argument names, types and default values.
Returns:
List with tuples of 3 values:
(argument name, set with argument types, argument default value).
Examples:
>>> from pytorch_lightning import Trainer
>>> args = get_init_arguments_and_types(Trainer)
"""
cls_default_params = inspect.signature(cls).parameters # 注
name_type_default = []
for arg in cls_default_params:
arg_type = cls_default_params[arg].annotation
arg_default = cls_default_params[arg].default
try:
arg_types = tuple(arg_type.__args__)
except (AttributeError, TypeError):
arg_types = (arg_type,)
name_type_default.append((arg, arg_types, arg_default))
return name_type_default
理解上面的代码所需要的 Python 知识如下:
- 内置模块 inspect 相关知识: 此处仅用到 inspect.signature(callable).parameters, 用于获取函数或类的 __init__ 方法定义中的变量名, 变量类型, 以及默认值
- contextlib.suppress: 参考 stackoverflow, 这两种写法基本等价:
  # 写法一:
  with contextlib.suppress(ValueError):
      x = int('a')
  # 写法二:
  try:
      x = int('a')
  except ValueError:
      pass
- 内置模块 ast.literal_eval: 此函数接受的参数为一个合法的字符串形式的 python 数据, 例如:
  ast.literal_eval("['a', 'b']")  # 返回列表: ["a", "b"]
  ast.literal_eval("'a'")         # 返回字符串: "a"
  ast.literal_eval("1")           # 返回整数: 1
  ast.literal_eval("1.2")         # 返回浮点数: 1.2
  ast.literal_eval("1+1")         # 报错
  ast.literal_eval("a")           # 报错
  即功能弱于内置函数 eval, 官方文档中建议一切能用 ast.literal_eval 代替 eval 的地方, 都使用 ast.literal_eval, 无法替代的情况下, 应该选择其他实现方式, 而不能依赖 eval
接下来进入 Trainer.__init__ 函数的函数体, 完整源代码如下:
# src/pytorch_lightning/trainer/trainer.py
@_defaults_from_env_vars
def __init__(self, logger, ....): # 注: 一共有约50个参数
super().__init__()
# 即执行: torch._C._log_api_usage_once("lightning.trainer." + "init")
# 在环境变量为PYTORCH_API_USAGE_STDERR=1时才打印信息
Trainer._log_api_event("init")
# 此处的 log 是本文件的"全局"变量
# log = logging.getLogger(__name__)
log.detail(f"{self.__class__.__name__}: Initializing trainer with parameters: {locals()}")
# 见说明 `TrainerState`
self.state = TrainerState()
# 见说明 `Connector`
# init connectors
self._data_connector = DataConnector(self, multiple_trainloader_mode)
self._accelerator_connector = AcceleratorConnector(
num_processes=num_processes,
devices=devices,
tpu_cores=tpu_cores,
ipus=ipus,
accelerator=accelerator,
strategy=strategy,
gpus=gpus,
num_nodes=num_nodes,
sync_batchnorm=sync_batchnorm,
benchmark=benchmark,
replace_sampler_ddp=replace_sampler_ddp,
deterministic=deterministic,
auto_select_gpus=auto_select_gpus,
precision=precision,
amp_type=amp_backend,
amp_level=amp_level,
plugins=plugins,
)
self._logger_connector = LoggerConnector(self)
self._callback_connector = CallbackConnector(self)
self._checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
self._signal_connector = SignalConnector(self)
# 见下面说明 `Tuner`
self.tuner = Tuner(self)
# 见下面说明 `Loop`
fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
# 注: 执行的函数体为: fit_loop.epoch_loop=training_epoch_loop
fit_loop.connect(epoch_loop=training_epoch_loop)
# default .fit() loop
self.fit_loop = fit_loop
# default .validate() loop
self.validate_loop = EvaluationLoop()
# default .test() loop
self.test_loop = EvaluationLoop()
# default .predict() loop
self.predict_loop = PredictionLoop()
# set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`.
self._ckpt_path: Optional[str] = None
# init callbacks
# Declare attributes to be set in _callback_connector on_trainer_init
self._callback_connector.on_trainer_init(
callbacks,
enable_checkpointing,
enable_progress_bar,
default_root_dir,
enable_model_summary,
max_time,
accumulate_grad_batches,
)
# hook
# 见下面说明 `_call_callback_hooks`
# V1.8版本对这个做了移除处理, 默认的几个Callback都没有这个hook
self._call_callback_hooks("on_init_start")
# init data flags
# 有点诡异, 没有赋值?
self.check_val_every_n_epoch: int
self._data_connector.on_trainer_init(
val_check_interval,
reload_dataloaders_every_n_epochs,
check_val_every_n_epoch,
)
# gradient clipping
if gradient_clip_val is not None and not isinstance(gradient_clip_val, (int, float)):
raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")
if gradient_clip_algorithm is not None and not GradClipAlgorithmType.supported_type(
gradient_clip_algorithm.lower()
):
raise MisconfigurationException(
f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid. "
f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
)
# gradient norm tracking
if track_grad_norm != -1 and not (
(isinstance(track_grad_norm, (int, float)) or track_grad_norm == "inf") and float(track_grad_norm) > 0
):
raise MisconfigurationException(
f"`track_grad_norm` must be a positive number or 'inf' (infinity norm). Got {track_grad_norm}."
)
self.gradient_clip_val: Union[int, float] = gradient_clip_val
self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = (
GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None
)
self.track_grad_norm: float = float(track_grad_norm)
self._detect_anomaly: bool = detect_anomaly
# 见下面说明 `_setup_on_init`
self._setup_on_init()
# configure tuner
# 见下面说明 `Tuner`
self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)
# configure profiler
# 见下面说明 `setup._init_profiler`
setup._init_profiler(self, profiler)
# init logger flags
self._loggers: List[Logger]
self._logger_connector.on_trainer_init(logger, log_every_n_steps, move_metrics_to_cpu)
# init debugging flags
self.val_check_interval: Union[int, float]
self.num_sanity_val_steps: Union[float, int]
# 见下面说明 `setup._init_debugging_flags`
setup._init_debugging_flags(
self,
limit_train_batches,
limit_val_batches,
limit_test_batches,
limit_predict_batches,
fast_dev_run,
overfit_batches,
val_check_interval,
num_sanity_val_steps,
)
# Callback system
self._call_callback_hooks("on_init_end")
整体流程简要描述:
总的来说基本上是一些为 Trainer 的属性赋值的操作
self.state = TrainerState() # 后续调用fit/test等函数时会对这个self.state进行设置
# 初始化DataConnector,AcceleratorConnector,LoggerConnector,CallbackConnector,CheckpointConnector,SignalConnector, 代码从略, 除了AcceleratorConnector进行了一些实质性的准备工作外(例如DDP的一些诸如:dist.init_process_group的操作,是否实际执行存疑,待后续明确), 其余基本上都只是对属性值进行了一些初始化
# 初始化几个loop,实际上仅根据入参设定了一些参数, 涉及到的几个loop嵌套关系见后文说明
self.fit_loop = fit_loop
self.validate_loop = EvaluationLoop()
self.test_loop = EvaluationLoop()
self.predict_loop = PredictionLoop()
# 主要执行逻辑是依次将如下默认Callback类添加至`Trainer.callbacks`中, 然后将这些callback按照类型进行重排序,先后顺序为:tuner_callbacks(BatchSizeFinder),other_callbacks, checkpoint_callbacks(ModelCheckpoint)。
self._callback_connector.on_trainer_init(...)
# 依次调用所有callback的"on_init_start" hook, lightning v1.8对这一过程做了移除, 可参考关于Trainer.fit的代码解析
# self._call_callback_hooks("on_init_start")
# 作用是根据入参设定trainer的几个属性
self._data_connector.on_trainer_init(...)
# 设定一些属性, 并在主进程上打印一些GPU/TPU是否可用, 是否使用的日志
self._setup_on_init()
# 作用是根据入参设定trainer的几个属性, self.auto_lr_find = auto_lr_find
self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)
# 初始化self.profiler, 默认为初始化一个PassThroughProfiler(profiler=None)
setup._init_profiler(self, profiler)
# 设定trainer.loggers(列表): 如果参数为logger默认值True,则创建TensorBoardLogger, 否则按照logger设定
self._logger_connector.on_trainer_init(...)
# 设定一些debug用的参数, 作用未知?
setup._init_debugging_flags(...)
需要细致展开的部分如下:
TrainerState
# from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
self.state = TrainerState()
涉及的源代码如下:
备注: LightningEnum 实际上继承自 python 原生的 Enum, 增加了一个 from_str 方法, 并允许它直接与字符串比较(__eq__ 函数), 如果与枚举状态所对应的字符串相同, 则返回 True。
# src/lightning_lite/utilities/enums.py
from pytorch_lightning.utilities import LightningEnum
@dataclass
class TrainerState:
"""Dataclass to encapsulate the current :class:`~pytorch_lightning.trainer.trainer.Trainer` state."""
# trainer的运行状态: "initializing", "running", "finished", "interrupted"
status: TrainerStatus = TrainerStatus.INITIALIZING
# "fit", "validate", "test", "predict", "tune"
# 与trainer.fit/validate/test/predict/tune直接绑定
fn: Optional[TrainerFn] = None
# "sanity_check", "train", "validate", "test", "predict", "tune"
# trainer.fit函数内的具体状态, 会依次变为: "sanity_check", "train", "validate"
stage: Optional[RunningStage] = None
# detect the fault tolerant flag
# 这个不确定是用来做什么的
_fault_tolerant_mode: _FaultTolerantMode = field(default_factory=_FaultTolerantMode.detect_current_mode)
@property
def finished(self) -> bool:
return self.status == TrainerStatus.FINISHED
@property
def stopped(self) -> bool:
return self.status.stopped
以下是更为具体的源代码
class TrainerStatus(LightningEnum):
"""Enum for the status of the :class:`~pytorch_lightning.trainer.trainer.Trainer`"""
INITIALIZING = "initializing" # trainer creation
RUNNING = "running"
FINISHED = "finished"
INTERRUPTED = "interrupted"
@property
def stopped(self) -> bool:
return self in (self.FINISHED, self.INTERRUPTED)
class TrainerFn(LightningEnum):
"""
Enum for the user-facing functions of the :class:`~pytorch_lightning.trainer.trainer.Trainer`
such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and
:meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.
"""
FITTING = "fit"
VALIDATING = "validate"
TESTING = "test"
PREDICTING = "predict"
TUNING = "tune"
@property
def _setup_fn(self) -> "TrainerFn":
"""``FITTING`` is used instead of ``TUNING`` as there are no "tune" dataloaders.
This is used for the ``setup()`` and ``teardown()`` hooks
"""
return TrainerFn.FITTING if self == TrainerFn.TUNING else self
class RunningStage(LightningEnum):
"""Enum for the current running stage.
This stage complements :class:`TrainerFn` by specifying the current running stage for each function.
More than one running stage value can be set while a :class:`TrainerFn` is running:
- ``TrainerFn.FITTING`` - ``RunningStage.{SANITY_CHECKING,TRAINING,VALIDATING}``
- ``TrainerFn.VALIDATING`` - ``RunningStage.VALIDATING``
- ``TrainerFn.TESTING`` - ``RunningStage.TESTING``
- ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING``
- ``TrainerFn.TUNING`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}``
"""
TRAINING = "train"
SANITY_CHECKING = "sanity_check"
VALIDATING = "validate"
TESTING = "test"
PREDICTING = "predict"
TUNING = "tune"
@property
def evaluating(self) -> bool:
return self in (self.VALIDATING, self.TESTING)
@property
def dataloader_prefix(self) -> Optional[str]:
if self in (self.SANITY_CHECKING, self.TUNING):
return None
if self == self.VALIDATING:
return "val"
return self.value
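根据上面对 LightningEnum 的说明, 这些枚举可以直接与字符串比较, 下面是一个小示意:
from pytorch_lightning.trainer.states import TrainerStatus

TrainerStatus.RUNNING == "running"   # True
TrainerStatus.from_str("finished")   # TrainerStatus.FINISHED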
Connector(DataConnector,AcceleratorConnector,LoggerConnector,CallbackConnector,CheckpointConnector,SignalConnector)
Trainer.__init__ 函数依次进行了 DataConnector, AcceleratorConnector, LoggerConnector, CallbackConnector, CheckpointConnector, SignalConnector 几个的初始化
推测: 这种 XXXConnector 类的作用基本上就是给 Trainer 添加一些属性, 不知道为啥不直接写在 Trainer 的内部(也许是写在 Trainer 内部, Trainer 类的定义会变得很冗长? 看源码长度, pytorch_lightning/trainer/trainer.py 本身已有 2000 多行, 如果这些 Connector 也写在 Trainer 里, 估计会更长)
这里几个 Connector 的 __init__ 函数的初始化基本上只是设定 self.trainer=trainer, 以及初始化一些状态, 并无太多需要说明之处。唯一的例外是 AcceleratorConnector 做了许多工作(此处从略)。
Loop
fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
# 注: 执行的函数体为: fit_loop.epoch_loop=training_epoch_loop
fit_loop.connect(epoch_loop=training_epoch_loop)
为了更好地看出上面三行代码的执行逻辑, 将其展开如下(|- 引导内部的调用顺序)
备注: 一些简单的操作例如 self.min_steps=min_steps 被省略了
fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
|-self.epoch_loop = TrainingEpochLoop()
| |-self.batch_progress = BatchProgress()
| |-self.scheduler_progress = SchedulerProgress()
| |-self.batch_loop = TrainingBatchLoop()
| | |-# 内部保存一个长度为20的memory
| | |-self.accumulated_loss = TensorRunningAccum(window_length=20)
| | |-self.running_loss = TensorRunningAccum(window_length=20)
| | |-self.optimizer_loop = OptimizerLoop()
| | | |-self.optim_progress: OptimizationProgress = OptimizationProgress()
| | |-self.manual_loop = ManualOptimization()
| | |-self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker)
| |-self.val_loop = loops.EvaluationLoop(verbose=False)
| |-self.epoch_loop = EvaluationEpochLoop()
| |-self.batch_progress = BatchProgress()
|-self.epoch_progress = Progress()
training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
# fit_loop.epoch_loop=training_epoch_loop
fit_loop.connect(epoch_loop=training_epoch_loop)
self.fit_loop = fit_loop
self.validate_loop = EvaluationLoop()
self.test_loop = EvaluationLoop()
self.predict_loop = PredictionLoop()
|-self.epoch_loop = PredictionEpochLoop()
|-self.batch_progress = Progress()
这里代码粗看上有些诡异: FitLoop 中包含了一个 TrainingEpochLoop, 而这个 TrainingEpochLoop 又包含 TrainingBatchLoop 和一个 EvaluationLoop, 似乎并不在一个层级上
TrainingEpochLoop -- TrainingBatchLoop -- OptimizerLoop
                  \_ EvaluationLoop -- EvaluationEpochLoop (没有EvaluationBatchLoop?)
*Progress 的定义均在 src/pytorch_lightning/trainer/progress.py 文件中, 主要作用是循环时记录下标和一些状态, 即记录 for i, x in enumerate(x_list) 中的 i 与 x。其细致的源码分析如下:
CallbackConnector.on_trainer_init
# trainer.__init__函数内部调用了这个函数
self._callback_connector.on_trainer_init(...)
主要执行逻辑是依次将如下默认 Callback 类添加至 Trainer.callbacks 中, 然后将这些 callback 按照类型进行重排序, 先后顺序为: tuner_callbacks(BatchSizeFinder), other_callbacks, checkpoint_callbacks(ModelCheckpoint)。
这些默认的 Callback 及相应控制的 Trainer.__init__ 函数的入参默认值以及所包含的 hook 列举如下:
pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint(默认有, enable_checkpointing=True): on_train_batch_end, on_train_epoch_end, on_train_start, on_validation_end
pytorch_lightning.callbacks.timer.Timer(默认无, max_time=None): on_fit_start, on_test_end, on_test_start, on_train_batch_end, on_train_end, on_train_epoch_end, on_train_start, on_validation_end, on_validation_start
pytorch_lightning.callbacks.progress.tqdm_progress.TQDMProgressBar(默认有, enable_progress_bar=True): on_predict_batch_end, on_predict_batch_start, on_predict_end, on_predict_start, on_sanity_check_end, on_sanity_check_start, on_test_batch_end, on_test_batch_start, on_test_end, on_test_start, on_train_batch_end, on_train_end, on_train_epoch_end, on_train_epoch_start, on_train_start, on_validation_batch_end, on_validation_batch_start, on_validation_end, on_validation_start
pytorch_lightning.callbacks.rich_model_summary.RichModelSummary(默认无, 除非enable_progress_bar=False且手动传入这个callback): on_fit_start
pytorch_lightning.callbacks.model_summary.ModelSummary(默认有, enable_model_summary): on_fit_start
pytorch_lightning.callbacks.gradient_accumulation_scheduler.GradientAccumulationScheduler(默认无, accumulate_grad_batches=None): on_train_epoch_start
pytorch_lightning.callbacks.fault_tolerance._FaultToleranceCheckpoint(必然有): on_exception
pytorch_lightning.callbacks.batch_size_finder.BatchSizeFinder(默认无, auto_lr_find=False, auto_scale_batch_size=False): on_fit_start, on_predict_start, on_test_start, on_validation_start
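可以直接查看一个默认 Trainer 实例上注册的 callback 来验证(输出与版本及参数有关, 仅作示意):
from pytorch_lightning import Trainer

trainer = Trainer()
print([type(cb).__name__ for cb in trainer.callbacks])
# 例如: ['TQDMProgressBar', 'ModelSummary', 'ModelCheckpoint'], 其中 ModelCheckpoint 被排在最后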
完整源代码如下:
class CallbackConnector:
...
def on_trainer_init(
self,
callbacks: Optional[Union[List[Callback], Callback]],
enable_checkpointing: bool,
enable_progress_bar: bool,
default_root_dir: Optional[str],
enable_model_summary: bool,
max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
) -> None:
# init folder paths for checkpoint + weights save callbacks
self.trainer._default_root_dir = default_root_dir or os.getcwd()
# init callbacks
if isinstance(callbacks, Callback):
callbacks = [callbacks]
self.trainer.callbacks = callbacks or []
# configure checkpoint callback
# pass through the required args to figure out defaults
# 注: self.trainer.callbacks增加pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
self._configure_checkpoint_callbacks(enable_checkpointing)
# configure the timer callback.
# responsible to stop the training when max_time is reached.
# 注: max_time为程序最大运行时间,如果设置,则增加pytorch_lightning.callbacks.timer.Timer
self._configure_timer_callback(max_time)
# init progress bar
# 注: 增加pytorch_lightning.callbacks.progress.tqdm_progress.TQDMProgressBar
self._configure_progress_bar(enable_progress_bar)
# configure the ModelSummary callback
# 注: 增加pytorch_lightning.callbacks.model_summary.ModelSummary
self._configure_model_summary_callback(enable_model_summary)
# accumulated grads
# 注: 如果设置了梯度累累积, 则设置pytorch_lightning.callback.gradient_accumulation_scheduler.GradientAccumulationScheduler
self._configure_accumulated_gradients(accumulate_grad_batches)
# 注: ...
if self.trainer.state._fault_tolerant_mode.is_enabled:
self._configure_fault_tolerance_callbacks()
# 注: 一般是空列表, 可以从Entrypoint中加入callback, 见官网说明
# https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html#entry-points
self.trainer.callbacks.extend(_configure_external_callbacks())
# push all model checkpoint callbacks to the end
# it is important that these are the last callbacks to run
self.trainer.callbacks = self._reorder_callbacks(self.trainer.callbacks) # 注: 见下面的函数定义, Checkpoint被排放至最后
@staticmethod
def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
tuner_callbacks: List[Callback] = []
other_callbacks: List[Callback] = []
checkpoint_callbacks: List[Callback] = []
for cb in callbacks:
if isinstance(cb, BatchSizeFinder):
tuner_callbacks.append(cb)
elif isinstance(cb, Checkpoint):
checkpoint_callbacks.append(cb)
else:
other_callbacks.append(cb)
return tuner_callbacks + other_callbacks + checkpoint_callbacks
DataConnector.on_trainer_init
# trainer.__init__函数内部调用了这个函数
self._data_connector.on_trainer_init(...)
作用是根据入参设定trainer的几个属性
除去异常处理的源代码如下
def on_trainer_init(
    self,
    val_check_interval: Optional[Union[int, float]],
    reload_dataloaders_every_n_epochs: int,
    check_val_every_n_epoch: Optional[int],
) -> None:
    self.trainer.datamodule = None
    self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
    self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
    self.trainer._is_data_prepared = False
Trainer.fit
完整源代码如下:
def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
r"""
Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
as all errors should funnel through them
Args:
trainer_fn: one of (fit, validate, test, predict)
*args: positional arguments to be passed to the `trainer_fn`
**kwargs: keyword arguments to be passed to `trainer_fn`
"""
try:
if trainer.strategy.launcher is not None:
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
else:
return trainer_fn(*args, **kwargs)
except _TunerExitException:
trainer._call_teardown_hook()
trainer._teardown()
trainer.state.status = TrainerStatus.FINISHED
trainer.state.stage = None
# TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
except KeyboardInterrupt as exception:
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
# user could press Ctrl+c many times... only shutdown once
if not trainer.interrupted:
trainer.state.status = TrainerStatus.INTERRUPTED
trainer._call_callback_hooks("on_exception", exception)
for logger in trainer.loggers:
logger.finalize("failed")
except BaseException as exception:
trainer.state.status = TrainerStatus.INTERRUPTED
if distributed_available() and trainer.world_size > 1:
# try syncing remaining processes, kill otherwise
trainer.strategy.reconciliate_processes(traceback.format_exc())
trainer._call_callback_hooks("on_exception", exception)
for logger in trainer.loggers:
logger.finalize("failed")
trainer._teardown()
# teardown might access the stage so we reset it after
trainer.state.stage = None
raise
class Trainer:
def fit(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
ckpt_path: Optional[str] = None,
) -> None:
r"""
Runs the full optimization routine.
Args:
model: Model to fit.
train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.
val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of two special
keywords ``"last"`` and ``"hpc"``. If there is no checkpoint file at the path, an exception is raised.
If resuming from mid-epoch checkpoint, training will start from the beginning of the next epoch.
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
"""
if not isinstance(model, pl.LightningModule):
raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
# self.strategy.lightning_module即是_lightning_module
self.strategy._lightning_module = model
call._call_and_handle_interrupt(
self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
)
# 实际上执行发生在这, 但可能被self.strategy包裹
def _fit_impl(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
ckpt_path: Optional[str] = None,
) -> None:
Trainer._log_api_event("fit")
log.detail(f"{self.__class__.__name__}: trainer fit stage")
self.state.fn = TrainerFn.FITTING
self.state.status = TrainerStatus.RUNNING
self.training = True
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloaders, LightningDataModule):
datamodule = train_dataloaders
train_dataloaders = None
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
raise MisconfigurationException(
"You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`"
)
# links data to the trainer
self._data_connector.attach_data(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)
# TODO: ckpt_path only in v2.0
ckpt_path = ckpt_path or self.resume_from_checkpoint
self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
self.state.fn,
ckpt_path, # type: ignore[arg-type]
model_provided=True,
model_connected=self.lightning_module is not None,
)
self._run(model, ckpt_path=self.ckpt_path)
assert self.state.stopped
self.training = False
return
Trainer._call_callback_hooks(hook_name)
Trainer.fit 方法中也会多次调用 Trainer._call_callback_hooks 方法, 其完整源代码如下:
class Trainer:
def _call_callback_hooks(
self,
hook_name: str,
*args: Any,
**kwargs: Any,
) -> None:
log.debug(f"{self.__class__.__name__}: calling callback hook: {hook_name}")
# TODO: remove if block in v1.8
if hook_name in ("on_init_start", "on_init_end"):
# these `Callback` hooks are the only ones that do not take a lightning module.
# we also don't profile bc profiler hasn't been set yet
for callback in self.callbacks:
fn = getattr(callback, hook_name)
if callable(fn):
fn(self, *args, **kwargs)
return
# 注: self.lightning_module的定义见如下注解
pl_module = self.lightning_module
if pl_module:
prev_fx_name = pl_module._current_fx_name
pl_module._current_fx_name = hook_name
for callback in self.callbacks:
fn = getattr(callback, hook_name)
if callable(fn):
with self.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
fn(self, self.lightning_module, *args, **kwargs)
if pl_module:
# restore current_fx when nested context
pl_module._current_fx_name = prev_fx_name
注: 为何要对 "on_init_start", "on_init_end" 这两个做单独的处理? 因为其他的 hook_name 都在 Trainer.fit 方法内部被调用, Trainer.fit 方法的源代码如下:
class Trainer:
def fit(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
ckpt_path: Optional[str] = None,
) -> None:
if not isinstance(model, pl.LightningModule):
raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
self.strategy._lightning_module = model
call._call_and_handle_interrupt(
self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
)
@property
def lightning_module(self) -> "pl.LightningModule":
# TODO: this is actually an optional return
return self.strategy.lightning_module
# src/pytorch_lightning/strategies/strategy.py
class Strategy(ABC):
@property
def lightning_module(self) -> Optional["pl.LightningModule"]:
"""Returns the pure LightningModule without potential wrappers."""
return self._lightning_module
其他补充
apply_to_collection
lightning_utilities 中包含一些递归应用函数(具体实现可以参看源码)
from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
data = [[1, 1.2], 2, 2.3, None]
data2 = [[1, 3], 3., 2, None]
# 这个函数也适用于字典和 dataclass, lambda 函数仅应用于指定的类型: int
apply_to_collection(data, int, lambda x: x * 2, include_none=False)
# [[2, 1.2], 4, 2.3]
apply_to_collections(data, data2, int, lambda x, y: x+y)
# [[2, 1.2], 5.0, 2.3, None]
由于 pytorch-lightning 很多地方将这两个函数用于 dataclass 上, 因此补充一个示例:
from dataclasses import dataclass

@dataclass
class A:
    a: int
    b: float
    c: str
    d: list

x = A(a=2, b=2.3, c="str", d=[1, 2, 3.9])
apply_to_collection(x, int, lambda x: x * 2)
# A(a=4, b=2.3, c='str', d=[2, 4, 3.9])