Write code for model training.
import argparse
import os
import shutil
import time
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.models.resnet import resnet18
from pytorch_nndct import Pruner
def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is
    ``True`` because every non-empty string is truthy, so ``--ana False``
    would silently enable the analysis.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


# Command-line options of the pruning script.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--data_dir',
    default='/scratch/workspace/dataset/imagenet/pytorch',
    help='Data set directory.')
parser.add_argument(
    '--pretrained',
    default='/scratch/workspace/models/resnet18.pth',
    help='Trained model file path.')
parser.add_argument(
    '--ratio',
    default=0.1,
    type=float,
    help='Desired pruning ratio. The larger this value, the smaller '
    'the model after pruning.')
parser.add_argument(
    '--ana',
    default=False,
    type=_str2bool,
    # Accept both "--ana True" and a bare "--ana" flag.
    nargs='?',
    const=True,
    help='Whether to perform model analysis.')
# Unknown arguments are tolerated so the script can run under wrappers that
# add their own options.
args, _ = parser.parse_known_args()
class AverageMeter(object):
    """Track the latest value of a metric together with its running mean.

    Attributes:
        name: label printed in front of the values.
        fmt: format spec (e.g. ``':6.3f'``) applied to both printed values.
        val: most recent value passed to ``update``.
        sum: weighted sum of all values seen since the last ``reset``.
        count: total weight seen since the last ``reset``.
        avg: ``sum / count``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear every accumulated statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        """Render as ``'<name> <val> (<avg>)'`` using the ``fmt`` spec."""
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(name=self.name, val=self.val, avg=self.avg)
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Write ``state`` to ``filename`` with ``torch.save``.

    When ``is_best`` is true, the file is additionally copied to
    'model_best.pth.tar' in the current working directory.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')
class ProgressMeter(object):
    """Print a one-line progress report built from a list of meters.

    Each line is ``<prefix>[<batch>/<num_batches>]`` followed by the string
    form of every meter, separated by tabs.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for batch index ``batch``."""
        fields = [self.prefix + self.batch_fmtstr.format(batch)]
        fields.extend(str(meter) for meter in self.meters)
        print('\t'.join(fields))

    def _get_batch_fmtstr(self, num_batches):
        """Build the ``'[{:Nd}/<num_batches>]'`` template.

        N is the number of characters in ``num_batches``, so the batch index
        is right-aligned to the width of the total.
        """
        width = len(str(num_batches // 1))
        slot = '{:%dd}' % width
        return '[' + slot + '/' + slot.format(num_batches) + ']'
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy for each requested value of k.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: tensor holding the true class index of each sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per entry of ``topk``, each
        holding the percentage (0-100) of samples whose true class is among
        the k highest-scoring predictions.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk best classes per sample, best first.
        _, pred = output.topk(maxk, 1, True, True)
        # Transpose so that row r holds every sample's rank-r prediction.
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` inherits the non-contiguous layout of ``pred.t()``,
            # so ``view(-1)`` raises for k > 1; ``reshape`` copies if needed.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
def adjust_learning_rate(optimizer, epoch, lr):
    """Apply a step-decayed learning rate to every parameter group.

    The rate is the initial ``lr`` multiplied by 0.1 once for every two
    epochs elapsed (``epoch // 2``).
    """
    decayed = lr * (0.1**(epoch // 2))
    for group in optimizer.param_groups:
        group['lr'] = decayed
def train(train_loader, model, criterion, optimizer, epoch):
    """Train ``model`` for one pass over ``train_loader`` on the GPU.

    Timing, loss and top-1/top-5 accuracy are accumulated in meters and a
    progress line is printed every 10 batches.

    Args:
        train_loader: iterable yielding ``(images, target)`` batches.
        model: network to optimise; it is moved to the GPU.
        criterion: loss callable invoked as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used only in the progress prefix.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader), [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    # Moving the model is loop-invariant: do it once, not on every batch.
    model = model.cuda()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # non_blocking matches evaluate(); it only takes effect when the
        # loader yields pinned-memory batches.
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do one optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 10 == 0:
            progress.display(i)
def evaluate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` on the GPU without gradients.

    A progress line is printed every 50 batches and a summary line at the
    end.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: network to evaluate; it is moved to the GPU.
        criterion: loss callable invoked as ``criterion(output, target)``.

    Returns:
        Tuple ``(top1.avg, top5.avg)``: the running averages of the top-1
        and top-5 accuracy percentages over all batches.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    # Moving the model is loop-invariant: do it once, not on every batch.
    model = model.cuda()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 50 == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
            top1=top1, top5=top5))

    return top1.avg, top5.avg
Define a function used for model analysis.
def ana_eval_fn(model, val_loader, loss_fn):
    """Score ``model`` for the pruner's analysis step.

    Runs ``evaluate`` and returns only its second result, the average top-5
    accuracy.
    """
    _, top5_avg = evaluate(val_loader, model, loss_fn)
    return top5_avg
Create a resnet18 model and add pruning APIs to perform pruning.
if __name__ == '__main__':
    # Build the baseline network on the CPU and load the trained weights.
    model = resnet18().cpu()
    # map_location keeps loading independent of the device the checkpoint
    # was saved from (a GPU-saved file would otherwise need that GPU).
    model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))

    batch_size = 128
    workers = 4
    traindir = os.path.join(args.data_dir, 'train')
    valdir = os.path.join(args.data_dir, 'validation')

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Training data: random crop + horizontal flip augmentation.
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True)

    # Validation data: deterministic resize + centre crop.
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=True)

    criterion = torch.nn.CrossEntropyLoss().cuda()

    # Dummy input handed to the pruner together with the model.
    # NOTE(review): presumably used to trace the network graph - confirm
    # against the pytorch_nndct documentation.
    inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
    pruner = Pruner(model, inputs)
    if args.ana:
        # NOTE(review): the GPU list is hard-coded to four devices; adjust
        # it to the devices available on the machine.
        pruner.ana(ana_eval_fn, args=(val_loader, criterion), gpus=[0, 1, 2, 3])

    model = pruner.prune(ratio=args.ratio)
    pruner.summary(model)

    # Fine-tune the pruned model.
    lr = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=1e-4)

    best_acc5 = 0
    epochs = 1
    for epoch in range(epochs):
        adjust_learning_rate(optimizer, epoch, lr)
        train(train_loader, model, criterion, optimizer, epoch)

        acc1, acc5 = evaluate(val_loader, model, criterion)

        # remember best acc@5 and save the sparse checkpoint
        is_best = acc5 > best_acc5
        best_acc5 = max(acc5, best_acc5)

        if is_best:
            torch.save(model.pruned_state_dict(), 'resnet18_sparse.pth')

    # Written after the last epoch, whether or not that epoch was the best.
    torch.save(model.state_dict(), 'resnet18_dense.pth')
Download pretrained ResNet18 model:
wget https://download.pytorch.org/models/resnet18-5c106cde.pth -O resnet18.pth
Prepare ImageNet dataset, see http://image-net.org/download-images.
Run the first round of pruning with model analysis:
$ python -u resnet18_pruning.py --data_dir imagenet_dir --pretrained resnet18.pth --ratio 0.1 --ana True
Starting from the second round, pruning no longer requires model analysis; you only need to increase the pruning ratio and use the sparse checkpoint saved in the previous round as the pretrained weights.
$ python -u resnet18_pruning.py --data_dir imagenet_dir --pretrained resnet18_sparse.pth --ratio 0.2