PyTorch - 3.5 简体中文

Vitis AI 用户指南 (UG1414)

Document ID
UG1414
Release Date
2023-09-28
Version
3.5 简体中文

量化 API

wego_torch.quantize(
module: torch.nn.Module, 
input_shapes: Sequence[Sequence], 
dataloader: Iterable, 
calibrator: Callable[[torch.nn.Module, Any, int, torch.device], None], export_dataloader: Iterable = None, 
device: torch.device = torch.device("cpu"), 
output_dir: str = "quantize_result", 
bitwidth: int = None, 
quant_config_file: Optional[str] = None, 
*args, **kwargs) -> torch.jit.ScriptModule
此函数采用训练后量化 (Post Training Quantization, PTQ) 方法来量化 torch 浮点模型,并返回量化后的 TorchScript 模块用于 WeGO 编译。

如果 PTQ 无法达到所需精度,您可能需要考虑使用量化感知训练 (QAT) 搭配 Vitis AI 量化器 API。如需深入了解量化进程,请参阅 Vitis AI 用户指南中的 量化模型

参数

module
(torch.nn.Module) 输入 PyTorch 浮点模型。
input_shapes
(Sequence[Sequence] 模型的输入形状:列表或元组序列。
dataloader
(Iterable) 校准数据集的数据加载器。它必须可迭代。API 会通过它进行迭代并将返回的值传递给校准器。
calibrator
(Callable) 可调用对象,用于执行批量数据预处理和前传。从数据加载器获取批量数据、按需对其进行预处理,并使用模块对其进行转发。此校准器会在校准和导出阶段调用 N + 1 次。
  • 阶段 1 用于校准。在此阶段,您的数据加载器会进行迭代,并通过模块进行数据传递以收集量化统计数据。校准器会调用 N 次 (N = len(dataloader))。在阶段 1 中,如果未传递可选 export_dataloader,则会保存数据加载器返回的第一个批次,稍后供阶段 2 使用。在此情况下,请确保校准器或迭代副作用未更改第一个批次。
  • 阶段 2 用于导出量化的 TorchScript 模块。在此阶段中,校准器仅调用一次,含一个批次的数据。如果您传入 export_dataloader,那么此 export_dataloader 会迭代,并且仅使用第一个批次。程序处理完第一个批次后,会中断迭代。如果您未传入 export_dataloader,则会使用从阶段 1 保存的第一个批次。
    校准器实参:
    module
    (torch.nn.Module) 用于量化的模块。这是您传入的模块的修改版本,其中具有用于收集数据统计信息的必要机制。您应使用此模块代替原始浮点模型来进行数据前传。
    batch_data
    (Any) 从数据加载器返回的批次数据。
    batch_index
    (int) 批次索引。请按需使用器件 (torch.device) 进行前传。当前仅支持 CPU。
    注释: 用于量化 API 的额外定位实参和关键字实参都将前传给校准器。如需了解更多信息,请参阅 量化模型
export_dataloader
(Iterable) 可选数据加载器,用于导出阶段。默认值为 None。如果值为 None,则使用从阶段 1 保存的第一个批次。
device
(torch.device) 用于校准的器件。当前仅支持 CPU。
output_dir
(str) 临时工作目录。默认值为 quantize_result。部分中间文件保存在此处。
bitwidth
(int) 全局量化位宽。默认值为 8。
quant_config_file
(str) 指向量化器配置文件的路径。默认值为 None。
args
要传递给校准器的额外定位实参。
kwargs
要传递给校准器的额外关键字实参。

如需了解有关如何使用 WeGO 中的即时量化的更多信息,请参阅 WeGO 示例