Generally, quantization might lead to a slight accuracy loss in the model. However, for specific networks like MobileNets, the accuracy loss can be more significant. To address this, Quantization Aware Training (QAT) offers a solution to enhance the accuracy of quantized models further.
QAT is similar to training/finetuning floating-point models, except that vai_q_tensorflow2 rewrites the float graph to convert it into a quantized model before the training begins. You can find a complete example here.
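To make the idea of training on a quantized graph more concrete, the following is a minimal sketch of simulated ("fake") quantization and of the Straight-Through Estimator gradient. It assumes symmetric signed 8-bit quantization with a power-of-two scale. The function names `fake_quantize` and `ste_gradient` are hypothetical, and this is not the vai_q_tensorflow2 implementation.

```python
def fake_quantize(x, frac_bits, bit_width=8):
    """Snap x to a signed fixed-point grid, then map it back to float.

    Assumption: power-of-two scale 2**frac_bits and Python's built-in
    round() (round-half-to-even); a real quantizer may round differently.
    """
    scale = 2.0 ** frac_bits
    qmin = -(2 ** (bit_width - 1))      # -128 for 8 bits
    qmax = 2 ** (bit_width - 1) - 1     # 127 for 8 bits
    q = round(x * scale)                # nearest integer level
    q = max(qmin, min(qmax, q))         # saturate to the integer range
    return q / scale                    # back to float for the next layer


def ste_gradient(x, upstream_grad, frac_bits, bit_width=8):
    """Straight-Through Estimator: treat rounding as the identity in the
    backward pass, and block the gradient where the value saturated."""
    scale = 2.0 ** frac_bits
    lower = -(2 ** (bit_width - 1)) / scale
    upper = (2 ** (bit_width - 1) - 1) / scale
    return upstream_grad if lower <= x <= upper else 0.0


print(fake_quantize(0.3, frac_bits=6))   # value on the 1/64 grid
print(fake_quantize(5.0, frac_bits=6))   # saturates at 127/64
```

Because the forward pass already carries the rounding and saturation error, the weights can adapt to it during finetuning.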
The typical workflow for QAT is as follows:
- Preparing the float model, dataset, and training scripts:
Before QAT, prepare the following files:
Table 1. Input Files for vai_q_tensorflow2 QAT

No.  Name              Description
1    Float model       Floating-point model files from which to start. You can omit this if you are training from scratch.
2    Dataset           The training dataset with labels.
3    Training scripts  The Python scripts to run float training/finetuning of the model.

- (Optional) Evaluate the float model.
Evaluate the float model before QAT to check the accuracy of the scripts and dataset. The accuracy and loss values of the float checkpoint can also be a baseline for QAT.
- Modify the training scripts and run QAT.
Use the vai_q_tensorflow2 API, VitisQuantizer.get_qat_model, to convert the model to a quantized model, and then proceed to training/finetuning with it. The following is an example:

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras.optimizers import RMSprop
    from tensorflow_model_optimization.quantization.keras import vitis_quantize

    model = tf.keras.models.load_model('float_model.h5')

    # Call the vai_q_tensorflow2 API to create the quantize training model.
    quantizer = vitis_quantize.VitisQuantizer(model)
    qat_model = quantizer.get_qat_model(
        # Initial PTQ quantization gives the quantizers a better initial state,
        # especially for the `pof2s_tqt` strategy. Must be used together with
        # calib_dataset.
        init_quant=True,
        calib_dataset=calib_dataset)

    # Then run the training process with this qat_model to get the quantize
    # finetuned model. calib_dataset, train_dataset, and lr_schedule come from
    # your own training scripts.

    # Compile the model
    qat_model.compile(
        optimizer=RMSprop(learning_rate=lr_schedule),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseTopKCategoricalAccuracy()])

    # Start the training/finetuning
    qat_model.fit(train_dataset)
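The `init_quant=True` option initializes the quantizers from calibration data before training starts. As a rough illustration of what such an initialization can look like, the sketch below picks a power-of-two scale from the largest absolute calibration value. The max-abs rule and the name `init_frac_bits` are assumptions made for this example; the calibration method that vai_q_tensorflow2 actually uses may differ.

```python
import math


def init_frac_bits(calib_values, bit_width=8):
    """Pick the largest number of fractional bits whose signed range
    still covers every calibration value (hypothetical max-abs rule)."""
    max_abs = max(abs(v) for v in calib_values)
    if max_abs == 0.0:
        # All-zero data: any scale works, so use the finest common grid.
        return bit_width - 1
    qmax = 2 ** (bit_width - 1) - 1
    # Need max_abs * 2**frac_bits <= qmax, so frac_bits <= log2(qmax / max_abs).
    return math.floor(math.log2(qmax / max_abs))


activations = [0.1, -0.9, 0.5]          # stand-in for calibration outputs
bits = init_frac_bits(activations)
print(bits, 0.9 * 2 ** bits)            # largest value still fits in int8
```

Starting QAT from a scale that already fits the data avoids heavy saturation in the first training steps, which is one reason an uninitialized trained-threshold quantizer might fail to converge.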
Note: Vitis AI supports the pof2s_tqt quantize strategy from version 2.0. It uses trained thresholds in the quantizers and might produce better results for QAT. By default, the Straight-Through Estimator is used. The 8bit_tqt approach should only be used in QAT together with init_quant=True to get the best performance. Initialization with PTQ quantization can generate a better initial state for the quantizer parameters, especially for pof2s_tqt. Otherwise, the training might not converge.
- Save the model.
Call model.save() to save the trained model, or use callbacks in model.fit() to save the model periodically. For example:

    # Save the model manually
    qat_model.save('trained_model.h5')

    # Save the model periodically during fit using callbacks
    qat_model.fit(
        train_dataset,
        callbacks=[
            keras.callbacks.ModelCheckpoint(
                filepath='./quantize_train/',
                save_best_only=True,
                # Must match the name of a metric passed to compile()
                monitor="sparse_top_k_categorical_accuracy",
                verbose=1,
            )])
- Convert to a deployable quantized model.
Modify the trained/finetuned model to meet the compiler requirements. For example, if train_with_bn is set to True, the batch normalization layers remain unfolded during training and must be folded before deployment. Some quantizer parameters might also vary during training and exceed the ranges that the compiler supports. These must be corrected before deployment.
Use the get_deploy_model() function to perform these conversions and generate a deployable model, as shown in the following example:

    # quantizer is the VitisQuantizer instance that created qat_model
    quantized_model = quantizer.get_deploy_model(qat_model)
    quantized_model.save('quantized_model.h5')
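Folding a batch normalization layer means absorbing its per-channel scale and shift into the weights and bias of the preceding layer. The sketch below shows the standard arithmetic for a small dense layer in plain Python. The helper names are hypothetical, `eps=1e-3` is assumed to match the Keras BatchNormalization default, and the code only illustrates the math rather than what get_deploy_model() does internally.

```python
import math


def dense(x, weights, bias):
    """y[j] = sum_i x[i] * weights[i][j] + bias[j]."""
    return [sum(xi * weights[i][j] for i, xi in enumerate(x)) + bias[j]
            for j in range(len(bias))]


def batch_norm(y, gamma, beta, mean, var, eps=1e-3):
    """Inference-mode batch normalization, one statistic per channel."""
    return [g * (yi - m) / math.sqrt(v + eps) + b
            for yi, g, b, m, v in zip(y, gamma, beta, mean, var)]


def fold_batch_norm(weights, bias, gamma, beta, mean, var, eps=1e-3):
    """Return weights and bias that reproduce dense followed by batch_norm."""
    factors = [g / math.sqrt(v + eps) for g, v in zip(gamma, var)]
    folded_w = [[w * f for w, f in zip(row, factors)] for row in weights]
    folded_b = [(b - m) * f + bt
                for b, m, f, bt in zip(bias, mean, factors, beta)]
    return folded_w, folded_b


w = [[0.5, -1.0], [2.0, 0.25]]
b = [0.1, -0.2]
gamma, beta = [1.5, 0.8], [0.0, 0.3]
mean, var = [0.2, -0.1], [4.0, 0.25]
x = [1.0, -2.0]

reference = batch_norm(dense(x, w, b), gamma, beta, mean, var)
folded_w, folded_b = fold_batch_norm(w, b, gamma, beta, mean, var)
folded = dense(x, folded_w, folded_b)
print(reference, folded)  # the two outputs agree up to rounding error
```

After folding, a single layer carries the combined computation, so only one set of weights has to be quantized for deployment.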
- (Optional) Evaluate the quantized model.
Call model.evaluate() on the eval_dataset to evaluate the quantized model, similar to the evaluation of the float model.

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow_model_optimization.quantization.keras import vitis_quantize

    quantized_model = tf.keras.models.load_model('quantized_model.h5')
    quantized_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseTopKCategoricalAccuracy()])
    quantized_model.evaluate(eval_dataset)
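The examples report SparseTopKCategoricalAccuracy, which counts a prediction as correct when the true integer label is among the k highest-scoring classes (k defaults to 5 in Keras). The plain-Python sketch below illustrates that definition so that float and quantized results can be read the same way. The function name is hypothetical, and tie handling might differ from the Keras implementation.

```python
def sparse_top_k_accuracy(labels, scores, k=5):
    """Fraction of samples whose integer label is among the k classes
    with the highest scores."""
    hits = 0
    for label, row in zip(labels, scores):
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        hits += label in ranked[:k]
    return hits / len(labels)


labels = [2, 0, 1]
scores = [
    [0.1, 0.2, 0.7],   # top class is 2: hit for k=1
    [0.3, 0.6, 0.1],   # top class is 1, label 0 is second: hit only for k>=2
    [0.2, 0.5, 0.3],   # top class is 1: hit for k=1
]
print(sparse_top_k_accuracy(labels, scores, k=1))
print(sparse_top_k_accuracy(labels, scores, k=2))
```

Comparing this value for the float checkpoint and for the quantized model gives a direct measure of the accuracy that quantization costs.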
Recommended: Train and finetune the float model before proceeding to QAT.