此 API 具有下列方法:
-
__init__(model, inputs)
- model
- 要进行剪枝的
torch.nn.Module
对象。 - inputs
- 单个 torch 或 torch 列表。张量用作为模型推断的输入。输入无需采用真实数据。可以采用随机生成的张量,与真实数据的形状和数据类型相同即可。
-
ana(eval_fn, args=(), gpus=None, excludes=None, forced=False)
- eval_fn
- 可调用对象,取
torch.nn.Module
对象作为其第一个实参,并返回评估得分。 - args
- 传递给 eval_fn 的实参元组。
- gpus
- 要使用的 GPU 索引的元组或列表。如不设置,则使用默认 GPU。
- excludes
- 要从剪枝中排除的节点名称或 torch 模块的列表。
- forced
- 如为 False,则跳过模型分析并使用缓存的结果。
-
prune(removal_ratio=None, threshold=None, spec_path=None, excludes=None, mode='sparse')
- removal_ratio
- 期望的 MAC 缩减百分比。
- threshold
- 可承受的模型性能损失的相对比例。
- spec_path
- 预定义的剪枝规范。
- excludes
- 要从剪枝中排除的节点名称或 torch 模块的列表。
- mode
- 以下值之一:['sparse', 'slim']。在迭代循环中请始终使用 'sparse'。slim(精简)模型用于量化感知训练。