QAT is similar to float model training/finetuning, except that vai_q_tensorflow2 rewrites the float graph to convert it to a quantized model before training starts. The typical workflow is as follows. You can find a complete example here.
Prepare the float model, dataset, and training scripts.
Before QAT, prepare the following files:
Table 1. Input Files for vai_q_tensorflow2 QAT

No.  Name              Description
1    Float model       Floating-point model files to start from. Can be omitted if training from scratch.
2    Dataset           The training dataset with labels.
3    Training scripts  The Python scripts that run float training/finetuning of the model.
(Optional) Evaluate the float model.
Evaluate the float model before QAT to check that the scripts and dataset are correct. The accuracy and loss values of the float checkpoint also serve as a baseline for QAT.
Modify the training scripts and run QAT.
Use the vai_q_tensorflow2 API, VitisQuantizer.get_qat_model, to convert the model to a quantized model, and then train/finetune the converted model. The following is an example, in which calib_dataset, train_dataset, and lr_schedule stand for your own calibration data, training data, and learning-rate schedule:

    import tensorflow as tf
    from tensorflow_model_optimization.quantization.keras import vitis_quantize

    model = tf.keras.models.load_model('float_model.h5')

    # Call the vai_q_tensorflow2 API to create the quantize-training model.
    quantizer = vitis_quantize.VitisQuantizer(model, quantize_strategy='8bit_tqt')
    qat_model = quantizer.get_qat_model(
        # Running PTQ initialization gives the quantizers a better initial state,
        # especially for the 8bit_tqt strategy. Must be used together with calib_dataset.
        init_quant=True,
        calib_dataset=calib_dataset)

    # Train qat_model (not the float model) to get the quantize-finetuned model.
    qat_model.compile(
        optimizer=tf.keras.optimizers.RMSprop(learning_rate=lr_schedule),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy()])

    # Start the training/finetuning.
    qat_model.fit(train_dataset)
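Conceptually, the model returned by get_qat_model() passes weights and activations through quantize-dequantize ("fake quantization") steps, and init_quant=True chooses the starting quantizer parameters from calibration data. The NumPy sketch below illustrates both ideas under stated assumptions: symmetric signed 8-bit quantization with a power-of-two scale, and a simple max-abs calibration rule. The names fake_quantize, init_frac_bits, and frac_bits are hypothetical and are not part of the vai_q_tensorflow2 API; the tool's actual quantizers and calibration method may differ.

```python
import numpy as np


def fake_quantize(x, frac_bits, num_bits=8):
    """Quantize-dequantize x on a power-of-two grid with step 2**-frac_bits.

    Hypothetical helper for illustration; assumes symmetric signed quantization.
    """
    scale = 2.0 ** frac_bits                      # quantization steps per unit
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    q = np.clip(np.round(x * scale), qmin, qmax)  # snap to the integer grid
    return q / scale                              # back to float for the next layer


def init_frac_bits(calib_batches, num_bits=8):
    """Choose an initial fractional position from calibration data.

    Assumed max-abs rule: the largest frac_bits for which the biggest observed
    magnitude still fits in the signed num_bits range.
    """
    qmax = 2 ** (num_bits - 1) - 1
    max_abs = max(float(np.max(np.abs(b))) for b in calib_batches)
    if max_abs == 0.0:
        raise ValueError("calibration data is all zeros")
    return int(np.floor(np.log2(qmax / max_abs)))


# Usage: calibrate on two small batches, then fake-quantize with the result.
calib = [np.array([0.5, -1.9]), np.array([1.2, 0.3])]
fb = init_frac_bits(calib)
print(fb)  # 6
print(fake_quantize(np.array([0.1, -0.37, 1.9, -3.0]), fb))
```

With frac_bits=6 the representable range is [-2, 1.984375], so -3.0 is clipped to -2.0, while in-range values are only rounded to the nearest multiple of 1/64.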
Note: Vitis AI 1.4 supports the 8bit_tqt strategy, which uses trained thresholds in the quantizers and may give better QAT results. By default, the Straight-Through Estimator (STE) is used. Use 8bit_tqt only in QAT, and only with init_quant=True, to get the best performance: initialization with PTQ quantization generates a better initial state for the quantizer parameters, especially for 8bit_tqt, and without it the training may not converge.
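The following NumPy sketch shows the idea behind the default Straight-Through Estimator. Rounding has a zero gradient almost everywhere, so STE treats the rounding as the identity during backpropagation; in the common clipped variant sketched here, gradients pass through where the input lies inside the representable range and are zeroed where it was clipped. Whether vai_q_tensorflow2 uses exactly this variant is an assumption, and the sketch does not model 8bit_tqt, where the threshold itself also receives a gradient. The function name ste_grad is hypothetical.

```python
import numpy as np


def ste_grad(x, upstream_grad, frac_bits, num_bits=8):
    """Gradient of a power-of-two fake-quantize op w.r.t. x under a clipped STE.

    Hypothetical helper: round() is treated as identity, clipping is kept.
    """
    scale = 2.0 ** frac_bits
    lo = -(2 ** (num_bits - 1)) / scale     # smallest representable value
    hi = (2 ** (num_bits - 1) - 1) / scale  # largest representable value
    inside = (x >= lo) & (x <= hi)          # True where no clipping occurred
    return upstream_grad * inside           # pass the gradient through or zero it


x = np.array([0.1, -0.37, 1.9, -3.0])
print(ste_grad(x, np.ones_like(x), frac_bits=6))  # [1. 1. 1. 0.]
```

Only -3.0 falls outside [-2, 1.984375], so only its gradient is blocked; this is one reason a poor initial range (no PTQ initialization) can starve many values of gradient signal.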
Save the model.
Call qat_model.save() to save the trained model, or use callbacks in qat_model.fit() to save the model periodically. For example:

    # Save the model manually.
    qat_model.save('trained_model.h5')

    # Save the model periodically during fit() using callbacks.
    # monitor must match a logged metric name; SparseTopKCategoricalAccuracy
    # is logged as 'sparse_top_k_categorical_accuracy'.
    qat_model.fit(
        train_dataset,
        callbacks=[
            tf.keras.callbacks.ModelCheckpoint(
                filepath='./quantize_train/',
                save_best_only=True,
                monitor='sparse_top_k_categorical_accuracy',
                verbose=1,
            )
        ])
Convert to a deployable quantized model.
Modify the trained/finetuned model to meet the compiler requirements. For example, if train_with_bn is set to True, the batch normalization (BN) and dropout layers are not folded during training and must be folded before deployment. In addition, some quantizer parameters may change during training and exceed the ranges the compiler permits; these must be corrected before deployment.
The get_deploy_model() function performs these conversions and generates a deployable model, as shown in the following example:

    # quantizer is the VitisQuantizer instance used to create qat_model.
    quantized_model = quantizer.get_deploy_model(qat_model)
    quantized_model.save('quantized_model.h5')
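The BN folding that get_deploy_model() performs can be illustrated with the standard algebra for a dense layer followed by inference-mode batch normalization: scaling each output channel's weights by gamma / sqrt(var + eps) and adjusting the bias yields a single layer with identical outputs. This is a NumPy sketch under those assumptions (Keras-style kernel layout of shape (in_features, out_features), eps=1e-3 as in the Keras BatchNormalization default); fold_batchnorm is a hypothetical name, and the tool's internal conversion may differ. Dropout needs no arithmetic: it is the identity at inference, so folding it amounts to removing the layer.

```python
import numpy as np


def fold_batchnorm(w, b, gamma, beta, mean, var, eps=1e-3):
    """Fold inference-mode BatchNorm into the preceding dense layer.

    w: (in_features, out_features) kernel; b and BN params: (out_features,).
    """
    factor = gamma / np.sqrt(var + eps)    # per-output-channel BN scale
    w_folded = w * factor                  # broadcasts across the output axis
    b_folded = (b - mean) * factor + beta  # absorb the mean shift and BN offset
    return w_folded, b_folded


# Usage: the folded layer reproduces dense + BN on random data.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
w, b = rng.normal(size=(3, 5)), rng.normal(size=5)
gamma, beta = rng.uniform(0.5, 2.0, size=5), rng.normal(size=5)
mean, var = rng.normal(size=5), rng.uniform(0.1, 1.0, size=5)

reference = gamma * (x @ w + b - mean) / np.sqrt(var + 1e-3) + beta
w_f, b_f = fold_batchnorm(w, b, gamma, beta, mean, var)
print(np.allclose(x @ w_f + b_f, reference))  # True
```

Because the folded weights absorb gamma / sqrt(var + eps), their range can differ from the unfolded weights, which is one reason quantizer parameters may need adjusting at this stage.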
(Optional) Evaluate the quantized model.
Call model.evaluate() on eval_dataset to evaluate the quantized model, just as you evaluated the float model:

    import tensorflow as tf
    from tensorflow_model_optimization.quantization.keras import vitis_quantize

    # Load inside quantize_scope() so the custom quantize layers can be deserialized.
    with vitis_quantize.quantize_scope():
        quantized_model = tf.keras.models.load_model('quantized_model.h5')
    quantized_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy()])
    quantized_model.evaluate(eval_dataset)
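The metric used in these examples, SparseTopKCategoricalAccuracy, defaults to k=5 and counts a sample as correct when its integer label is among the k highest-scoring classes. The NumPy sketch below reproduces that definition so float and quantized predictions can be compared on the same terms; sparse_top_k_accuracy is a hypothetical helper, and it does not replicate how Keras handles tied scores.

```python
import numpy as np


def sparse_top_k_accuracy(labels, scores, k=5):
    """Fraction of samples whose integer label is among the k largest scores."""
    labels = np.asarray(labels)
    top_k = np.argsort(scores, axis=-1)[:, -k:]       # indices of the k largest scores
    hits = np.any(top_k == labels[:, None], axis=-1)  # is the label in each row's top k?
    return float(np.mean(hits))


scores = np.array([[0.1, 0.5, 0.4],
                   [0.7, 0.2, 0.1]])
print(sparse_top_k_accuracy([2, 2], scores, k=2))  # 0.5
print(sparse_top_k_accuracy([2, 2], scores, k=1))  # 0.0
```

Applying the same helper to the float and quantized models' predictions on eval_dataset gives the accuracy change relative to the float baseline.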
Recommended: Train or finetune the float model first, and then proceed to QAT.