从先前微调阶段保存的检查点执行加载。增大剪枝率值,提升稀疏度等级。执行每个剪枝步骤后,对此稀疏模型进行微调。重复此剪枝和微调循环,直至稀疏度达到目标值,或者直至观察到评估性能等级显著降级为止。
model.load_weights("model_sparse_0.2")
input_shape = [28, 28, 1]
input_spec = tf.TensorSpec((1, *input_shape), tf.float32)
runner = IterativePruningRunner(model, input_spec)
sparse_model = runner.prune(ratio=0.5)