This API has the following methods:

__init__(model, inputs)
- model: A torch.nn.Module object to prune.
- inputs: A single torch.Tensor or a list of torch.Tensor objects used as inputs for model inference. These need not be real data; a randomly generated tensor with the same shape and data type as the real data is sufficient.
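
The requirement on inputs concerns shape and data type, not values. The minimal sketch below illustrates this with a hypothetical helper, check_dummy_inputs, which is not part of this API and is not a check the pruner is known to perform. Like the constructor, it accepts either a single input or a list of inputs. It only reads the .shape and .dtype attributes, so it should also work with torch.Tensor objects.

```python
from typing import Any, List


def _as_list(x: Any) -> List[Any]:
    """Wrap a single input in a list; convert a list or tuple of inputs to a list."""
    return list(x) if isinstance(x, (list, tuple)) else [x]


def check_dummy_inputs(real: Any, dummy: Any) -> bool:
    """Hypothetical helper (not part of the API): True if each dummy input
    matches the corresponding real input in shape and dtype.
    Values are never compared, so random data is acceptable."""
    real_list, dummy_list = _as_list(real), _as_list(dummy)
    if len(real_list) != len(dummy_list):
        return False
    return all(
        tuple(r.shape) == tuple(d.shape) and r.dtype == d.dtype
        for r, d in zip(real_list, dummy_list)
    )
```

For example, suppose the real data is a float32 image batch of shape (1, 3, 224, 224). Then torch.randn(1, 3, 224, 224), which is float32 under PyTorch's default dtype, would pass this check.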

ana(eval_fn, args=(), gpus=None, excludes=None, forced=False)
- eval_fn: A callable that takes a torch.nn.Module object as its first argument and returns the evaluation score.
- args: A tuple of additional arguments passed to eval_fn after the model.
- gpus: A tuple or list of indices of the GPUs to use. If not set, the default GPU is used.
- excludes: A list of node names or torch modules to exclude from pruning.
- forced: If False, skip the model analysis and reuse the cached result when one exists. Set to True to force a new analysis.
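
The eval_fn contract can be illustrated without PyTorch. In the minimal sketch below, accuracy_eval_fn and call_eval_fn are hypothetical names, and the "model" is any callable that maps a sample to a predicted label. Based on the description above, the sketch assumes that ana invokes eval_fn with the model first, followed by the contents of args, that is, eval_fn(model, *args).

```python
from typing import Any, Callable, Sequence, Tuple


def accuracy_eval_fn(model: Callable[[Any], Any],
                     samples: Sequence[Any],
                     labels: Sequence[Any]) -> float:
    """Hypothetical eval_fn: the model comes first; samples and labels are
    supplied through args. Returns top-1 accuracy as the score."""
    if len(samples) == 0:
        return 0.0
    correct = sum(1 for x, y in zip(samples, labels) if model(x) == y)
    return correct / len(samples)


def call_eval_fn(eval_fn: Callable[..., float], model: Any,
                 args: Tuple[Any, ...] = ()) -> float:
    """Assumed calling convention: eval_fn(model, *args)."""
    return eval_fn(model, *args)


if __name__ == "__main__":
    def parity_model(x: int) -> int:
        return x % 2  # toy "model": predicts the parity of an integer

    score = call_eval_fn(accuracy_eval_fn, parity_model,
                         args=([1, 2, 3, 4], [1, 0, 0, 0]))
    print(score)  # 0.75
```

With a real torch.nn.Module, the body of eval_fn would instead iterate over a DataLoader passed through args. It would run the model under torch.no_grad() and compare logits.argmax(dim=1) with the targets.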

prune(removal_ratio=None, threshold=None, spec_path=None, excludes=None, mode='sparse')
- removal_ratio: The expected proportion of MACs to remove.
- threshold: The relative proportion of model performance loss that can be tolerated.
- spec_path: Path to a pre-defined pruning specification.
- excludes: A list of node names or torch modules to exclude from pruning.
- mode: One of 'sparse' or 'slim'. Always use 'sparse' inside an iterative pruning loop. A 'slim' model is used for quantization-aware training.
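
A little arithmetic makes the meaning of removal_ratio and threshold concrete. The sketch below uses hypothetical helpers that are not part of the API. It assumes both values are fractions (for example, 0.2 for a 20% reduction in MACs and 0.01 for a 1% relative drop in the evaluation score), and that a higher evaluation score means a better model.

```python
def target_macs(baseline_macs: float, removal_ratio: float) -> float:
    """Hypothetical: MACs left after removing the given fraction of the original MACs."""
    if not 0.0 <= removal_ratio < 1.0:
        raise ValueError("removal_ratio must be in [0, 1)")
    return baseline_macs * (1.0 - removal_ratio)


def relative_loss(baseline_score: float, pruned_score: float) -> float:
    """Hypothetical: performance loss relative to the unpruned model
    (assumes a higher score is better and the baseline is positive)."""
    if baseline_score <= 0:
        raise ValueError("baseline_score must be positive")
    return (baseline_score - pruned_score) / baseline_score


def within_threshold(baseline_score: float, pruned_score: float,
                     threshold: float) -> bool:
    """Hypothetical: True if the relative loss does not exceed the tolerated threshold."""
    return relative_loss(baseline_score, pruned_score) <= threshold
```

For instance, take a model with 4.0e9 MACs pruned with removal_ratio=0.25. It would be expected to keep about 3.0e9 MACs. Similarly, a top-1 accuracy that falls from 0.80 to 0.79 is a relative loss of 0.0125, which is within a threshold of 0.02.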