pytorch_nndct.IterativePruningRunner

Vitis AI User Guide (UG1414)

Document ID: UG1414
Release Date: 2023-09-28
Version: 3.5 English

The IterativePruningRunner class has the following methods:

  • __init__(model, inputs)
    model
    A torch.nn.Module object to prune.
    inputs
    A single torch.Tensor, or a list of torch.Tensor objects, used as inputs for model inference. The inputs do not need to be real data: a randomly generated tensor with the same shape and data type as the real data is sufficient.
  • ana(eval_fn, args=(), gpus=None, excludes=None, forced=False)
    eval_fn
    A callable that takes a torch.nn.Module object as its first argument and returns the evaluation score.
    args
    A tuple of additional arguments passed to eval_fn after the model.
    gpus
    A tuple or list of GPU indices to be used. If not set, the default GPU is used.
    excludes
    A list of node names or torch modules to be excluded from pruning.
    forced
    If False, skip model analysis and use the cached result; set to True to force the analysis to run again.
  • prune(removal_ratio=None, threshold=None, spec_path=None, excludes=None, mode='sparse')
    removal_ratio
    The expected proportion by which the model's MACs are reduced.
    threshold
    The relative proportion of model performance loss that can be tolerated.
    spec_path
    Path to a predefined pruning specification.
    excludes
    A list of node names or torch modules to be excluded from pruning.
    mode
    One of 'sparse' or 'slim'. Always use 'sparse' within the iterative pruning loop. Use 'slim' to obtain a slim model, which is used for quantization-aware training.
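Because removal_ratio is stated in terms of MACs rather than channels, pruning a given fraction of one layer's channels does not remove the same fraction of the model's MACs. The standalone Python sketch below makes this concrete for a toy chain of convolutions. It does not use pytorch_nndct; the Conv class and the helper functions are hypothetical and exist only for illustration. It assumes channel (filter) pruning, plain non-grouped convolutions with square kernels and square output feature maps, a strictly sequential chain in which each layer consumes the previous layer's output, and it ignores bias terms.

```python
from dataclasses import dataclass, replace


@dataclass
class Conv:
    """Shape of a plain (non-grouped) 2D convolution; bias is ignored."""
    in_ch: int
    out_ch: int
    kernel: int   # square kernel: kernel x kernel
    out_hw: int   # square output feature map: out_hw x out_hw

    def macs(self) -> int:
        # Each of the out_ch * out_hw * out_hw output elements needs
        # in_ch * kernel * kernel multiply-accumulates.
        return self.out_ch * self.out_hw * self.out_hw * self.in_ch * self.kernel * self.kernel


def total_macs(layers: list[Conv]) -> int:
    return sum(layer.macs() for layer in layers)


def prune_output_channels(layers: list[Conv], idx: int, keep: int) -> list[Conv]:
    """Return a pruned copy of a sequential chain in which layer `idx` keeps only
    `keep` output channels. The next layer consumes that output directly, so it
    loses the same input channels. The original list and layers are not modified."""
    pruned = list(layers)
    pruned[idx] = replace(pruned[idx], out_ch=keep)
    if idx + 1 < len(pruned):
        pruned[idx + 1] = replace(pruned[idx + 1], in_ch=keep)
    return pruned


def macs_reduction(original: list[Conv], pruned: list[Conv]) -> float:
    """Fraction of the original MACs removed by pruning."""
    return 1.0 - total_macs(pruned) / total_macs(original)


# Toy network: three 3x3 convolutions with 112, 112 and 56 pixel output maps.
net = [Conv(3, 32, 3, 112), Conv(32, 64, 3, 112), Conv(64, 64, 3, 56)]
# Remove 16 of the 64 output channels (25%) from the middle layer.
slim = prune_output_channels(net, idx=1, keep=48)
print(f"{macs_reduction(net, slim):.3f}")  # prints 0.242
```

Here, removing a quarter of the middle layer's output channels also removes a quarter of the last layer's input channels, yet the network as a whole sheds only 8/33 (about 24.2%) of its MACs, because the first layer is unaffected. In a real model, the MAC cost of each layer and the connectivity between layers determine how a requested removal_ratio maps onto individual layers.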