剪枝算法的设计使其能在减少模型参数数量的同时尽可能降低精度损失。这是一个迭代进程,如下图所示。剪枝会导致精度损失,而通过训练对剩余权重进行微调则可恢复精度。经过训练有素且未剪枝的模型充当第一次迭代的输入,称为基线模型。该模型会加以剪枝和微调。接下来,从上一次迭代得到的微调模型即成为新的基线,并再次进行剪枝和微调。此进程会反复进行多次迭代,直到达到所需的稀疏度。迭代方法是必需的,因为在单次传递中,无法在维持精度的同时对具有高剪枝率的模型进行剪枝。单次迭代中如果剪枝的参数过多,那么精度损失可能过于剧烈,导致不可能通过微调来恢复精度。
重要: 每次迭代中会逐渐减少参数以改善微调阶段的精度。
利用迭代剪枝的进程,可以达到更高的剪枝率,同时模型性能不会出现显著损失。
图 1. 迭代剪枝
以下描述了迭代剪枝的 4 个主要阶段:
- 分析
- 对模型执行敏感度分析,判定最优剪枝策略。
- 剪枝
- 减少输入模型中的计算次数。
- 微调
- 重新训练已剪枝的模型以恢复精度。
- 变换
- 生成含更低权重的密集模型。