pikhan's picture
uploading some code files
b5f1696 verified
raw
history blame contribute delete
402 Bytes
import torch
from transformers import GPT2LMHeadModel
# Load the pre-trained GPT-2 checkpoint from the local directory and put it in
# evaluation mode before the graph is traced for export.
model = GPT2LMHeadModel.from_pretrained('./saved_gpt2_medium_nice_model_directory')
model.eval()

# Dummy token ids used only to trace the graph: one sequence of 10 tokens.
# The upper bound comes from the loaded config so the ids are always valid
# indices into this checkpoint's embedding table, even if the vocabulary was
# resized during fine-tuning.
dummy_input = torch.randint(0, model.config.vocab_size, (1, 10))

# Export the model to ONNX.
# Without `dynamic_axes` the exported graph is specialised to the dummy
# input's (1, 10) shape and rejects any other batch size or sequence length at
# inference time, so the batch and sequence dimensions are declared dynamic.
# NOTE(review): `output_names` names only the first model output; if the model
# emits additional outputs (e.g. cached key/values) they receive
# auto-generated names -- confirm this is what downstream consumers expect.
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    opset_version=11,
    input_names=['input_ids'],
    output_names=['outputs'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'outputs': {0: 'batch_size', 1: 'sequence_length'},
    },
)