# pytorch-lightning

参考资料：

* [pytorch-lightning 101](https://www.youtube.com/playlist?list=PLaMu-SDt_RB5NUm67hU2pdE75j6KaIOv2): 视频课程, 可以用来理解概念, 共 4 小节课程, 其中第 3 小节是 pytorch-lightning 的基本用法, 第 4 小节介绍了 pytorch-lightning 的实现细节
* [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.html): 轻量版的 pytorch-lighting, 目前(2022.9.29)并未完全成熟. 用于尽量做很少的代码改动, 快速将 pytorch 训练代码进行转换, 好处是可以很快地将写好的单 GPU 或 CPU 训练流程变得可以自动支持多卡训练, fp16 训练等.
* [pytorch-lightning with huggingface transfomers](https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/text-transformers.html)

本文档主要分为两大部分。第一部分从使用 Lightning 的角度，介绍使用方法，以样例为主，尽量不涉及过多的源码。第二部分主要解释 Lightning 的源代码，有利于更好地使用。

## 第一部分：`Lightning` 的使用

### 疑惑

* `LightningModule` 中的 `self.log(...)` 是指什么?(猜测用于传给torchmetric, 类似于tf的Metric) 似乎最终调用的是`pytorch_lightning.trainer.connectors.logger_connector.result._ResultCollection.log()`
  * 此函数体内涉及到`lightning_utilities.core.apply_func.apply_to_collections`

### Pytorch vs Lightning

* Dataset, DataLoader: 在 Lightning 可以沿用, 或者使用 `LightningDataModule`, 多卡训练时, Dataloader 所需的 DistributedSampler 在 Lightning 中无需手动写
* nn.Module: 在 Lightning 使用 `LightningModule`, 需要提供 `forward`, `training_step`, `configure_optimizers` 方法
* 训练过程:
  * for loop: 在 Lightning 无需自己写 for loop
  * loss backward: 在 Lightning 中可以让 `training_step` 返回 loss, 自动进行 backward, 或者也可以手工写 backward 的执行过程
* DDP, AMP: 在 Lightning 中用 `Trainer` 的初始化参数指定
* 模型加载与保存: 最简单的用法是 Lightning 中用 `Trainer` 自动处理, 高级用法是初始化 `Trainer` 时增加 `pytorch_lightning.callbacks.ModelCheckpoint` 这个 callback, 更复杂的用法是关闭 `Trainer` 的模型保存功能(`enable_checkpointing=False`), 在 `LightningModule` 的 `training_epoch_end` 或者 `validation_epoch_end` 中检测当前的 local\_rank, 只在 local\_rank 为0的进程上做保存模型的工作
* 控制台打印: 可以使用 python 本身的 `logging` 模块进行, 参考[官方文档](https://pytorch-lightning.readthedocs.io/en/stable/common/console_logs.html)

### 使用模板

```python
from pytorch_lightning import LightModule, Trainer
model = MyModule(...)  # MyModule 继承自 LightModule
trainer = Trainer(...)  # max_steps, min_steps 等参数
trainer.fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None,ckpt_path=None)
```

### pytorch\_lightning.LightningModule

代码参考自: [pytorch-lightning with huggingface transfomers](https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/text-transformers.html)

```python
class GLUETransformer(LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int,
        task_name: str,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        eval_splits: Optional[list] = None,
        **kwargs,
    ):
        super().__init__()

        # 此函数用于保存__init__函数的入参至self.hparams, 含义是收集所有的超参数,推荐在此处使用
        # 如果__init__函数的入参中有torch.nn.Module, 可以设置参数ignore将其忽略
        self.save_hyperparameters()

        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)
        self.metric = datasets.load_metric(
            "glue", self.hparams.task_name, experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
        )

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs[0]
        # 返回一个标量版的loss即可, 或者返回一个字典, 字典中有一个键值对为{"loss": loss}
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        val_loss, logits = outputs[:2]

        if self.hparams.num_labels > 1:
            preds = torch.argmax(logits, axis=1)
        elif self.hparams.num_labels == 1:
            preds = logits.squeeze()

        labels = batch["labels"]

        return {"loss": val_loss, "preds": preds, "labels": labels}

    def validation_epoch_end(self, outputs):
        if self.hparams.task_name == "mnli":
            for i, output in enumerate(outputs):
                # matched or mismatched
                split = self.hparams.eval_splits[i].split("_")[-1]
                preds = torch.cat([x["preds"] for x in output]).detach().cpu().numpy()
                labels = torch.cat([x["labels"] for x in output]).detach().cpu().numpy()
                loss = torch.stack([x["loss"] for x in output]).mean()
                self.log(f"val_loss_{split}", loss, prog_bar=True)
                split_metrics = {
                    f"{k}_{split}": v for k, v in self.metric.compute(predictions=preds, references=labels).items()
                }
                self.log_dict(split_metrics, prog_bar=True)
            return loss

        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,  # 注意: 此处可以使用trainer的变量
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
```

#### inference without pytorch-lightning package

为了更好地模块化, 建议采用如下方式组织代码

```python
# 存疑: 官方演示视频中此处继承的是LightningModule
class TorchModule(torch.nn.Module):
    def __init__(self, **model_hparams):
        ...
    def forward(self, input):
        ...
class PLModule(LightningModule)
    def __init__(self, model, **kwargs):
        self.model = model
    def training_step(self, batch, batch_idx):
        ...
    def validation_step(self, batch, batch_idx):
        ...
    def configure_optimizers(self):
        ...
```

#### 控制backward的逻辑

更多详细内容参考[官方文档](https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html)

如果需要控制backward的逻辑, 需要做以下事情

* 在`__init__`中设置`self.automatic_optimization=False`
* 在`training_step`中使用如下API:
  * 使用 `optimizer=self.optimizers()` 获取所有优化器
  * 使用 `optimizer.zero_grad()` 清除梯度
  * 使用 `self.manual_backward(loss)` 而不要使用 `loss.backward()`
  * 使用 `optimizer.step()`

#### 多个optimizer与scheduler

建议 `configure_optimizers` 函数按如下方式返回

```python
(
    {
        "optimizer": optimizer_1,
        "lr_scheduler": {
            "scheduler": scheduler_1,
            "interval": "step",
            "frequency": 1,
        }
    },
    {
        "optimizer": optimizer_2,
        "lr_scheduler": {
            "scheduler": scheduler_2,
            "interval": "step",
            "frequency": 1,
        }
    },
)
```

#### 当前 rank、step、epoch 等

```
def training_step(self, batch, batch_idx):
    self.global_rank
    self.local_rank
    self.current_epoch  # 当前epoch数
    self.global_step  # 全局步数
```

#### fit 函数伪代码(hook 编程)

从Lightning的实现上

* hook 指的是类的一些特定方法, 例如: `on_train_epoch_start`, `on_train_batch_start`。
* callback 指的是含有这些特定方法的类。

具体来说，简易版的大致实现方式如下

备注: 下面的写法仅做示意 hook 编程的大致形式, 并非对齐 `LightningModule` 的真正实现

```python
# 不引入Trainer的时候, LightningModule的简易版实现
class MyLightningModule:
    def __init__(self, callbacks):
        self.callbacks = callbacks
    def call_hook(self, hook_name, *args. **kwargs):
        for callback in self.callbacks:
            if hasattr(callback, hook_name):
                func = getattr(callback, hook_name)
                if callable(func):
                    func(*args. **kwargs)  # 返回值怎么处理?
    def fit(self, loader):
        for batch in loader:
            self.call_hook("on_before_batch", batch)
            self.training_step()  # 这个是 hook 吗?
            self.call_hook("on_after_batch", batch)
    def training_step(self):
        raise NotImplementedError()

class MyCustomModule(MyLightningModule):
    def __init__(self, callbacks):
        super().__init__(self, callbacks)
    def training_step(self):
        # ...

class FirstCallback:
    def on_before_batch(self):
        # ...
    def on_after_batch(self):
        # ...
class SecondCallback:
    def on_before_batch(self):
        # ...

if __name__ == "__main__":
    callbacks = [FirstCallback(), SecondCallback()]
    model = MyLightningModule(callbacks)
    loader = ...
    model.fit(loader)
```

在理解了 hook/callback 之后, 可以参考[官方文档](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#hooks)中真正的实现流程理解 fit 函数的整个过程, 并在 `LightningModule` 的子类中覆盖这些方法或者传入自定义 Callback 类, 摘录如下:

```python
def fit(self):
    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()

    configure_callbacks()

    with parallel(devices):
        # devices can be GPUs, TPUs, ...
        train_on_device(model)


def train_on_device(model):
    # called PER DEVICE
    on_fit_start()
    setup("fit")
    configure_optimizers()

    # the sanity check runs here

    on_train_start()
    for epoch in epochs:
        fit_loop()
    on_train_end()

    on_fit_end()
    teardown("fit")


def fit_loop():
    on_train_epoch_start()

    for batch in train_dataloader():
        on_train_batch_start()

        on_before_batch_transfer()
        transfer_batch_to_device()
        on_after_batch_transfer()

        training_step()

        on_before_zero_grad()
        optimizer_zero_grad()

        on_before_backward()
        backward()
        on_after_backward()

        on_before_optimizer_step()
        configure_gradient_clipping()
        optimizer_step()

        on_train_batch_end()

        if should_check_val:
            val_loop()
    # end training epoch
    training_epoch_end()

    on_train_epoch_end()


def val_loop():
    on_validation_model_eval()  # calls `model.eval()`
    torch.set_grad_enabled(False)

    on_validation_start()
    on_validation_epoch_start()

    val_outs = []
    for batch_idx, batch in enumerate(val_dataloader()):
        on_validation_batch_start(batch, batch_idx)

        batch = on_before_batch_transfer(batch)
        batch = transfer_batch_to_device(batch)
        batch = on_after_batch_transfer(batch)

        out = validation_step(batch, batch_idx)

        on_validation_batch_end(batch, batch_idx)
        val_outs.append(out)

    validation_epoch_end(val_outs)

    on_validation_epoch_end()
    on_validation_end()

    # set up for train
    on_validation_model_train()  # calls `model.train()`
    torch.set_grad_enabled(True)
```

#### training\_step 出入参

* 入参: 由以下这些 hook 最后的输出得到

  ```python
  for batch in train_dataloader():
      on_train_batch_start()

      on_before_batch_transfer()
      # 将batch中的tensor转移到相关的device上，如果默认的方法不能满足要求, 则可以重载这个函数
      transfer_batch_to_device()
      on_after_batch_transfer()

      training_step()
  ```
* 出参: 返回一个标量版的loss即可, 或者返回一个字典, 字典中有一个键值对为{"loss": loss}

```python
@dataclass
class OutputResult:
    def asdict(self) -> Dict[str, Any]:
        raise NotImplementedError

# src/pytorch_lightning/loops/optimization/optimizer_loop.py
class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
    """Runs over a sequence of optimizers.

    This loop implements what is known in Lightning as Automatic Optimization.
    """

    output_result_cls = ClosureResult
    def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
        ...
        # 注: training_step_output即为training_step的返回结果
        result = self.output_result_cls.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)
        ...


@dataclass
class ClosureResult(OutputResult):
    """A container to hold the result of a :class:`Closure` call.

    It is created from the output of :meth:`~pytorch_lightning.core.module.LightningModule.training_step`.

    Attributes:
        closure_loss: The loss with a graph attached.
        loss: A detached copy of the closure loss.
        extra: Any keys other than the loss returned.
    """

    closure_loss: Optional[Tensor]
    loss: Optional[Tensor] = field(init=False, default=None)
    extra: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        self._clone_loss()

    def _clone_loss(self) -> None:
        if self.closure_loss is not None:
            # the loss will get scaled for amp. avoid any modifications to it
            self.loss = self.closure_loss.detach().clone()

    @classmethod
    def from_training_step_output(
        cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1
    ) -> "ClosureResult":
        closure_loss, extra = None, {}

        if isinstance(training_step_output, dict):
            # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
            closure_loss = training_step_output.get("loss")
            if closure_loss is None:
                raise MisconfigurationException(
                    "In automatic_optimization, when `training_step` returns a dict, the 'loss' key needs to be present"
                )
            extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
        elif isinstance(training_step_output, Tensor):
            closure_loss = training_step_output
        elif training_step_output is not None:
            raise MisconfigurationException(
                "In automatic optimization, `training_step` must return a Tensor, "
                "a dict, or None (where the step will be skipped)."
            )

        if closure_loss is not None:
            # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
            # note: avoid in-place operation `x /= y` here on purpose
            closure_loss = closure_loss / normalize

        return cls(closure_loss, extra=extra)

    def asdict(self) -> Dict[str, Any]:
        return {"loss": self.loss, **self.extra}
```

#### `training_step`, `validation_step`, `test_step`, `predict_step`

* training\_step: 训练过程, batch中应包含x与y, 被trainer.fit调用
* validation\_step: 验证过程, batch中应包含x与y, 通常的每个epoch结束后被trainer.fit调用
* test\_step: 训练过程, batch中应包含x与y, 在trainer.fit中不被调用, 被trainer.test调用
* predict\_step: 在不定义predict\_step的情况下, trainer.pred会调用model.forward, 否则会调用predict\_step, 因此batch中只包含x即可

```python
def training_step(batch, batch_idx, optimizer_idx, hiddens):
    # 返回可以是三种:(1) loss tensor (2) dict, 但必须包含"loss"这个key (3) None, 跳过此次training_step的过程, 一般用于手动backward

def validation_step(batch, batch_idx, dataloader_idx):
    # 返回可以是(1)tensor (2)dict of tensor (3) None, 跳过此次validation_step

def test_step(batch, batch_idx, dataloader_id):
    # 返回可以是(1)tensor (2)dict of tensor (3) None, 跳过此次test_step

def predict_step(batch, batch_idx, dataloader_id):
    # 返回是Any
```

#### save checkpoint advanced

**方式1** 默认情况下, 会自动保存模型(只保存最新的), 参考[官方文档](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#enable-checkpointing)

> By default Lightning saves a checkpoint for you in your current working directory, with the state of your last training epoch, Checkpoints capture the exact value of all parameters used by a model. To disable automatic checkpointing, set this to False.

```python
# default used by Trainer, saves the most recent model to a single checkpoint after each epoch
trainer = Trainer(enable_checkpointing=True)

# turn off automatic checkpointing
trainer = Trainer(enable_checkpointing=False)
```

备注: 这种情况下默认会保存在`lightening_logs/version_n/checkpoints` 目录中, 且会保存许多其他非模型权重的东西

```python
import torch
d = torch.load("lightening_logs/version_n/checkpoints/xxx.ckpt")
d.keys()
# epoch, global_step, pytorch_lightening_version, state_dict, loops, callbacks, optimizer_states, lr_schedulers
# 其中state_dict为模型的权重
```

**方式2** trainer 中传入的callbacks包含一个 `ModelCheckpoint` 实例, 参考[官方文档](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#enable-checkpointing)。用于指定保存路径，保留多少个checkpoint，是否只保留权重，是否根据某个特定的监控值保存最优模型等

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Init ModelCheckpoint callback, monitoring 'val_loss'
checkpoint_callback = ModelCheckpoint(monitor="val_loss")

# Add your callback to the callbacks list
trainer = Trainer(callbacks=[checkpoint_callback])
```

**方式3** 手写, 在各个hook中加上一些条件进行模型保存

```python
from pytorch_lightning import LightningModule, Trainer
import torch
class MyModel:
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(64, 32)
    def training_epoch_end(self, training_outs):
        if self.current_epoch in [0, 2] and self.local_rank == 0:
            torch.save(self.layer.state_dict(), f"epoch_{self.current_epoch}.pth")
model = MyModel()
trainer = Trainer(max_epochs=4, gpus=2, enable_checkpointing=False)
```

**方式4(不确定): 继承ModelCheckpoint**

#### 怎样控制 log 打印（未解决）

包含以下几个部分：

* `LightningModule.training_step` 的返回值在 `LightningModule.on_epoch_end` 中保存的具体逻辑是？
* 如何控制 Tensorboard 的打印内容与打印时机
* 如何分别控制打印到控制台/保存到日志文件/输出到 Tensorboard 的信息

`LightningModule` 中的 `self.log(...)` 最终调用的是 `trainer._results.log`，而这个对象最终对应的是类似于`trainer.fit_loop.epoch_loop._results`，它是`pytorch_lightning.trainer.connectors.logger_connector.result._ResultCollection` 对象。

逻辑大约是`module.log`调用时，在 `_results` 中记录下信息，之后在 `TrainingEpochLoop` 的 `advance`函数中调用`self.trainer._logger_connector.update_train_step_metrics()`实际将 log 写入：具体进一步触发 `LoggerConnector` 中调用 `self.log_metrics(self.trainer._results.metrics(not self._epoch_end_reached)['log'])`，根据是否需要写日志，再触发 TensorBoardLogger 的 `log_metrics` 函数。

* `trainer` 的构造函数 `__init__` 中有一个参数 `log_every_n_steps` 用于控制 `update_train_step_metrics` 函数将数据写入 tensorboard
* 也就是 `module.log()` 最终修改了 `trainer._results`, 然后利用这个在某个时机写入了 tensorboard

备注：由于 `TensorBoardLogger` 中没有保存图像的函数，因此，如果想完全发挥它内部包含的 `torch.utils.tensorboard import SummaryWriter`。可以通过 `TensorBoardLogger.experiment`进行直接实现。

例如（待验证）：

```python
import pytorch_lightning as pl

# 在 LightningModule 中即刻将log写入tensorboard
    def training_step(...):
        for logger in self.trainer.loggers:
            if isinstance(logger, pl.loggers.TensorBoardLogger):
                logger.experiment.add_image('a/ori_image', ori_image, global_idx)
```

### pytorch\_lightning.Trainer

```
trainer = Trainer()
trainer.fit(model)
```

备注：相比于 huggingface transformers 中的 `Trainer` 类，官方文档中鼓励对其使用继承的方法重写一些方法。pytorch-lightning 中的推荐做法是直接使用 `Trainer`，而对 `LightningModule` 进行继承以及方法重写。

### pytorch\_lightning.LightningDataModule

```python
import torch
from pytorch_lightning import LightningDataModule

class MyDataset(torch.utils.data.Dataset):
    ...
    def __getitem__(self, idx):
        return transform(data[idx])

class MyDataModule(LightningDataModule):

    def __init__(
        self,
        model_name_or_path: str,
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def setup(self, stage: str):
        # 此部分代码在多卡场景下会被每个进程执行, 因此可以设置变量
        # 建议设置self.dataset
        self.datasets = {
            "train": MyDataset(...)
            "val": MyDataset(...)
        }

    def prepare_data(self):
        # 此部分代码仅在rank0上运行, 建议不要设置类的属性
        # 建议做一些数据的转换工作, 例如切分数据至train/val文件夹,将数据tokenizer化保存
        pass

    def train_dataloader(self):
        # 此部分代码在多卡场景下会被每个进程执行
        # 建议dataset在self.setup方法中设定, 此处直接使用torch.utils.data.DataLoader进行包装
        return torch.utils.data.DataLoader(self.dataset["train"], batch_size=self.train_batch_size, shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.dataset["val"], batch_size=self.train_batch_size, shuffle=False)
```

## 第二部分：`Lightning` 源码阅读

OpenMMLab对pytorch-lightning也有一篇源码解读文章: <https://zhuanlan.zhihu.com/p/389271556>

目标: 理解如下代码片段的执行过程

```python
from pytorch_lightning import LightModule, Trainer
model = MyModule(...)  # MyModule 继承自 LightModule
trainer = Trainer(...)  # max_steps, min_steps 等参数
trainer.fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None,ckpt_path=None)
```

### `LightningModule`

#### 父类

源码中关于 `LightningModule` 类的定义继承自了多个父类, 特别注意它也继承自`torch.nn.Module`。因此需要先对几个父类的代码做个了解

```python
class LightningModule(
    _DeviceDtypeModuleMixin,
    HyperparametersMixin,
    ModelIO,
    ModelHooks,
    DataHooks,
    CheckpointHooks,
    Module,
):
    ...
```

<details>

<summary>HyperparametersMixin</summary>

主要使用的方式如下参考[官网例子](https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/text-transformers.html): 通过调用 `self.save_hyperparameters` 方法, 将所有 `__init__` 的传参保存到`self._hparams`中

```python
class MyModule(LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int,
        learning_rate: float = 2e-5,
        train_batch_size: int = 32,
        **kwargs,
    ):
        super().__init__()
        # MyModule中其他地方可以使用self.hparams.num_labels
        self.save_hyperparameters(ignore=["model_name_or_path"])
model = MyModule("x.pth", 10)
print(model.hparams)  # pytorch_lightning.utilities.parsing.AttributeDict
# {"num_labels": 10, "learning_rate": 2e-5, "train_batch_size": 32}
```

</details>

### `Trainer.__init__`

`Trainer` 类没有父类, 直接继承自 `object`.

第 2 行代码: `trainer = Trainer(...)` 的源码如下:

```python
# src/pytorch_lightning/trainer/trainer.py:Trainer
@_defaults_from_env_vars
def __init__(self, logger=True, ...)  # 共有约50个参数
    ...
```

首先解释一下这个装饰器的作用:

利用 `os.environ.get` 方法获取形如 `PL_TRAINER_{XXX}` 环境变量, 并用环境变量的值取代被装饰的函数(上面的例子中为`Trainer.__init__`函数)中的默认值. 即最终第 2 行代码参数设定的优先顺序为:

```
实参 > 环境变量 > 函数定义中形参的默认值
```

<details>

<summary>装饰器`_defaults_from_env_vars`的具体实现</summary>

```python
# src/pytorch_lightning/utilities/argparse.py
def _defaults_from_env_vars(fn: _T) -> _T:
    @wraps(fn)  # 注: functools.wraps
    def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any:
        cls = self.__class__  # get the class
        if args:  # in case any args passed move them to kwargs
            # parse only the argument names
            cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
            # convert args to kwargs
            kwargs.update(dict(zip(cls_arg_names, args)))  # 注: 此处的kwargs为实参
        env_variables = vars(parse_env_variables(cls))  # 注: 此处为从环境变量处解析得到的默认值
        # update the kwargs by env variables
        # 注: 这里第2项中的键值对会覆盖第1项的键值对, 因此优先级为实参>环境变量>函数定义中的默认值
        kwargs = dict(list(env_variables.items()) + list(kwargs.items()))

        # all args were already moved to kwargs
        return fn(self, **kwargs)

    return cast(_T, insert_env_defaults)  # 注: typing.cast
```

理解上面的代码所需要的 Python 知识如下:

* `functools.wraps`装饰器的作用
* `vars` 内置函数的作用
* `typing.cast` 函数的作用: 参考[stackoverflow](https://stackoverflow.com/questions/51457563/what-does-typing-cast-do-in-python),此函数只在类型检查时起作用, 而在实际运行时什么都不做, 直接将入参返回

上述代码中进一步调用了如下两段代码

```python
# src/pytorch_lightning/utilities/argparse.py
def parse_env_variables(cls: _ARGPARSE_CLS, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
    """Parse environment arguments if they are defined.

    Examples:

        >>> from pytorch_lightning import Trainer
        >>> parse_env_variables(Trainer)
        Namespace()
        >>> import os
        >>> os.environ["PL_TRAINER_GPUS"] = '42'
        >>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23'
        >>> parse_env_variables(Trainer)
        Namespace(gpus=42)
        >>> del os.environ["PL_TRAINER_GPUS"]
    """
    cls_arg_defaults = get_init_arguments_and_types(cls)

    env_args = {}
    for arg_name, _, _ in cls_arg_defaults:
        env = template % {"cls_name": cls.__name__.upper(), "cls_argument": arg_name.upper()}
        val = os.environ.get(env)
        if not (val is None or val == ""):
            # todo: specify the possible exception
            with suppress(Exception):  # 注: contextlib.suppress
                # converting to native types like int/float/bool
                val = literal_eval(val)  # 注: ast.literal_eval
            env_args[arg_name] = val
    return Namespace(**env_args)


def get_init_arguments_and_types(cls: _ARGPARSE_CLS) -> List[Tuple[str, Tuple, Any]]:
    r"""Scans the class signature and returns argument names, types and default values.

    Returns:
        List with tuples of 3 values:
        (argument name, set with argument types, argument default value).

    Examples:

        >>> from pytorch_lightning import Trainer
        >>> args = get_init_arguments_and_types(Trainer)

    """
    cls_default_params = inspect.signature(cls).parameters  # 注
    name_type_default = []
    for arg in cls_default_params:
        arg_type = cls_default_params[arg].annotation
        arg_default = cls_default_params[arg].default
        try:
            arg_types = tuple(arg_type.__args__)
        except (AttributeError, TypeError):
            arg_types = (arg_type,)

        name_type_default.append((arg, arg_types, arg_default))

    return name_type_default
```

理解上面的代码所需要的 Python 知识如下:

* 内置模块 `inspect` 相关知识: 此处仅用到 `inspect.signature(callable).parameters`, 用于获取函数或类的`__init__`方法定义中的变量名, 变量类型, 以及默认值
* `contextlib.suppress`: 参考[stackoverflow](https://stackoverflow.com/questions/34566806/why-use-contextlib-suppress-as-opposed-to-try-except-with-pass), 这两种写法基本等价:

  ```python
  # 写法一:
  with contextlib.suppress(ValueError):
      x = int('a')
  # 写法二:
  try:
      x = int('a')
  except ValueError:
      pass
  ```
* 内置模块 `ast.literal_eval`: 此函数接受的参数为一个合法的字符串形式的python数据, 例如:

  ```python
  ast.literal_eval("['a', 'b']")  # 返回列表: ["a", "b"]
  ast.literal_eval("'a'")  # 返回字符串: "a"
  ast.literal_eval("1")  # 返回整数: 1
  ast.literal_eval("1.2")  # 返回浮点数: 1.2
  ast.literal_eval("1+1")  # 报错
  ast.literal_eval("a")  # 报错
  ```

  即功能弱于内置函数`eval`, 官方文档中建议一切能用 `ast.literal_eval` 代替 `eval` 的地方, 都使用 `ast.literal_eval`, 无法替代的情况下, 应该选择其他实现方式, 而不能依赖 `eval`

</details>

接下来进入 `Trainer.__init__` 函数的函数体, **完整**源代码如下:

```python
# src/pytorch_lightning/trainer/trainer.py
@_defaults_from_env_vars
def __init__(self, logger, ....):  # 注: 一共有约50个参数
    super().__init__()
    # 即执行: torch._C._log_api_usage_once("lightning.trainer." + "init")
    # 在环境变量为PYTORCH_API_USAGE_STDERR=1时才打印信息
    Trainer._log_api_event("init")
    # 此处的 log 是本文件的"全局"变量
    # log = logging.getLogger(__name__)
    log.detail(f"{self.__class__.__name__}: Initializing trainer with parameters: {locals()}")
    # 见说明 `TrainerState`
    self.state = TrainerState()

    # 见说明 `Connector`
    # init connectors
    self._data_connector = DataConnector(self, multiple_trainloader_mode)

    self._accelerator_connector = AcceleratorConnector(
        num_processes=num_processes,
        devices=devices,
        tpu_cores=tpu_cores,
        ipus=ipus,
        accelerator=accelerator,
        strategy=strategy,
        gpus=gpus,
        num_nodes=num_nodes,
        sync_batchnorm=sync_batchnorm,
        benchmark=benchmark,
        replace_sampler_ddp=replace_sampler_ddp,
        deterministic=deterministic,
        auto_select_gpus=auto_select_gpus,
        precision=precision,
        amp_type=amp_backend,
        amp_level=amp_level,
        plugins=plugins,
    )
    self._logger_connector = LoggerConnector(self)
    self._callback_connector = CallbackConnector(self)
    self._checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
    self._signal_connector = SignalConnector(self)
    # 见下面说明 `Tuner`
    self.tuner = Tuner(self)

    # 见下面说明 `Loop`
    fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
    training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
    # 注: 执行的函数体为: fit_loop.epoch_loop=training_epoch_loop
    fit_loop.connect(epoch_loop=training_epoch_loop)

    # default .fit() loop
    self.fit_loop = fit_loop

    # default .validate() loop
    self.validate_loop = EvaluationLoop()

    # default .test() loop
    self.test_loop = EvaluationLoop()

    # default .predict() loop
    self.predict_loop = PredictionLoop()

    # set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`.
    self._ckpt_path: Optional[str] = None

    # init callbacks
    # Declare attributes to be set in _callback_connector on_trainer_init
    self._callback_connector.on_trainer_init(
        callbacks,
        enable_checkpointing,
        enable_progress_bar,
        default_root_dir,
        enable_model_summary,
        max_time,
        accumulate_grad_batches,
    )

    # hook
    # 见下面说明 `_call_callback_hooks`
    # V1.8版本对这个做了移除处理, 默认的几个Callback都没有这个hook
    self._call_callback_hooks("on_init_start")

    # init data flags
    # 有点诡异, 没有赋值?
    self.check_val_every_n_epoch: int
    self._data_connector.on_trainer_init(
        val_check_interval,
        reload_dataloaders_every_n_epochs,
        check_val_every_n_epoch,
    )

    # gradient clipping
    if gradient_clip_val is not None and not isinstance(gradient_clip_val, (int, float)):
        raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")

    if gradient_clip_algorithm is not None and not GradClipAlgorithmType.supported_type(
        gradient_clip_algorithm.lower()
    ):
        raise MisconfigurationException(
            f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid. "
            f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
        )

    # gradient norm tracking
    if track_grad_norm != -1 and not (
        (isinstance(track_grad_norm, (int, float)) or track_grad_norm == "inf") and float(track_grad_norm) > 0
    ):
        raise MisconfigurationException(
            f"`track_grad_norm` must be a positive number or 'inf' (infinity norm). Got {track_grad_norm}."
        )

    self.gradient_clip_val: Union[int, float] = gradient_clip_val
    self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = (
        GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None
    )
    self.track_grad_norm: float = float(track_grad_norm)

    self._detect_anomaly: bool = detect_anomaly
    # 见下面说明 `_setup_on_init`
    self._setup_on_init()

    # configure tuner
    # 见下面说明 `Tuner`
    self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

    # configure profiler
    # 见下面说明 `setup._init_profiler`
    setup._init_profiler(self, profiler)

    # init logger flags
    self._loggers: List[Logger]
    self._logger_connector.on_trainer_init(logger, log_every_n_steps, move_metrics_to_cpu)

    # init debugging flags
    self.val_check_interval: Union[int, float]
    self.num_sanity_val_steps: Union[float, int]
    # 见下面说明 `setup._init_debugging_flags`
    setup._init_debugging_flags(
        self,
        limit_train_batches,
        limit_val_batches,
        limit_test_batches,
        limit_predict_batches,
        fast_dev_run,
        overfit_batches,
        val_check_interval,
        num_sanity_val_steps,
    )

    # Callback system
    self._call_callback_hooks("on_init_end")
```

**整体流程简要描述**：

总的来说基本上是一些为`Trainer`的属性赋值的操作

```python
self.state = TrainerState()  # 后续调用fit/test等函数时会对这个self.state进行设置
# 初始化DataConnector,AcceleratorConnector,LoggerConnector,CallbackConnector,CheckpointConnector,SignalConnector, 代码从略, 除了AcceleratorConnector进行了一些实质性的准备工作外(例如DDP的一些诸如：dist.init_process_group的操作，是否实际执行存疑，待后续明确), 其余基本上都只是对属性值进行了一些初始化

# 初始化几个loop，实际上仅根据入参设定了一些参数, 涉及到的几个loop嵌套关系见后文说明
self.fit_loop = fit_loop
self.validate_loop = EvaluationLoop()
self.test_loop = EvaluationLoop()
self.predict_loop = PredictionLoop()

# 主要执行逻辑是依次将如下默认Callback类添加至`Trainer.callbacks`中, 然后将这些callback按照类型进行重排序，先后顺序为：tuner_callbacks(BatchSizeFinder),other_callbacks, checkpoint_callbacks(ModelCheckpoint)。
self._callback_connector.on_trainer_init(...)

# 依次调用所有callback的"on_init_start" hook, lightning v1.8对这一过程做了移除, 可参考关于Trainer.fit的代码解析
# self._call_callback_hooks("on_init_start")

# 作用是根据入参设定trainer的几个属性
self._data_connector.on_trainer_init(...)

# 设定一些属性, 并在主进程上打印一些GPU/TPU是否可用, 是否使用的日志
self._setup_on_init()

# 作用是根据入参设定trainer的几个属性, self.auto_lr_find = auto_lr_find
self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

# 初始化self.profiler, 默认为初始化一个PassThroughProfiler(profiler=None)
setup._init_profiler(self, profiler)

# 设定trainer.loggers(列表): 如果参数为logger默认值True,则创建TensorBoardLogger, 否则按照logger设定
self._logger_connector.on_trainer_init(...)

# 设定一些debug用的参数, 作用未知?
setup._init_debugging_flags(...)
```

需要细致展开的部分如下：

<details>

<summary>TrainerState</summary>

```python
# from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
self.state = TrainerState()
```

涉及的源代码如下：

备注：`LightningEnum` 实际上继承自 python 原生的 `Enum`，增加了一个 `from_str` 方法，并允许它直接与字符串比较（`__eq__`函数），如果与枚举状态所对应的字符串相同，则返回 `True`。

```python
# src/lightning_lite/utilities/enums.py
from pytorch_lightning.utilities import LightningEnum
@dataclass
class TrainerState:
    """Dataclass to encapsulate the current :class:`~pytorch_lightning.trainer.trainer.Trainer` state."""
    # trainer的运行状态: "initializing", "running", "finished", "interrupted"
    status: TrainerStatus = TrainerStatus.INITIALIZING
    # "fit", "validate", "test", "predict", "tune"
    # 与trainer.fit/validate/test/predict/tune直接绑定
    fn: Optional[TrainerFn] = None
    # "sanity_check", "train", "validate", "test", "predict", "tune"
    # trainer.fit函数内的具体状态, 会依次变为: "sanity_check", "train", "validate"
    stage: Optional[RunningStage] = None

    # detect the fault tolerant flag
    # 这个不确定是用来做什么的
    _fault_tolerant_mode: _FaultTolerantMode = field(default_factory=_FaultTolerantMode.detect_current_mode)

    @property
    def finished(self) -> bool:
        return self.status == TrainerStatus.FINISHED

    @property
    def stopped(self) -> bool:
        return self.status.stopped
```

以下是更为具体的源代码

```python
class TrainerStatus(LightningEnum):
    """Enum for the status of the :class:`~pytorch_lightning.trainer.trainer.Trainer`"""

    INITIALIZING = "initializing"  # trainer creation
    RUNNING = "running"
    FINISHED = "finished"
    INTERRUPTED = "interrupted"

    @property
    def stopped(self) -> bool:
        return self in (self.FINISHED, self.INTERRUPTED)

class TrainerFn(LightningEnum):
    """
    Enum for the user-facing functions of the :class:`~pytorch_lightning.trainer.trainer.Trainer`
    such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and
    :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.
    """

    FITTING = "fit"
    VALIDATING = "validate"
    TESTING = "test"
    PREDICTING = "predict"
    TUNING = "tune"

    @property
    def _setup_fn(self) -> "TrainerFn":
        """``FITTING`` is used instead of ``TUNING`` as there are no "tune" dataloaders.

        This is used for the ``setup()`` and ``teardown()`` hooks
        """
        return TrainerFn.FITTING if self == TrainerFn.TUNING else self

class RunningStage(LightningEnum):
    """Enum for the current running stage.

    This stage complements :class:`TrainerFn` by specifying the current running stage for each function.
    More than one running stage value can be set while a :class:`TrainerFn` is running:

        - ``TrainerFn.FITTING`` - ``RunningStage.{SANITY_CHECKING,TRAINING,VALIDATING}``
        - ``TrainerFn.VALIDATING`` - ``RunningStage.VALIDATING``
        - ``TrainerFn.TESTING`` - ``RunningStage.TESTING``
        - ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING``
        - ``TrainerFn.TUNING`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}``
    """

    TRAINING = "train"
    SANITY_CHECKING = "sanity_check"
    VALIDATING = "validate"
    TESTING = "test"
    PREDICTING = "predict"
    TUNING = "tune"

    @property
    def evaluating(self) -> bool:
        return self in (self.VALIDATING, self.TESTING)

    @property
    def dataloader_prefix(self) -> Optional[str]:
        if self in (self.SANITY_CHECKING, self.TUNING):
            return None
        if self == self.VALIDATING:
            return "val"
        return self.value
```

</details>

<details>

<summary>Connector(DataConnector,AcceleratorConnector,LoggerConnector,CallbackConnector,CheckpointConnector,SignalConnector)</summary>

`Trainer.__init__` 函数依次进行了 `DataConnector`, `AcceleratorConnector`,`LoggerConnector`, `CallbackConnector`, `CheckpointConnector`, `SignalConnector`几个的初始化

推测: 这种`XXXConnector`类的作用基本上就是给`Trainer`添加一些属性, 不知道为啥不直接写在`Trainer`的内部(也许是写在Trainer内部, Trainer类的定义会变得很冗长?看源码长度`pytorch_ligtening/trainer/trainer.py`本身已有2000多行, 如果这些Connector也写在Trainer里,估计会更长)

这里几个 `Connector` 的 `__init__` 函数的初始化基本上只是设定 `self.trainer=trainer`，以及初始化一些状态，并无太多需要说明之处。唯一的例外是`AcceleratorConnector`做了许多工作（此处从略）。

</details>

<details>

<summary>Loop</summary>

```python
fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
# 注: 执行的函数体为: fit_loop.epoch_loop=training_epoch_loop
fit_loop.connect(epoch_loop=training_epoch_loop)
```

为了更好地看出上面三行代码的执行逻辑，将其展开为（`|-`引导内部的调用顺序）如下

备注：一些简单的操作例如：`self.min_steps=min_steps` 被省略了

```python
fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
|-self.epoch_loop = TrainingEpochLoop()
| |-self.batch_progress = BatchProgress()
| |-self.scheduler_progress = SchedulerProgress()
| |-self.batch_loop = TrainingBatchLoop()
| | |-# 内部保存一个长度为20的memory
| | |-self.accumulated_loss = TensorRunningAccum(window_length=20)
| | |-self.running_loss = TensorRunningAccum(window_length=20)
| | |-self.optimizer_loop = OptimizerLoop()
| | | |-self.optim_progress: OptimizationProgress = OptimizationProgress()
| | |-self.manual_loop = ManualOptimization()
| |   |-self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker)
| |-self.val_loop = loops.EvaluationLoop(verbose=False)
|   |-self.epoch_loop = EvaluationEpochLoop()
|     |-self.batch_progress = BatchProgress()
|-self.epoch_progress = Progress()

training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
# fit_loop.epoch_loop=training_epoch_loop
fit_loop.connect(epoch_loop=training_epoch_loop)
self.fit_loop = fit_loop
self.validate_loop = EvaluationLoop()
self.test_loop = EvaluationLoop()
self.predict_loop = PredictionLoop()
|-self.epoch_loop = PredictionEpochLoop()
  |-self.batch_progress = Progress()
```

这里代码粗看上有些诡异:

* `FitLoop` 中包含了一个`TrainingEpochLoop`, 继续而这个 `TrainingEpochLoop`包含`TrainingBatchLoop`和一个`EvaluationLoop`，似乎并不在一个层级上

  ```
  TrainingEpochLoop -- TrainingBatchLoop -- OptimizerLoop
                    \_ EvaluationLoop -- EvaluationEpochLoop(没有EvaluationBatchLoop?)
  ```
* `*Progress` 的定义均在 `src/pytorch_lightning/trainer/progress.py` 文件中, 主要作用是循环时记录下标和一些状态, 即记录 `for i, x in enumerate(x_list)` 中的 `i` 与 `x`。其细致的源码分析如下：

</details>

<details>

<summary>CallbackConnector.on_trainer_init</summary>

```python
# trainer.__init__函数内部调用了这个函数
self._callback_connector.on_trainer_init(...)
```

**主要执行逻辑是依次将如下默认Callback类添加至`Trainer.callbacks`中, 然后将这些callback按照类型进行重排序，先后顺序为：tuner\_callbacks(BatchSizeFinder),other\_callbacks, checkpoint\_callbacks(ModelCheckpoint)。**

这些默认的Callback及及相应控制的`Trainer.__init__`函数的入参默认值以及所包含的hook列举如下:

* **pytorch\_lightning.callbacks.model\_checkpoint.ModelCheckpoint**（默认有, enable\_checkpointing=True）: on\_train\_batch\_end, on\_train\_epoch\_end, on\_train\_start, on\_validation\_end
* **pytorch\_lightning.callbacks.timer.Timer**（默认无, max\_time=None）: on\_fit\_start, on\_test\_end, on\_test\_start, on\_train\_batch\_end, on\_train\_end, on\_train\_epoch\_end, on\_train\_start, on\_validation\_end, on\_validation\_start
* **pytorch\_lightning.callbacks.progress.tqdm\_progress.TQDMProgressBar**（默认有, enable\_progress\_bar=True）: on\_predict\_batch\_end, on\_predict\_batch\_start, on\_predict\_end, on\_predict\_start, on\_sanity\_check\_end, on\_sanity\_check\_start, on\_test\_batch\_end, on\_test\_batch\_start, on\_test\_end, on\_test\_start, on\_train\_batch\_end, on\_train\_end, on\_train\_epoch\_end, on\_train\_epoch\_start, on\_train\_start, on\_validation\_batch\_end, on\_validation\_batch\_start, on\_validation\_end, on\_validation\_start
* **pytorch\_lightning.callbacks.rich\_model\_summary.RichModelSummary**（默认无, 除非enable\_progress\_bar=False且手动传入这个callback）: on\_fit\_start
* **pytorch\_lightning.callbacks.model\_summary.ModelSummary**（默认有, enable\_model\_summary）: on\_fit\_start
* **pytorch\_lightning.callbacks.gradient\_accumulation\_scheduler.GradientAccumulationScheduler**（默认无, accumulate\_grad\_batches=None）: on\_train\_epoch\_start
* **pytorch\_lightning.callbacks.fault\_tolerance.\_FaultToleranceCheckpoint**（必然有）: on\_exception
* **pytorch\_lightning.callbacks.batch\_size\_finder.BatchSizeFinder**（默认无, auto\_lr\_find=False, auto\_scale\_batch\_size=False）: on\_fit\_start, on\_predict\_start, on\_test\_start, on\_validation\_start

**完整**源代码如下：

```python
class CallbackConnector:
    ...
    def on_trainer_init(
        self,
        callbacks: Optional[Union[List[Callback], Callback]],
        enable_checkpointing: bool,
        enable_progress_bar: bool,
        default_root_dir: Optional[str],
        enable_model_summary: bool,
        max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
        accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
    ) -> None:
        # init folder paths for checkpoint + weights save callbacks
        self.trainer._default_root_dir = default_root_dir or os.getcwd()

        # init callbacks
        if isinstance(callbacks, Callback):
            callbacks = [callbacks]
        self.trainer.callbacks = callbacks or []

        # configure checkpoint callback
        # pass through the required args to figure out defaults
        # 注: self.trainer.callbacks增加pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        self._configure_checkpoint_callbacks(enable_checkpointing)

        # configure the timer callback.
        # responsible to stop the training when max_time is reached.
        # 注: max_time为程序最大运行时间,如果设置,则增加pytorch_lightning.callbacks.timer.Timer
        self._configure_timer_callback(max_time)

        # init progress bar
        # 注: 增加pytorch_lightning.callbacks.progress.tqdm_progress.TQDMProgressBar
        self._configure_progress_bar(enable_progress_bar)

        # configure the ModelSummary callback
        # 注: 增加pytorch_lightning.callbacks.model_summary.ModelSummary
        self._configure_model_summary_callback(enable_model_summary)

        # accumulated grads
        # 注: 如果设置了梯度累累积, 则设置pytorch_lightning.callback.gradient_accumulation_scheduler.GradientAccumulationScheduler
        self._configure_accumulated_gradients(accumulate_grad_batches)

        # 注: ...
        if self.trainer.state._fault_tolerant_mode.is_enabled:
            self._configure_fault_tolerance_callbacks()
        
        # 注: 一般是空列表, 可以从Entrypoint中加入callback, 见官网说明
        # https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html#entry-points
        self.trainer.callbacks.extend(_configure_external_callbacks())

        # push all model checkpoint callbacks to the end
        # it is important that these are the last callbacks to run
        self.trainer.callbacks = self._reorder_callbacks(self.trainer.callbacks) # 注: 见下面的函数定义, Checkpoint被排放至最后

    @staticmethod
    def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
        tuner_callbacks: List[Callback] = []
        other_callbacks: List[Callback] = []
        checkpoint_callbacks: List[Callback] = []

        for cb in callbacks:
            if isinstance(cb, BatchSizeFinder):
                tuner_callbacks.append(cb)
            elif isinstance(cb, Checkpoint):
                checkpoint_callbacks.append(cb)
            else:
                other_callbacks.append(cb)

        return tuner_callbacks + other_callbacks + checkpoint_callbacks
```

</details>

<details>

<summary>DataConnector.on_trainer_init</summary>

```python
# trainer.__init__函数内部调用了这个函数
self._data_connector.on_trainer_init(...)
```

**作用是根据入参设定trainer的几个属性**

**除去异常处理**的源代码如下

```python
def on_trainer_init(
        self,
        val_check_interval: Optional[Union[int, float]],
        reload_dataloaders_every_n_epochs: int,
        check_val_every_n_epoch: Optional[int],
    ) -> None:
        self.trainer.datamodule = None
        self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
        self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
        self.trainer._is_data_prepared = False
```

</details>

<details>

<summary>Tuner</summary>

Tuner的主要作用是自动尝试学习率与显存大小, 在\`Trainer.\_\_init\_\_\`函数中仅设定参数, 运行逻辑在 \`Trainer.fit\`函数中

</details>

### `Trainer.fit`

**完整**源代码如下：

```python

def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    r"""
    Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
    as all errors should funnel through them

    Args:
        trainer_fn: one of (fit, validate, test, predict)
        *args: positional arguments to be passed to the `trainer_fn`
        **kwargs: keyword arguments to be passed to `trainer_fn`
    """
    try:
        if trainer.strategy.launcher is not None:
            return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
        else:
            return trainer_fn(*args, **kwargs)

    except _TunerExitException:
        trainer._call_teardown_hook()
        trainer._teardown()
        trainer.state.status = TrainerStatus.FINISHED
        trainer.state.stage = None

    # TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
    except KeyboardInterrupt as exception:
        rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
        # user could press Ctrl+c many times... only shutdown once
        if not trainer.interrupted:
            trainer.state.status = TrainerStatus.INTERRUPTED
            trainer._call_callback_hooks("on_exception", exception)
            for logger in trainer.loggers:
                logger.finalize("failed")
    except BaseException as exception:
        trainer.state.status = TrainerStatus.INTERRUPTED
        if distributed_available() and trainer.world_size > 1:
            # try syncing remaining processes, kill otherwise
            trainer.strategy.reconciliate_processes(traceback.format_exc())
        trainer._call_callback_hooks("on_exception", exception)
        for logger in trainer.loggers:
            logger.finalize("failed")
        trainer._teardown()
        # teardown might access the stage so we reset it after
        trainer.state.stage = None
        raise

class Trainer:
    def fit(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None,
    ) -> None:
        r"""
        Runs the full optimization routine.

        Args:
            model: Model to fit.

            train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
                :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
                In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.

            val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.

            ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of two special
                keywords ``"last"`` and ``"hpc"``. If there is no checkpoint file at the path, an exception is raised.
                If resuming from mid-epoch checkpoint, training will start from the beginning of the next epoch.

            datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
        """
        if not isinstance(model, pl.LightningModule):
            raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
        # self.strategy.lightning_module即是_lightning_module
        self.strategy._lightning_module = model
        call._call_and_handle_interrupt(
            self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
        )

    # 实际上执行发生在这, 但可能被self.strategy包裹
    def _fit_impl(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None,
    ) -> None:
        Trainer._log_api_event("fit")
        log.detail(f"{self.__class__.__name__}: trainer fit stage")

        self.state.fn = TrainerFn.FITTING
        self.state.status = TrainerStatus.RUNNING
        self.training = True

        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(train_dataloaders, LightningDataModule):
            datamodule = train_dataloaders
            train_dataloaders = None
        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
        if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
            raise MisconfigurationException(
                "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`"
            )

        # links data to the trainer
        self._data_connector.attach_data(
            model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
        )

        # TODO: ckpt_path only in v2.0
        ckpt_path = ckpt_path or self.resume_from_checkpoint
        self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
            self.state.fn,
            ckpt_path,  # type: ignore[arg-type]
            model_provided=True,
            model_connected=self.lightning_module is not None,
        )
        self._run(model, ckpt_path=self.ckpt_path)

        assert self.state.stopped
        self.training = False
        return
```

Trainer.\_call\_callback\_hooks(hook\_name)

`Trainer.fit` 方法中也会多次调用`Trainer._call_callback_hooks`方法, 其**完整**源代码如下:

```python
class Trainer:
    def _call_callback_hooks(
        self,
        hook_name: str,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        log.debug(f"{self.__class__.__name__}: calling callback hook: {hook_name}")
        # TODO: remove if block in v1.8
        if hook_name in ("on_init_start", "on_init_end"):
            # these `Callback` hooks are the only ones that do not take a lightning module.
            # we also don't profile bc profiler hasn't been set yet
            for callback in self.callbacks:
                fn = getattr(callback, hook_name)
                if callable(fn):
                    fn(self, *args, **kwargs)
            return
        # 注: self.lightning_module的定义见如下注解
        pl_module = self.lightning_module
        if pl_module:
            prev_fx_name = pl_module._current_fx_name
            pl_module._current_fx_name = hook_name

        for callback in self.callbacks:
            fn = getattr(callback, hook_name)
            if callable(fn):
                with self.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
                    fn(self, self.lightning_module, *args, **kwargs)

        if pl_module:
            # restore current_fx when nested context
            pl_module._current_fx_name = prev_fx_name
```

注: 为何要对`"on_init_start", "on_init_end"`这两个做单独的处理? 因为其他的`hook_name`都在`Trainer.fit`方法内部被调用,`Trainer.fit`方法的源代码如下:

```python
class Trainer:
    def fit(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None,
    ) -> None:
        if not isinstance(model, pl.LightningModule):
            raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
        self.strategy._lightning_module = model
        call._call_and_handle_interrupt(
            self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
        )
    @property
    def lightning_module(self) -> "pl.LightningModule":
        # TODO: this is actually an optional return
        return self.strategy.lightning_module

# src/pytorch_lightning/strategies/strategy.py
class Strategy(ABC):
    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        """Returns the pure LightningModule without potential wrappers."""
        return self._lightning_module
```

## 其他补充

### apply\_to\_collection

`lightning_utilities` 中包含一些递归应用函数(具体实现可以参看源码)

```python
from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
data = [[1, 1.2], 2, 2.3, None]
data2 = [[1, 3], 3., 2, None]
# 这个函数也适用于字典和 dataclass, lambda 函数仅应用于指定的类型: int
apply_to_collection(data, int, lambda x: x * 2, include_none=False)
# [[2, 1.2], 4, 2.3]
apply_to_collections(data, data2, int, lambda x, y: x+y)
# [[2, 1.2], 5.0, 2.3, None]
```

由于 `pytorch-lightning` 很多地方将这两个函数用于 `dataclass` 上，因此补充一个示例：

```python
from dataclasses import dataclass
@dataclass
class A:
    a: int
    b: float
    c: str
    d: list
x = A(a=2, b=2.3, c="str", d=[1, 2, 3.9])
apply_to_collection(x, int, lambda x: x * 2)
# A(a=4, b=2.3, c='str', d=[2, 4, 3.9])
```
