app file
- Roboto-Regular.ttf +0 -0
- app.py +285 -0
Roboto-Regular.ttf
ADDED
Binary file (172 kB).
app.py
ADDED
@@ -0,0 +1,285 @@
import os, sys
import random
import warnings

# Install the vendored packages and fetch the demo image and model weights.
os.system("python -m pip install -e segment_anything")
os.system("python -m pip install -e GroundingDINO")
os.system("pip install --upgrade diffusers[torch]")
os.system("pip install opencv-python pycocotools matplotlib onnxruntime onnx ipykernel")
os.system("wget https://github.com/IDEA-Research/Grounded-Segment-Anything/raw/main/assets/demo1.jpg")
os.system("wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth")
os.system("wget https://huggingface.co/spaces/mrtlive/segment-anything-model/resolve/main/sam_vit_h_4b8939.pth")
sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
warnings.filterwarnings("ignore")
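
# Note: the os.system installs above rerun on every Space restart; on persistent
# hardware they would normally live in requirements.txt or a setup script instead.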

import gradio as gr
import argparse

import numpy as np
import torch
import torchvision
from PIL import Image, ImageDraw, ImageFont

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# segment anything
from segment_anything import build_sam, SamPredictor

# diffusers
from diffusers import StableDiffusionInpaintPipeline

# BLIP
from transformers import BlipProcessor, BlipForConditionalGeneration
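
# The BLIP and diffusers imports support the captioning/inpainting task types of
# the upstream Grounded-SAM demo; with task_type pinned to "seg" below they are
# currently unused.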


def generate_caption(processor, blip_model, raw_image):
    # unconditional image captioning (assumes a CUDA device with fp16 support)
    inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
    out = blip_model.generate(**inputs)
    caption = processor.decode(out[0], skip_special_tokens=True)
    return caption


def transform_image(image_pil):
    # Grounding DINO eval preprocessing: resize the short side to 800
    # (long side capped at 1333) and normalize with ImageNet statistics.
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image, _ = transform(image_pil, None)  # 3, h, w
    return image
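
# Minimal usage sketch (demo1.jpg is fetched at startup; illustrative only):
#   pil = Image.open("demo1.jpg").convert("RGB")
#   tensor = transform_image(pil)  # float tensor of shape (3, H, W)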


def load_model(model_config_path, model_checkpoint_path, device):
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    # strict=False tolerates missing/unexpected keys in the released checkpoint
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print(load_res)
    _ = model.eval()
    return model


def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
    # The prompt is pinned to the marine-litter classes, overriding the caption argument.
    # Alternative: "all plastic.all metal.all paper.all glass.all cardboard.all wood.all rubber"
    caption = "plastic.metal.paper.glass.cardboard.wood.rubber"
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption = caption + "."

    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # keep queries whose best token score clears the box threshold
    filt_mask = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[filt_mask]  # num_filt, 256
    boxes_filt = boxes[filt_mask]  # num_filt, 4

    # map each surviving query back to the prompt phrase it grounds
    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    pred_phrases = []
    scores = []
    for logit, box in zip(logits_filt, boxes_filt):
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
        else:
            pred_phrases.append(pred_phrase)
        scores.append(logit.max().item())

    return boxes_filt, torch.Tensor(scores), pred_phrases
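
# Illustrative return triple, assuming two detections survive (values made up):
#   boxes_filt   -> tensor of shape (2, 4), normalized (cx, cy, w, h) boxes
#   scores       -> tensor([0.52, 0.41])
#   pred_phrases -> ["plastic(0.52)", "metal(0.41)"]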


def draw_mask(mask, draw, random_color=False):
    if random_color:
        color = (random.randint(0, 255), random.randint(0, 255),
                 random.randint(0, 255), 153)
    else:
        color = (30, 144, 255, 153)  # semi-transparent dodger blue

    # np.nonzero gives (row, col) pairs; reverse each to (x, y) for PIL
    nonzero_coords = np.transpose(np.nonzero(mask))
    for coord in nonzero_coords:
        draw.point(coord[::-1], fill=color)
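
# draw.point per pixel is simple but slow on large masks; a vectorized sketch
# (same output, untested here) would paste the whole overlay at once:
#   overlay = np.zeros((*mask.shape, 4), dtype=np.uint8)
#   overlay[mask.astype(bool)] = color
#   mask_image.alpha_composite(Image.fromarray(overlay))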


def draw_box(box, draw, label):
    # random color per box
    color = tuple(np.random.randint(0, 255, size=3).tolist())

    draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline=color, width=2)

    if label:
        # font = ImageFont.load_default()
        font = ImageFont.truetype("Roboto-Regular.ttf", 20)
        # Pillow >= 8 has getbbox/textbbox; older releases only offer textsize
        if hasattr(font, "getbbox"):
            bbox = draw.textbbox((box[0], box[1]), str(label), font)
        else:
            w, h = draw.textsize(str(label), font)
            bbox = (box[0], box[1], w + box[0], box[1] + h)
        draw.rectangle(bbox, fill=color)
        draw.text((box[0], box[1]), str(label), fill="white", font=font)


config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filename = "groundingdino_swint_ogc.pth"
sam_checkpoint = 'sam_vit_h_4b8939.pth'
output_dir = "outputs"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Models are loaded lazily on the first request and cached in these globals.
blip_processor = None
blip_model = None
groundingdino_model = None
sam_predictor = None
inpaint_pipeline = None


def run_grounded_sam(input_image, text_prompt=None, task_type="seg",
                     box_threshold=0.3, text_threshold=0.25, iou_threshold=0.8):
    # Defaults are required: the click handler below only passes input_image.
    global blip_processor, blip_model, groundingdino_model, sam_predictor, inpaint_pipeline

    # The Space pins the task and thresholds regardless of caller arguments.
    task_type = "seg"
    box_threshold = 0.3
    text_threshold = 0.25
    iou_threshold = 0.8

    # make dir
    os.makedirs(output_dir, exist_ok=True)
    # load image
    image_pil = input_image.convert("RGB")
    transformed_image = transform_image(image_pil)

    if groundingdino_model is None:
        groundingdino_model = load_model(config_file, ckpt_filename, device=device)

    # run grounding dino model
    boxes_filt, scores, pred_phrases = get_grounding_output(
        groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold
    )

    size = image_pil.size

    # rescale normalized (cx, cy, w, h) boxes to pixel (x0, y0, x1, y1) corners
    H, W = size[1], size[0]
    for i in range(boxes_filt.size(0)):
        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
        boxes_filt[i][2:] += boxes_filt[i][:2]

    boxes_filt = boxes_filt.cpu()

    # drop overlapping detections with non-maximum suppression
    print(f"Before NMS: {boxes_filt.shape[0]} boxes")
    nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
    boxes_filt = boxes_filt[nms_idx]
    pred_phrases = [pred_phrases[idx] for idx in nms_idx]
    print(f"After NMS: {boxes_filt.shape[0]} boxes")
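
    # Worked example of the box conversion above (made-up numbers): a normalized
    # (0.5, 0.5, 0.2, 0.4) box on a 1000x800 image scales to (500, 400, 200, 320)
    # and becomes corners (400, 240, 600, 560).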

    if task_type == 'seg':
        if sam_predictor is None:
            # initialize SAM
            assert sam_checkpoint, 'sam_checkpoint is not found!'
            sam = build_sam(checkpoint=sam_checkpoint)
            sam.to(device=device)
            sam_predictor = SamPredictor(sam)

        image = np.array(image_pil)
        sam_predictor.set_image(image)

        # prompt SAM with the detected boxes to get one mask per box
        transformed_boxes = sam_predictor.transform.apply_boxes_torch(
            boxes_filt, image.shape[:2]).to(device)

        masks, _, _ = sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        # masks: [num_boxes, 1, H, W]

        # composite the masks and labeled boxes onto the input image
        mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
        mask_draw = ImageDraw.Draw(mask_image)
        for mask in masks:
            draw_mask(mask[0].cpu().numpy(), mask_draw, random_color=True)

        image_draw = ImageDraw.Draw(image_pil)
        for box, label in zip(boxes_filt, pred_phrases):
            draw_box(box, image_draw, label)

        image_pil = image_pil.convert('RGBA')
        image_pil.alpha_composite(mask_image)
        return [image_pil, mask_image]
    else:
        print("task_type:{} error!".format(task_type))
|
250 |
+
if __name__ == "__main__":
|
251 |
+
parser = argparse.ArgumentParser("Marine Litter", add_help=True)
|
252 |
+
parser.add_argument("--debug", action="store_true",
|
253 |
+
help="using debug mode")
|
254 |
+
parser.add_argument("--share", action="store_true", help="share the app")
|
255 |
+
parser.add_argument('--no-gradio-queue', action="store_true",
|
256 |
+
help='path to the SAM checkpoint')
|
257 |
+
args = parser.parse_args()
|
258 |
+
|
259 |
+
print(args)
|
260 |
+
|
261 |
+
block = gr.Blocks()
|
262 |
+
if not args.no_gradio_queue:
|
263 |
+
block = block.queue()
|
264 |
+
|
265 |
+
with block:
|
266 |
+
with gr.Row():
|
267 |
+
with gr.Column():
|
268 |
+
input_image = gr.Image(
|
269 |
+
source='upload', type="pil")
|
270 |
+
#, value="demo1.jpg"
|
271 |
+
#task_type = gr.Dropdown(
|
272 |
+
# ["det", "seg", "inpainting", "automatic"], value="seg", label="task_type")
|
273 |
+
#text_prompt = gr.Textbox(label="Text Prompt", placeholder="all plastic.all metal.all paper.all glass.all cardboard.all wood.all rubber")
|
274 |
+
#inpaint_prompt = gr.Textbox(label="Inpaint Prompt", placeholder="A dinosaur, detailed, 4K.")
|
275 |
+
run_button = gr.Button(label="Run")
|
276 |
+
with gr.Column():
|
277 |
+
gallery = gr.Gallery(
|
278 |
+
label="Generated images", show_label=True, elem_id="gallery"
|
279 |
+
).style(preview=True, grid=2, object_fit="scale-down")
|
280 |
+
|
281 |
+
run_button.click(fn=run_grounded_sam, inputs=[
|
282 |
+
input_image], outputs=gallery)
|
283 |
+
|
284 |
+
block.launch(debug=args.debug, share=args.share, show_error=True)
|
285 |
+
#block.launch(debug=args.debug, show_error=True)
|
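
# Launch sketch using the flags defined above (assumes default Space hardware;
# CPU works, but GroundingDINO + SAM inference will be slow):
#   python app.py --share
#   python app.py --no-gradio-queue --debug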