vai_q_pytorch QAT - 3.5 English

Vitis AI User Guide (UG1414)

Assuming a pre-defined model architecture, follow the steps below to perform quantization aware training (QAT), using ResNet18 from torchvision as an example. The complete model definition is here.
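Conceptually, QAT inserts simulated ("fake") quantization into the forward pass so that training adapts the weights to rounding and clipping error. The pure-Python sketch below illustrates that quantize-dequantize round trip for a signed 8-bit fixed-point value with a power-of-two step. It is a generic illustration under assumed parameters, not the vai_q_pytorch implementation, which chooses scales internally; `fake_quantize` and its `frac_bits` argument are hypothetical names.

```python
def fake_quantize(x: float, bitwidth: int = 8, frac_bits: int = 4) -> float:
    """Round x onto a signed fixed-point grid, then return it as a float.

    Hypothetical helper for illustration only. The grid step is
    2**-frac_bits (a power-of-two scale). The integer code is clipped to
    the signed range of `bitwidth` bits, so the result carries both the
    rounding error and the clipping error of the quantized value.
    """
    step = 2.0 ** -frac_bits
    q_min = -(2 ** (bitwidth - 1))   # -128 for 8 bits
    q_max = 2 ** (bitwidth - 1) - 1  # 127 for 8 bits
    q = round(x / step)              # nearest integer code (Python rounds ties to even)
    q = max(q_min, min(q_max, q))    # clip to the representable range
    return q * step                  # dequantize back to float


if __name__ == "__main__":
    for value in (0.3, 1.23, 10.0, -10.0):
        fq = fake_quantize(value)
        print(f"{value:>6} -> {fq:>7} (error {abs(fq - value):.4f})")
```

With the default 8-bit, 4-fractional-bit grid, 0.3 maps to 0.3125, while 10.0 exceeds the representable range and clips to 7.9375. In typical QAT setups the backward pass treats the rounding step as identity (a straight-through estimator), so gradients still reach the float weights.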

  1. Check whether the model contains non-module operations that need to be quantized. For example, ResNet18 uses '+' to add two tensors; replace each of these operations with pytorch_nndct.nn.modules.functional.Add.
  2. Check whether any module is called more than once in the forward pass. Such modules usually have no weights; the most common is torch.nn.ReLU. Define a separate instance for each call site and call each instance only once in the forward pass. The revised definition meeting these requirements is as follows:
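    import torch.nn as nn
    from pytorch_nndct.nn.modules import functional
    # conv3x3 is the 3x3-convolution helper from torchvision's ResNet definition
    from torchvision.models.resnet import conv3x3
    class BasicBlock(nn.Module):
      expansion = 1
      def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                   base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()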
        if norm_layer is None:
          norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
          raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
          raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        # Use a functional module to replace '+'
        self.skip_add = functional.Add()
        # Additional defined module
        self.relu2 = nn.ReLU(inplace=True)
      def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
          identity = self.downsample(x)
        # Use function module instead of '+'
        # out += identity
        out = self.skip_add(out, identity)
        out = self.relu2(out)
        return out
  3. Insert QuantStub and DeQuantStub.

    Use QuantStub to quantize the network inputs and DeQuantStub to dequantize the network outputs. Any sub-network between a QuantStub and a DeQuantStub in the forward pass is quantized. Multiple QuantStub-DeQuantStub pairs are allowed:

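    import torch
    import torch.nn as nn
    from pytorch_nndct import nn as nndct_nn
    # Bottleneck is only used by the zero_init_residual check below
    from torchvision.models.resnet import Bottleneck
    class ResNet(nn.Module):
      def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                   groups=1, width_per_group=64, replace_stride_with_dilation=None,
                   norm_layer=None):
        super(ResNet, self).__init__()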
        if norm_layer is None:
          norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
          # each element in the tuple indicates if we should replace
          # the 2x2 stride with a dilated convolution instead
          replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
          raise ValueError(
              "replace_stride_with_dilation should be None "
              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.quant_stub = nndct_nn.QuantStub()
        self.dequant_stub = nndct_nn.DeQuantStub()
        for m in self.modules():
          if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
          elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to
        if zero_init_residual:
          for m in self.modules():
            if isinstance(m, Bottleneck):
              nn.init.constant_(m.bn3.weight, 0)
            elif isinstance(m, BasicBlock):
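              nn.init.constant_(m.bn2.weight, 0)
      def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        # Same structure as torchvision's ResNet._make_layer
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
          self.dilation *= stride
          stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
          # 1x1 convolution projects the identity branch to the new shape
          downsample = nn.Sequential(
              nn.Conv2d(self.inplanes, planes * block.expansion,
                        kernel_size=1, stride=stride, bias=False),
              norm_layer(planes * block.expansion))
        layers = [block(self.inplanes, planes, stride, downsample, self.groups,
                        self.base_width, previous_dilation, norm_layer)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
          layers.append(block(self.inplanes, planes, groups=self.groups,
                              base_width=self.base_width, dilation=self.dilation,
                              norm_layer=norm_layer))
        return nn.Sequential(*layers)
      def forward(self, x):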
        x = self.quant_stub(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.dequant_stub(x)
        return x
  4. Use QAT APIs to create the quantizer and train the model:
    import torch
    def _resnet(arch, block, layers, pretrained, progress, **kwargs):
      model = ResNet(block, layers, **kwargs)
      if pretrained:
        #state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        # model_urls is assumed to map each arch name to a local checkpoint file
        state_dict = torch.load(model_urls[arch])
        model.load_state_dict(state_dict)
      return model
    def resnet18(pretrained=False, progress=True, **kwargs):
      r"""ResNet-18 model from
      "Deep Residual Learning for Image Recognition".
      Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
      """
      return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                     **kwargs)
    model = resnet18(pretrained=True)
    # Generate dummy inputs; batch_size is the training batch size.
    inputs = torch.randn([batch_size, 3, 224, 224], dtype=torch.float32)
    # Create a quantizer
    from pytorch_nndct import QatProcessor
    qat_processor = QatProcessor(model, inputs, bitwidth=8)
    quantized_model = qat_processor.trainable_model()
    # Set lr and other hyperparameters as you would for float training
    optimizer = torch.optim.Adam(quantized_model.parameters())
    # Use the optimizer to train the model, just like a normal float model.
  5. Get the deployable model and test it.

    After completing QAT, convert the quantized model to a deployable model. The accuracy of the deployable model might differ slightly from that of the quantized model:

    # validate(), val_loader, criterion and gpu are assumed to come from the float training script
    output_dir = 'qat_result'
    deployable_model = qat_processor.to_deployable(quantized_model, output_dir)
    validate(val_loader, deployable_model, criterion, gpu)
  6. Export XMODEL from the deployable model.

    A batch size of 1 is mandatory for XMODEL compilation:

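    # Use CPU mode to export xmodel.
    deployable_model = deployable_model.cpu()
    val_subset = torch.utils.data.Subset(val_loader.dataset, list(range(1)))
    subset_loader = torch.utils.data.DataLoader(val_subset, batch_size=1, shuffle=False)
    # Forward the deployable model for at least one iteration with batch_size=1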
    for images, _ in subset_loader: