Quantization API
wego_torch.quantize(
    module: torch.nn.Module,
    input_shapes: Sequence[Sequence],
    dataloader: Iterable,
    calibrator: Callable[[torch.nn.Module, Any, int, torch.device], None],
    export_dataloader: Optional[Iterable] = None,
    device: torch.device = torch.device("cpu"),
    output_dir: str = "quantize_result",
    bitwidth: Optional[int] = None,
    quant_config_file: Optional[str] = None,
    *args, **kwargs) -> torch.jit.ScriptModule
This function quantizes a float PyTorch model using the Post-Training Quantization (PTQ) method and returns a quantized TorchScript module for WeGO compilation.
If PTQ cannot achieve the required accuracy, consider using Quantization-Aware Training (QAT) with the Vitis AI Quantizer API instead. For an in-depth understanding of the quantization process, see Quantizing the Model in the Vitis AI User Guide.
Parameters
- module (torch.nn.Module): The input PyTorch float model.
- input_shapes (Sequence[Sequence]): Input shapes for the model, given as a sequence of lists or tuples.
- dataloader (Iterable): Dataloader for the calibration dataset. It must be iterable; the API iterates through it and passes the returned values to the calibrator.
- calibrator (Callable): Callable object that performs batch data pre-processing and forwarding: it receives batch data from the dataloader, pre-processes it if necessary, and uses the module to forward it. The calibrator is called N + 1 times across the calibration and export stages.
  - Stage 1 is calibration. In this stage, your dataloader is iterated and the data is passed through the module to collect quantization statistics. The calibrator is called N times (N = len(dataloader)). If you did not pass the optional export_dataloader, the first batch returned by the dataloader is saved in this stage and used later in stage 2; in that case, ensure the first batch is not changed by the calibrator or by iteration side effects.
  - Stage 2 is exporting the quantized TorchScript module. In this stage, the calibrator is called only once, with one batch of data. If you pass in an export_dataloader, it is iterated and only the first batch is used; the program breaks out of the iteration after processing that batch. If you did not pass in an export_dataloader, the saved first batch from stage 1 is used.
  Calibrator arguments:
  - module (torch.nn.Module): The module to quantize. This is a modified version of the module you passed in, with the mechanisms needed to collect data statistics. Use this module, not the original float model, to forward your data.
  - batch_data (Any): Batch data returned from the dataloader.
  - batch_index (int): Index of the batch.
  - device (torch.device): Device to use for forwarding, if necessary. Currently only CPU is supported.
  Note: Extra positional and keyword arguments passed to the quantize API are forwarded to the calibrator. For more information, see Quantizing the Model.
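As a concrete illustration, a minimal calibrator might look like the following. This is a sketch, not part of the WeGO API: it assumes each batch from the dataloader is an (inputs, labels) pair, which is an assumption about your dataset rather than a requirement of the API.

```python
def calibrator(module, batch_data, batch_index, device):
    """Minimal calibration callback (sketch).

    Assumes each batch is an (inputs, labels) pair; adapt the
    unpacking to your own dataset layout.
    """
    inputs, _labels = batch_data   # hypothetical batch layout
    inputs = inputs.to(device)     # move data to the target device (CPU)
    module(inputs)                 # forward pass collects statistics
```

Any extra positional or keyword arguments you pass to the quantize API are forwarded to the calibrator as well, as described in the note above.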
- export_dataloader (Optional[Iterable]): An optional dataloader for the export stage. The default value is None; if None, the first batch saved from stage 1 is used.
- device (torch.device): Device to use for calibration. Currently only CPU is supported.
- output_dir (str): A temporary working directory where some intermediate files are saved. The default value is "quantize_result".
- bitwidth (Optional[int]): Global quantization bit width. The default value is None, in which case a bit width of 8 is used.
- quant_config_file (Optional[str]): Path to the quantizer configuration file. The default value is None.
- args: Extra positional arguments to pass to the calibrator.
- kwargs: Extra keyword arguments to pass to the calibrator.
For more information on how to use on-the-fly quantization in WeGO, see the WeGO examples.
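The export-stage batch selection described for export_dataloader above can be summarized in a few lines. This is an illustrative sketch of the documented behavior, not WeGO's actual implementation:

```python
def select_export_batch(export_dataloader, saved_first_batch):
    """Sketch of stage-2 batch selection as documented above.

    If an export_dataloader was given, iterate it and use only the
    first batch; otherwise fall back to the first batch saved
    during stage 1 calibration.
    """
    if export_dataloader is not None:
        for batch in export_dataloader:
            return batch  # break out after the first batch
    return saved_first_batch
```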