BERT-based Models

The following example runs BERT-large inference from Hugging Face Transformers through the zentorch backend for torch.compile. The only change to a standard PyTorch workflow is the torch.compile call marked as the code modification below.
import torch
import zentorch  # registers the "zentorch" backend for torch.compile
from transformers import BertTokenizer, BertModel
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("imdb", split="test")

print(dataset[0]['text'])

# Model ID on the Hugging Face Hub
model_id = "google-bert/bert-large-uncased"

# Load the tokenizer
tokenizer = BertTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Load the model in bfloat16 and switch it to evaluation mode
model = BertModel.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
model = model.eval()

############### Code modification ###############
model.forward = torch.compile(model.forward, backend="zentorch")
#################################################
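# The "zentorch" backend compiles the forward pass with torch.compile and
# dispatches supported operators to ZenDNN-optimized kernels on AMD EPYC CPUs.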

# Inference
with torch.inference_mode():
    # Prepare inputs by tokenizing the examples
    inputs = tokenizer(dataset['text'][:3], return_tensors="pt", padding=True, truncation=True)

    # Run the model forward pass
    outputs = model(**inputs)

# Get last hidden states
last_hidden_states = outputs.last_hidden_state

# Print the shape of the last hidden states
	print("Last hidden states shape:", last_hidden_states.shape)

Sample Output

Last hidden states shape: torch.Size([3, 339, 1024])
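
The last hidden states can be turned into one embedding per input if needed. The snippet below is a minimal sketch, not part of the ZenDNN example: it continues from the code above (reusing inputs and last_hidden_states) and mean-pools the token embeddings with the attention mask so that padding tokens do not contribute.

# Mean-pool token embeddings, ignoring padding, to get one fixed-size
# vector per input sequence.
mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden_states.dtype)
sentence_embeddings = (last_hidden_states * mask).sum(dim=1) / mask.sum(dim=1)

print("Sentence embeddings shape:", sentence_embeddings.shape)
# Expected shape for the three tokenized reviews: torch.Size([3, 1024])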