Retraining a Model - 2.5 English

Vitis AI Optimizer User Guide (UG1333)

Document ID: UG1333
Release Date: 2022-06-15
Version: 2.5 English

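Retraining a pruned model works the same way as training the baseline model. Run the usual training loop, evaluate on the validation set after every epoch, and save the checkpoint whenever top-1 accuracy improves.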

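import torch

# model, criterion, train_loader, val_loader, args, train() and evaluate()
# are the same objects and helpers used to train the baseline model.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
best_acc1 = 0

for epoch in range(args.epoches):
  train(train_loader, model, criterion, optimizer, epoch)
  acc1, acc5 = evaluate(val_loader, model, criterion)

  # Keep only the checkpoint with the best top-1 accuracy seen so far.
  is_best = acc1 > best_acc1
  best_acc1 = max(acc1, best_acc1)
  if is_best:
    torch.save(model.state_dict(), 'model_pruned.pth')
    # In iterative pruning, the sparse model provides one extra method,
    # slim_state_dict(); save its output as well.
    if hasattr(model, 'slim_state_dict'):
      torch.save(model.slim_state_dict(), 'model_slim.pth')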
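The loop above relies on helpers from the baseline training script. The example below is a minimal, self-contained sketch of the same keep-the-best-checkpoint logic. It assumes a toy classifier, random data, and hypothetical `train()`, `evaluate()` and `retrain()` helpers, none of which are part of the Vitis AI API. A plain `nn.Sequential` has no `slim_state_dict()` method, so this sketch writes only `model_pruned.pth`.

```python
import os
import tempfile

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


# Hypothetical stand-in for the baseline train() helper: one pass over the data.
def train(loader, model, criterion, optimizer, epoch):
    model.train()
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()


# Hypothetical stand-in for evaluate(): returns (top-1, top-5) accuracy in percent.
# criterion is accepted only to match the call signature used in the guide.
def evaluate(loader, model, criterion):
    model.eval()
    correct1 = correct5 = total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            top5 = model(inputs).topk(5, dim=1).indices  # sorted, best first
            correct1 += (top5[:, 0] == targets).sum().item()
            correct5 += (top5 == targets.unsqueeze(1)).any(dim=1).sum().item()
            total += targets.size(0)
    return 100.0 * correct1 / total, 100.0 * correct5 / total


# Hypothetical wrapper around the retraining loop shown above.
def retrain(model, train_loader, val_loader, epochs, out_dir):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
    best_acc1 = 0
    for epoch in range(epochs):
        train(train_loader, model, criterion, optimizer, epoch)
        acc1, acc5 = evaluate(val_loader, model, criterion)
        if acc1 > best_acc1:  # save only when top-1 accuracy improves
            best_acc1 = acc1
            torch.save(model.state_dict(), os.path.join(out_dir, 'model_pruned.pth'))
            # Only a sparse model from iterative pruning has this method.
            if hasattr(model, 'slim_state_dict'):
                torch.save(model.slim_state_dict(),
                           os.path.join(out_dir, 'model_slim.pth'))
    return best_acc1


# Toy run: 10-class random data, reused as the validation set for brevity.
torch.manual_seed(0)
x = torch.randn(256, 16)
y = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
out_dir = tempfile.mkdtemp()
best = retrain(model, loader, loader, epochs=3, out_dir=out_dir)
print(f'best top-1: {best:.1f}%')
```

Saving inside the `if` block means a checkpoint on disk always corresponds to the best validation result so far, so a later epoch that overfits cannot overwrite it.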