Generating a Pruned Model

Vitis AI Optimizer User Guide (UG1333)
Document ID: UG1333
Release Date: 2022-06-15
Version: 2.5 English

There are two ways to generate the final slim model. The slim model is typically used directly for quantization or evaluation.

With Pruning API

import torch
from pytorch_nndct import get_pruning_runner

method = 'iterative'  # or 'one_step'
runner = get_pruning_runner(model, input_signature, method)
# 'slim' mode removes the pruned channels and returns the final smaller model.
slim_model = runner.prune(removal_ratio=0.2, mode='slim')
slim_model.load_state_dict(torch.load('model_pruned.pth'))
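
In the snippet above, model is the baseline (unpruned) network and input_signature is a dummy tensor that conveys the input shape and dtype used to trace the model. The following is a minimal preparation sketch; the torchvision ResNet-18 and the 224x224 RGB input are assumptions for illustration only.

# A preparation sketch; the network and input shape are illustrative assumptions.
import torch
from torchvision.models import resnet18

model = resnet18()  # baseline network to be pruned
# The input signature only describes shape and dtype; its values do not matter.
input_signature = torch.randn([1, 3, 224, 224], dtype=torch.float32)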

Without Pruning API

import json

import numpy as np
import torch
import torch.nn as nn


def slim_model_from_state_dict_and_pruning_info(model, state_dict, json_path):
  """Modify modules according to the pruning information saved in a JSON file
  and load the pruned state dict into the model.

  Args:
    model: A torch.nn.Module instance to load the state dict into.
    state_dict: The state dict to be loaded.
    json_path: Path to the JSON file containing the pruning information.

  Returns:
    The model modified according to the pruning information, with the pruned
    state dict loaded.
  """
  with open(json_path, 'r') as f:
    pruning_info = json.load(f)
  tensorname_to_nodename = pruning_info['MapOfTensornameToNodename']

  def change_conv_module(module, node_name):
    # Update the channel attributes of a (transposed) convolution module.
    # A value of 0 in the pruning info means the dimension is unchanged.
    if module.groups == 1:
      if pruning_info[node_name]['in_dim'] != 0:
        module.in_channels = pruning_info[node_name]['in_dim']
      if pruning_info[node_name]['out_dim'] != 0:
        module.out_channels = pruning_info[node_name]['out_dim']
    else:
      # Grouped (e.g. depthwise) convolution: the group count follows the
      # pruned output dimension.
      if pruning_info[node_name]['out_dim'] != 0:
        module.groups = pruning_info[node_name]['out_dim']
        module.out_channels = pruning_info[node_name]['out_dim']
      if pruning_info[node_name]['in_dim'] != 0:
        module.in_channels = pruning_info[node_name]['in_dim']

  def change_tensor(tensor_name):
    # Remove the pruned input/output channels from a tensor in the state dict.
    node_name = tensorname_to_nodename[tensor_name]
    tensor = state_dict[tensor_name].cpu().numpy()
    removed_inputs = pruning_info[node_name]['removed_inputs']
    removed_outputs = pruning_info[node_name]['removed_outputs']
    if tensor.ndim >= 2:
      # Delete pruned output channels along axis 0 and pruned input channels
      # along axis 1.
      new_tensor = np.delete(tensor, removed_outputs, axis=0)
      new_tensor = np.delete(new_tensor, removed_inputs, axis=1)
    elif tensor.ndim == 1:
      # Biases and batch-norm running statistics only have an output dimension.
      new_tensor = np.delete(tensor, removed_outputs, axis=0)
    else:
      new_tensor = tensor
    return torch.from_numpy(new_tensor)

  for key, module in model.named_modules():
    weight_key = key + '.weight'
    if weight_key not in tensorname_to_nodename:
      continue  # Skip modules that are not covered by the pruning information.
    node_name = tensorname_to_nodename[weight_key]
    bias_key = key + '.bias'
    if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d)):
      state_dict[weight_key] = change_tensor(weight_key)
      module.weight = nn.Parameter(state_dict[weight_key])
      state_dict[bias_key] = change_tensor(bias_key)
      module.bias = nn.Parameter(state_dict[bias_key])
      running_mean_key = key + '.running_mean'
      state_dict[running_mean_key] = change_tensor(running_mean_key)
      module.running_mean = state_dict[running_mean_key]
      running_var_key = key + '.running_var'
      state_dict[running_var_key] = change_tensor(running_var_key)
      module.running_var = state_dict[running_var_key]
      if pruning_info[node_name]['out_dim'] != 0:
        module.num_features = pruning_info[node_name]['out_dim']
    elif isinstance(module, (nn.Conv2d, nn.Conv3d)):
      change_conv_module(module, node_name)
      state_dict[weight_key] = change_tensor(weight_key)
      module.weight = nn.Parameter(state_dict[weight_key])
      if bias_key in state_dict:
        state_dict[bias_key] = change_tensor(bias_key)
        module.bias = nn.Parameter(state_dict[bias_key])
    elif isinstance(module, (nn.ConvTranspose2d, nn.ConvTranspose3d)):
      change_conv_module(module, node_name)
      state_dict[weight_key] = change_tensor(weight_key)
      module.weight = nn.Parameter(state_dict[weight_key])
      if bias_key in state_dict:
        state_dict[bias_key] = change_tensor(bias_key)
        module.bias = nn.Parameter(state_dict[bias_key])
    elif isinstance(module, nn.Linear):
      state_dict[weight_key] = change_tensor(weight_key)
      module.weight = nn.Parameter(state_dict[weight_key])
      if bias_key in state_dict:
        state_dict[bias_key] = change_tensor(bias_key)
        module.bias = nn.Parameter(state_dict[bias_key])
      if pruning_info[node_name]['out_dim'] != 0:
        module.out_features = pruning_info[node_name]['out_dim']
      if pruning_info[node_name]['in_dim'] != 0:
        module.in_features = pruning_info[node_name]['in_dim']
    else:
      pass
  model.load_state_dict(state_dict)
  return model

slim_model = slim_model_from_state_dict_and_pruning_info(model, state_dict, json_path)
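
Here, state_dict holds the pruned weights (for example, loaded from a checkpoint saved after pruning and fine-tuning) and json_path points to the JSON file containing the pruning information. The following end-to-end sketch wires the call together; the file names and the 1x3x224x224 input are placeholders, not names or shapes required by the tool.

# A usage sketch with placeholder file names and an assumed input shape.
import torch

state_dict = torch.load('model_pruned.pth', map_location='cpu')
json_path = 'pruning_info.json'
slim_model = slim_model_from_state_dict_and_pruning_info(model, state_dict, json_path)

# The slim model can now be evaluated directly or passed on to quantization.
slim_model.eval()
with torch.no_grad():
  _ = slim_model(torch.randn(1, 3, 224, 224))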