The steps for pruning a model according to this pruning method are as follows:
Creating a Model
For simplicity, ResNet18 from torchvision
is used here. In real-world applications, the process of
creating a model can be more complicated.
from torchvision.models.resnet import resnet18
model = resnet18(pretrained=True)
Creating a Pruning Runner
Import modules and prepare input signature:
import torch
from pytorch_nndct import get_pruning_runner

# The input signature should have the same shape and dtype as the model input.
input_signature = torch.randn([1, 3, 224, 224], dtype=torch.float32)
Create an iterative pruning runner:
runner = get_pruning_runner(model, input_signature, 'iterative')
Or, a one-step pruning runner:
runner = get_pruning_runner(model, input_signature, 'one_step')
Pruning a Model
- Iterative Pruning
- The method includes two stages: model analysis and pruned model generation.
After the analysis completes, the result is saved to a file
named .vai/xxx.sens. You can reuse this file to prune the
model iteratively: increase the pruning ratio gradually toward the target
sparsity, because setting too high a pruning ratio at once can prevent the
model from recovering its performance during retraining (see the
pruning-schedule sketch after this list).
- Define an evaluation function. The function must
take a model as its first argument and return a
score.
def eval_fn(model, dataloader):
    # AverageMeter and accuracy are standard helpers from the PyTorch ImageNet
    # example (a minimal sketch is given after this list).
    top1 = AverageMeter('Acc@1', ':6.2f')
    model.eval()
    with torch.no_grad():
        for i, (images, targets) in enumerate(dataloader):
            images = images.cuda()
            targets = targets.cuda()
            outputs = model(images)
            acc1, _ = accuracy(outputs, targets, topk=(1, 5))
            top1.update(acc1[0], images.size(0))
    return top1.avg
- Run model analysis and get a pruned
model.
runner.ana(eval_fn, args=(val_loader,))
model = runner.prune(ratio=0.2)
Run the analysis only once for a given model. You can then prune the model iteratively without re-running the analysis, because a specific pruning ratio always yields the same pruned model. The subnetwork obtained this way may not be optimal, because an approximate algorithm generates this unique pruned model from the analysis result; the one-step pruning method can find a better subnetwork.
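Because the analysis result is reused, pruning gradually simply means calling prune() with an increasing ratio and retraining between rounds. The following is a minimal sketch of such a schedule; the ratios are illustrative and the retraining step is the one described under "Retraining a Model" below.
# Hypothetical schedule: step up to the target sparsity instead of jumping to it.
for ratio in [0.2, 0.35, 0.5]:
    model = runner.prune(ratio=ratio)  # reuses the saved .vai/xxx.sens analysis
    # ...retrain the sparse model here before increasing the ratio further
    # (see "Retraining a Model" below).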
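The AverageMeter and accuracy helpers used in eval_fn are not provided by pytorch_nndct; they follow the standard PyTorch ImageNet example. A minimal sketch of the assumed helpers:
class AverageMeter:
    """Tracks a running average of a metric (e.g., top-1 accuracy)."""
    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += float(val) * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Computes precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res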
- One-step Pruning
- The method also includes two stages: adaptive-BN-based search for a pruning
strategy, and pruned model generation. After the search, a file named
.vai/xxx.search is generated, in which the search results
(pruning strategies and their corresponding evaluation scores) are stored.
You can then get the final pruned model in one step.
# Adaptive-BN-based searching for the pruning strategy.
# 'calibration_fn' is a function for calibrating BN layer statistics.
runner.search(
    gpus=['0'],
    calibration_fn=calibration_fn,
    calib_args=(val_loader,),
    eval_fn=eval_fn,
    eval_args=(val_loader,),
    num_subnet=1000,
    sparsity=0.7)
model = runner.prune()
num_subnet specifies the number of candidate subnetworks satisfying the sparsity requirement to be searched; the best subnetwork is selected from these candidates. The higher the value, the longer the search takes, but the higher the probability of finding a better subnetwork.
The eval_fn is the same as in the iterative pruning method. A calibration_fn
function that implements adaptive-BN is shown in the following example code. Define your own calibration function similarly.
def calibration_fn(model, dataloader, number_forward=100):
    # Run forward passes in training mode (without gradient computation) so the
    # BN layers' running statistics adapt to the pruned network.
    model.train()
    with torch.no_grad():
        for index, (images, target) in enumerate(dataloader):
            images = images.cuda()
            model(images)
            if index > number_forward:
                break
The one-step pruning method has several advantages over the iterative approach.
- The generated pruned models are more accurate, because all candidate subnetworks that meet the sparsity requirement are evaluated.
- The workflow is simpler because you can obtain the final pruned model in one step without iterations.
- Retraining a slim model is faster than retraining a sparse model.
There are two disadvantages to one-step pruning: the random generation of pruning strategies makes the search results less stable, and the subnetwork search must be rerun for every pruning ratio.
Retraining a Model
Retraining a model is the same as training a baseline model.
# Assumes a standard classification setup; 'criterion', 'train', and 'evaluate'
# are the same as in the baseline training script (a sketch follows below).
criterion = torch.nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=5e-4)
best_acc1 = 0

for epoch in range(args.epochs):
    train(train_loader, model, criterion, optimizer, epoch)
    acc1, acc5 = evaluate(val_loader, model, criterion)

    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)
    if is_best:
        torch.save(model.state_dict(), 'model_pruned.pth')
        # The sparse model has one more special method in iterative pruning:
        # slim_state_dict() returns the weights of the corresponding slim model.
        if hasattr(model, 'slim_state_dict'):
            torch.save(model.slim_state_dict(), 'model_slim.pth')
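The train and evaluate functions referenced above are the baseline training and validation loops; they are not part of pytorch_nndct. A minimal sketch, assuming a standard cross-entropy classification setup and the AverageMeter/accuracy helpers shown earlier:
def train(dataloader, model, criterion, optimizer, epoch):
    # One epoch of standard supervised training.
    model.train()
    for images, targets in dataloader:
        images, targets = images.cuda(), targets.cuda()
        loss = criterion(model(images), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def evaluate(dataloader, model, criterion):
    # Reports top-1/top-5 accuracy on the validation set.
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    with torch.no_grad():
        for images, targets in dataloader:
            images, targets = images.cuda(), targets.cuda()
            outputs = model(images)
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))
    return top1.avg, top5.avg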
Getting a Pruned Model
After retraining, rebuild the pruning runner and call prune() with mode='slim' to obtain the final compact model, then load the retrained weights into it.
method = 'iterative' # or 'one_step'
runner = get_pruning_runner(model, input_signature, method)
slim_model = runner.prune(ratio=0.2, mode='slim')
slim_model.load_state_dict(torch.load('model_pruned.pth'))
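As a quick sanity check, the slim model can be evaluated with the same eval_fn used during pruning (assuming the val_loader defined earlier):
slim_model = slim_model.cuda()
top1 = eval_fn(slim_model, val_loader)
print('Top-1 accuracy of the slim model: {:.2f}'.format(top1))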