Here is an example of using the new zentorch.llm.optimize() method in BFloat16.
For this example, you will need a Hugging Face token configured for use in the code snippet below. If you are using Python 3.9, you may encounter a huggingface/tokenizers warning; to disable it, set the following environment variable:
export TOKENIZERS_PARALLELISM=false
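As a minimal sketch (assuming the huggingface_hub package is installed), the token can also be supplied and the warning suppressed from within Python; the token value below is a placeholder:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # suppress the huggingface/tokenizers warning

from huggingface_hub import login
# Log in with your Hugging Face access token so the gated Llama weights can be downloaded
login(token="<your Hugging Face token>")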
import torch
import zentorch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load Tokenizer and Model
model_id = "meta-llama/Llama-3.1-8B"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torchscript=True,
    return_dict=False,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = model.eval()
# Prepare Inputs
generate_kwargs = dict(
    do_sample=False,
    num_beams=4,
    max_new_tokens=10,
    min_new_tokens=2,
)
prompt = "Hi, How are you today?"
# Inference
############### Code modification ###############
model = zentorch.llm.optimize(model, dtype=torch.bfloat16)
#################################################
with torch.inference_mode(), torch.no_grad(), torch.amp.autocast('cpu', enabled=True):
    ############### Code modification ###############
    model.forward = torch.compile(model.forward, backend="zentorch")
    #################################################
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    output = model.generate(input_ids, **generate_kwargs)
    gen_text = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(gen_text)
Sample output
Hi, How are you today? I hope you are having a great day.
ZenDNN supports INT4 weight-only quantization with BFloat16 activations (W4A16). Here is an example of how to load a pre-quantized INT4 model from Hugging Face. This model has been quantized using AMD Quark version 0.8. For detailed instructions on using the tool, please refer to the AMD Quark documentation.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import zentorch
# Load Tokenizer and Model
model_id = "meta-llama/Llama-3.1-8B"
config = AutoConfig.from_pretrained(
    model_id,
    torchscript=True,
    return_dict=False,
    torch_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16)
# Load WOQ model
############### Code modification ###############
safetensor_path = "<Path to Quantized Model>"
model = zentorch.load_quantized_model(model, safetensor_path)
model = model.eval()
#################################################
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, padding_side="left", use_fast=False)
# Prepare Inputs
generate_kwargs = dict(
    do_sample=False,
    num_beams=4,
    max_new_tokens=10,
    min_new_tokens=2,
)
prompt = "Hi, How are you today?"
# Inference
############### Code modification ###############
model = zentorch.llm.optimize(model, dtype=torch.bfloat16)
#################################################
with torch.inference_mode(), torch.no_grad(), torch.amp.autocast('cpu', enabled=True):
    ############### Code modification ###############
    model.forward = torch.compile(model.forward, backend="zentorch")
    #################################################
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    output = model.generate(input_ids, **generate_kwargs)
    gen_text = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(gen_text)
Sample output
Hi, How are you today? I hope you are having a great day.