此运行器适用于以迭代方式对模型进行结构化剪枝。此 API 具有下列方法:
-
__init__(model, input_specs)
创建新的
IterativePruningRunner
对象。- model
- 要进行剪枝的基线模型。模型应为
keras.Model
的实例。 - input_specs
- 单一或列表形式的
tf.TensorSpec
,用于表示模型输入规范。
-
ana(eval_fn, excludes=None, forced=False)
执行模型分析。分析结果保存在 '.vai' 目录中,除非
forced
设为 True,否则此缓存结果将在后续调用中直接使用。- eval_fn
- 可调用对象,取
keras.Model
对象作为其第一个实参,并返回评估得分。 - excludes
- 要从剪枝中排除的层名称或层实例的列表。
- forced
- 此项设为 True 时,会运行模型分析来替代缓存的分析结果。
-
prune(ratio=None, threshold=None, spec_path=None, excludes=None, mode='sparse')
对基线模型进行剪枝,并返回稀疏模型。剪枝程度可通过三种方式来指定:比率、阈值或剪枝规范。首选第一种方法;后两种方法更适合搭配手动微调来进行实验。
- ratio
- 期望的基线模型 FLOP(每秒浮点运算次数)缩减百分比。这是指导值,实际 FLOP 缩减与该值并非严格相等。
- threshold
- 基线模型与剪枝后的模型之间的相对模型性能损失比例。
- spec_path
- 用于模型剪枝的剪枝规范路径。
- excludes
- 要从剪枝中排除的层名称或层实例的列表。
- mode
- 基线模型的剪枝模式,剪枝后返回稀疏模型。
-
get_slim_model(spec_path=None)
从稀疏模型获取精简模型。默认使用最新剪枝规范来执行此变换。如果稀疏模型不是根据最新规范生成的,则可显式提供规范路径。
- spec_path
- 剪枝规范路径会将稀疏模型变换为精简模型。