import torch
import zentorch
from transformers import BertTokenizer, BertModel
from datasets import load_dataset
# Load the dataset
dataset = load_dataset("imdb", split="test")
print(dataset[0]['text'])
# Load the tokenizer and the model
model_id = "google-bert/bert-large-uncased"
tokenizer = BertTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = BertModel.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model = model.eval()
############### Code modification ###############
model.forward = torch.compile(model.forward, backend="zentorch")
#################################################
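# Optional sanity check: importing zentorch is assumed to register a
# "zentorch" backend with torch.compile, which should then show up in
# torch._dynamo.list_backends(); adjust or drop this check if it does not.
assert "zentorch" in torch._dynamo.list_backends(), "zentorch backend not registered"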
# Inference
with torch.inference_mode(), torch.no_grad():
    # Prepare inputs by tokenizing the first three examples
    inputs = tokenizer(dataset['text'][:3], return_tensors="pt", padding=True, truncation=True)
    # Run the forward pass
    outputs = model(**inputs)
    # Get the last hidden states
    last_hidden_states = outputs.last_hidden_state
    # Print the shape of the last hidden states
    print("Last hidden states shape:", last_hidden_states.shape)
Sample Output
Last hidden states shape: torch.Size([3, 339, 1024])
The shape corresponds to the 3 tokenized examples, the padded sequence length of 339 tokens, and BERT-large's hidden size of 1024.
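To see what the zentorch backend buys you, a quick way is to time repeated forward passes after the first (compilation) call. The sketch below is an optional addition that reuses the model and inputs defined above; the warm-up call, iteration count, and use of time.perf_counter are illustrative choices rather than part of the original example.

import time

# Warm-up: the first compiled call includes graph capture and is excluded from timing
with torch.inference_mode():
    model(**inputs)
    # Time a few steady-state iterations
    runs = 10
    start = time.perf_counter()
    for _ in range(runs):
        model(**inputs)
    elapsed = time.perf_counter() - start

print(f"Average latency over {runs} runs: {elapsed / runs * 1000:.1f} ms")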