"""DIS-SAM Gradio demo.

Runs Segment Anything (SAM, ``vit_l``) on an image with an optional
bounding-box prompt, then refines SAM's mask with the DIS-SAM network
(``ISNetDIS(in_ch=5)``: image channels + SAM mask + mask bounding box).
Both the raw SAM mask and the refined mask are returned as grayscale images.
"""
import os
import subprocess
import sys

# Best-effort install of the bundled wheel before the model code is imported
# (presumably required by the IS_Net / SAM modules below -- TODO confirm).
# ``sys.executable -m pip`` installs into the interpreter actually running this
# app; like the original ``os.system`` call, a failed install is not fatal here.
subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "./MultiScaleDeformableAttention-1.0-py3-none-any.whl",
    ],
    check=False,
)

import gradio as gr
from huggingface_hub import hf_hub_download
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
import torch.nn.functional as F
import gdown
from io import BytesIO

from IS_Net.data_loader import normalize, im_reader, im_preprocess
from IS_Net.models.isnet import ISNetGTEncoder, ISNetDIS
from SAM.segment_anything import sam_model_registry, SamPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"


def show_gray_images(images, m=8, alpha=3):
    """Display a stack of grayscale images in a grid with ``m`` columns.

    Args:
        images: array indexed as ``images[i]`` for i in range(n), where
            ``n = images.shape[0]``; each slice is drawn with a gray colormap.
        m: number of grid columns.
        alpha: figure-size scale; each grid cell is ``2 * alpha`` inches.
    """
    n = images.shape[0]
    if n == 0:
        return
    num_rows = (n + m - 1) // m
    # squeeze=False keeps ``axes`` 2-D even for a single row or column, so one
    # indexing scheme works for every grid shape.
    fig, axes = plt.subplots(
        num_rows,
        m,
        figsize=(m * 2 * alpha, num_rows * 2 * alpha),
        squeeze=False,
    )
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    # axes.flat is row-major, so idx == row * m + col.
    for idx, ax in enumerate(axes.flat):
        if idx < n:
            ax.imshow(images[idx], cmap="gray")
        # Hide frames of the trailing empty cells too.
        ax.axis("off")
    plt.show()


def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as a translucent RGBA layer.

    The color is a fixed blue (alpha 0.6) unless ``random_color`` is True.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red."""
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color="green", marker="*",
               s=marker_size, edgecolor="white", linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color="red", marker="*",
               s=marker_size, edgecolor="white", linewidth=1.25)


def show_box(box, ax):
    """Draw the (x0, y0, x1, y1) ``box`` on ``ax`` as a green outline."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green",
                               facecolor=(0, 0, 0, 0), lw=2))


# SAM (ViT-L) checkpoint, fetched from the Hugging Face Hub cache.
sam_checkpoint = hf_hub_download(repo_id="andzhang01/segment_anything",
                                 filename="sam_vit_l_0b3195.pth")
model_type = "vit_l"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint, device=device)
sam.to(device=device)
predictor = SamPredictor(sam)


class GOSNormalize(object):
    """Normalize an image tensor with per-channel ``mean`` / ``std``.

    Thin callable wrapper around ``IS_Net.data_loader.normalize`` so it can be
    used inside ``torchvision.transforms.Compose``.
    """

    def __init__(self, mean=None, std=None):
        # None sentinels instead of list literals: default lists would be
        # shared by every instance.
        self.mean = [0.485, 0.456, 0.406, 0] if mean is None else mean
        self.std = [0.229, 0.224, 0.225, 1.0] if std is None else std

    def __call__(self, image):
        return normalize(image, self.mean, self.std)


# 5-channel normalization: shift the three image channels by 0.5, leave the
# mask and box channels unchanged.
transform = transforms.Compose(
    [GOSNormalize([0.5, 0.5, 0.5, 0, 0], [1.0, 1.0, 1.0, 1.0, 1.0])]
)


def build_model(hypar, device):
    """Prepare ``hypar["model"]`` for inference on ``device`` and return it.

    If ``hypar["model_digit"] == "half"`` the network is converted to float16
    with BatchNorm2d layers kept in float32. When ``hypar["restore_model"]`` is
    non-empty, the state dict is loaded from
    ``hypar["model_path"]/hypar["restore_model"]``. The model is put in eval
    mode.
    """
    net = hypar["model"]

    if hypar["model_digit"] == "half":
        net.half()
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()

    net.to(device)
    if hypar["restore_model"] != "":
        checkpoint_path = os.path.join(hypar["model_path"], hypar["restore_model"])
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
    net.eval()
    return net


def get_box(input_box, size):
    """Rasterize ``input_box`` = (x1, y1, x2, y2) into a tensor of shape ``size``.

    Pixels in rows ``y1:y2`` and columns ``x1:x2`` (end-exclusive) are 255;
    all others are 0.
    """
    # Start from an all-zero image.
    image = torch.zeros(size)
    # Fill the box region with white (value 255).
    image[input_box[1]:input_box[3], input_box[0]:input_box[2]] = 255
    return image


def get_box_from_mask(gt):
    """Return a float32 tensor with a 255-valued rectangle around ``gt``'s foreground.

    The rectangle spans the rows/columns of the pixels where ``gt > 0``; every
    other pixel is 0. If ``gt`` has no positive pixels, an all-zero tensor is
    returned instead of raising.
    """
    gt = torch.from_numpy(np.array(gt))
    box = (torch.zeros_like(gt) + gt).float()
    rows, cols = torch.where(box > 0)
    if rows.numel() == 0:
        # Empty mask: there is no box to draw (torch.min on an empty tensor
        # would raise a RuntimeError).
        return torch.zeros_like(box)

    top, bottom = torch.min(rows), torch.max(rows)
    left, right = torch.min(cols), torch.max(cols)
    # NOTE(review): the slice end is exclusive, so the last foreground row and
    # column are not filled. Kept as-is in case DIS-SAM was trained with the
    # same convention -- confirm before changing to bottom + 1 / right + 1.
    box[top:bottom, left:right] = 255
    box[box != 255] = 0
    return box


def predict_one(net, image, mask, box, transforms, hypar, device):
    """Refine a coarse mask with the DIS-SAM network; return an Otsu-binarized mask.

    Args:
        net: network prepared by ``build_model``.
        image: H x W x C image array (C presumably 3, since the network is
            built with ``in_ch=5`` = C + mask + box -- TODO confirm).
        mask: H x W coarse mask; if its maximum is 1 it is rescaled to 0/255.
        box: H x W box mask with values 0/255 (see ``get_box_from_mask``).
        transforms: callable applied to the scaled 5-channel input batch.
        hypar: dict providing ``"input_size"`` (network input H, W) and
            ``"model_digit"`` (``"full"`` for float32, otherwise float16).
        device: torch device string.

    Returns:
        uint8 H x W array from ``cv2.threshold(..., THRESH_OTSU)`` (0 or 255).
    """
    with torch.no_grad():
        image = torch.from_numpy(np.array(image))
        mask = torch.from_numpy(np.array(mask))
        box = torch.from_numpy(np.array(box))
        if mask.max() == 1:
            mask = mask.type(torch.float32) * 255.0

        # Stack image, mask and box along the channel axis, then convert to a
        # 1 x 5 x H x W batch.
        inputs = torch.cat([image, mask[..., None], box[..., None]], dim=2)
        inputs = inputs.permute(2, 0, 1)[None, ...]
        original_size = inputs.shape[-2:]
        inputs = F.interpolate(inputs, hypar["input_size"], mode="bilinear")

        # Bilinear resizing blurs the box edges; re-binarize that channel.
        resized_box = inputs[0][-1]
        resized_box[resized_box > 127] = 255
        resized_box[resized_box <= 127] = 0
        inputs[0][-1] = resized_box

        inputs = inputs.divide(255.0)

        net.eval()
        dtype = torch.float32 if hypar["model_digit"] == "full" else torch.float16
        inputs = inputs.to(dtype).to(device)
        inputs = transforms(inputs)

        ds_val = net(inputs)[0][0]
        # Recover the prediction's spatial size to the original image size.
        pred_val = F.interpolate(ds_val, original_size, mode="bilinear")[0][0]

        # Min-max normalize to [0, 1]. A constant prediction would divide by
        # zero and yield NaNs, so map it to all zeros instead.
        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        value_range = ma - mi
        if value_range > 0:
            pred_val = (pred_val - mi) / value_range
        else:
            pred_val = torch.zeros_like(pred_val)

        if device == "cuda":
            torch.cuda.empty_cache()

        refined_mask = (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
        _, binary = cv2.threshold(refined_mask, 0, 255, cv2.THRESH_OTSU)
        return binary  # the refined binary mask


# Parameters for inference.
hypar = {}

dis_model_path = hf_hub_download(repo_id="jwlarocque/DIS-SAM",
                                 filename="DIS-SAM-checkpoint.pth")
hypar["model_path"] = os.path.split(dis_model_path)[0]
hypar["restore_model"] = os.path.split(dis_model_path)[1]
hypar["model_digit"] = "full"
hypar["input_size"] = [1024, 1024]
hypar["model"] = ISNetDIS(in_ch=5)
net = build_model(hypar, device)


def bbox_from_str(bbox_str: str):
    """Parse ``"x1,y1,x2,y2"`` into a 4-element int numpy array.

    Returns None for an empty string, a wrong number of fields, or any field
    that is not an integer.
    """
    if not bbox_str:
        return None
    parts = bbox_str.strip().split(",")
    if len(parts) != 4:
        return None
    try:
        return np.array([int(x) for x in parts])
    except ValueError:
        return None


def predict(input_img: np.ndarray, bbox_str: str):
    """Gradio handler: run SAM, then refine its mask with DIS-SAM.

    Args:
        input_img: image array from the Gradio image component.
        bbox_str: optional ``"x1,y1,x2,y2"`` box prompt in pixels; when empty
            or malformed, the whole image is used as the box.

    Returns:
        (SAM mask, DIS-SAM mask), both uint8 grayscale arrays (0/255).

    Raises:
        gr.Error: if no image was provided.
    """
    if input_img is None:
        raise gr.Error("Please select or capture an image first.")

    predictor.set_image(input_img)
    # NOTE(review): upstream SAM appears to ignore point_labels when no
    # point_coords are given; kept in case the local SAM fork uses it.
    input_label = np.array([1])
    bbox = bbox_from_str(bbox_str)
    if bbox is not None:
        input_box = bbox
    else:
        input_box = np.array([0, 0, input_img.shape[1], input_img.shape[0]])

    masks, _scores, _logits = predictor.predict(
        box=input_box,
        point_labels=input_label,
        multimask_output=True,
    )
    sam_mask = masks[0]

    dis_box = get_box_from_mask(sam_mask)
    refined_mask = predict_one(net, input_img, sam_mask, dis_box, transform, hypar, device)

    mask_gray = (sam_mask * 255).astype(np.uint8)
    refined_mask_gray = refined_mask.astype(np.uint8)
    return mask_gray, refined_mask_gray


gradio_app = gr.Interface(
    predict,
    inputs=[
        gr.Image(label="Select Image", sources=['upload', 'webcam'], type="numpy"),
        gr.Textbox(label="Bounding Box Prompt (pixels)", placeholder="x1,y1,x2,y2"),
    ],
    outputs=[
        gr.Image(label="SAM Mask", type="numpy", image_mode="L"),
        gr.Image(label="DIS-SAM Mask", type="numpy", image_mode="L"),
    ],
    title="DIS-SAM",
    examples=[
        ["./images/wire_shelf.jpg", "20,100,480,660"],
        ["./images/radio_telescope.jpg", "1130,320,4000,2920"],
        ["./images/bridge.jpg", ""],
        ["./images/tree.jpg", "70,110,2290,1800"],
        ["./images/bicycle.jpg", "135,235,2425,1580"],
        ["./images/capybara.jpg", "630,440,2060,1650"],
        ["./images/capybara.jpg", "1050,173,1550,618"],
    ],
)

if __name__ == "__main__":
    gradio_app.launch()