astro-seg / modelling_yolo.py
rayh's picture
Ok, try this
933bb62 unverified
raw
history blame
779 Bytes
import torch
import onnxruntime as ort
from transformers import PreTrainedModel, PretrainedConfig
from PIL import Image
import numpy as np
class YOLOConfig(PretrainedConfig):
    """Configuration for the ONNX-backed YOLO segmentation wrapper.

    Args:
        onnx_model: Path to the exported ``.onnx`` file that the model opens with
            ONNX Runtime. Defaults to ``None``; it must be set before a model is
            built from this config.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "yolo-segmentation"

    def __init__(self, onnx_model=None, **kwargs):
        # Declare the field explicitly so it has a default and is part of the
        # documented, serialized config rather than arriving only via **kwargs.
        self.onnx_model = onnx_model
        super().__init__(**kwargs)
class YOLOTransformersModel(PreTrainedModel):
    """Transformers-style wrapper that runs a YOLO segmentation ONNX graph.

    Inference is delegated entirely to an ONNX Runtime session on the CPU
    execution provider; this class defines no torch parameters of its own.
    """

    config_class = YOLOConfig

    def __init__(self, config):
        """Open an ONNX Runtime session for ``config.onnx_model``.

        Args:
            config: Config object whose ``onnx_model`` attribute is the path to
                the exported ONNX file.

        Raises:
            ValueError: If ``config.onnx_model`` is missing or ``None``.
        """
        super().__init__(config)
        onnx_path = getattr(config, "onnx_model", None)
        if onnx_path is None:
            raise ValueError(
                "config.onnx_model must be set to the path of an ONNX model file"
            )
        self.session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

        graph_inputs = self.session.get_inputs()
        # Prefer the "images" input this wrapper has always fed; otherwise fall
        # back to the graph's first input instead of failing on a name mismatch.
        model_input = next((i for i in graph_inputs if i.name == "images"), graph_inputs[0])
        self._input_name = model_input.name
        shape = model_input.shape
        # Symbolic dims are reported as strings/None, so only an explicit int 3
        # on axis 1 (and not on the last axis) is taken to mean NCHW layout.
        self._channels_first = len(shape) == 4 and shape[1] == 3 and shape[3] != 3

    def forward(self, images):
        """Run the ONNX graph on a single image.

        Args:
            images: One image object with a ``.convert("RGB")`` method (e.g. a
                PIL image). Despite the plural name, exactly one image is
                processed per call.

        Returns:
            The list returned by ``InferenceSession.run``: every graph output in
            graph order. It is not converted into a Transformers output type.
        """
        input_array = self._preprocess(images)
        return self.session.run(None, {self._input_name: input_array})

    def _preprocess(self, image):
        """Turn one RGB-convertible image into a float32 batch of size one."""
        # (H, W, 3) float32 pixel values, still in the 0-255 range.
        array = np.array(image.convert("RGB")).astype(np.float32)
        if self._channels_first:
            # The graph declares NCHW, so move channels first and re-pack the
            # memory so ONNX Runtime receives a contiguous buffer.
            array = np.ascontiguousarray(array.transpose(2, 0, 1))
        # NOTE(review): no scaling or resizing is applied. Ultralytics-style YOLO
        # exports typically expect values in [0, 1] and a fixed spatial size
        # (e.g. 640x640); confirm against the exported graph.
        return np.expand_dims(array, axis=0)  # add batch dimension