Spaces:
Running
Running
import gradio as gr | |
import numpy as np | |
import numpy as np | |
import torch | |
import matplotlib.pyplot as plt | |
import cv2 | |
from PIL import Image | |
import torch.nn as nn | |
from torch.autograd import Variable | |
from torchvision import transforms | |
import torch.nn.functional as F | |
import gdown | |
import os | |
from io import BytesIO | |
from IS_Net.data_loader import normalize, im_reader, im_preprocess | |
from IS_Net.models.isnet import ISNetGTEncoder, ISNetDIS | |
from SAM.segment_anything import sam_model_registry, SamPredictor | |
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def show_gray_images(images, m=8, alpha=3):
    """Display a stack of images as a grayscale grid with ``m`` columns.

    Args:
        images: Array whose ``shape`` unpacks to ``(n, h, w)``; ``images[i]``
            is drawn in the i-th grid cell (row-major order).
        m: Number of grid columns.
        alpha: Size factor; the figure is ``m * 2 * alpha`` by
            ``num_rows * 2 * alpha`` inches.
    """
    n, _, _ = images.shape
    if n == 0:
        # Nothing to draw; a grid with zero rows cannot be created.
        return

    num_rows = (n + m - 1) // m
    # squeeze=False always returns a 2-D array of axes, so the single-row,
    # single-column and 1x1 layouts go through the same code path.
    fig, axes = plt.subplots(
        num_rows, m, figsize=(m * 2 * alpha, num_rows * 2 * alpha), squeeze=False
    )
    plt.subplots_adjust(wspace=0.05, hspace=0.05)

    for idx, ax in enumerate(axes.flat):
        # Cells past the last image stay empty but still lose their ticks.
        if idx < n:
            ax.imshow(images[idx], cmap='gray')
        ax.axis('off')
    plt.show()
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are the height and width.
        ax: Object exposing ``imshow`` (e.g. a matplotlib axes).
        random_color: When true, use a random RGB colour; otherwise a fixed
            blue. The alpha channel is 0.6 in both cases.
    """
    rgba = (
        np.append(np.random.random(3), 0.6)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.6])
    )
    height, width = mask.shape[-2:]
    # Broadcast the (h, w, 1) mask against the (1, 1, 4) colour.
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
    """Plot prompt points on ``ax`` as stars, coloured by label.

    Args:
        coords: Array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: Array aligned with ``coords``; 1 is drawn green, 0 red.
        ax: Object exposing ``scatter`` (e.g. a matplotlib axes).
        marker_size: Marker size passed to ``scatter`` as ``s``.
    """
    foreground = coords[labels == 1]
    background = coords[labels == 0]
    for points, colour in ((foreground, 'green'), (background, 'red')):
        ax.scatter(
            points[:, 0],
            points[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )
def show_box(box, ax):
    """Outline ``box`` on ``ax`` with a green, unfilled rectangle.

    Args:
        box: Indexable ``(x0, y0, x1, y1)``; width and height are computed as
            ``x1 - x0`` and ``y1 - y0``.
        ax: Object exposing ``add_patch`` (e.g. a matplotlib axes).
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top), width, height, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2
    )
    ax.add_patch(outline)
# SAM checkpoint location and the registry key of the matching architecture.
sam_checkpoint = r"~/.cache/huggingface/hub/sam_vit_l_0b3195.pth"
# Python file APIs do not expand "~". Use the literal path when it exists
# (keeps any existing setup working), otherwise resolve it against the
# user's home directory.
if not os.path.exists(sam_checkpoint):
    sam_checkpoint = os.path.expanduser(sam_checkpoint)
model_type = "vit_l"

# Build SAM from the registry, move it to the inference device and wrap it in
# the predictor used by the Gradio callback.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint, device=device)
sam.to(device=device)
predictor = SamPredictor(sam)
class GOSNormalize(object):
    """Callable transform that normalizes an image with per-channel statistics.

    The call is forwarded to the module-level ``normalize`` helper together
    with the ``mean`` and ``std`` given at construction time.
    """

    def __init__(self, mean=(0.485, 0.456, 0.406, 0), std=(0.229, 0.224, 0.225, 1.0)):
        """Store the per-channel mean and std.

        Args:
            mean: Sequence of per-channel means.
            std: Sequence of per-channel standard deviations.
        """
        # Copy into fresh lists: the defaults are immutable tuples and every
        # instance owns its statistics, so mutating one instance (or the
        # caller's list) can never leak into another instance.
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, image):
        """Return ``image`` normalized with the stored mean and std."""
        return normalize(image, self.mean, self.std)
# Input normalisation for the refinement network: five per-channel statistics,
# one per channel of the stacked input built in predict_one.
transform = transforms.Compose(
    [
        GOSNormalize(
            mean=[0.5, 0.5, 0.5, 0, 0],
            std=[1.0, 1.0, 1.0, 1.0, 1.0],
        )
    ]
)
def build_model(hypar, device):
    """Prepare the network stored in ``hypar`` for inference.

    Args:
        hypar: Dict with the keys ``"model"`` (the network instance),
            ``"model_digit"`` (``"half"`` switches to half precision),
            ``"restore_model"`` (checkpoint file name, ``""`` to skip
            loading) and ``"model_path"`` (directory of the checkpoint).
        device: Device the network is moved to; also used as
            ``map_location`` when loading the checkpoint.

    Returns:
        The same network object, on ``device`` and in eval mode.
    """
    net = hypar["model"]

    # Convert to half precision, keeping the batch-norm layers in float32.
    if hypar["model_digit"] == "half":
        net.half()
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()

    net.to(device)

    if hypar["restore_model"] != "":
        checkpoint_path = os.path.join(hypar["model_path"], hypar["restore_model"])
        # File APIs do not expand "~": prefer the literal path when it exists,
        # otherwise resolve it against the user's home directory.
        if not os.path.exists(checkpoint_path):
            checkpoint_path = os.path.expanduser(checkpoint_path)
        # load_state_dict copies into the parameters already on `device`, so
        # no second `.to(device)` is needed afterwards.
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))

    net.eval()
    return net
def get_box(input_box, size):
    """Return a zero tensor of shape ``size`` with the box region set to 255.

    Args:
        input_box: Indexable ``(x0, y0, x1, y1)``; rows ``y0:y1`` and columns
            ``x0:x1`` are filled.
        size: Shape passed to ``torch.zeros``.
    """
    x0, y0 = input_box[0], input_box[1]
    x1, y1 = input_box[2], input_box[3]
    # Start from an all-zero (black) canvas.
    canvas = torch.zeros(size)
    # Paint the box region white (value 255).
    canvas[y0:y1, x0:x1] = 255
    return canvas
def get_box_from_mask(gt):
    """Return a float map with the bounding box of ``gt`` filled with 255.

    Args:
        gt: 2-D mask (anything ``np.array`` accepts); pixels greater than 0
            count as foreground.

    Returns:
        A float32 tensor of the same shape as the mask: 255 inside the tight
        bounding box of the foreground (edges inclusive), 0 elsewhere. An
        all-zero map is returned when the mask has no foreground.
    """
    mask = torch.from_numpy(np.array(gt)).float()
    box = torch.zeros_like(mask)

    rows, cols = torch.where(mask > 0)
    if rows.numel() == 0:
        # Empty mask: min/max over no elements would raise.
        return box

    top, bottom = int(rows.min()), int(rows.max())
    left, right = int(cols.min()), int(cols.max())
    # max() yields the last foreground index, so the slice end needs +1 to
    # include the bottom row and the right column.
    box[top:bottom + 1, left:right + 1] = 255
    return box
def predict_one(net, image, mask, box, transforms, hypar, device):
    """Refine a coarse mask with ``net`` and return a binary mask.

    The image, the coarse mask and the box map are stacked along the channel
    axis, resized to ``hypar["input_size"]``, scaled by 1/255, passed through
    ``transforms`` and ``net``; the first network output is resized back to
    the original spatial size, min-max normalised and Otsu-thresholded.

    Args:
        net: Callable model; ``net(x)[0][0]`` is used as the prediction.
        image: Array-like of shape (H, W, C) — assumed to hold 0-255 values,
            confirm against the caller.
        mask: Array-like of shape (H, W); rescaled to 0-255 when its maximum
            equals 1.
        box: Array-like of shape (H, W) with the box map.
        transforms: Callable applied to the batched input tensor.
        hypar: Dict providing ``"input_size"`` and ``"model_digit"``
            (``"full"`` selects float32, anything else float16).
        device: Device the input tensor is moved to.

    Returns:
        A uint8 numpy array of the original (H, W) size as produced by
        ``cv2.threshold`` with ``THRESH_OTSU`` and a max value of 255.
    """
    with torch.no_grad():
        image = torch.from_numpy(np.array(image))
        mask = torch.from_numpy(np.array(mask))
        box = torch.from_numpy(np.array(box))
        if mask.max() == 1:
            mask = mask.type(torch.float32) * 255.0

        # (H, W, C+2) -> (1, C+2, H, W)
        stacked = torch.cat([image, mask[..., None], box[..., None]], dim=2)
        inputs_val_v = stacked.permute(2, 0, 1)[None, ...]
        shapes_val = tuple(inputs_val_v.shape[-2:])

        # Bilinear interpolation needs a floating-point tensor.
        # F.interpolate with align_corners=False is what the deprecated
        # F.upsample call resolved to.
        inputs_val_v = F.interpolate(
            inputs_val_v.float(), hypar["input_size"], mode='bilinear', align_corners=False
        )

        # Resizing blurs the box edges; snap the box channel back to {0, 255}.
        box_channel = inputs_val_v[0, -1]  # view: edits land in inputs_val_v
        box_channel[box_channel > 127] = 255
        box_channel[box_channel <= 127] = 0

        inputs_val_v = inputs_val_v.divide(255.0)

        net.eval()
        if hypar["model_digit"] == "full":
            inputs_val_v = inputs_val_v.float()
        else:
            inputs_val_v = inputs_val_v.half()
        inputs_val_v = transforms(inputs_val_v.to(device))

        ds_val = net(inputs_val_v)[0][0]

        # Recover the original spatial size of the prediction.
        pred_val = F.interpolate(
            ds_val, shapes_val, mode='bilinear', align_corners=False
        )[0][0]

        # Min-max normalise to [0, 1]; a constant prediction has no range and
        # would otherwise divide by zero, so it maps to all zeros.
        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        value_range = ma - mi
        if value_range > 0:
            pred_val = (pred_val - mi) / value_range
        else:
            pred_val = torch.zeros_like(pred_val)

        if device == 'cuda':
            torch.cuda.empty_cache()

        refined_mask = (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
        _, binary = cv2.threshold(refined_mask, 0, 255, cv2.THRESH_OTSU)
        return binary
# Parameters for inference, consumed by build_model and predict_one.
hypar = {
    "model_path": "~/.cache/huggingface/hub",    # directory holding the checkpoint
    "restore_model": "DIS-SAM-checkpoint.pth",   # checkpoint file name
    "model_digit": "full",                       # "full" -> float32 inputs
    "input_size": [1024, 1024],                  # size the network input is resized to
    "model": ISNetDIS(in_ch=5),                  # five input channels
}
net = build_model(hypar, device)
def bbox_from_str(bbox_str: str):
    """Parse ``"x1,y1,x2,y2"`` into an integer numpy array.

    Returns:
        A length-4 ``np.ndarray`` of ints, or ``None`` when the string is
        empty, does not have exactly four comma-separated fields, or a field
        is not an integer.
    """
    if not bbox_str:
        return None

    fields = bbox_str.strip().split(",")
    if len(fields) != 4:
        return None

    try:
        return np.array(list(map(int, fields)))
    except ValueError:
        return None
def predict(input_img: np.ndarray, bbox_str: str):
    """Gradio callback: segment with SAM, then refine the mask with DIS.

    Args:
        input_img: Image array; ``shape[0]`` and ``shape[1]`` are used as the
            height and width.
        bbox_str: Optional ``"x1,y1,x2,y2"`` box prompt. When it is empty or
            cannot be parsed, the whole image is used as the box.

    Returns:
        ``(mask_gray, refined_mask_gray)``: the first SAM mask scaled to
        0/255 and the refined mask, both as uint8 arrays.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_img is None:
        # Give the user a readable message instead of an internal traceback.
        raise gr.Error("Please provide an input image.")

    predictor.set_image(input_img)
    input_label = np.array([1])
    bbox = bbox_from_str(bbox_str)
    if bbox is not None:
        input_box = bbox
    else:
        # No usable prompt: fall back to a box covering the whole image.
        input_box = np.array([0, 0, input_img.shape[1], input_img.shape[0]])

    masks, _scores, _logits = predictor.predict(
        box=input_box,
        point_labels=input_label,
        multimask_output=True,
    )
    # Of the masks returned, the first one is used.
    mask = masks[0]
    mask_gray = (mask * 255).astype(np.uint8)

    if not mask.any():
        # SAM found nothing, so there is nothing to refine; return the empty
        # mask for both outputs.
        return mask_gray, mask_gray.copy()

    DIS_box = get_box_from_mask(mask)
    refined_mask = predict_one(net, input_img, mask, DIS_box, transform, hypar, device)
    refined_mask_gray = refined_mask.astype(np.uint8)
    return mask_gray, refined_mask_gray
# Gradio UI: an image plus an optional box prompt in, the SAM mask and the
# DIS-SAM refined mask out (both shown as single-channel "L" images).
gradio_app = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Select Image", sources=['upload', 'webcam'], type="numpy"),
        gr.Textbox(label="Bounding Box Prompt (pixels)", placeholder="x1,y1,x2,y2"),
    ],
    outputs=[
        gr.Image(label="SAM Mask", type="numpy", image_mode="L"),
        gr.Image(label="DIS-SAM Mask", type="numpy", image_mode="L"),
    ],
    title="DIS-SAM",
    # Each example is [image path, box prompt]; an empty prompt means the
    # whole image is used as the box.
    examples=[
        ["./images/wire_shelf.jpg", "20,100,480,660"],
        ["./images/radio_telescope.jpg", "1130,320,4000,2920"],
        ["./images/bridge.jpg", ""],
        ["./images/tree.jpg", "70,110,2290,1800"],
    ],
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    gradio_app.launch()