Retraining a pruned model is the same as training a baseline model:
# Retrain the (pruned) model with the same loop used for baseline training,
# keeping a checkpoint of the epoch with the best validation `acc1`.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
best_acc1 = 0  # best validation acc1 seen so far (presumably top-1 accuracy)
for epoch in range(args.epoches):
    train(train_loader, model, criterion, optimizer, epoch)
    acc1, acc5 = evaluate(val_loader, model, criterion)

    # Checkpoint only when acc1 strictly improves on the best so far.
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)
    if is_best:
        torch.save(model.state_dict(), 'model_pruned.pth')
        # Sparse model has one more special method in iterative pruning:
        # when `slim_state_dict` exists, also save its result from this
        # same best epoch as a separate checkpoint.
        if hasattr(model, 'slim_state_dict'):
            torch.save(model.slim_state_dict(), 'model_slim.pth')