Once-for-All (OFA) - 2.0 English

Vitis AI Optimizer User Guide (UG1333)

Creating a Model

For simplicity, mobilenet_v2 from torchvision is used here.

from torchvision.models.mobilenet import mobilenet_v2
model = mobilenet_v2(pretrained=True)

Creating an OFA Pruner

The pruner requires two arguments: the model to be pruned and inputs needed by the model for inference. Note that the input does not need to be real data. You can use randomly generated dummy data if it has the same shape and type as the real data.

import torch
from pytorch_nndct import OFAPruner

inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
ofa_pruner = OFAPruner(model, inputs)

Generating an OFA Model

Call ofa_model() on the pruner to get an OFA model. This method finds all the nn.Conv2d modules with kernel_size > 1 and all the nn.BatchNorm2d modules, then replaces them with DynamicConv2d and DynamicBatchNorm2d, respectively.
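As a rough illustration of the idea behind dynamic modules (not the actual DynamicConv2d implementation in pytorch_nndct), the sketch below assumes, as in the original Once-for-All work, that a narrower subnetwork reuses the leading output channels of the full weight. The class DynamicLayerSketch and its methods are hypothetical, and the layer is reduced to a 1x1 convolution over a single pixel so that it runs in plain Python.

```python
class DynamicLayerSketch:
    """Hypothetical stand-in for a layer with an elastic number of output channels."""

    def __init__(self, in_channels, out_channels):
        # Full-width weight: one filter (row of in_channels values) per output channel.
        self.weight = [[float(o * in_channels + i) for i in range(in_channels)]
                       for o in range(out_channels)]
        self.active_out = out_channels  # start at full width

    def set_active_out(self, k):
        # Choose how many output channels the current subnetwork uses.
        if not 1 <= k <= len(self.weight):
            raise ValueError(f"active width must be in [1, {len(self.weight)}], got {k}")
        self.active_out = k

    def active_weight(self):
        # Assumption: the subnetwork uses the leading filters, so it shares the
        # same row objects as the full layer instead of copying them.
        return self.weight[:self.active_out]

    def forward(self, x):
        # 1x1 convolution on a single pixel: one dot product per active filter.
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.active_weight()]


layer = DynamicLayerSketch(in_channels=2, out_channels=4)
print(layer.forward([1.0, 1.0]))  # [1.0, 5.0, 9.0, 13.0]
layer.set_active_out(2)
print(layer.forward([1.0, 1.0]))  # [1.0, 5.0]
```

Because every subnetwork reads from the same weight storage, training one subnetwork also changes the weights that the others use, which is what makes joint training of all subnetworks possible.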

A list of pruning ratios is required to specify what the OFA model will be.

For each convolution layer in the OFA model, the number of output channels can be scaled by any ratio in this list. Each value is the fraction of a layer's output channels that a subnetwork keeps, so the minimum value defines the most compressed subnetwork and the maximum value defines the least compressed one; the values in between define intermediate subnetworks that are optimized as well. By default, the pruning ratio list is [0.5, 0.75, 1].

For a subnetwork sampled from the OFA model, the number of output channels of a convolution layer is one of the values in the pruning ratio list multiplied by the layer's original number of output channels. For example, with the pruning ratio list [0.5, 0.75, 1] and the layer nn.Conv2d(16, 32, 5), this layer has 0.5*32, 0.75*32, or 1*32 (that is, 16, 24, or 32) output channels in a sampled subnetwork.
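The arithmetic in this example can be checked with a few lines of Python. The helper subnet_out_channels is hypothetical, and rounding to the nearest integer with a floor of one channel is an assumption for ratios that do not divide the channel count evenly; this guide does not say how such cases are handled.

```python
def subnet_out_channels(original_out, ratios):
    """Candidate output-channel counts for one layer (hypothetical helper)."""
    # Assumption: round to the nearest integer and always keep at least one channel.
    return [max(1, round(r * original_out)) for r in ratios]


# nn.Conv2d(16, 32, 5) has 32 output channels.
print(subnet_out_channels(32, [0.5, 0.75, 1]))  # [16, 24, 32]
```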

Because the first and last layers have a significant impact on network performance, they are commonly excluded from pruning.

ofa_model = ofa_pruner.ofa_model([0.5, 0.75, 1], excludes=['features.0.0', 'features.1.conv.0.0', 'features.18.0'])

Training an OFA Model

Training uses the sandwich rule to jointly optimize all the OFA subnetworks. Call sample_subnet() to get a subnetwork; the sampled dynamic subnetwork supports forward and backward passes.

In each training step, given a mini-batch of data, the sandwich rule samples a ‘max’ subnetwork, a ‘min’ subnetwork, and two random subnetworks. Each subnetwork runs its own forward and backward pass on the same mini-batch, the gradients accumulate in the shared weights, and then the parameters of all the subnetworks are updated together in one optimizer step.

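# Train with the sandwich rule. teacher_model, kd_loss, cross_entropy_loss,
# and optimizer are defined elsewhere by the user.
for i, (images, target) in enumerate(train_loader):

  images = images.cuda(non_blocking=True)
  target = target.cuda(non_blocking=True)

  # clear gradients once per mini-batch; all sampled subnets accumulate into them
  optimizer.zero_grad()

  # soft targets from the teacher for knowledge distillation
  with torch.no_grad():
    soft_logits = teacher_model(images).detach()

  # total subnets to be sampled: max, min, and two random ones
  for arch_id in range(4):
    if arch_id == 0:
      model, _ = ofa_pruner.sample_subnet(ofa_model, 'max')
    elif arch_id == 1:
      model, _ = ofa_pruner.sample_subnet(ofa_model, 'min')
    else:
      model, _ = ofa_pruner.sample_subnet(ofa_model, 'random')

    output = model(images)

    loss = kd_loss(output, soft_logits) + cross_entropy_loss(output, target)
    loss.backward()

  # a single parameter update for all sampled subnets
  torch.nn.utils.clip_grad_value_(ofa_model.parameters(), 1.0)
  optimizer.step()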
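The order in which the loop visits subnetworks (largest, smallest, then random ones) can be sketched independently of the pruner. Here a subnetwork is represented as a hypothetical mapping from layer name to width ratio, and each layer is assumed to choose its ratio independently; the real sample_subnet() may tie some layers together (for example, layers joined by residual connections). The function sandwich_configs and the layer names are illustrative only.

```python
import random


def sandwich_configs(layer_names, ratios, num_random=2, rng=None):
    """Per-step subnetwork configs under the sandwich rule (illustrative sketch)."""
    rng = rng or random.Random()
    max_cfg = {name: max(ratios) for name in layer_names}  # largest subnetwork
    min_cfg = {name: min(ratios) for name in layer_names}  # smallest subnetwork
    random_cfgs = [
        {name: rng.choice(ratios) for name in layer_names}  # random subnetworks
        for _ in range(num_random)
    ]
    return [max_cfg, min_cfg] + random_cfgs


layers = ["features.2.conv.0.0", "features.3.conv.0.0"]  # illustrative names
for cfg in sandwich_configs(layers, [0.5, 0.75, 1], rng=random.Random(0)):
    print(cfg)
```

Each of the four configurations corresponds to one value of arch_id in the loop; training the two extremes in every step keeps both the largest and the smallest subnetworks well optimized, while the random ones cover the widths in between.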

Searching Constrained Subnetworks

After training is complete, you can run an evolutionary search based on the neural-network-twins to find subnetworks with the best trade-offs between FLOPs and accuracy within a FLOPs range given by min_flops and max_flops. The calibration_fn and eval_fn functions and the data loaders are user-defined.

pareto_global = ofa_pruner.run_evolutionary_search(
    ofa_model, calibration_fn, (train_loader,), eval_fn, (val_loader,),
    min_flops=230, max_flops=250)

ofa_pruner.save_subnet_config(pareto_global, 'pareto_global.txt')
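The search itself is handled by run_evolutionary_search(). To show the general idea of an evolutionary search under a FLOPs constraint, here is a self-contained toy version using mutation and crossover. The cost model, the accuracy proxy standing in for eval_fn, the layer costs, and all function names are hypothetical; the real search evaluates actual subnetworks of the OFA model.

```python
import random

# Hypothetical toy setup: candidate width ratios and each layer's cost at full width.
RATIOS = [0.5, 0.75, 1]
LAYER_COST = [40.0, 60.0, 80.0, 120.0]  # made-up numbers, not real MobileNetV2 FLOPs


def flops(cfg):
    # Toy cost model: a layer's cost scales linearly with its width ratio.
    return sum(c * r for c, r in zip(LAYER_COST, cfg))


def proxy_acc(cfg):
    # Stand-in for a real eval_fn: wider layers score higher, later layers weigh more.
    return sum((i + 1) * r for i, r in enumerate(cfg))


def in_range(cfg, lo, hi):
    return lo <= flops(cfg) <= hi


def mutate(cfg, rng, prob=0.3):
    # Resample each layer's ratio with probability `prob`.
    return [rng.choice(RATIOS) if rng.random() < prob else r for r in cfg]


def crossover(a, b, rng):
    # Take each layer's ratio from one of the two parents at random.
    return [rng.choice(pair) for pair in zip(a, b)]


def evolutionary_search(lo, hi, pop_size=16, generations=10, seed=0):
    rng = random.Random(seed)
    # Initial population: random configurations that satisfy the FLOPs range.
    pop = []
    for _ in range(10_000):
        if len(pop) == pop_size:
            break
        cfg = [rng.choice(RATIOS) for _ in LAYER_COST]
        if in_range(cfg, lo, hi):
            pop.append(cfg)
    if not pop:
        return None  # no configuration found inside the range
    for _ in range(generations):
        # Keep the better-scoring half as parents.
        parents = sorted(pop, key=proxy_acc, reverse=True)[: max(2, pop_size // 2)]
        children = []
        for _ in range(pop_size * 20):  # bounded number of attempts
            if len(children) == pop_size:
                break
            child = mutate(crossover(rng.choice(parents), rng.choice(parents), rng), rng)
            if in_range(child, lo, hi):  # discard children outside the range
                children.append(child)
        pop = parents + children
    # Every member satisfies the constraint, so return the best-scoring one.
    return max(pop, key=proxy_acc)


best = evolutionary_search(180, 220)
print(best, flops(best))
```

Filtering out-of-range candidates at every step keeps the whole population feasible, so the final pick only has to maximize the accuracy proxy.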

The search result looks like the following:

"230": {
    "net_id": "net_evo_0_crossover_0",
    "mode": "evaluate",
    "acc1": 69.04999542236328,
    "flops": 228.356192,
    "params": 3.096728,
    "subnet_setting": [...]
},
"240": {
    "net_id": "net_evo_0_mutate_1",
    "mode": "evaluate",
    "acc1": 69.22000122070312,
    "flops": 243.804128,
    "params": 3.114,
    "subnet_setting": [...]
}

Getting a Subnetwork

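Load the saved search result with load_subnet_config(), then call get_static_subnet() with the subnet_setting of the chosen entry to get a specific static subnetwork, along with its configuration, FLOPs, and parameter count.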

pareto_global = ofa_pruner.load_subnet_config('pareto_global.txt')
static_subnet, static_subnet_config, flops, params = ofa_pruner.get_static_subnet(
    ofa_model, pareto_global['240']['subnet_setting'])
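Rather than hard-coding the '240' key, you could pick an entry from the search result by a FLOPs budget. The sketch below assumes the loaded configuration behaves like a Python dict keyed as in the example result, with the numbers copied from that example and subnet_setting left out; best_under_budget is a hypothetical helper.

```python
# Entries copied from the example search result (subnet_setting omitted).
pareto_global = {
    "230": {"net_id": "net_evo_0_crossover_0", "acc1": 69.04999542236328,
            "flops": 228.356192, "params": 3.096728},
    "240": {"net_id": "net_evo_0_mutate_1", "acc1": 69.22000122070312,
            "flops": 243.804128, "params": 3.114},
}


def best_under_budget(results, max_flops):
    """Return (key, entry) with the highest acc1 whose flops fit the budget."""
    fitting = [(k, v) for k, v in results.items() if v["flops"] <= max_flops]
    if not fitting:
        return None
    return max(fitting, key=lambda kv: kv[1]["acc1"])


print(best_under_budget(pareto_global, 240)[0])  # 230
print(best_under_budget(pareto_global, 250)[0])  # 240
```

With a budget of 240, only the "230" entry fits (its flops value is 228.356192), even though the "240" entry is slightly more accurate.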