# DIS-SAM / app.py
# Hugging Face Space by jwlarocque.
# Last commit: 7d138ae ("Add two prompt example"); file size 8.49 kB.
import os
os.system("pip install ./MultiScaleDeformableAttention-1.0-py3-none-any.whl")
import gradio as gr
from huggingface_hub import hf_hub_download
import numpy as np
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
import torch.nn.functional as F
import gdown
import os
from io import BytesIO
from IS_Net.data_loader import normalize, im_reader, im_preprocess
from IS_Net.models.isnet import ISNetGTEncoder, ISNetDIS
from SAM.segment_anything import sam_model_registry, SamPredictor
# Run the models on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def show_gray_images(images, m=8, alpha=3):
    """Plot a stack of grayscale images in a grid with ``m`` columns.

    Args:
        images: array indexed as ``images[i]`` for ``i < images.shape[0]``;
            each entry is drawn with the ``'gray'`` colormap (typically an
            (n, h, w) stack).
        m: number of grid columns.
        alpha: figure-size scale; every grid cell is ``2 * alpha`` inches wide
            and tall.
    """
    n = images.shape[0]
    if n == 0:
        # plt.subplots cannot build a grid with zero rows.
        return
    num_rows = (n + m - 1) // m
    # squeeze=False always returns a 2-D array of Axes, so 1x1, 1xm and mx1
    # grids are indexed the same way as larger ones.
    _, axes = plt.subplots(
        num_rows, m,
        figsize=(m * 2 * alpha, num_rows * 2 * alpha),
        squeeze=False,
    )
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    for idx, ax in enumerate(axes.ravel()):
        # Cells past the last image stay empty but are still hidden.
        if idx < n:
            ax.imshow(images[idx], cmap='gray')
        ax.axis('off')
    plt.show()
def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as a semi-transparent RGBA layer.

    The layer colour is a fixed blue with alpha 0.6, or a random RGB colour
    with alpha 0.6 when ``random_color`` is true. The last two dimensions of
    ``mask`` give the overlay height and width.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.6])
    )
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as stars.

    Points whose label is 1 are drawn in green, points whose label is 0 in
    red, both with a white edge.
    """
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        chosen = coords[labels == wanted_label]
        ax.scatter(
            chosen[:, 0], chosen[:, 1],
            color=colour, marker='*', s=marker_size,
            edgecolor='white', linewidth=1.25,
        )
def show_box(box, ax):
    """Draw ``box`` = (x0, y0, x1, y1) on ``ax`` as an unfilled green rectangle."""
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top), right - left, bottom - top,
        edgecolor='green', facecolor=(0, 0, 0, 0), lw=2,
    )
    ax.add_patch(outline)
# Build the SAM (Segment Anything) predictor once, at import time.
# The ViT-L checkpoint is fetched from the Hugging Face Hub.
sam_checkpoint = hf_hub_download(repo_id="andzhang01/segment_anything", filename="sam_vit_l_0b3195.pth")
# Alternative: point directly at a local copy of the checkpoint.
# sam_checkpoint = r"~/.cache/huggingface/hub/models--andzhang01--segment-anything/sam_vit_l_0b3195.pth"
model_type = "vit_l"  # registry key for the ViT-L variant, matching the sam_vit_l checkpoint file
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint, device=device)
sam.to(device=device)
# Module-level predictor used by predict().
# NOTE(review): set_image() is called on this shared object per request; confirm
# the app does not process requests concurrently.
predictor = SamPredictor(sam)
class GOSNormalize(object):
    '''
    Callable transform that normalizes an input channel-wise.

    It delegates to the ``normalize`` helper imported from
    ``IS_Net.data_loader``.

    Args:
        mean: per-channel values passed to ``normalize`` as the mean.
        std: per-channel values passed to ``normalize`` as the std.
    '''
    def __init__(self, mean=(0.485, 0.456, 0.406, 0), std=(0.229, 0.224, 0.225, 1.0)):
        # Store fresh list copies. The former list defaults were aliased into
        # every instance, so mutating one instance's mean/std silently changed
        # the defaults of all later instances.
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, image):
        return normalize(image, self.mean, self.std)
# Channel-wise normalization for the 5-channel DIS input assembled in predict_one
# (three image channels, then the SAM mask, then the box map).
# Assuming `normalize` computes (x - mean) / std:
#   - the image channels are shifted by -0.5;
#   - the mask and box channels pass through unchanged (mean 0, std 1).
# predict_one applies it after scaling the input to [0, 1].
transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5,0,0],[1.0,1.0,1.0,1.0,1.0])])
def build_model(hypar, device):
    """Prepare ``hypar["model"]`` for inference and optionally load its weights.

    Args:
        hypar: configuration dict. Keys read:
            ``"model"``: the ``nn.Module`` to prepare (modified in place).
            ``"model_digit"``: ``"half"`` converts the network to float16,
                keeping ``BatchNorm2d`` layers in float32.
            ``"restore_model"``: checkpoint file name; an empty value skips
                loading.
            ``"model_path"``: directory containing the checkpoint (read only
                when ``restore_model`` is set).
        device: target passed to ``.to`` and to ``torch.load(map_location=...)``.

    Returns:
        The same module, on ``device`` and in eval mode.
    """
    net = hypar["model"]
    if hypar["model_digit"] == "half":
        net.half()
        # BatchNorm stays in float32, presumably for numerical stability.
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()
    net.to(device)
    if hypar["restore_model"]:
        checkpoint = os.path.join(hypar["model_path"], hypar["restore_model"])
        # NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints; consider weights_only=True where the torch version
        # supports it. load_state_dict copies into the existing parameters,
        # which are already on `device`.
        net.load_state_dict(torch.load(checkpoint, map_location=device))
    net.eval()
    return net
def get_box(input_box, size):
    """Return a float tensor of shape ``size`` that is 255 inside ``input_box``.

    ``input_box`` is (x0, y0, x1, y1). Rows ``y0:y1`` and columns ``x0:x1``
    (end-exclusive) are set to 255 (white); every other element stays 0.
    """
    # Start from an all-zero (black) canvas.
    canvas = torch.zeros(size)
    left, top, right, bottom = input_box[0], input_box[1], input_box[2], input_box[3]
    # Paint the box region white.
    canvas[top:bottom, left:right] = 255
    return canvas
def get_box_from_mask(gt):
    """Return the filled bounding box of the positive pixels of ``gt``.

    Args:
        gt: 2-D mask (array-like; boolean, 0/1 or 0/255). Pixels > 0 count as
            foreground.

    Returns:
        float32 tensor shaped like ``gt``. It is 255 on every row/column from
        the first to the last foreground pixel (inclusive) and 0 elsewhere.
        An all-zero mask yields an all-zero box instead of raising.
    """
    mask = torch.from_numpy(np.array(gt)).float()
    box = torch.zeros_like(mask)
    rows, cols = torch.where(mask > 0)
    if rows.numel() == 0:
        # torch.min/max fail on empty tensors; an empty mask has no box.
        return box
    top, bottom = rows.min(), rows.max()
    left, right = cols.min(), cols.max()
    # +1 so the last foreground row/column is inside the box. The old
    # end-exclusive slice dropped it, making the result depend on whether
    # the mask was 0/1 or 0/255.
    box[top:bottom + 1, left:right + 1] = 255
    return box
def predict_one(net, image, mask, box, transforms, hypar, device):
    '''
    Refine a coarse segmentation mask with the DIS network.

    Processing steps:
      1. Stack the image, mask and box into one (1, C, H, W) tensor.
      2. Resize it to ``hypar["input_size"]`` and scale it to [0, 1].
      3. Run ``transforms``, then ``net``.
      4. Resize the prediction back to H x W, min-max normalise it and
         binarise it with Otsu's threshold.

    Args:
        net: network module; ``net(x)[0][0]`` must be a 4-D prediction tensor.
        image: H x W x C image (array-like). NOTE(review): the network built in
            this file takes 5 channels, so C is presumably 3.
        mask: H x W coarse mask. A mask whose maximum is 1 (0/1 or boolean) is
            rescaled to 0/255.
        box: H x W box map holding 0/255 values.
        transforms: callable applied to the [0, 1]-scaled input tensor.
        hypar: dict; reads ``"input_size"`` and ``"model_digit"``
            (``"full"`` -> float32 input, anything else -> float16).
        device: device the input is moved to.

    Returns:
        uint8 ndarray of shape H x W with values 0 or 255.
    '''
    with torch.no_grad():
        image = torch.from_numpy(np.array(image)).float()
        mask = torch.from_numpy(np.array(mask))
        box = torch.from_numpy(np.array(box)).float()
        if mask.max() == 1:
            # Bring 0/1 or boolean masks to the same 0..255 scale as the image.
            mask = mask.type(torch.float32) * 255.0
        mask = mask.float()
        # Explicit float casts make the concatenation independent of torch's
        # mixed-dtype promotion rules.
        inputs = torch.cat([image, mask[..., None], box[..., None]], dim=2)
        inputs = inputs.permute(2, 0, 1)[None, ...]
        original_size = inputs.shape[-2:]
        # F.interpolate replaces the deprecated F.upsample. align_corners=False
        # is the value F.upsample used implicitly.
        inputs = F.interpolate(inputs, hypar["input_size"], mode='bilinear', align_corners=False)
        # Bilinear resizing blurs the box edges; snap the box channel back to
        # 0/255 in place.
        box_channel = inputs[0][-1]
        box_channel[box_channel > 127] = 255
        box_channel[box_channel <= 127] = 0
        inputs = inputs.divide(255.0)
        net.eval()
        if hypar["model_digit"] == "full":
            inputs = inputs.float()
        else:
            inputs = inputs.half()
        inputs = transforms(inputs.to(device))
        prediction = net(inputs)[0][0]
        # Recover the prediction at the original image size.
        pred = F.interpolate(prediction, original_size, mode='bilinear', align_corners=False)[0][0]
        hi = torch.max(pred)
        lo = torch.min(pred)
        if hi > lo:
            pred = (pred - lo) / (hi - lo)  # min-max normalise to [0, 1]
        else:
            # A constant prediction made the old division produce NaNs
            # (undefined uint8 values); treat it as an empty mask instead.
            pred = torch.zeros_like(pred)
        if str(device).startswith('cuda'):
            torch.cuda.empty_cache()
        refined_mask = (pred.detach().cpu().numpy() * 255).astype(np.uint8)
        _, binary = cv2.threshold(refined_mask, 0, 255, cv2.THRESH_OTSU)
        return binary  # the refined binary mask
# Inference configuration for the DIS refinement network; consumed by
# build_model (below) and predict_one.
hypar = {} # parameters for inference
# Fetch the DIS-SAM checkpoint from the Hugging Face Hub; build_model loads it
# from model_path/restore_model.
dis_model_path = hf_hub_download(repo_id="jwlarocque/DIS-SAM", filename="DIS-SAM-checkpoint.pth")
# hypar["model_path"] ="~/.cache/huggingface/hub/jwlarocque/DIS-SAM"
hypar["model_path"] = os.path.split(dis_model_path)[0]  # directory of the downloaded checkpoint
# hypar["restore_model"] = "DIS-SAM-checkpoint.pth"
hypar["restore_model"] = os.path.split(dis_model_path)[1]  # checkpoint file name
hypar["model_digit"] = "full"  # "full": float32 inference; "half": build_model converts to float16
hypar["input_size"] = [1024, 1024]  # size the stacked input is resized to in predict_one
hypar["model"] = ISNetDIS(in_ch=5)  # 5 input channels: image channels + SAM mask + box map
net = build_model(hypar, device)
def bbox_from_str(bbox_str: str):
    """Parse an ``"x1,y1,x2,y2"`` pixel box prompt.

    Whitespace around the string and around each number is ignored. The two
    corners may be given in either order; the result is normalised so that
    ``x1 <= x2`` and ``y1 <= y2``.

    Returns:
        ``np.ndarray`` ``[x1, y1, x2, y2]`` of ints. Returns ``None`` if
        ``bbox_str`` is empty or ``None``, does not split into exactly four
        comma-separated parts, or any part is not an integer.
    """
    if not bbox_str:
        return None
    parts = bbox_str.strip().split(",")
    if len(parts) != 4:
        return None
    try:
        x1, y1, x2, y2 = (int(part) for part in parts)
    except ValueError:
        return None
    # SAM takes an XYXY box (top-left, bottom-right). Accept corners typed in
    # any order instead of passing on a box with negative width/height.
    return np.array([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)])
def predict(input_img: np.ndarray, bbox_str: str):
    """Segment ``input_img`` with SAM, then refine the SAM mask with DIS-SAM.

    Args:
        input_img: image array from the ``gr.Image(type="numpy")`` input,
            H x W x C.
        bbox_str: optional ``"x1,y1,x2,y2"`` box prompt. When it is empty or
            cannot be parsed, the whole image is used as the box.

    Returns:
        ``(sam_mask, refined_mask)``, both uint8 H x W arrays:
          - ``sam_mask``: the first SAM mask scaled by 255.
          - ``refined_mask``: the Otsu-binarised (0/255) DIS-SAM output.

    Raises:
        gr.Error: if no image was provided.
    """
    if input_img is None:
        # Submitting without an image presumably yields None here; show a
        # readable message instead of a traceback from inside SAM.
        raise gr.Error("Please upload or capture an image first.")
    predictor.set_image(input_img)
    input_label = np.array([1])
    bbox = bbox_from_str(bbox_str)
    height, width = input_img.shape[:2]
    input_box = bbox if bbox is not None else np.array([0, 0, width, height])
    # NOTE(review): upstream SamPredictor.predict only uses point_labels
    # together with point_coords, so this label is likely ignored here.
    masks, scores, logits = predictor.predict(
        box=input_box,
        point_labels=input_label,
        multimask_output=True,
    )
    # NOTE(review): the first multimask output is taken regardless of
    # `scores`; confirm this is the intended choice.
    sam_mask = masks[0]
    dis_box = get_box_from_mask(sam_mask)
    refined_mask = predict_one(net, input_img, sam_mask, dis_box, transform, hypar, device)
    mask_gray = (sam_mask * 255).astype(np.uint8)
    refined_mask_gray = refined_mask.astype(np.uint8)
    return mask_gray, refined_mask_gray
# Gradio UI. Inputs: an image and an optional "x1,y1,x2,y2" box prompt.
# Outputs: the raw SAM mask and the DIS-SAM refined mask, both shown as
# single-channel ("L") images.
gradio_app = gr.Interface(
    predict,
    inputs=[
        gr.Image(label="Select Image", sources=['upload', 'webcam'], type="numpy"),
        gr.Textbox(label="Bounding Box Prompt (pixels)", placeholder="x1,y1,x2,y2")],
    outputs=[gr.Image(label="SAM Mask", type="numpy", image_mode="L"), gr.Image(label="DIS-SAM Mask", type="numpy", image_mode="L")],
    title="DIS-SAM",
    # Each example is [image path, box prompt]; an empty prompt makes predict()
    # use the whole image as the box.
    examples=[
        ["./images/wire_shelf.jpg", "20,100,480,660"],
        ["./images/radio_telescope.jpg", "1130,320,4000,2920"],
        ["./images/bridge.jpg", ""],
        ["./images/tree.jpg", "70,110,2290,1800"],
        ["./images/bicycle.jpg", "135,235,2425,1580"],
        ["./images/capybara.jpg", "0,650,2000,1706"]
    ]
)
# Launch the web app only when this file is run as a script.
if __name__ == "__main__":
    gradio_app.launch()