# medsam2_oct / inference.py
# Dramb's picture
# Upload folder using huggingface_hub
# c8ed8e3 verified
# raw
# history blame
# 1.5 kB
from typing import Dict
import torch
import numpy as np
from PIL import Image
from skimage import transform
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
class PreTrainedModel:
    """Box-prompted segmentation wrapper around a SAM2 image predictor.

    Construction builds the model from the ``sam2_hiera_t`` config and the
    ``MedSAM2_pretrain_10ep_b1_AMD-SD_sam2_hiera_t.pth`` checkpoint, on CUDA
    when ``torch.cuda.is_available()`` and on the CPU otherwise. Calling the
    instance runs a single box-prompted prediction on one image.
    """

    # Side length of the square image that is handed to the predictor.
    _INPUT_SIZE = 1024

    def __init__(self):
        """Build the SAM2 model and wrap it in a ``SAM2ImagePredictor``."""
        # Resolve the device once so model placement and autocast always agree.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = build_sam2(
            "sam2_hiera_t",
            "MedSAM2_pretrain_10ep_b1_AMD-SD_sam2_hiera_t.pth",
            device=self.device,
        )
        self.predictor = SAM2ImagePredictor(self.model)

    def __call__(self, inputs: Dict):
        """Segment the region of an image selected by a bounding box.

        Args:
            inputs: Mapping with two entries:
                ``"image"`` -- anything ``PIL.Image.open`` accepts.
                ``"box"`` -- four numbers in (x, y, x, y) order, expressed in
                pixel coordinates of the image as loaded (presumably
                x_min, y_min, x_max, y_max -- confirm against the caller).

        Returns:
            ``{"mask": nested_list}`` where the nested list is the first
            predicted mask cast to ``uint8``. NOTE(review): the mask is
            presumably at the 1024x1024 resolution of the resized image given
            to the predictor, not the original image size -- confirm against
            ``SAM2ImagePredictor.predict``.

        Raises:
            KeyError: If ``"image"`` or ``"box"`` is missing from ``inputs``.
            ValueError: If ``"box"`` does not hold exactly four values, or a
                value cannot be converted to ``float``.
        """
        # Validate the prompt before doing any image I/O. A one-element box
        # would otherwise broadcast silently against the 4-element divisor
        # below and be treated as a valid prompt.
        box = [float(value) for value in inputs["box"]]
        if len(box) != 4:
            raise ValueError(
                f"'box' must contain exactly 4 values, got {len(box)}"
            )

        # convert("RGB") always yields an H x W x 3 array, so no separate
        # grayscale-to-3-channel expansion is required afterwards.
        with Image.open(inputs["image"]) as source:
            image_np = np.array(source.convert("RGB"))

        height, width = image_np.shape[:2]
        size = self._INPUT_SIZE

        # preserve_range keeps the 0-255 value range so the uint8 cast is
        # meaningful (skimage would otherwise rescale to [0, 1]).
        img_1024 = transform.resize(
            image_np, (size, size), preserve_range=True
        ).astype(np.uint8)

        # Rescale the box from original pixel coordinates into the resized
        # image's coordinate frame, then add a leading batch axis.
        box_1024 = np.array(box) / [width, height, width, height] * size
        box_1024 = box_1024[None, :]

        with torch.inference_mode(), torch.autocast(
            self.device, dtype=torch.bfloat16
        ):
            self.predictor.set_image(img_1024)
            masks, _, _ = self.predictor.predict(
                point_coords=None,
                point_labels=None,
                box=box_1024,
                multimask_output=False,
            )

        mask = masks[0].astype(np.uint8)
        return {"mask": mask.tolist()}