Dramb commited on
Commit
2d2347f
·
verified ·
1 Parent(s): f8a743c

Create medsam2_model.py

Browse files
Files changed (1) hide show
  1. medsam2_model.py +32 -0
medsam2_model.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from skimage import transform
4
+ # from sam2_train.build_sam import build_sam2
5
+ # from sam2_train.sam2_image_predictor import SAM2ImagePredictor
6
+ from sam2.build_sam import build_sam2
7
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
8
+
9
+ class MedSAM2:
10
+ def __init__(self, model_path, device="cpu"):
11
+ self.device = device
12
+ self.model = build_sam2("sam2_hiera_t", model_path, device=device)
13
+ self.predictor = SAM2ImagePredictor(self.model)
14
+
15
+ def predict(self, image: np.ndarray, box: list[float]) -> np.ndarray:
16
+ image_3c = image if image.shape[2] == 3 else np.repeat(image[:, :, None], 3, axis=-1)
17
+ img_1024 = transform.resize(image_3c, (1024, 1024), preserve_range=True).astype(np.uint8)
18
+
19
+ box_np = np.array(box)
20
+ box_1024 = box_np / np.array([image.shape[1], image.shape[0], image.shape[1], image.shape[0]]) * 1024
21
+ box_1024 = box_1024[None, :]
22
+
23
+ with torch.inference_mode(), torch.autocast(self.device, dtype=torch.bfloat16):
24
+ self.predictor.set_image(img_1024)
25
+ masks, _, _ = self.predictor.predict(
26
+ point_coords=None,
27
+ point_labels=None,
28
+ box=box_1024,
29
+ multimask_output=False
30
+ )
31
+
32
+ return masks[0].astype(np.uint8)