huggingface

huggingface 的几个重要项目

  • transformers

  • datasets

  • tokenizers

  • accelerate

  • huggingface-hub

  • evaluate

  • gradio: 一个适用于AI模型做demo的简易前端

transformers 包

整体代码结构

构建模型主要的基类如下

  • PreTrainedModel: 模型

  • PretrainedConfig: 配置

  • PreTrainedTokenizerBase: tokenizer

在以上三个类之上, Pipeline 基类用于组合这三个类.

另外, 还有些小东西: ModelOutput, 是模型输出的结果的基类.

transformers 的代码总体遵循的设计哲学是不强调代码复用, 比如没有一个 attention.py 文件中实现所有的注意力机制, 与之相对应的是将所有的模型基本上写在三个文件里, 例如在 transformers/models/bart 文件夹里与 pytorch 有关的代码文件如下:

使用小技巧

实例化一个随机权重的模型

关于 mask (待研究, 1和0是否有统一的语义约定)

PreTrainedModel

使用

源码解析

transformers 代码中的带有 from_pretrained 的类都继承自 PreTrainedModel, 其具体继承关系如下:

example: BartForConditionalGeneration

具体到上面的例子中:

ModelOutput

Pipeline

使用

源码解析

framework 取值为 tf 或者 pt, 代表 tensorflowpytorch. 一般用于指示代码的输出为 tf.tensortorch.tensor.

PreTrainedTokenizerBase

继承关系如下图所示:

其中 Fast 版本的 Tokenizer 依赖于 huggingface tokenizers 库中的实现, 而普通版本的 Tokenizer 是 huggingface transformers 库中纯 Python 的实现。

在查阅网上不同的资料的过程中,会有几个疑问:

  • 查询 tokenizer 词表的基本信息例如词表、特殊 token 的方法有哪些?

  • 往往会发现许多关于 Tokenizer 相似的方法调用,例如下面这些,但它们之间的关联/区别是什么?

  • 添加 token 的方法有哪些,区别是什么?

  • 怎么训练得到一个 tokenizer

BertTokenizer 为例,向上追溯至 SpecialTokensMixinPretrainedTokenizerBasePretrainedTokenizer 中的一些方法,来回答上述的这些问题:

针对前面几个问题,相关的方法在继承关系中实质上的定义位置如下图所示:

获取 tokenizer 的基本信息

几个相似的方法

备注:

  • __call__ 方法实质上根据输入是文本列表或文本,分别调用了 batch_encode_plusencode_plus 方法

  • encode_plus 方法实质上是依次调用了 tokenizeconvert_tokens_to_idsprepare_for_model 等方法

  • batch_encode_plus 方法实质上是依次调用了 tokenizeconvert_tokens_to_idsprepare_for_modelpad 等方法

  • encode 方法实质上是调用了 encode_plus 方法,然后只取出 "input_ids" 作为返回值

  • decode 方法实质上是依次调用 convert_ids_to_tokensconvert_tokens_to_string 等方法将整数序列转换为文本

  • batch_decode 方法实质上是对输入使用 decode 方法进行列表推导式

添加token的方法及注意事项

添加 token 的方法来源于 SpecialTokensMixin 中的 add_tokensadd_special_tokens,而 add_special_tokens 方法最终会使用到 add_tokens 方法。在使用上,在增加了 token 后,模型侧需要将 embedding 进行 resize,最常见的做法如下:

备注:针对 BertTokenizer 而言,可以使用如下技巧避免对 model 进行改动

怎么训练得到一个 tokenizer

参考资料(待续):

  • huggingface tokenizer 官方文档:https://huggingface.co/docs/tokenizers/index

Trainer

一个完整的例子可以参考 transformers GitHub 源码 examples/pytorch/summarization/run_summarization.py

使用方式如下:

其中training_args以如下方式获取到:

  • Seq2SeqTrainingArguments 继承自 transformers.TrainingArguments(被dataclass装饰),只是一个自定义的“结构体”

  • HfArgumentParser 继承自 argparse.ArgumentParserHfArgumentParser只是在父类的基础上增加了几个方法:parse_json_fileparse_args_into_dataclasses

  • transformers.Seq2SeqTrainer继承自transformers.TrainerSeq2SeqTrainer只是在父类的基础上覆盖了少量的几个方法:它的主体逻辑例如配置多卡训练,整体循环迭代等过程继承自transformers.Trainer,仅覆盖一些training_step中的关键步骤。

Trainer.train的循环体为Trainer.training_step

Seq2SeqTrainer继承自Trainer, 只重载了evaluate,predict,prediction_step 这几个方法

关于 transformers.Trainer

  • Trainer.__init__函数中也允许传入一些callback, 与pytorch-lightning类似, 但hook会更少一些

关于 HfArgumentParser 的一个小示例

Trainer的扩展方式有两种:

  • 增加Callback,但作用有限,按官方的说法callback不影响训练流程

  • 集成Trainer类,重写一些方式例如:compute_loss

TrainerControl, TrainerState, CallbackHandler, TrainerCallback

Trainer 中包含:

  • TrainerControl: 一些是否需要保存,是否需要记录日志的标志

  • TrainerState: 记录当前的训练轮数等,注意log_history是历史的日志记录列表

  • CallbackHandler: for循环各个callback进行调用

总的来说, huggingface transformers 库的 Trainer 写得不是太好, 不利于扩展,但怎么结合 pytorch-lightning 使用 huggingface transformers 库的模型: lightning-transformers

如何增加 Tensorboard 的打印信息

首先看一下 TensorBoardCallback 的实现

接下来看Trainer中跟 TensorBoardCallback 相关的代码

因此,自定义tensorboard的输出内容(在不自定义子类重写self.train方法的前提下):

  • 每隔100个batch,输出训练集的损失:无法做到,原因是 _maybe_log_save_evaluate 无法传递当前batch的数据信息,因此训练集的信息很难记录在日志中

  • 每个100个batch,输出验证集的损失:可以增加一个Callback,在隔100个batch时,将self.control.should_evaluate设置为True

  • 输出验证集的其他信息,例如计算准确率,召回率等多个指标时:自定义一个 CustomTrainer 继承自 Trainer,重载 self.evaluate 方法,并在这个重载方法内部使用 self.log 方法记录到日志中

终极解决方案:自定义子类重写Trainer.train方法,在必要的地方增加逻辑进行日志记录。但self.train方法的代码过于冗长(大约400行代码),基本上这种做法需要将原本的 train 方法抄录大部分。因此,使用 Trainer 不太能随心所欲地增加日志打印逻辑。

模型保存相关

训练时的保存总入口在 trainer._save_checkpoint 函数处,主要保存以下内容:

离线使用数据集、metric、模型文件

运行示例:官方示例

此脚本使用 transformers 包加载模型,使用 datasets 加载数据集以及 metric。

  • 模型的离线下载:去 huggingface 搜索并下载, 并在 from_pretrained 函数参数替换为本地路径

  • 数据离线下载:去 huggingface 搜索并下载, 并在 load_dataset 函数参数替换为本地路径。

    • 备注:在上面这个例子中,下载的是 glue 数据集下的 mrpc 数据,因此搜索下载好 glue 数据集后,还需要进一步根据 data_infos.jsonglue.py 内的内容下载 mrpc 数据文件

    注意如果按上述方式组织文件,需要做几项修改:

  • metric离线下载: 在有网环境下使用 load_metric 函数,默认缓存目录为 ~/.cache/huggingface/modules/datasets_modules/metrics/glue/91f3cfc5498873918ecf119dbf806fb10815786c84f41b85a5d3c47c1519b343。只需要将此目录下的文件拷贝出来,在无网环境下将 load_metric 函数参数替换为本地路径。

datasets

datasets.load_dataset

datasets.load_dataset用于加载数据集, 适用于如下情况:

  • huggingface hub 维护的数据集, 执行逻辑为下载数据集(有可能会去找到该仓库的同名下载与数据预处理脚本),然后缓存至 ~/.cache/huggingface/datasets 目录(默认缓存为.arrow格式), 最后返回数据集

  • 本地数据集情形下,依然会缓存至 ~/.cache/huggingface/datasets 目录,然后返回数据集

  • 如果本地已缓存则直接读缓存,详情参考

输出结果

备注:

  • 输出结果里:Downloading and preparing dataset及以下的内容的逻辑发生在datasets.builder:DatasetBuilder.download_and_prepare函数内

load_dataset 函数的全部参数如下(没有按照实际的参数列表排列):

个人觉得常用的

  • path, name, split: 参考官方文档, path 一般是huggingface hub的仓库名, name 在官方文档中被称为 dataset configuration, 一般是指一个数据集的几个子数据集, split 一般取值为 "train", "test" 等

  • data_dir, data_files: 这两个参数一般适用 path="csv","json" 等

  • cache_dir, keep_in_memory, storage_options: cache_dir 对应于默认的缓存目录 HF_DATASETS_CACH=~/.cache/huggingface/datasets, keep_in_memory 表示不使用缓存, storage_options 目前还不清楚使用场景

  • streaming: 流式下载

  • num_proc: 多进程处理

不常用的

  • features: 不清楚

  • download_config, download_mode, verification_mode, ignore_verification: 与下载相关的, 不清楚

  • save_infos: 不清楚

  • revision, use_auth_token: 与下载版本及下载权限相关

  • task: 不清楚

  • **config_kwargs: 不清楚

缓存数据文件手动读取

缓存目录

huggingface所有项目的默认缓存目录为~/.cache/huggingface

仅就 datasets 模块而言, 缓存的实际内容为【某个数据集使用特定的预处理脚本处理后最终得到的数据文件】,而这些【数据文件】默认以 .arrow 的方式进行缓存。

根据需求不同,对 datasets.load_dataset 的参数有不同的设定

原始文件位置
预处理方法
做法
传参

本地

json/csv格式默认的读取方式

path='csv', data_files='a.csv'

本地

自定义

编写预处理脚本得到datasets.arrow_dataset.Dataset

path='/path/to/script.py'

Huggingface Hub的datasets中

Hub仓库中的下载以及预处理方式

path='username/dataname'

Huggingface Hub的datasets中

自定义预处理方式

编写预处理脚本得到datasets.arrow_dataset.Dataset:方式一、参考Hub仓库中的默认预处理方式,自己编写预处理脚本,这种方法编写的脚本里应包含下载数据的过程;方式二、如有网络问题也可以预先将原始数据下载下来后再针对本地文件编写预处理脚本

path='/path/to/script.py'

Dataset 变换

具体使用参考官方文档

  • map: 对每条数据进行变换

  • filterselect: 挑选数据

一些不理解的代码

第二次执行时会从缓存中读取

tokenizers 包

tokenizers 包在安装 transformers 包时会自动进行安装,在 transformers 包中如何被使用需要进一步研究。

huggingface-hub 包

huggingface-hub 包在安装 transformersdatasets 包时会自动进行安装。前面在 transformers 包与 datasets 包中已简单涉及了许多关于 huggingface 缓存目录的介绍,此处更加清晰地进行介绍:

首先理清一下 huggingface 各个模块关于缓存目录的设置:

  • huggingface-hub 包的默认缓存目录为 HUGGINGFACE_HUB_CACHE=~/.cache/huggingface/hub,其本质是对 git 的一层封装。

  • transformers 包的默认缓存目录为 TRANSFORMERS_CACHE=~/.cache/huggingface/hub(与huggingface-hub一致,并且本质上是直接复用了huggingface-hub的缓存方式,即 blobsrefssnapshots 的方式)

  • datasets 包的默认缓存目录为:HF_DATASETS_CACHE=~/.cache/huggingface/datasets(与huggingface-hub不一致,其本质上是建立了自己的一套缓存数据集的方式,即采用 arrow 格式对数据进行缓存,从而加速数据的加载速度,提升训练效率),另外,使用 datasets.load_dataset 时会将需要的脚本缓存至 ~/.cache/huggingface/modules/datasets_modules 目录

  • evaluate 包设定了如下一些默认缓存路径:

    • HF_METRICS_CACHE=~/.cache/huggingface/metrics

    • HF_EVALUATE_CACHE=~/.cache/huggingface/evaluate

    • HF_MODULES_CACHE=~/.cache/huggingface/modules/evaluate_modules

  • diffusers 包的默认缓存目录为:DIFFUSERS_CACHE=~/.cache/huggingface/diffusers,而需要的脚本缓存目录设定在 ~/.cache/huggingface/modules/diffusers_modules 目录

从上面可以看出,huggingface-hub 包作为 huggingface 所有项目的“基础建设”,各个下游项目会根据需要决定是否完全复用这一“基础建设”。以下是一些具体的例子:

基本上都可以通过 huggingface-hub 的接口将 datasetsmodelsspaces 下载到本地,然后各个下游的包例如:transformersevaluatedatasetsdiffusers 中加载模型/数据集/脚本的函数中传入本地路径即可。

关于 huggingface-hub 缓存目录的官方文档

从 hub 下载文件的主要接口是 hf_hub_downloadsnapshot_download,参考官方文档即可

怎样确认文件下载正确

以下载 bert-base-uncased/pytorch_model.bin 文件为例

检验本地下载的数据是否与上面的信息一致

accelerate 包

accelerate 在安装 transformers 包时不会进行安装

safetensors

pytorch 的 torch.savetorch.load 底层使用了 pickle, 被认为是不安全的格式 (假设你打开这个文件, 那么它就有可能执行任意代码: Arbitrary Code Execution, ACE: people can do whatever they want with your machine), 具体解释可以参考: https://github.com/huggingface/safetensors/discussions/111

具体的API参见官方文档即可, 这里仅对存储格式做探究

结合 pytorch-lightning 使用 transformers 训练

源码

依赖的一些其他三方库学习

  • filelock: 文件锁?安全读写文件时有用?

  • pyarrow: datasets 底层依赖的存储方式

Last updated

Was this helpful?