# bgremoval / app.py
# petergpt's picture
# Update app.py
# de63122 verified
# raw
# history blame contribute delete
# 5.18 kB
import time
import gc
import torch
from PIL import Image
from torchvision import transforms
import gradio as gr
from transformers import AutoConfig, AutoModelForImageSegmentation
# 1) Build the model from a config that has been patched with a stub get_text_config().
def load_model():
    """Load BiRefNet_lite and move it to the best available device.

    The remote config is adjusted before the weights are loaded:
    ``is_encoder_decoder`` is forced to ``False`` and a stub
    ``get_text_config`` is attached so that the weight-tying step finds a
    ``tie_word_embeddings`` attribute instead of failing.

    Returns:
        A ``(model, device)`` tuple. The model is in eval mode and ``device``
        is ``'cuda'`` when CUDA is available, otherwise ``'cpu'``.
    """
    repo_id = "zhengpeng7/BiRefNet_lite"

    def _stub_text_config(decoder=True):
        # Minimal stand-in exposing only the attribute that weight tying reads.
        class DummyTextConfig:
            tie_word_embeddings = False

        return DummyTextConfig()

    cfg = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    cfg.is_encoder_decoder = False
    # Attach the stub on the instance so Hugging Face code can call it.
    cfg.get_text_config = _stub_text_config

    net = AutoModelForImageSegmentation.from_pretrained(
        repo_id,
        config=cfg,
        trust_remote_code=True,
    )

    target = "cuda" if torch.cuda.is_available() else "cpu"
    net.to(target)
    net.eval()
    return net, target
# 2) Module-wide model instance and the device it was placed on
birefnet, device = load_model()

# 3) Input pipeline: fixed-size resize, tensor conversion, then per-channel
#    mean/std normalization.
image_size = (1024, 1024)
transform_image = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def run_inference(images, model, device):
    """Cut the foreground out of a batch of PIL images in one forward pass.

    Args:
        images: Sequence of PIL images to process (the caller passes RGB
            images).
        model: Segmentation model; the last element of its output is taken
            as the mask logits.
        device: Device the input batch is moved to before the forward pass.

    Returns:
        A list of RGBA images, one per input and in the same order, each at
        its input's original size. An empty input yields an empty list.

    Raises:
        torch.OutOfMemoryError: Propagated unchanged (after a best-effort
            release of cached device memory) so the caller can retry with a
            smaller batch.
    """
    if not images:
        # torch.stack() rejects an empty list; there is nothing to run.
        return []

    original_sizes = [img.size for img in images]

    input_tensor = None
    try:
        # Building and moving the batch can itself exhaust memory, so it is
        # covered by the same cleanup as the forward pass.
        input_tensor = torch.stack(
            [transform_image(img) for img in images]
        ).to(device)
        with torch.no_grad():
            # The last output element is used as the mask logits; adapt this
            # if the model returns a different structure.
            preds = model(input_tensor)[-1].sigmoid().cpu()
    finally:
        # Release the device batch on every exit path, not only on success.
        del input_tensor
        gc.collect()
        torch.cuda.empty_cache()

    # Post-process: resize each mask to its image and use it as paste mask.
    to_pil = transforms.ToPILImage()
    results = []
    for img, size, pred in zip(images, original_sizes, preds):
        mask = to_pil(pred.squeeze()).resize(size)
        cutout = Image.new("RGBA", size, (0, 0, 0, 0))
        cutout.paste(img, mask=mask)
        results.append(cutout)
    return results
def binary_search_max(images):
    """Find the largest prefix of ``images`` that fits in a single batch.

    Binary-searches the batch size between 1 and ``len(images)``, treating a
    ``torch.OutOfMemoryError`` raised by ``run_inference`` as "too large".
    The module-level ``birefnet`` model is reused for every probe: loading a
    fresh copy while the old one is still referenced would hold two models in
    memory at once, and an out-of-memory error during that load would be
    mistaken for a batch that is too large.

    Args:
        images: List of PIL images; prefixes ``images[:k]`` are probed.

    Returns:
        ``(results, count)`` for the largest prefix that succeeded, or
        ``(None, 0)`` when no batch size succeeded (including empty input).
    """
    low, high = 1, len(images)
    best, best_count = None, 0
    while low <= high:
        mid = (low + high) // 2
        # Drop garbage and cached blocks left by the previous attempt so each
        # probe starts from as clean an allocator state as possible.
        gc.collect()
        torch.cuda.empty_cache()
        try:
            res = run_inference(images[:mid], birefnet, device)
        except torch.OutOfMemoryError:
            high = mid - 1
        else:
            best, best_count = res, mid
            low = mid + 1
    return best, best_count
def extract_objects(filepaths):
    """Gradio handler: remove backgrounds from the uploaded image files.

    All images are first attempted as one batch. If that runs out of memory,
    ``binary_search_max`` is used to find the largest leading subset that
    fits, and only that subset's results are returned.

    Args:
        filepaths: Paths of the uploaded image files; may be ``None`` or
            empty when nothing was uploaded.

    Returns:
        ``(images, summary)`` where ``images`` is a list of RGBA results
        (possibly empty) and ``summary`` is a human-readable timing report.
    """
    if not filepaths:
        # Nothing uploaded: report it instead of failing inside the model.
        return [], "No images were provided."

    images = []
    for path in filepaths:
        # convert() returns an independent image, so the file handle can be
        # closed as soon as the pixels are decoded.
        with Image.open(path) as src:
            images.append(src.convert("RGB"))

    start_time = time.time()

    # First attempt: all images at once
    try:
        results = run_inference(images, birefnet, device)
    except torch.OutOfMemoryError:
        initial_attempt_time = time.time() - start_time
    else:
        total_time = time.time() - start_time
        summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
        return results, summary

    # Fallback runs outside the except block so the failed attempt's
    # exception (and the frames it references) is released before retrying.
    best, best_count = binary_search_max(images)
    total_time = time.time() - start_time

    if best is None:
        # Not even 1 image can be processed
        summary = (
            f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
            f"Could not process even a single image.\n"
            f"Total time including fallback attempts: {total_time:.2f}s."
        )
        return [], summary

    summary = (
        f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
        f"Found that {best_count} images can be processed without OOM.\n"
        f"Total time including fallback attempts: {total_time:.2f}s.\n"
        f"Next time, try using up to {best_count} images."
    )
    return best, summary
# Gradio UI: multi-file upload in, result gallery plus timing text out.
_upload_input = gr.Files(
    label="Upload Multiple Images",
    type="filepath",
    file_count="multiple",
)
_gallery_output = gr.Gallery(label="Processed Images")
_timing_output = gr.Textbox(label="Timing Info")

iface = gr.Interface(
    fn=extract_objects,
    inputs=_upload_input,
    outputs=[_gallery_output, _timing_output],
    title="BiRefNet Bulk Background Removal (with fallback)",
    description="Upload multiple images. If OOM occurs, we fallback to smaller batches.",
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()