import logging
from typing import Any, Dict, List, Union

import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPModel, CLIPProcessor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class InferenceEngine:
    def __init__(self, model_path: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        """
        Initialize the InferenceEngine.

        Args:
            model_path (str): Path to the ONNX model.
            device (str): Device to run the model on ("cuda" or "cpu").
        """
        self.device = device
        # Request GPU execution providers only when running on CUDA; ONNX Runtime
        # falls back to the next provider in the list if one is unavailable.
        if device == "cuda":
            providers = [
                "TensorrtExecutionProvider",
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ]
        else:
            providers = ["CPUExecutionProvider"]
        try:
            self.session = ort.InferenceSession(model_path, providers=providers)
            logger.info(f"ONNX model loaded successfully on device: {self.device}")
        except Exception as e:
            logger.error(f"Failed to load ONNX model: {e}")
            raise

    def run_text_inference(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, prompt: str, max_length: int = 200) -> str:
        """
        Run text inference using a causal language model.

        Args:
            model (AutoModelForCausalLM): Pre-trained causal language model.
            tokenizer (AutoTokenizer): Tokenizer for the model.
            prompt (str): Input text prompt.
            max_length (int): Maximum total sequence length (prompt plus generated tokens).

        Returns:
            str: Generated text.
        """
        try:
            # The model is assumed to already live on self.device; only the
            # tokenized inputs are moved here.
            inputs = tokenizer(prompt, return_tensors="pt").to(self.device)
            outputs = model.generate(**inputs, max_length=max_length)
            return tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            logger.error(f"Text inference failed: {e}")
            raise

    def run_image_inference(self, clip_model: CLIPModel, processor: CLIPProcessor, image_path: str) -> np.ndarray:
        """
        Run image inference using a CLIP model.

        Args:
            clip_model (CLIPModel): Pre-trained CLIP model.
            processor (CLIPProcessor): Processor for the CLIP model.
            image_path (str): Path to the input image.

        Returns:
            np.ndarray: Image features as a numpy array.
        """
        try:
            image = Image.open(image_path).convert("RGB")
            inputs = processor(images=image, return_tensors="pt").to(self.device)
            # Feature extraction needs no gradients.
            with torch.no_grad():
                outputs = clip_model.get_image_features(**inputs)
            return outputs.cpu().numpy()
        except Exception as e:
            logger.error(f"Image inference failed: {e}")
            raise

    def run_audio_inference(self, whisper_model: Any, audio_file: str) -> str:
        """
        Run audio inference using a Whisper model.

        Args:
            whisper_model (Any): Pre-trained Whisper model exposing a transcribe()
                method (e.g. loaded via the openai-whisper package).
            audio_file (str): Path to the input audio file.

        Returns:
            str: Transcribed text.
        """
        try:
            waveform, sample_rate = torchaudio.load(audio_file)
            # Whisper expects mono audio sampled at 16 kHz, so downmix and
            # resample before transcribing.
            waveform = waveform.mean(dim=0)
            if sample_rate != 16000:
                waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
            return whisper_model.transcribe(waveform)["text"]
        except Exception as e:
            logger.error(f"Audio inference failed: {e}")
            raise

    def run_general_inference(self, input_data: Union[np.ndarray, List, Dict]) -> np.ndarray:
        """
        Run general inference using the ONNX model.

        Args:
            input_data (Union[np.ndarray, List, Dict]): Input data for the model.

        Returns:
            np.ndarray: Model output.
        """
        try:
            # Assumes a single-input, single-output ONNX graph.
            input_name = self.session.get_inputs()[0].name
            output_name = self.session.get_outputs()[0].name

            if not isinstance(input_data, np.ndarray):
                input_data = np.array(input_data, dtype=np.float32)

            return self.session.run([output_name], {input_name: input_data})[0]
        except Exception as e:
            logger.error(f"General inference failed: {e}")
            raise
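

# Example usage (a minimal sketch; the checkpoint names and file paths below
# are illustrative placeholders, not part of this module):
if __name__ == "__main__":
    engine = InferenceEngine("model.onnx")

    # Text generation with a small causal LM checkpoint.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2").to(engine.device)
    print(engine.run_text_inference(lm, tokenizer, "Hello, world", max_length=50))

    # CLIP image features for a local image file.
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(engine.device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    features = engine.run_image_inference(clip_model, clip_processor, "example.jpg")
    print(features.shape)

    # Raw ONNX inference on a dummy input (the shape depends on the exported model).
    dummy = np.random.rand(1, 3, 224, 224).astype(np.float32)
    print(engine.run_general_inference(dummy).shape)

    # Audio transcription additionally requires a loaded Whisper model, e.g.
    # whisper.load_model("base") from the openai-whisper package.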