量化 API
wego_torch.quantize(
module: torch.nn.Module,
input_shapes: Sequence[Sequence],
dataloader: Iterable,
calibrator: Callable[[torch.nn.Module, Any, int, torch.device], None], export_dataloader: Iterable = None,
device: torch.device = torch.device("cpu"),
output_dir: str = "quantize_result",
bitwidth: int = None,
quant_config_file: Optional[str] = None,
*args, **kwargs) -> torch.jit.ScriptModule
此函数采用训练后量化 (Post Training Quantization, PTQ) 方法来量化 torch 浮点模型,并返回量化后的 TorchScript 模块用于 WeGO 编译。
如果 PTQ 无法达到所需精度,您可能需要考虑使用量化感知训练 (QAT) 搭配 Vitis AI 量化器 API。如需深入了解量化进程,请参阅 Vitis AI 用户指南中的 量化模型。
参数
- module
-
(torch.nn.Module)
输入 PyTorch 浮点模型。 - input_shapes
-
(Sequence[Sequence]
模型的输入形状:列表或元组序列。 - dataloader
-
(Iterable)
校准数据集的数据加载器。它必须可迭代。API 会通过它进行迭代并将返回的值传递给校准器。 - calibrator
-
(Callable)
可调用对象,用于执行批量数据预处理和前传。从数据加载器获取批量数据、按需对其进行预处理,并使用模块对其进行转发。此校准器会在校准和导出阶段调用 N + 1 次。- 阶段 1 用于校准。在此阶段,您的数据加载器会进行迭代,并通过模块进行数据传递以收集量化统计数据。校准器会调用 N 次 (N = len(dataloader))。在阶段 1 中,如果未传递可选 export_dataloader,则会保存数据加载器返回的第一个批次,稍后供阶段 2 使用。在此情况下,请确保校准器或迭代副作用未更改第一个批次。
- 阶段 2 用于导出量化的 TorchScript 模块。在此阶段中,校准器仅调用一次,含一个批次的数据。如果您传入 export_dataloader,那么此 export_dataloader 会迭代,并且仅使用第一个批次。程序处理完第一个批次后,会中断迭代。如果您未传入 export_dataloader,则会使用从阶段 1 保存的第一个批次。校准器实参:
- module
-
(torch.nn.Module)
用于量化的模块。这是您传入的模块的修改版本,其中具有用于收集数据统计信息的必要机制。您应使用此模块代替原始浮点模型来进行数据前传。 - batch_data
-
(Any)
从数据加载器返回的批次数据。 - batch_index
-
(int)
批次索引。请按需使用器件 (torch.device) 进行前传。当前仅支持 CPU。
注释: 用于量化 API 的额外定位实参和关键字实参都将前传给校准器。如需了解更多信息,请参阅 量化模型。
- export_dataloader
-
(Iterable)
可选数据加载器,用于导出阶段。默认值为 None。如果值为 None,则使用从阶段 1 保存的第一个批次。 - device
-
(torch.device)
用于校准的器件。当前仅支持 CPU。 - output_dir
-
(str)
临时工作目录。默认值为 quantize_result。部分中间文件保存在此处。 - bitwidth
-
(int)
全局量化位宽。默认值为 8。 - quant_config_file
-
(str)
指向量化器配置文件的路径。默认值为 None。 - args
- 要传递给校准器的额外定位实参。
- kwargs
- 要传递给校准器的额外关键字实参。
如需了解有关如何使用 WeGO 中的即时量化的更多信息,请参阅 WeGO 示例。