There are two ways to generate the final slim model. The slim model is typically used for quantization or for direct evaluation.
With Pruning API
# Create a pruning runner with the same method that was used during pruning.
method = 'iterative'  # or 'one_step'
runner = get_pruning_runner(model, input_signature, method)
# Generate the slim model structure, then load the pruned weights into it.
slim_model = runner.prune(removal_ratio=0.2, mode='slim')
slim_model.load_state_dict(torch.load('model_pruned.pth'))
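After the pruned weights are loaded, the slim model behaves like any other PyTorch module and can be passed to quantization or evaluated directly. The snippet below is a minimal evaluation sketch only; val_loader and the accuracy bookkeeping are hypothetical placeholders and not part of the pruning API.

# Minimal evaluation sketch (assumes a hypothetical `val_loader` that yields
# (images, labels) batches; not part of the pruning API).
slim_model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in val_loader:
        outputs = slim_model(images)
        predictions = outputs.argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.numel()
print('Top-1 accuracy: {:.2f}%'.format(100.0 * correct / total))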
Without Pruning API
import json

import numpy as np
import torch
import torch.nn as nn


def slim_model_from_state_dict_and_pruning_info(model, state_dict, json_path):
    """Modify modules according to the pruning information saved in a JSON file and
    load the state dict into the model.

    Args:
        model: A torch.nn.Module instance to load the state dict into.
        state_dict: The state dict of the pruned model.
        json_path: Path to the JSON file that stores the pruning information.

    Returns:
        The model modified according to the pruning information in the JSON file.
    """
    with open(json_path, 'r') as f:
        pruning_info = json.load(f)
    tensorname_to_nodename = pruning_info['MapOfTensornameToNodename']

    def change_conv_module(module, node_name):
        # Update the channel attributes of a (transposed) convolution module.
        # A value of 0 means the corresponding dimension is unchanged.
        if module.groups == 1:
            if pruning_info[node_name]['in_dim'] != 0:
                module.in_channels = pruning_info[node_name]['in_dim']
            if pruning_info[node_name]['out_dim'] != 0:
                module.out_channels = pruning_info[node_name]['out_dim']
        else:
            # For grouped (e.g. depthwise) convolutions, groups follow the output dim.
            module.groups = pruning_info[node_name]['out_dim']
            if pruning_info[node_name]['in_dim'] != 0:
                module.in_channels = pruning_info[node_name]['in_dim']
            if pruning_info[node_name]['out_dim'] != 0:
                module.out_channels = pruning_info[node_name]['out_dim']

    def change_tensor(tensor_name):
        # Remove the pruned input/output channels from a tensor in the state dict.
        node_name = tensorname_to_nodename[tensor_name]
        tensor = state_dict[tensor_name].cpu().numpy()
        removed_inputs = pruning_info[node_name]['removed_inputs']
        removed_outputs = pruning_info[node_name]['removed_outputs']
        if tensor.ndim == 4 or tensor.ndim == 2:
            new_tensor = np.delete(tensor, removed_outputs, axis=0)
            new_tensor = np.delete(new_tensor, removed_inputs, axis=1)
        elif tensor.ndim == 1:
            new_tensor = np.delete(tensor, removed_outputs, axis=0)
        else:
            new_tensor = tensor
        return torch.from_numpy(new_tensor)

    for key, module in model.named_modules():
        weight_key = key + '.weight'
        if weight_key not in tensorname_to_nodename:
            continue
        node_name = tensorname_to_nodename[weight_key]
        bias_key = key + '.bias'
        if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d)):
            state_dict[weight_key] = change_tensor(weight_key)
            module.weight = nn.Parameter(state_dict[weight_key])
            state_dict[bias_key] = change_tensor(bias_key)
            module.bias = nn.Parameter(state_dict[bias_key])
            running_mean_key = key + '.running_mean'
            state_dict[running_mean_key] = change_tensor(running_mean_key)
            module.running_mean = state_dict[running_mean_key]
            running_var_key = key + '.running_var'
            state_dict[running_var_key] = change_tensor(running_var_key)
            module.running_var = state_dict[running_var_key]
            if pruning_info[node_name]['out_dim'] != 0:
                module.num_features = pruning_info[node_name]['out_dim']
        elif isinstance(module, (nn.Conv2d, nn.Conv3d,
                                 nn.ConvTranspose2d, nn.ConvTranspose3d)):
            change_conv_module(module, node_name)
            state_dict[weight_key] = change_tensor(weight_key)
            module.weight = nn.Parameter(state_dict[weight_key])
            if bias_key in state_dict:
                state_dict[bias_key] = change_tensor(bias_key)
                module.bias = nn.Parameter(state_dict[bias_key])
        elif isinstance(module, nn.Linear):
            state_dict[weight_key] = change_tensor(weight_key)
            module.weight = nn.Parameter(state_dict[weight_key])
            if bias_key in state_dict:
                state_dict[bias_key] = change_tensor(bias_key)
                module.bias = nn.Parameter(state_dict[bias_key])
            if pruning_info[node_name]['out_dim'] != 0:
                module.out_features = pruning_info[node_name]['out_dim']
            if pruning_info[node_name]['in_dim'] != 0:
                module.in_features = pruning_info[node_name]['in_dim']

    model.load_state_dict(state_dict)
    return model
slim_model = slim_model_from_state_dict_and_pruning_info(model, state_dict, json_path)
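For reference, the helper above only relies on a handful of keys in the pruning info file. The following is an illustrative sketch of that layout, inferred from the keys the function reads ('MapOfTensornameToNodename', 'in_dim', 'out_dim', 'removed_inputs', 'removed_outputs'); the node and tensor names are hypothetical, and files produced by the pruning tool may contain additional fields.

# Illustrative layout of the pruning info JSON, inferred from the keys the helper
# reads. Node/tensor names are hypothetical; real files may contain extra fields.
example_pruning_info = {
    'MapOfTensornameToNodename': {
        'conv1.weight': 'conv1',
        'conv1.bias': 'conv1',
    },
    'conv1': {
        'in_dim': 0,                  # 0 means the input channel count is unchanged
        'out_dim': 48,                # new output channel count after pruning
        'removed_inputs': [],         # indices of pruned input channels
        'removed_outputs': [5, 17],   # indices of pruned output channels
    },
}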