File size: 402 Bytes
b5f1696
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
from transformers import GPT2LMHeadModel

# Load your pre-trained model
model = GPT2LMHeadModel.from_pretrained('./saved_gpt2_medium_nice_model_directory')
model.eval()

# Dummy input for the export
dummy_input = torch.randint(0, 50256, (1, 10))

# Export the model
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=11, input_names=['input_ids'], output_names=['outputs'])