|
import torch |
|
import onnxruntime as ort |
|
from transformers import PreTrainedModel, PretrainedConfig |
|
from PIL import Image |
|
import numpy as np |
|
|
|
class YOLOConfig(PretrainedConfig):
    """Configuration for the ONNX-backed YOLO segmentation wrapper.

    Args:
        onnx_model: Value handed to ``onnxruntime.InferenceSession`` by
            ``YOLOTransformersModel`` (e.g. a path to an ``.onnx`` file).
            Defaults to None.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "yolo-segmentation"

    def __init__(self, onnx_model=None, **kwargs):
        # Declared explicitly so every instance, including ``YOLOConfig()``,
        # exposes ``onnx_model`` instead of raising AttributeError when the
        # model reads it.
        self.onnx_model = onnx_model
        super().__init__(**kwargs)
|
|
|
class YOLOTransformersModel(PreTrainedModel):
    """Expose an ONNX YOLO segmentation model through the ``PreTrainedModel`` API.

    All inference runs in an ONNX Runtime session on the CPU execution
    provider; this wrapper registers no torch parameters of its own.
    """

    config_class = YOLOConfig

    def __init__(self, config):
        """Open the ONNX Runtime session described by ``config``.

        Args:
            config: Configuration whose ``onnx_model`` attribute is handed to
                ``onnxruntime.InferenceSession`` (e.g. a path to an ``.onnx`` file).
        """
        super().__init__(config)
        self.session = ort.InferenceSession(config.onnx_model, providers=["CPUExecutionProvider"])
        # Resolve the graph input once. Prefer the historically hard-coded
        # "images" name so existing models are fed exactly as before; fall
        # back to the first declared input for graphs that use another name.
        declared = self.session.get_inputs()
        by_name = {meta.name: meta for meta in declared}
        self._input_meta = by_name.get("images", declared[0])

    @staticmethod
    def _is_channels_first(shape):
        """Return True if a 4-D ONNX input shape puts the 3 colour channels on axis 1.

        Dynamic dimensions may appear as strings or ``None``, so only a
        literal 3 counts; NHWC-looking or fully dynamic shapes return False.
        """
        return len(shape) == 4 and shape[1] == 3 and shape[3] != 3

    def forward(self, images):
        """Run the ONNX model on a single PIL image.

        Args:
            images: A ``PIL.Image.Image``; it is converted to RGB first.

        Returns:
            The list returned by ``InferenceSession.run(None, ...)``, i.e.
            every model output in the graph's declared order.

        The image becomes a float32 batch of one. It is fed as
        ``(1, H, W, 3)`` unless the model declares a channels-first
        ``(1, 3, H, W)`` input, in which case it is transposed to match.
        No resizing is performed, so the image must already match any fixed
        spatial size the model declares.
        """
        # NOTE(review): pixel values stay in the 0-255 range; many YOLO
        # exports expect inputs scaled to [0, 1] — confirm against the
        # export pipeline before relying on the outputs.
        pixels = np.asarray(images.convert("RGB"), dtype=np.float32)
        batch = pixels[np.newaxis]  # prepend the batch axis -> (1, H, W, 3)
        if self._is_channels_first(self._input_meta.shape):
            # transpose() yields a strided view; materialise it as a
            # C-contiguous NCHW array before handing it to ONNX Runtime.
            batch = np.ascontiguousarray(batch.transpose(0, 3, 1, 2))
        return self.session.run(None, {self._input_meta.name: batch})