Implements channel pruning at the module level.
Arguments
Pruner(module, inputs)
Create a new pruner object.
- module - A torch.nn.Module object to be pruned.
- inputs - The inputs of the module.
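As a rough usage sketch (not part of the reference itself), a pruner might be created as follows; the Pruner import path and the ResNet-18 example inputs are assumptions:

```python
import torch
import torchvision

# Import path is an assumption; adjust to the actual package layout.
from pruner import Pruner

# Any torch.nn.Module works; ResNet-18 is only an example.
model = torchvision.models.resnet18()
example_inputs = torch.randn(1, 3, 224, 224)  # inputs the pruner uses to run the module

pruner = Pruner(model, example_inputs)
```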
Methods
- ana(eval_fn, args=(), gpus=None)
Performs model analysis. A usage sketch follows the method list.
- eval_fn - Callable object that takes a torch.nn.Module object as its first argument and returns the evaluation score.
- args - A tuple of arguments that will be passed to eval_fn.
- gpus - A tuple or list of GPU indices used for model analysis. If not set, the default GPU will be used.
- prune(ratio=None, threshold=None, excludes=None, output_script='graph.py')
Prunes the network by the given ratio or threshold and returns a torch.nn.Module object. The returned object differs from a native torch module in that it has one extra method, pruned_state_dict(), which returns the parameters of the pruned dense model. The weights returned by pruned_state_dict() can be loaded into the model defined by the Python script saved to output_script. A usage sketch follows the method list.
- ratio - The expected percentage of FLOPs reduction. This is an approximation; the actual reduction may not match this value exactly after pruning.
- threshold - The relative proportion of model performance loss that can be tolerated.
- excludes - Modules to be excluded from pruning.
- output_script - File path of the generated script used to rebuild the model.
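Below is a hedged sketch of calling ana(); the toy validation loader, the top-1 accuracy metric, and the GPU index are illustrative assumptions, not part of the documented API:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy validation data so the sketch is self-contained; replace with a real loader.
val_loader = DataLoader(
    TensorDataset(torch.randn(8, 3, 224, 224), torch.randint(0, 1000, (8,))),
    batch_size=4,
)

def eval_fn(module, loader):
    # The first argument is the module under analysis; the return value is the score.
    module.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            preds = module(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total

# Extra positional arguments are forwarded through args; gpus is optional.
pruner.ana(eval_fn, args=(val_loader,), gpus=(0,))
```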
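And a sketch of prune() continuing the same example; the target ratio, the form of excludes (module references are assumed here), and the file names are illustrative:

```python
# Prune for roughly a 50% FLOPs reduction, keeping the classifier layer untouched,
# and write the rebuild script to resnet18_pruned.py.
pruned_model = pruner.prune(
    ratio=0.5,
    excludes=[model.fc],             # assumption: excludes takes module references
    output_script='resnet18_pruned.py',
)

# Parameters of the pruned dense model; these weights match the model defined
# in resnet18_pruned.py and can be loaded into it with load_state_dict().
torch.save(pruned_model.pruned_state_dict(), 'resnet18_pruned.pth')
```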