此函数用于获取 QAT 的浮点模型:
vitis_quantize.VitisQuantizer.get_qat_model(
init_quant=False,
calib_dataset=None,
calib_batch_size=None,
calib_steps=None,
train_with_bn=False,
freeze_bn_delay=-1)
实参
- init_quant
-
bool
对象,用于通知是否在 QAT 之前运行初始量化。运行初始 PTQ 量化可为量化器参数产生更好的初始状态,对于 8bit_tqt 策略尤其如此。否则,训练可能不会收敛。 - calib_dataset
-
tf.data.Dataset
、keras.utils.Sequence
或np.numpy
对象。用于校准的代表性数据集。当 init_quant 设置为True
时必须设置此参数。您可以将 eval_dataset、train_dataset 或其他数据集整体或其中一部分用作 calib_dataset。 - calib_steps
-
int
对象。表示初始 PTQ 步骤总数。可忽略,默认值为 None。如果 calib_dataset 为tf.data dataset
、生成器或keras.utils.Sequence
实例,且 steps 为 None,校准会运行到数据集耗尽为止。阵列输入不支持此实参。 - calib_batch_size
-
int
对象。表示初始 PTQ 的每批次样本数。如果“calib_dataset”为数据集、生成器或keras.utils.Sequence
实例形式,则批次大小由数据集本身控制。如果 calib_dataset 为numpy.array
对象形式,则默认批次大小为 32。 - train_with_bn
-
bool
对象。指示在 QAT 期间是否保留 bn 层。如果设为 True,bn 参数会在量化感知训练期间更新,并帮助模型收敛。然后这些经过训练的 bn 层融合到 get_deploy_model() 函数中先前的类卷积层。如果浮点模型没有 bn 层,则此选项无效。默认值为 false。 - freeze_bn_delay
-
int
对象。冻结 bn 参数前的训练步骤。在延迟步骤后,模型会切换推断 bn 参数以避免训练中出现不稳定。仅当 train_with_bn 为 True 时才会生效。默认值为 -1,表示从不执行 bn 冻结。