Training an OFA Model - 3.5 English

Vitis AI User Guide (UG1414)

Document ID: UG1414
Release Date: 2023-09-28
Version: 3.5 English

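Training uses the sandwich rule to jointly optimize all the OFA subnetworks. The sample_random_subnet() function is used to get a subnetwork; the listing in this section samples subnetworks with ofa_pruner.sample_subnet(), passing 'max', 'min', or 'random'. The sampled dynamic subnetwork is used in both forward and backward propagation.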

In each training step, given a mini-batch of data, the sandwich rule samples a 'max' subnetwork, a 'min' subnetwork, and two random subnetworks. Each subnetwork runs its own forward and backward pass on the same mini-batch. The gradients from all four passes accumulate in the shared OFA model weights, which are then updated together in a single optimizer step.
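As a toy illustration of how several subnetworks can update shared parameters together, the sketch below runs two backward passes over one weight tensor without clearing gradients in between. It uses plain PyTorch tensors rather than the Vitis AI pruner API; the names w, x, loss_max, and loss_min are hypothetical stand-ins for the shared weights and for a large and a small subnetwork.

```python
import torch

# Toy stand-in for shared supernet weights (not the Vitis AI API).
w = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([3.0, 4.0])

# A "max" subnetwork uses both weights; a "min" subnetwork uses only the first.
loss_max = (w * x).sum()
loss_min = (w[:1] * x[:1]).sum()

# No zero_grad() between the two backward passes, so gradients accumulate.
loss_max.backward()
loss_min.backward()

accumulated = w.grad.tolist()
print(accumulated)  # [6.0, 4.0]
```

The first weight receives a gradient from both passes (3 + 3), while the second receives one only from the larger subnetwork (4 + 0). A single optimizer step taken afterwards would therefore apply the combined gradient.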

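# Train with the sandwich rule: sample and train four subnetworks per step.
# ofa_model, ofa_pruner, teacher_model, train_loader, optimizer, lr_scheduler,
# kd_loss, and cross_entropy_loss are assumed to be defined earlier.
import torch

num_subnets = 4  # total subnets to be sampled: 'max', 'min', and two 'random'

for i, (images, target) in enumerate(train_loader):

  images = images.cuda(non_blocking=True)
  target = target.cuda(non_blocking=True)

  # Clear gradients once per step so the subnet passes below accumulate.
  optimizer.zero_grad()

  # The teacher only supplies soft labels; no_grad() keeps it out of backpropagation.
  teacher_model.train()
  with torch.no_grad():
    soft_logits = teacher_model(images).detach()

  for arch_id in range(num_subnets):
    if arch_id == 0:
      model, _ = ofa_pruner.sample_subnet(ofa_model, 'max')
    elif arch_id == 1:
      model, _ = ofa_pruner.sample_subnet(ofa_model, 'min')
    else:
      model, _ = ofa_pruner.sample_subnet(ofa_model, 'random')

    output = model(images)

    # Distillation loss against the teacher plus the supervised loss.
    loss = kd_loss(output, soft_logits) + cross_entropy_loss(output, target)
    loss.backward()

  # Clip the accumulated gradients, then update the shared weights once.
  torch.nn.utils.clip_grad_value_(ofa_model.parameters(), 1.0)
  optimizer.step()
  lr_scheduler.step()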
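
The listing calls kd_loss and cross_entropy_loss without defining them. One common choice, shown here only as a minimal sketch and not necessarily what the guide's examples use, is a soft-target cross-entropy between the teacher and student distributions combined with the standard PyTorch cross-entropy. The temperature parameter is an assumption added for illustration.

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, temperature=1.0):
    """Soft-target cross-entropy (assumed form; temperature is hypothetical)."""
    soft_targets = F.softmax(teacher_logits / temperature, dim=1)
    log_probs = F.log_softmax(student_logits / temperature, dim=1)
    # Average over the batch of the per-sample cross-entropy.
    return -(soft_targets * log_probs).sum(dim=1).mean()

# Supervised loss against the ground-truth labels.
cross_entropy_loss = torch.nn.CrossEntropyLoss()

# Usage with random logits for a batch of 4 samples and 10 classes.
student = torch.randn(4, 10)
teacher = torch.randn(4, 10)
target = torch.tensor([1, 0, 3, 9])
total = kd_loss(student, teacher) + cross_entropy_loss(student, target)
```

Both terms return a scalar tensor, so their sum can be passed directly to backward() as in the training loop.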