This sample uses a ResNet18 model.
- Prepare an evaluation script for model
analysis.
import argparse
import os
import shutil
import time

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.models.resnet import resnet18

from pytorch_nndct import Pruner


def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1``, ``no``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    any non-empty value (including ``"False"``) becomes ``True``.  This
    helper interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, matching what ``bool('')`` returned.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected, got {!r}.'.format(value))


parser = argparse.ArgumentParser()
parser.add_argument(
    '--data_dir',
    default='/scratch/workspace/dataset/imagenet/pytorch',
    help='Data set directory.')
parser.add_argument(
    '--pretrained',
    default='/scratch/workspace/models/resnet18.pth',
    help='Trained model file path.')
parser.add_argument(
    '--ratio',
    default=0.1,
    type=float,
    help='Desired pruning ratio. The larger this value, the smaller '
    'the model after pruning.')
parser.add_argument(
    '--ana',
    default=False,
    type=_str2bool,
    help='Whether to perform model analysis.')
# Unknown arguments are ignored rather than rejected.
args, _ = parser.parse_known_args()


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self, name, fmt=':f'):
        # ``fmt`` is a format-spec suffix (e.g. ':6.2f') spliced into the
        # template built by ``__str__``.
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the current value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Save ``state`` to ``filename``; copy it to 'model_best.pth.tar' if best."""
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class ProgressMeter(object):
    """Prints a tab-separated progress line for a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the ``[batch/total]`` counter and every meter."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the batch index to the width of the total batch count.
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: model scores; the top-k is taken along dimension 1.
        target: ground-truth labels, one per row of ``output``.
        topk: tuple of k values to evaluate.

    Returns:
        A list with one single-element tensor per k, holding the percentage
        of samples whose target is among the k highest-scoring predictions.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` derives from a transposed tensor and may be
            # non-contiguous, in which case ``view(-1)`` raises; ``reshape``
            # handles both layouts.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def adjust_learning_rate(optimizer, epoch, lr):
    """Sets the learning rate to the initial LR decayed by 10 every 2 epochs."""
    lr = lr * (0.1**(epoch // 2))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch on the GPU, printing progress every 10 batches.

    Args:
        train_loader: iterable yielding ``(images, target)`` batches.
        model: network to optimise; it is moved to the current CUDA device.
        criterion: loss function called as ``criterion(output, target)``.
        optimizer: optimizer stepping the model's parameters.
        epoch: epoch index, used only in the progress prefix.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader), [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    # Moving the model is loop-invariant, so do it once up front.
    model = model.cuda()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        images = images.cuda()
        target = target.cuda()

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 10 == 0:
            progress.display(i)


def evaluate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and report top-1/top-5 accuracy.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: network to evaluate; it is moved to the current CUDA device.
        criterion: loss function called as ``criterion(output, target)``.

    Returns:
        Tuple ``(top1.avg, top5.avg)``: the sample-weighted average top-1 and
        top-5 accuracy (in percent) over the whole loader.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    # Moving the model is loop-invariant, so do it once up front.
    model = model.cuda()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 50 == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
            top1=top1, top5=top5))

    return top1.avg, top5.avg
- Define a function used for model
analysis.
def ana_eval_fn(model, val_loader, loss_fn):
    """Evaluation callback for model analysis.

    Runs ``evaluate`` on the given loader and returns only the second
    element of its result (the top-5 average accuracy).
    """
    _, top5_avg = evaluate(val_loader, model, loss_fn)
    return top5_avg
- Create a ResNet18 model and add the pruning APIs to perform
pruning.
if __name__ == '__main__':
    # The model is built on the CPU, so map the checkpoint's tensors to the
    # CPU as well: a checkpoint saved from a GPU run would otherwise be
    # restored onto the GPU it was saved from.
    model = resnet18().cpu()
    model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))

    batch_size = 128
    workers = 4
    traindir = os.path.join(args.data_dir, 'train')
    valdir = os.path.join(args.data_dir, 'validation')

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Training data: random crop and horizontal flip for augmentation.
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True)

    # Validation data: deterministic resize and centre crop.
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=True)

    criterion = torch.nn.CrossEntropyLoss().cuda()

    # Dummy input handed to the pruner together with the model.
    inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
    pruner = Pruner(model, inputs)

    # Model analysis is optional; it is driven by ``ana_eval_fn`` with the
    # validation loader and loss passed through ``args``.
    if args.ana:
        pruner.ana(ana_eval_fn, args=(val_loader, criterion), gpus=[0, 1, 2, 3])

    model = pruner.prune(ratio=args.ratio)
    pruner.summary(model)

    # Fine-tune the pruned model.
    lr = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=1e-4)

    best_acc5 = 0
    epochs = 1
    for epoch in range(epochs):
        adjust_learning_rate(optimizer, epoch, lr)
        train(train_loader, model, criterion, optimizer, epoch)

        acc1, acc5 = evaluate(val_loader, model, criterion)

        # remember best acc@5 and save checkpoint
        is_best = acc5 > best_acc5
        best_acc5 = max(acc5, best_acc5)

        if is_best:
            torch.save(model.pruned_state_dict(), 'resnet18_sparse.pth')
            torch.save(model.state_dict(), 'resnet18_dense.pth')
- Download the pretrained ResNet18
model:
wget https://download.pytorch.org/models/resnet18-5c106cde.pth -O resnet18.pth
- Prepare ImageNet dataset. For more information, see http://image-net.org/download-images.
- Run the first round of pruning with model
analysis:
$ python -u resnet18_pruning.py --data_dir imagenet_dir --pretrained resnet18.pth --ratio 0.1 --ana True
- From the second round onwards, model analysis is no longer required. Increase the pruning
ratio and use the sparse checkpoint saved from the previous round as the pretrained
weights.
$ python -u resnet18_pruning.py --data_dir imagenet_dir --pretrained resnet18_sparse.pth --ratio 0.2