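This method uses the sandwich rule to jointly optimize all the OFA subnetworks. The sample_random_subnet() function can be used to get a subnetwork. The training loop below calls ofa_pruner.sample_subnet() instead, passing a mode of 'max', 'min', or 'random'. The sampled subnetwork is dynamic: it shares its weights with the OFA supernetwork and supports ordinary forward and backward passes.
In each training step, given a mini-batch of data, the sandwich rule samples the 'max' subnetwork (the largest), the 'min' subnetwork (the smallest), and two random subnetworks. Each subnetwork runs its own forward and backward pass on the same mini-batch. Its loss combines a distillation term against the teacher model's soft logits with the cross-entropy against the labels. Because the subnetworks share weights, their gradients accumulate, and a single optimizer step then updates the shared parameters for all of them together.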
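The following minimal pure-Python sketch makes the weight sharing concrete under simplifying assumptions. A hypothetical toy supernet is a single shared weight vector, and a subnetwork of width k uses only the first k weights, which stands in for OFA's elastic width. The names WIDTHS, sandwich_widths, subnet_forward, and train_step are illustrative and are not part of the OFA pruner API. Gradients from the max, min, and two random subnets are summed before one SGD update. This mirrors the pattern of calling zero_grad() once, backward() several times, and step() once.

```python
import random

# Hypothetical toy supernet: one shared weight vector. A subnet of width k
# uses only the first k weights (a stand-in for OFA's elastic width).
WIDTHS = [1, 2, 3, 4]


def sandwich_widths(rng, num_random=2):
    """Return subnet widths in sandwich order: max, min, then random ones."""
    return [max(WIDTHS), min(WIDTHS)] + [rng.choice(WIDTHS) for _ in range(num_random)]


def subnet_forward(weights, x, width):
    """Linear prediction using only the first `width` shared weights."""
    return sum(w * xi for w, xi in zip(weights[:width], x[:width]))


def train_step(weights, x, y, rng, lr=0.01):
    """One sandwich-rule step with squared loss 0.5 * (pred - y) ** 2."""
    grads = [0.0] * len(weights)  # analogous to optimizer.zero_grad()
    for width in sandwich_widths(rng):
        pred = subnet_forward(weights, x, width)
        err = pred - y  # derivative of the loss with respect to pred
        for i in range(width):  # analogous to loss.backward(): accumulate
            grads[i] += err * x[i]
    # Analogous to optimizer.step(): one update of the shared weights.
    return [w - lr * g for w, g in zip(weights, grads)]


new_weights = train_step([0.0] * 4, [1.0] * 4, 1.0, random.Random(0))
print(new_weights)
```

Every subnet uses the first weight, so it receives gradient from all four passes. The last weight is updated only by subnets of full width, and the 'max' subnet guarantees at least one such pass per step.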
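# Train with the sandwich rule: each step samples the max, min, and two random subnets.
# Assumes ofa_model, ofa_pruner, teacher_model, train_loader, optimizer,
# lr_scheduler, kd_loss, and cross_entropy_loss were created earlier.
import torch

ofa_model.train()
teacher_model.eval()  # frozen teacher: keep its BatchNorm statistics fixed

for i, (images, target) in enumerate(train_loader):
    images = images.cuda(non_blocking=True)
    target = target.cuda(non_blocking=True)

    optimizer.zero_grad()

    # Soft targets from the teacher for knowledge distillation (no gradients).
    with torch.no_grad():
        soft_logits = teacher_model(images).detach()

    # Total subnets to be sampled per step: max, min, and two random ones.
    for arch_id in range(4):
        if arch_id == 0:
            model, _ = ofa_pruner.sample_subnet(ofa_model, 'max')
        elif arch_id == 1:
            model, _ = ofa_pruner.sample_subnet(ofa_model, 'min')
        else:
            model, _ = ofa_pruner.sample_subnet(ofa_model, 'random')
        output = model(images)
        loss = kd_loss(output, soft_logits) + cross_entropy_loss(output, target)
        # Gradients from each subnet accumulate in the shared supernet weights.
        loss.backward()

    # One clipped update of the shared weights after all four subnets.
    torch.nn.utils.clip_grad_value_(ofa_model.parameters(), 1.0)
    optimizer.step()
    lr_scheduler.step()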
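The loop assumes that kd_loss and cross_entropy_loss already exist. One common choice is shown here as a sketch, not as the pruner's own definition: a temperature-scaled KL divergence between the teacher's and the student's softened output distributions, paired with torch.nn.CrossEntropyLoss. The temperature parameter and its default of 1.0 are assumptions.

```python
import torch
import torch.nn.functional as F


def kd_loss(student_logits, teacher_logits, temperature=1.0):
    """Hypothetical distillation loss: KL(teacher || student) on softened logits."""
    t = temperature
    log_p_student = F.log_softmax(student_logits / t, dim=1)
    p_teacher = F.softmax(teacher_logits / t, dim=1)
    # "batchmean" averages the per-sample KL divergence over the batch;
    # the t * t factor keeps gradient magnitudes comparable across temperatures.
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (t * t)


# Standard cross-entropy against integer class labels.
cross_entropy_loss = torch.nn.CrossEntropyLoss()
```

The teacher logits are already computed under torch.no_grad(), so gradients from this loss flow only into the student subnetwork.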