This API has the following methods:
- __init__(model, inputs)
  - model - A torch.nn.Module object to prune.
  - inputs - A single torch.Tensor or a list of torch.Tensor objects used as inputs for model inference. Real data is not required; a randomly generated tensor with the same shape and data type as the real data works.
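As a minimal sketch of constructing the pruner, the snippet below builds a small stand-in model and a random dummy input of the expected shape. The class name OneStepPruner is a placeholder, not the library's actual export; substitute the real class name from your installation.

```python
import torch
import torch.nn as nn

# A tiny stand-in model, used only for illustration.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.pool(self.conv(x)).flatten(1)
        return self.fc(x)

model = TinyNet()

# The input only needs to match the real data's shape and dtype;
# random values are fine.
dummy_input = torch.randn(1, 3, 224, 224)

# Hypothetical class name -- replace with the pruner class your
# library actually exports:
# pruner = OneStepPruner(model, dummy_input)
```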
- search(gpus=['0'], calibration_fn=None, calib_args=(), num_subnet=10, removal_ratio=0.5, excludes=[], eval_fn=None, eval_args=())
  - gpus - A tuple or list of GPU indices to use. If not set, the default GPU is used.
  - calibration_fn - Callable object that takes a torch.nn.Module object as its first argument. It is used to calibrate the statistics of the BatchNormalization layers.
  - calib_args - A tuple of arguments passed to calibration_fn.
  - num_subnet - Number of subnetworks that satisfy the FLOPs constraint.
  - removal_ratio - The expected percentage of MACs reduction.
  - excludes - Modules to exclude from pruning.
  - eval_fn - Callable object that takes a torch.nn.Module object as its first argument and returns the evaluation score.
  - eval_args - A tuple of arguments passed to eval_fn.
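The callables above can be sketched as follows. Both take the model as their first argument; the extra arguments (here a data loader and a batch budget, which are assumptions for illustration) are supplied through calib_args and eval_args. The commented search call is hypothetical and only mirrors the signature documented above.

```python
import torch
import torch.nn as nn

def calibration_fn(model, dataloader, num_batches=100):
    # Re-estimate BatchNorm running statistics by forwarding
    # calibration data through the candidate subnetwork.
    model.train()
    with torch.no_grad():
        for i, (images, _) in enumerate(dataloader):
            if i >= num_batches:
                break
            model(images)

def eval_fn(model, dataloader):
    # Return a scalar score (here top-1 accuracy); higher is better.
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            preds = model(images).argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total

# Hypothetical call, following the documented signature:
# pruner.search(gpus=['0'],
#               calibration_fn=calibration_fn, calib_args=(calib_loader,),
#               num_subnet=10, removal_ratio=0.5,
#               eval_fn=eval_fn, eval_args=(val_loader,))
```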
- prune(mode='slim', index=None, removal_ratio=None, pruning_info_path=None)
  - mode - One of ['sparse', 'slim']. Always use 'slim' mode for the one-step method.
  - index - Subnetwork index. By default, the optimal subnetwork is selected automatically.
  - removal_ratio - The expected percentage of MACs reduction.
  - pruning_info_path - Path to a .json file in which detailed pruning information for the current model is saved. A slim model can be generated from this file together with the original model.
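A hedged sketch of the final step: the prune call below is hypothetical (it only mirrors the signature documented above), while the helper shows that the saved pruning-info file is ordinary JSON and can be inspected directly.

```python
import json

# Hypothetical call, following the documented signature:
# slim_model = pruner.prune(mode='slim', removal_ratio=0.5,
#                           pruning_info_path='pruning_info.json')

def load_pruning_info(path):
    # The file written via pruning_info_path is plain JSON; together
    # with the original model it allows regenerating the slim model.
    with open(path) as f:
        return json.load(f)
```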