Update app.py
Browse files
app.py
CHANGED
@@ -1,22 +1,29 @@
|
|
1 |
import time
|
2 |
-
import torch
|
3 |
import gc
|
4 |
-
|
|
|
5 |
from PIL import Image
|
6 |
from torchvision import transforms
|
7 |
import gradio as gr
|
8 |
|
|
|
|
|
|
|
|
|
9 |
def load_model():
|
10 |
-
# Fetch the config first (with trust_remote_code=True)
|
11 |
config = AutoConfig.from_pretrained("zhengpeng7/BiRefNet_lite", trust_remote_code=True)
|
12 |
-
|
13 |
-
# Ensure it's not treated as a seq2seq model
|
14 |
config.is_encoder_decoder = False
|
15 |
|
16 |
-
#
|
17 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
-
# Now load the model with our tweaked config
|
20 |
model = AutoModelForImageSegmentation.from_pretrained(
|
21 |
"zhengpeng7/BiRefNet_lite",
|
22 |
config=config,
|
@@ -28,14 +35,16 @@ def load_model():
|
|
28 |
model.eval()
|
29 |
return model, device
|
30 |
|
|
|
31 |
birefnet, device = load_model()
|
32 |
|
33 |
-
# Preprocessing
|
34 |
image_size = (1024, 1024)
|
35 |
transform_image = transforms.Compose([
|
36 |
transforms.Resize(image_size),
|
37 |
transforms.ToTensor(),
|
38 |
-
transforms.Normalize([0.485, 0.456, 0.406],
|
|
|
39 |
])
|
40 |
|
41 |
def run_inference(images, model, device):
|
@@ -44,13 +53,17 @@ def run_inference(images, model, device):
|
|
44 |
for img in images:
|
45 |
original_sizes.append(img.size)
|
46 |
inputs.append(transform_image(img))
|
47 |
-
input_tensor = torch.stack(inputs).to(device)
|
48 |
|
|
|
49 |
try:
|
50 |
with torch.no_grad():
|
51 |
-
# If the
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
54 |
except torch.OutOfMemoryError:
|
55 |
del input_tensor
|
56 |
torch.cuda.empty_cache()
|
@@ -70,40 +83,42 @@ def run_inference(images, model, device):
|
|
70 |
del input_tensor, preds
|
71 |
gc.collect()
|
72 |
torch.cuda.empty_cache()
|
|
|
73 |
return results
|
74 |
|
75 |
def binary_search_max(images):
    """Find the largest leading batch of ``images`` that fits in memory.

    Binary-searches over prefix lengths ``1..len(images)``. Every probe
    reloads the model through ``load_model()`` (rebinding the module-level
    ``birefnet`` and ``device``) and runs ``run_inference`` on
    ``images[:mid]``. A ``torch.OutOfMemoryError`` raised by either call
    marks that prefix length as too large.

    Args:
        images: Sequence of images accepted by ``run_inference``.

    Returns:
        A ``(best, best_count)`` tuple: the ``run_inference`` result for the
        largest prefix that succeeded and that prefix's length, or
        ``(None, 0)`` when no probe succeeded (including empty input).
    """
    global birefnet, device

    low, high = 1, len(images)
    # Initialise both results up front. ``best_count`` used to be assigned
    # only after a successful probe, so the final ``return`` raised
    # UnboundLocalError when every probe hit OOM or ``images`` was empty.
    best, best_count = None, 0

    while low <= high:
        mid = (low + high) // 2
        try:
            birefnet, device = load_model()
            res = run_inference(images[:mid], birefnet, device)
        except torch.OutOfMemoryError:
            high = mid - 1  # too large: search smaller prefixes
        else:
            best, best_count = res, mid
            low = mid + 1  # fits: try a larger prefix

    return best, best_count
|
93 |
|
94 |
def extract_objects(filepaths):
|
95 |
images = [Image.open(p).convert("RGB") for p in filepaths]
|
96 |
start_time = time.time()
|
97 |
|
98 |
-
# First attempt: all images
|
99 |
try:
|
100 |
results = run_inference(images, birefnet, device)
|
101 |
end_time = time.time()
|
102 |
total_time = end_time - start_time
|
103 |
summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
|
104 |
return results, summary
|
|
|
105 |
except torch.OutOfMemoryError:
|
106 |
-
# OOM
|
107 |
oom_time = time.time()
|
108 |
initial_attempt_time = oom_time - start_time
|
109 |
|
@@ -112,7 +127,7 @@ def extract_objects(filepaths):
|
|
112 |
total_time = end_time - start_time
|
113 |
|
114 |
if best is None:
|
115 |
-
# Not even 1 image
|
116 |
summary = (
|
117 |
f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
|
118 |
f"Could not process even a single image.\n"
|
@@ -132,8 +147,8 @@ iface = gr.Interface(
|
|
132 |
fn=extract_objects,
|
133 |
inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
|
134 |
outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
|
135 |
-
title="BiRefNet Bulk Background Removal with
|
136 |
-
description="Upload
|
137 |
)
|
138 |
|
139 |
if __name__ == "__main__":
|
|
|
1 |
import time
|
|
|
2 |
import gc
|
3 |
+
import torch
|
4 |
+
|
5 |
from PIL import Image
|
6 |
from torchvision import transforms
|
7 |
import gradio as gr
|
8 |
|
9 |
+
from transformers import AutoConfig, AutoModelForImageSegmentation
|
10 |
+
|
11 |
+
# 1) Wrap config loading in a helper that monkey-patches a dummy get_text_config().
|
12 |
+
|
13 |
def load_model():
|
|
|
14 |
config = AutoConfig.from_pretrained("zhengpeng7/BiRefNet_lite", trust_remote_code=True)
|
|
|
|
|
15 |
config.is_encoder_decoder = False
|
16 |
|
17 |
+
# We define a dummy function that returns a minimal object
|
18 |
+
# with a tie_word_embeddings attribute, so tie_weights() won't fail.
|
19 |
+
def dummy_text_config(decoder=True):
    """Return a stub object whose ``tie_word_embeddings`` is ``False``.

    ``decoder`` is accepted but never read; presumably it mirrors the
    signature of the real ``get_text_config`` -- verify against the
    installed transformers version.
    """
    class DummyTextConfig:
        tie_word_embeddings = False

    stub = DummyTextConfig()
    return stub
|
23 |
+
|
24 |
+
# Patch the config so huggingface code won't blow up
|
25 |
+
setattr(config, "get_text_config", dummy_text_config)
|
26 |
|
|
|
27 |
model = AutoModelForImageSegmentation.from_pretrained(
|
28 |
"zhengpeng7/BiRefNet_lite",
|
29 |
config=config,
|
|
|
35 |
model.eval()
|
36 |
return model, device
|
37 |
|
38 |
+
# 2) Load the model once at import time. ``birefnet`` and ``device`` are
#    module-level globals: extract_objects() reads them and
#    binary_search_max() rebinds them after reloading the model.
birefnet, device = load_model()

# 3) Preprocessing: resize to a fixed 1024x1024 input, convert to a tensor,
#    then normalise each channel. The mean/std values match the commonly
#    used ImageNet statistics.
image_size = (1024, 1024)
transform_image = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
49 |
|
50 |
def run_inference(images, model, device):
|
|
|
53 |
for img in images:
|
54 |
original_sizes.append(img.size)
|
55 |
inputs.append(transform_image(img))
|
|
|
56 |
|
57 |
+
input_tensor = torch.stack(inputs).to(device)
|
58 |
try:
|
59 |
with torch.no_grad():
|
60 |
+
# If the model returns multiple outputs, adapt as needed
|
61 |
+
output = model(input_tensor)
|
62 |
+
# The last element might be your segmentation mask. Adjust if needed:
|
63 |
+
# e.g. preds = output[-1] if it returns a list/tuple
|
64 |
+
# or preds = output.logits if it returns a named field
|
65 |
+
# The original example used `output[-1].sigmoid()`, so:
|
66 |
+
preds = output[-1].sigmoid().cpu()
|
67 |
except torch.OutOfMemoryError:
|
68 |
del input_tensor
|
69 |
torch.cuda.empty_cache()
|
|
|
83 |
del input_tensor, preds
|
84 |
gc.collect()
|
85 |
torch.cuda.empty_cache()
|
86 |
+
|
87 |
return results
|
88 |
|
89 |
def binary_search_max(images):
    """Search for the biggest prefix of ``images`` that can be inferred.

    Probes prefix lengths between 1 and ``len(images)`` with a binary
    search. Before each probe the model is reloaded with ``load_model()``
    and the module-level ``birefnet`` / ``device`` are rebound to the fresh
    instance. A ``torch.OutOfMemoryError`` from the reload or from
    ``run_inference`` shrinks the search window; a success grows it.

    Args:
        images: Sequence of images accepted by ``run_inference``.

    Returns:
        ``(best, best_count)`` -- the ``run_inference`` output for the
        largest prefix that succeeded together with its length, or
        ``(None, 0)`` if no probe succeeded.
    """
    global birefnet, device

    best, best_count = None, 0
    lo, hi = 1, len(images)

    while lo <= hi:
        size = lo + (hi - lo) // 2
        try:
            # Author's stated intent: reload the model so leftover memory
            # fragmentation from the failed attempt does not skew the probe.
            birefnet, device = load_model()
            outputs = run_inference(images[:size], birefnet, device)
        except torch.OutOfMemoryError:
            hi = size - 1
            continue

        # The probe fit: remember it and look for a larger prefix.
        best, best_count = outputs, size
        lo = size + 1

    return best, best_count
|
107 |
|
108 |
def extract_objects(filepaths):
|
109 |
images = [Image.open(p).convert("RGB") for p in filepaths]
|
110 |
start_time = time.time()
|
111 |
|
112 |
+
# First attempt: all images at once
|
113 |
try:
|
114 |
results = run_inference(images, birefnet, device)
|
115 |
end_time = time.time()
|
116 |
total_time = end_time - start_time
|
117 |
summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
|
118 |
return results, summary
|
119 |
+
|
120 |
except torch.OutOfMemoryError:
|
121 |
+
# If it fails with OOM, do a fallback
|
122 |
oom_time = time.time()
|
123 |
initial_attempt_time = oom_time - start_time
|
124 |
|
|
|
127 |
total_time = end_time - start_time
|
128 |
|
129 |
if best is None:
|
130 |
+
# Not even 1 image can be processed
|
131 |
summary = (
|
132 |
f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
|
133 |
f"Could not process even a single image.\n"
|
|
|
147 |
fn=extract_objects,
|
148 |
inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
|
149 |
outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
|
150 |
+
title="BiRefNet Bulk Background Removal (with fallback)",
|
151 |
+
description="Upload multiple images. If OOM occurs, we fallback to smaller batches."
|
152 |
)
|
153 |
|
154 |
if __name__ == "__main__":
|