vai_q_tensorflow 量化感知训练 - 3.5 简体中文

Vitis AI 用户指南 (UG1414)

Document ID
UG1414
Release Date
2023-09-28
Version
3.5 简体中文
量化感知训练 (QAT) 与浮点模型训练/微调类似。但在 QAT 中,vai_q_tensorflow API 会在训练开始前将浮点计算图转换为量化计算图。典型工作流程如下:
  1. 准备:开始 QAT 前,准备下列文件:
    表 1. vai_q_tensorflow QAT 的输入文件
    编号 名称 描述
    1 检查点文件 作为起点的浮点检查点文件。如果要从头开始训练模型,请忽略此文件。
    2 数据集 含标记的训练数据集。
    3 训练脚本 用于运行模型的浮点训练/微调的 Python 脚本。
  2. 评估浮点模型(可选):在执行量化微调之前,评估浮点检查点文件以检查脚本和数据集的准确性。浮点检查点的精度和损失值也可作为 QAT 的基线。
  3. 修改训练脚本:要创建量化训练计算图,请在构建浮点计算图后,修改训练脚本以调用函数。下面给出了 1 个示例:
    # train.py
    						
    # ...
    						
    # Create the float training graph
    model = model_fn(is_training=True)
    						
    # *Set the quantize configurations
    import vai_q_tensorflow
    q_config = vai_q_tensorflow.QuantizeConfig(input_nodes=['net_in'],
    				output_nodes=['net_out'], 
    				input_shapes=[[-1, 224, 224, 3]])
    # *Call Vai_q_tensorflow API to create the quantize training graph
    						vai_q_tensorflow.CreateQuantizeTrainingGraph(config=q_config)
    						
    # Create the optimizer 
    optimizer = tf.train.GradientDescentOptimizer()
    						
    # start the training/finetuning; you can use sess.run(), tf.train, tf.estimator, tf.slim and so on
    # ...
    注释: 您可使用 import vai_q_tensorflow as decent_q 来保证与 vai_q_tensorflow 的更低版本的代码(即,import tensorflow.contrib.decent_q)兼容

    QuantizeConfig 包含量化的配置。

    部分基本配置(如,input_nodesoutput_nodesand input_shapes)必须根据模型结构加以设置。

    其他配置(如 weight_bitactivation_bitmethod)则包含默认值,可按需修改。请参阅 vai_q_tensorflow 用法 以获取所有配置的详细信息。

    input_nodes/output_nodes
    这两项结合使用,即可判定要量化的子计算图范围。预处理和后处理组件通常不可量化,应置于此范围之外。input_nodes 与 output_nodes 应相同,以使浮点训练计算图与评估计算图之间的量化运算相匹配。
    注释: 当前不支持含多个输出张量(例如,FIFO)的运算。可添加 tf.identity 节点为 input_tensor 生成别名,以构成含单一输出的输入节点。
    input_shapes
    对于每个节点,input_nodes 的形状列表必须为 4 维。该信息以逗号分隔,例如,[[1,224,224,3] [1, 128, 128, 1]];支持 batch_size 为未知大小,例如,[[-1,224,224,3]]。
  4. 评估并生成量化模型:在 QAT 之后,使用检查点文件评估量化计算图,并生成冻结模型。方法是在构建浮点评估计算图之后调用以下函数。冻结进程取决于量化评估计算图,因此通常两者一起调用。
    注释: vai_q_tensorflow.CreateQuantizeTrainingGraph 函数和 vai_q_tensorflow.CreateQuantizeEvaluationGraph 函数用于修改 Tensorflow 中的默认计算图。这两个函数必须在不同计算图阶段内分别调用。vai_q_tensorflow.CreateQuantizeTrainingGraph 必须在浮点训练计算图上调用,而 vai_q_tensorflow.CreateQuantizeEvaluationGraph 则需在浮点评估计算图上调用。调用 vai_q_tensorflow.CreateQuantizeTrainingGraph 函数后就无法再调用 vai_q_tensorflow.CreateQuantizeEvaluationGraph,因为默认计算图已转换为量化训练计算图。正确的方法是在调用浮点模型创建函数后立即调用该函数。
    # eval.py
    						
    # ...
    						
    # Create the float evaluation graph
    model = model_fn(is_training=False)
    						
    # *Set the quantize configurations
    import vai_q_tensorflow
    q_config = vai_q_tensorflow.QuantizeConfig(input_nodes=['net_in'],
    				output_nodes=['net_out'], 
    				input_shapes=[[-1, 224, 224, 3]])
    # *Call Vai_q_tensorflow API to create the quantize evaluation graph
    						vai_q_tensorflow.CreateQuantizeEvaluationGraph(config=q_config)
    # *Call Vai_q_tensorflow API to freeze the model and generate the deploy model
    						vai_q_tensorflow.CreateQuantizeDeployGraph(checkpoint="path to checkpoint folder", config=q_config)
    						
    # start the evaluation; You can use sess.run, tf.train, tf.estimator, tf.slim and so on
    # ...