ResNet18 - 2.0 English

Vitis AI Optimizer User Guide (UG1333)

Document ID
UG1333
Release Date
2022-01-20
Version
2.0 English

This sample uses a ResNet18 model.

  1. Prepare an evaluation script for model analysis.
    import argparse
    import os
    import shutil
    import time
    import torch
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms
    
    from torchvision.models.resnet import resnet18
    
    from pytorch_nndct import Pruner
    
    def _str_to_bool(value):
      """Convert a command-line string such as 'True', 'false', '1' or 'no' to bool.

      argparse's ``type=bool`` applies ``bool()`` to the raw string, so every
      non-empty value -- including 'False' -- would be parsed as True.

      Raises:
        argparse.ArgumentTypeError: if `value` is not a recognized boolean.
      """
      lowered = value.strip().lower()
      if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
      # The empty string stays False, as it was under ``bool('')``.
      if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
      raise argparse.ArgumentTypeError(
          'Expected a boolean value, got {!r}.'.format(value))
    
    # Command-line options of the pruning script.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_dir',
        default='/scratch/workspace/dataset/imagenet/pytorch',
        help='Data set directory.')
    parser.add_argument(
        '--pretrained',
        default='/scratch/workspace/models/resnet18.pth',
        help='Trained model file path.')
    parser.add_argument(
        '--ratio',
        default=0.1,
        type=float,
        help='Desired pruning ratio. The larger this value, the smaller '
        'the model after pruning.')
    # `--ana True` / `--ana False` are parsed by value; a bare `--ana` means True.
    parser.add_argument(
        '--ana',
        nargs='?',
        const=True,
        default=False,
        type=_str_to_bool,
        help='Whether to perform model analysis.')
    # Unknown arguments are tolerated and discarded.
    args, _ = parser.parse_known_args()
    
    class AverageMeter(object):
      """Tracks the most recent value of a metric and its running weighted mean."""
    
      def __init__(self, name, fmt=':f'):
        # `fmt` is a format spec (including the leading ':') applied to both
        # the current value and the average when the meter is printed.
        self.name, self.fmt = name, fmt
        self.reset()
    
      def reset(self):
        """Zero the current value, average, sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0
    
      def update(self, val, n=1):
        """Record `val` as the latest value, counted as `n` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
    
      def __str__(self):
        # Builds e.g. '{name} {val:6.3f} ({avg:6.3f})' and fills it from the
        # instance attributes.
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(**vars(self))
    
    def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
      """Save `state` to `filename`; when `is_best`, also copy it to 'model_best.pth.tar'."""
      torch.save(state, filename)
      if not is_best:
        return
      shutil.copyfile(filename, 'model_best.pth.tar')
    
    class ProgressMeter(object):
      """Prints a row of meters prefixed with a '[current/total]' batch counter."""
    
      def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix
    
      def display(self, batch):
        """Print the prefix, the batch counter and every meter, tab-separated."""
        columns = [self.prefix + self.batch_fmtstr.format(batch)]
        columns.extend(str(meter) for meter in self.meters)
        print('\t'.join(columns))
    
      def _get_batch_fmtstr(self, num_batches):
        # The counter is padded to the width of the total so columns line up,
        # e.g. num_batches=100 gives '[{:3d}/100]'.
        width = len(str(num_batches // 1))
        counter = '{:' + str(width) + 'd}'
        return '[{}/{}]'.format(counter, counter.format(num_batches))
    
    def accuracy(output, target, topk=(1,)):
      """Computes the accuracy (in percent) over the k top predictions
        for the specified values of k.

      Args:
        output: Score tensor; the top-k entries are taken along dimension 1
          (assumes a (batch, num_classes) layout -- TODO confirm with callers).
        target: Tensor of ground-truth class indices; its first dimension is
          used as the batch size.
        topk: Iterable of k values to evaluate.

      Returns:
        A list with one single-element tensor per entry of `topk`, holding the
        percentage of samples whose target is among the k highest scores.
      """
      with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
    
        # Indices of the maxk largest scores per sample, sorted descending.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
    
        res = []
        for k in topk:
          # `correct` can inherit the transposed (non-contiguous) layout of
          # `pred`, in which case view(-1) raises a RuntimeError; reshape(-1)
          # copies only when it has to.
          correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
          res.append(correct_k.mul_(100.0 / batch_size))
        return res
    
    def adjust_learning_rate(optimizer, epoch, lr):
      """Step schedule: multiply the base `lr` by 0.1 for every 2 epochs elapsed
      and write the result into every parameter group of `optimizer`."""
      decayed_lr = lr * (0.1**(epoch // 2))
      for group in optimizer.param_groups:
        group['lr'] = decayed_lr
    
    def train(train_loader, model, criterion, optimizer, epoch):
      """Run one training pass over `train_loader` on the GPU.

      Args:
        train_loader: Iterable yielding (images, target) batches.
        model: Network to optimize; it is moved to the current CUDA device.
        criterion: Loss callable invoked as criterion(output, target).
        optimizer: Optimizer; zero_grad() and step() are called once per batch.
        epoch: Epoch index, used only for the progress prefix.
      """
      batch_time = AverageMeter('Time', ':6.3f')
      data_time = AverageMeter('Data', ':6.3f')
      losses = AverageMeter('Loss', ':.4e')
      top1 = AverageMeter('Acc@1', ':6.2f')
      top5 = AverageMeter('Acc@5', ':6.2f')
      progress = ProgressMeter(
          len(train_loader), [batch_time, data_time, losses, top1, top5],
          prefix="Epoch: [{}]".format(epoch))
    
      # switch to train mode
      model.train()
    
      # Moving the model to the GPU is loop-invariant: do it once up front
      # instead of on every batch.
      model = model.cuda()
    
      end = time.time()
      for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
    
        # Same transfer mode as evaluate(); non_blocking only has an effect
        # when the loader yields pinned-memory batches.
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
    
        # compute output
        output = model(images)
        loss = criterion(output, target)
    
        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))
    
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    
        # Report progress every 10 batches.
        if i % 10 == 0:
          progress.display(i)
    
    def evaluate(val_loader, model, criterion):
      """Evaluate `model` on `val_loader` on the GPU without tracking gradients.

      Args:
        val_loader: Iterable yielding (images, target) batches.
        model: Network to evaluate; it is moved to the current CUDA device.
        criterion: Loss callable invoked as criterion(output, target).

      Returns:
        A (top1.avg, top5.avg) tuple: the sample-weighted average top-1 and
        top-5 accuracy, in percent, over the whole loader.
      """
      batch_time = AverageMeter('Time', ':6.3f')
      losses = AverageMeter('Loss', ':.4e')
      top1 = AverageMeter('Acc@1', ':6.2f')
      top5 = AverageMeter('Acc@5', ':6.2f')
      progress = ProgressMeter(
          len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ')
    
      # switch to evaluate mode
      model.eval()
    
      # Moving the model to the GPU is loop-invariant: do it once up front
      # instead of on every batch.
      model = model.cuda()
    
      with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
          images = images.cuda(non_blocking=True)
          target = target.cuda(non_blocking=True)
    
          # compute output
          output = model(images)
          loss = criterion(output, target)
    
          # measure accuracy and record loss
          acc1, acc5 = accuracy(output, target, topk=(1, 5))
          losses.update(loss.item(), images.size(0))
          top1.update(acc1[0], images.size(0))
          top5.update(acc5[0], images.size(0))
    
          # measure elapsed time
          batch_time.update(time.time() - end)
          end = time.time()
    
          # Report progress every 50 batches.
          if i % 50 == 0:
            progress.display(i)
    
        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
            top1=top1, top5=top5))
    
      return top1.avg, top5.avg
    
  2. Define a function used for model analysis.
    def ana_eval_fn(model, val_loader, loss_fn):
      """Evaluation callback for model analysis.

      Runs evaluate() and returns only the second element of its result,
      i.e. the average top-5 accuracy.
      """
      _, top5_avg = evaluate(val_loader, model, loss_fn)
      return top5_avg
    
  3. Create a ResNet18 model and add the pruning APIs to perform pruning.
    # Script entry point: load a trained ResNet18, prune it with the
    # pytorch_nndct Pruner, then fine-tune the pruned model and save the
    # checkpoints with the best top-5 accuracy.
    if __name__ == '__main__':
      # Build the network on the CPU and load the weights given by --pretrained.
      model = resnet18().cpu()
      model.load_state_dict(torch.load(args.pretrained))
    
      batch_size = 128
      workers = 4
      # The data set directory must contain 'train' and 'validation' subfolders.
      traindir = os.path.join(args.data_dir, 'train')
      valdir = os.path.join(args.data_dir, 'validation')
    
      # Per-channel normalization (presumably the usual ImageNet statistics --
      # TODO confirm).
      normalize = transforms.Normalize(
          mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    
      # Training data: random crop + horizontal flip augmentation.
      train_dataset = datasets.ImageFolder(
          traindir,
          transforms.Compose([
              transforms.RandomResizedCrop(224),
              transforms.RandomHorizontalFlip(),
              transforms.ToTensor(),
              normalize,
          ]))
    
      train_loader = torch.utils.data.DataLoader(
          train_dataset,
          batch_size=batch_size,
          shuffle=True,
          num_workers=workers,
          pin_memory=True)
    
      # Validation data: deterministic resize + center crop, no shuffling.
      val_dataset = datasets.ImageFolder(
          valdir,
          transforms.Compose([
              transforms.Resize(256),
              transforms.CenterCrop(224),
              transforms.ToTensor(),
              normalize,
          ]))
    
      val_loader = torch.utils.data.DataLoader(
          val_dataset,
          batch_size=batch_size,
          shuffle=False,
          num_workers=workers,
          pin_memory=True)
    
      criterion = torch.nn.CrossEntropyLoss().cuda()
    
      # Random tensor with a single 3x224x224 sample handed to the Pruner
      # (presumably an example input used to trace the model -- confirm
      # against the pytorch_nndct documentation).
      inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
      pruner = Pruner(model, inputs)
      # Model analysis is optional (--ana); ana_eval_fn receives the extra
      # `args` after the model.
      # NOTE(review): the GPU ids 0-3 are hard-coded -- adjust for the machine.
      if args.ana:
        pruner.ana(ana_eval_fn, args=(val_loader, criterion), gpus=[0, 1, 2, 3])
      # From here on `model` is the object returned by the pruner.
      model = pruner.prune(ratio=args.ratio)
      pruner.summary(model)
    
      # Fine-tune the pruned model with Adam.
      lr = 1e-4
      optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=1e-4)
    
      best_acc5 = 0
      epochs = 1
      for epoch in range(epochs):
        adjust_learning_rate(optimizer, epoch, lr)
    
        train(train_loader, model, criterion, optimizer, epoch)
    
        acc1, acc5 = evaluate(val_loader, model, criterion)
    
        # remember best acc@5 and save checkpoint
        is_best = acc5 > best_acc5
        best_acc5 = max(acc5, best_acc5)
    
        # Two checkpoints are written for the best epoch: one from
        # pruned_state_dict() (a method of the model returned by
        # pruner.prune) and one from the regular state_dict().
        if is_best:
          torch.save(model.pruned_state_dict(), 'resnet18_sparse.pth')
          torch.save(model.state_dict(), 'resnet18_dense.pth')
  4. Download the pretrained ResNet18 model:
    wget https://download.pytorch.org/models/resnet18-5c106cde.pth -O resnet18.pth
  5. Prepare the ImageNet dataset. For more information, see http://image-net.org/download-images.
  6. Run the first round of pruning with model analysis:
    $ python -u resnet18_pruning.py --data_dir imagenet_dir --pretrained resnet18.pth --ratio 0.1  --ana True
  7. From the second round onwards, model analysis is no longer required. Increase the pruning ratio and use the sparse checkpoint saved from the previous round as the pretrained weights.
    $ python -u resnet18_pruning.py --data_dir imagenet_dir --pretrained resnet18_sparse.pth --ratio 0.2