Creating a Model
For simplicity, mobilenet_v2 from torchvision is used here.
from torchvision.models.mobilenet import mobilenet_v2
model = mobilenet_v2(pretrained=True)
Creating an OFA Pruner
The pruner requires two arguments: the model to be pruned and inputs needed by the model for inference. Note that the input does not need to be real data. You can use randomly generated dummy data if it has the same shape and type as the real data.
import torch
from pytorch_nndct import OFAPruner
inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
ofa_pruner = OFAPruner(model, inputs)
Generating an OFA Model
Call ofa_model() to get an OFA model. This method finds all the nn.Conv2d (kernel_size > 1) and nn.BatchNorm2d modules, then replaces those modules with DynamicConv2d and DynamicBatchNorm2d.
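The selection rule described above can be illustrated with plain PyTorch. The following is a minimal sketch, not the pruner's actual implementation: the helper find_replaceable_modules is hypothetical, and it assumes that "kernel_size > 1" means at least one kernel dimension is larger than 1.

```python
import torch.nn as nn


def find_replaceable_modules(model):
    """Hypothetical helper: list the modules that the rule above would replace."""
    found = []
    for name, module in model.named_modules():
        # nn.Conv2d stores kernel_size as a tuple, for example (3, 3).
        if isinstance(module, nn.Conv2d) and max(module.kernel_size) > 1:
            found.append((name, "DynamicConv2d"))
        elif isinstance(module, nn.BatchNorm2d):
            found.append((name, "DynamicBatchNorm2d"))
    return found


toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),  # kernel_size > 1: selected
    nn.BatchNorm2d(8),               # selected
    nn.Conv2d(8, 8, kernel_size=1),  # 1x1 convolution: not selected
    nn.ReLU(),
)
replaceable = find_replaceable_modules(toy)
print(replaceable)
```

In this toy model only the 3x3 convolution and the batch normalization layer are reported; the 1x1 convolution is left untouched.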
A list of pruning ratios is required to specify the OFA model.
For each convolution layer in the OFA model, any pruning ratio from the list can be applied to the output channels. The maximum and minimum values in this list represent the maximum and minimum compression rates of the model. Other values in the list represent the subnetworks to be optimized. By default, the pruning ratio list is [0.5, 0.75, 1].
For a subnetwork sampled from the OFA model, the number of output channels of a convolution layer is its original number multiplied by one of the ratios in the list. For example, with a pruning ratio list of [0.5, 0.75, 1] and a convolution layer nn.Conv2d(16, 32, 5), the number of output channels of this layer in a sampled subnetwork is one of [0.5*32, 0.75*32, 1*32], that is, 16, 24, or 32.
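The arithmetic for the ratio list [0.5, 0.75, 1] can be checked with a few lines of Python. This is only a sketch: candidate_out_channels is a hypothetical helper, and it assumes plain truncation to an integer, whereas the pruner itself may round channel counts differently.

```python
def candidate_out_channels(out_channels, ratios):
    """Return the output-channel counts a sampled subnetwork could use."""
    # Assumption: simple truncation; the real pruner may apply other rounding.
    return [int(out_channels * ratio) for ratio in ratios]


# nn.Conv2d(16, 32, 5) has 32 output channels.
channels = candidate_out_channels(32, [0.5, 0.75, 1])
print(channels)  # [16, 24, 32]
```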
Because the first and last layers have a significant impact on network performance, they are commonly excluded from pruning.
ofa_model = ofa_pruner.ofa_model([0.5, 0.75, 1], excludes=['features.0.0', 'features.1.conv.0.0', 'features.18.0'])
Training an OFA Model
This method uses the sandwich rule to jointly optimize all the OFA subnetworks. The sample_random_subnet() function can be used to get a subnetwork. The dynamic subnetwork can do a forward/backward pass.
In each training step, given a mini-batch of data, the sandwich rule samples a ‘max’ subnetwork, a ‘min’ subnetwork, and two random subnetworks. Each subnetwork does a separate forward/backward pass with the given data and then all the subnetworks update their parameters together.
# using sandwich rule and sampling subnet.
for i, (images, target) in enumerate(train_loader):
    images = images.cuda(non_blocking=True)
    target = target.cuda(non_blocking=True)

    # total subnets to be sampled
    optimizer.zero_grad()

    teacher_model.train()
    with torch.no_grad():
        soft_logits = teacher_model(images).detach()

    for arch_id in range(4):
        if arch_id == 0:
            model, _ = ofa_pruner.sample_subnet(ofa_model, 'max')
        elif arch_id == 1:
            model, _ = ofa_pruner.sample_subnet(ofa_model, 'min')
        else:
            model, _ = ofa_pruner.sample_subnet(ofa_model, 'random')

        output = model(images)
        loss = kd_loss(output, soft_logits) + cross_entropy_loss(output, target)
        loss.backward()

    torch.nn.utils.clip_grad_value_(ofa_model.parameters(), 1.0)
    optimizer.step()
    lr_scheduler.step()
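The per-step sampling order of the sandwich rule can also be written as a small schedule. The sketch below is illustrative only: sandwich_modes is a hypothetical helper, not part of the pruner API, and the default of two random subnetworks follows the description above.

```python
def sandwich_modes(num_random=2):
    """Sampling modes for one training step under the sandwich rule."""
    # One 'max' subnetwork, one 'min' subnetwork, then the random ones.
    return ["max", "min"] + ["random"] * num_random


modes = sandwich_modes()
for arch_id, mode in enumerate(modes):
    print(arch_id, mode)
```

Iterating over such a list and passing each mode to the sampling call is equivalent to the if/elif chain in the training loop, and it makes the number of random subnetworks easy to change.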
Searching Constrained Subnetworks
pareto_global = ofa_pruner.run_evolutionary_search(ofa_model, calibration_fn, (train_loader,), eval_fn, (val_loader,), min_flops=230, max_flops=250)
ofa_pruner.save_subnet_config(pareto_global, 'pareto_global.txt')
The search result looks like the following:
{
  "230": {
    "net_id": "net_evo_0_crossover_0",
    "mode": "evaluate",
    "acc1": 69.04999542236328,
    "flops": 228.356192,
    "params": 3.096728,
    "subnet_setting": [...]
  },
  "240": {
    "net_id": "net_evo_0_mutate_1",
    "mode": "evaluate",
    "acc1": 69.22000122070312,
    "flops": 243.804128,
    "params": 3.114,
    "subnet_setting": [...]
  }
}
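Once the result is available as a dictionary, choosing an entry is a simple filter and maximum. The following sketch assumes the structure shown above and uses the same acc1 and flops values; the elided subnet_setting lists are left out, and best_under_budget is a hypothetical helper rather than a pruner method.

```python
# Values copied from the search result above; subnet_setting is omitted.
results = {
    "230": {"net_id": "net_evo_0_crossover_0", "acc1": 69.04999542236328,
            "flops": 228.356192, "params": 3.096728},
    "240": {"net_id": "net_evo_0_mutate_1", "acc1": 69.22000122070312,
            "flops": 243.804128, "params": 3.114},
}


def best_under_budget(results, max_flops):
    """Return the key of the highest-acc1 entry within the FLOPs budget, or None."""
    feasible = {key: entry for key, entry in results.items()
                if entry["flops"] <= max_flops}
    if not feasible:
        return None
    return max(feasible, key=lambda key: feasible[key]["acc1"])


best_key = best_under_budget(results, max_flops=250)
print(best_key)  # 240
```

The returned key can then be used to look up the corresponding subnet_setting in the full result.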
Getting a Subnetwork
Call get_static_subnet() to get a specific subnetwork.
pareto_global = ofa_pruner.load_subnet_config('pareto_global.txt')
static_subnet, static_subnet_config, flops, params = ofa_pruner.get_static_subnet(
    ofa_model, pareto_global['240']['subnet_setting'])