To run model analysis, an evaluation function must be passed to the pruner.ana() method. The only constraint on this function is that its first argument must be the model to be evaluated. An existing evaluation function generally does not meet this requirement, so you must define a wrapper function, as shown below.

Suppose the following is your existing evaluation function:
import time

import torch

# AverageMeter, ProgressMeter, and accuracy are the metric helpers
# from the PyTorch ImageNet example (minimal stand-ins are sketched below).
def evaluate(val_loader, model, criterion):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ')

    # switch to evaluate mode and move the model to the GPU once,
    # rather than on every batch
    model.eval()
    model = model.cuda()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 50 == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
            top1=top1, top5=top5))

    return top1.avg, top5.avg
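The evaluate function above relies on the AverageMeter, ProgressMeter, and accuracy helpers from the PyTorch ImageNet example. If they are not already defined in your project, minimal stand-ins such as the following sketch are enough to make the function run:

class AverageMeter:
    """Tracks a running average of a metric (minimal stand-in)."""
    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(name=self.name, val=self.val, avg=self.avg)

class ProgressMeter:
    """Prints the meters for the current batch (minimal stand-in)."""
    def __init__(self, num_batches, meters, prefix=''):
        self.num_batches, self.meters, self.prefix = num_batches, meters, prefix

    def display(self, batch):
        entries = [self.prefix + '[{}/{}]'.format(batch, self.num_batches)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

def accuracy(output, target, topk=(1,)):
    """Computes top-k accuracy over a batch (minimal stand-in)."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res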
Define a wrapper that satisfies the requirement by reordering the arguments so that the model comes first. It returns a single scalar for the pruner to use as its score; here, element [1] of evaluate's return value, the top-5 accuracy:

def ana_eval_fn(model, val_loader, loss_fn):
    return evaluate(val_loader, model, loss_fn)[1]
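If you want to verify the wrapper before handing it to the pruner, you can call it directly. This sketch assumes model, val_loader, and criterion have already been constructed:

# The wrapper should return a single scalar metric
# (here, the top-5 accuracy computed by evaluate).
score = ana_eval_fn(model, val_loader, criterion)
print('analysis score:', score)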
Then, call the ana() method with the wrapper function defined above as its first argument:

pruner.ana(ana_eval_fn, args=(val_loader, criterion))
Here, args is a tuple containing the remaining arguments of ana_eval_fn, starting from its second argument; the pruner supplies the model itself as the first argument.
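Conceptually, the pruner invokes the evaluation function as eval_fn(model, *args). The following sketch is a simplified illustration of this calling convention, not the library's actual implementation:

def call_eval_fn(eval_fn, model, args=()):
    # The pruner supplies the model as the first positional argument;
    # the elements of `args` fill the remaining positions.
    return eval_fn(model, *args)

# Equivalent to what pruner.ana(ana_eval_fn, args=(val_loader, criterion))
# does with the model under evaluation:
score = call_eval_fn(ana_eval_fn, model, args=(val_loader, criterion))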