此 API 具有下列方法:
-
__init__(model, inputs)
- model
- 要进行剪枝的
torch.nn.Module
对象。 - inputs
- 单个 torch 或 torch 列表。张量用作为模型推断的输入。输入无需采用真实数据。可以采用随机生成的张量,与真实数据的形状和数据类型相同即可。
-
search(gpus=['0'], calibration_fn=None, calib_args=(), num_subnet=10, removal_ratio=0.5, excludes=[], eval_fn=None, eval_args=())
- gpus
- 要使用的 GPU 索引的元组或列表。如不设置,则使用默认 GPU。
- calibration_fn
- 可调用对象,取
torch.nn.Module
对象作为其首个实参。它用于为 BatchNormalization 层校准统计数据。 - calib_args
- 传递给 calibration_fn 的实参元组。
- num_subnet
- 满足 MAC 约束的子网络数量。
- removal_ratio
- 期望的 MAC 缩减百分比。
- excludes
- 需从剪枝中排除的模块。
- eval_fn
- 可调用对象,取
torch.nn.Module
对象作为其第一个实参,并返回评估得分。 - eval_args
- 传递给 eval_fn 的实参元组。
-
prune(mode='slim', index=None, removal_ratio=None, pruning_info_path=None)
- mode
- 以下值之一:['sparse', 'slim']。对于单步方法,应始终使用 'slim'(精简)模式。
- index
- 子网络索引。默认会自动选中最优子网络。
- removal_ratio
- 期望的 MAC 缩减百分比。
- pruning_info_path
- JSON 文件。为当前模型保存详细的剪枝信息。可生成含该文件和原始模型的精简模型。