完成迭代剪枝后,即可生成稀疏模型,此模型所含参数数量与原始模型相同,但其中大量参数现已设置为零 (0)。
调用 get_slim_model()
移除稀疏模型中含 0 值的参数并生成最终剪枝后的模型:
model.load_weights("model_sparse_0.5")
input_shape = [28, 28, 1]
input_spec = tf.TensorSpec((1, *input_shape), tf.float32)
runner = IterativePruningRunner(model, input_spec)
slim_model = runner.get_slim_model()
默认情况下,运行器使用最新剪枝规范来生成精简模型。您可以看到,最新规范文件含如下命令:
$ cat .vai/latest_spec
$ ".vai/mnist_ratio_0.5.spec"
如果此文件与您的稀疏模型不匹配,您可显式指定要使用的文件路径:
slim_model = runner.get_slim_model(".vai/mnist_ratio_0.5.spec")
您可使用 Keras 模型保存 API 保存精简模型并重新加载此模型用于推断或量化。例如:
slim_model.save('/tmp/model')
loaded_model = tf.keras.models.load_model('/tmp/model')