量化感知训练 (QAT) 与浮点模型训练/微调类似。但在 QAT 中,vai_q_tensorflow API 会在训练开始前将浮点计算图转换为量化计算图。典型工作流程如下:
- 准备:开始 QAT 前,准备下列文件:
表 1. vai_q_tensorflow QAT 的输入文件 编号 名称 描述 1 检查点文件 作为起点的浮点检查点文件。如果要从头开始训练模型,请忽略此文件。 2 数据集 含标记的训练数据集。 3 训练脚本 用于运行模型的浮点训练/微调的 Python 脚本。 - 评估浮点模型(可选):在执行量化微调之前,评估浮点检查点文件以检查脚本和数据集的准确性。浮点检查点的精度和损失值也可作为 QAT 的基线。
- 修改训练脚本:要创建量化训练计算图,请在构建浮点计算图后,修改训练脚本以调用函数。下面给出了 1 个示例:
# train.py # ... # Create the float training graph model = model_fn(is_training=True) # *Set the quantize configurations import vai_q_tensorflow q_config = vai_q_tensorflow.QuantizeConfig(input_nodes=['net_in'], output_nodes=['net_out'], input_shapes=[[-1, 224, 224, 3]]) # *Call Vai_q_tensorflow API to create the quantize training graph vai_q_tensorflow.CreateQuantizeTrainingGraph(config=q_config) # Create the optimizer optimizer = tf.train.GradientDescentOptimizer() # start the training/finetuning; you can use sess.run(), tf.train, tf.estimator, tf.slim and so on # ...
注释: 您可使用import vai_q_tensorflow as decent_q
来保证与 vai_q_tensorflow 的更低版本的代码(即,import tensorflow.contrib.decent_q
)兼容QuantizeConfig
包含量化的配置。部分基本配置(如,
input_nodes
、output_nodes
和and input_shapes
)必须根据模型结构加以设置。其他配置(如
weight_bit
、activation_bit
和method
)则包含默认值,可按需修改。请参阅 vai_q_tensorflow 用法 以获取所有配置的详细信息。-
input_nodes
/output_nodes
- 这两项结合使用,即可判定要量化的子计算图范围。预处理和后处理组件通常不可量化,应置于此范围之外。input_nodes 与 output_nodes 应相同,以使浮点训练计算图与评估计算图之间的量化运算相匹配。 注释: 当前不支持含多个输出张量(例如,FIFO)的运算。可添加 tf.identity 节点为 input_tensor 生成别名,以构成含单一输出的输入节点。
-
input_shapes
- 对于每个节点,input_nodes 的形状列表必须为 4 维。该信息以逗号分隔,例如,[[1,224,224,3] [1, 128, 128, 1]];支持 batch_size 为未知大小,例如,[[-1,224,224,3]]。
-
- 评估并生成量化模型:在 QAT 之后,使用检查点文件评估量化计算图,并生成冻结模型。方法是在构建浮点评估计算图之后调用以下函数。冻结进程取决于量化评估计算图,因此通常两者一起调用。 注释:
vai_q_tensorflow.CreateQuantizeTrainingGraph
函数和vai_q_tensorflow.CreateQuantizeEvaluationGraph
函数用于修改 Tensorflow 中的默认计算图。这两个函数必须在不同计算图阶段内分别调用。vai_q_tensorflow.CreateQuantizeTrainingGraph
必须在浮点训练计算图上调用,而vai_q_tensorflow.CreateQuantizeEvaluationGraph
则需在浮点评估计算图上调用。调用vai_q_tensorflow.CreateQuantizeTrainingGraph
函数后就无法再调用vai_q_tensorflow.CreateQuantizeEvaluationGraph
,因为默认计算图已转换为量化训练计算图。正确的方法是在调用浮点模型创建函数后立即调用该函数。# eval.py # ... # Create the float evaluation graph model = model_fn(is_training=False) # *Set the quantize configurations import vai_q_tensorflow q_config = vai_q_tensorflow.QuantizeConfig(input_nodes=['net_in'], output_nodes=['net_out'], input_shapes=[[-1, 224, 224, 3]]) # *Call Vai_q_tensorflow API to create the quantize evaluation graph vai_q_tensorflow.CreateQuantizeEvaluationGraph(config=q_config) # *Call Vai_q_tensorflow API to freeze the model and generate the deploy model vai_q_tensorflow.CreateQuantizeDeployGraph(checkpoint="path to checkpoint folder", config=q_config) # start the evaluation; You can use sess.run, tf.train, tf.estimator, tf.slim and so on # ...