petergpt committed
Commit de63122 · verified · 1 Parent(s): cd88311

Update app.py

Files changed (1)
  1. app.py +40 -25
app.py CHANGED
@@ -1,22 +1,29 @@
 import time
-import torch
 import gc
-from transformers import AutoConfig, AutoModelForImageSegmentation
+import torch
+
 from PIL import Image
 from torchvision import transforms
 import gradio as gr
 
+from transformers import AutoConfig, AutoModelForImageSegmentation
+
+# 1) Wrap config loading in a helper that monkey-patches a dummy get_text_config().
+
 def load_model():
-    # Fetch the config first (with trust_remote_code=True)
     config = AutoConfig.from_pretrained("zhengpeng7/BiRefNet_lite", trust_remote_code=True)
-
-    # Ensure it's not treated as a seq2seq model
     config.is_encoder_decoder = False
 
-    # Optionally, block calls to get_text_config if needed:
-    # config.get_text_config = lambda decoder=True: None
+    # We define a dummy function that returns a minimal object
+    # with a tie_word_embeddings attribute, so tie_weights() won't fail.
+    def dummy_text_config(decoder=True):
+        class DummyTextConfig:
+            tie_word_embeddings = False
+        return DummyTextConfig()
+
+    # Patch the config so huggingface code won't blow up
+    setattr(config, "get_text_config", dummy_text_config)
 
-    # Now load the model with our tweaked config
     model = AutoModelForImageSegmentation.from_pretrained(
         "zhengpeng7/BiRefNet_lite",
         config=config,
@@ -28,14 +35,16 @@ def load_model():
     model.eval()
     return model, device
 
+# 2) Initialize global model & device
 birefnet, device = load_model()
 
-# Preprocessing
+# 3) Preprocessing transform
 image_size = (1024, 1024)
 transform_image = transforms.Compose([
     transforms.Resize(image_size),
     transforms.ToTensor(),
-    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    transforms.Normalize([0.485, 0.456, 0.406],
+                         [0.229, 0.224, 0.225])
 ])
 
 def run_inference(images, model, device):
@@ -44,13 +53,17 @@ def run_inference(images, model, device):
     for img in images:
         original_sizes.append(img.size)
         inputs.append(transform_image(img))
-    input_tensor = torch.stack(inputs).to(device)
 
+    input_tensor = torch.stack(inputs).to(device)
     try:
         with torch.no_grad():
-            # If the last layer is returned as [-1],
-            # adjust accordingly or see how your model outputs are structured
-            preds = model(input_tensor)[-1].sigmoid().cpu()
+            # If the model returns multiple outputs, adapt as needed
+            output = model(input_tensor)
+            # The last element might be your segmentation mask. Adjust if needed:
+            # e.g. preds = output[-1] if it returns a list/tuple
+            # or preds = output.logits if it returns a named field
+            # The original example used `output[-1].sigmoid()`, so:
+            preds = output[-1].sigmoid().cpu()
     except torch.OutOfMemoryError:
         del input_tensor
         torch.cuda.empty_cache()
@@ -70,40 +83,42 @@ def run_inference(images, model, device):
     del input_tensor, preds
     gc.collect()
     torch.cuda.empty_cache()
+
     return results
 
 def binary_search_max(images):
-    # After OOM, try to find max feasible batch
     low, high = 1, len(images)
-    best = None
-    best_count = 0
+    best, best_count = None, 0
+
     while low <= high:
         mid = (low + high) // 2
         batch = images[:mid]
         try:
+            # Re-load the model to avoid leftover memory fragmentation
             global birefnet, device
-            birefnet, device = load_model()  # re-init to reduce memory fragmentation
+            birefnet, device = load_model()
             res = run_inference(batch, birefnet, device)
-            best = res
-            best_count = mid
+            best, best_count = res, mid
             low = mid + 1
         except torch.OutOfMemoryError:
             high = mid - 1
+
     return best, best_count
 
 def extract_objects(filepaths):
     images = [Image.open(p).convert("RGB") for p in filepaths]
     start_time = time.time()
 
-    # First attempt: all images
+    # First attempt: all images at once
    try:
        results = run_inference(images, birefnet, device)
        end_time = time.time()
        total_time = end_time - start_time
        summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
        return results, summary
+
    except torch.OutOfMemoryError:
-        # OOM occurred, try fallback
+        # If it fails with OOM, do a fallback
        oom_time = time.time()
        initial_attempt_time = oom_time - start_time
 
@@ -112,7 +127,7 @@ def extract_objects(filepaths):
         total_time = end_time - start_time
 
         if best is None:
-            # Not even 1 image works
+            # Not even 1 image can be processed
             summary = (
                 f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
                 f"Could not process even a single image.\n"
@@ -132,8 +147,8 @@ iface = gr.Interface(
     fn=extract_objects,
     inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
     outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
-    title="BiRefNet Bulk Background Removal with On-Demand Fallback",
-    description="Upload as many images as you want. If OOM occurs, fallback logic will find the max feasible number."
+    title="BiRefNet Bulk Background Removal (with fallback)",
+    description="Upload multiple images. If OOM occurs, we fallback to smaller batches."
 )
 
 if __name__ == "__main__":
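Why the dummy get_text_config() works: per the comments in this change, transformers' tie_weights() calls config.get_text_config() and reads tie_word_embeddings off the result, and the custom BiRefNet config has nothing sensible to return there. The patch only has to hand back some object carrying that attribute. A model-free sketch of the same pattern (FakeConfig is a hypothetical stand-in, so nothing is downloaded):

class FakeConfig:
    """Stand-in for the custom BiRefNet config, which lacks a text sub-config."""

config = FakeConfig()
config.is_encoder_decoder = False

def dummy_text_config(decoder=True):
    # tie_weights() only reads .tie_word_embeddings, so a bare class is enough.
    class DummyTextConfig:
        tie_word_embeddings = False
    return DummyTextConfig()

setattr(config, "get_text_config", dummy_text_config)

# Anything that consults the text config now gets a safe answer:
assert config.get_text_config(decoder=True).tie_word_embeddings is False

Because the function is set on the instance, it shadows any inherited get_text_config method for this one config object only.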
 
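The new inference block deliberately leaves the output indexing open ("adjust as needed"). One way to settle it is to run a single image through and inspect what the loaded model actually returns before trusting output[-1]. A sketch assuming the birefnet, device, and transform_image globals defined above, plus a hypothetical example.jpg:

import torch
from PIL import Image

img = Image.open("example.jpg").convert("RGB")    # hypothetical test image
x = transform_image(img).unsqueeze(0).to(device)  # batch of one

with torch.no_grad():
    out = birefnet(x)

print(type(out))  # a list/tuple suggests out[-1]; a ModelOutput suggests a named field
if isinstance(out, (list, tuple)):
    # The (N, 1, H, W) mask-shaped tensor is the candidate to .sigmoid()
    print([tuple(o.shape) for o in out if torch.is_tensor(o)])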
120
  except torch.OutOfMemoryError:
121
+ # If it fails with OOM, do a fallback
122
  oom_time = time.time()
123
  initial_attempt_time = oom_time - start_time
124
 
 
127
  total_time = end_time - start_time
128
 
129
  if best is None:
130
+ # Not even 1 image can be processed
131
  summary = (
132
  f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
133
  f"Could not process even a single image.\n"
 
147
  fn=extract_objects,
148
  inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
149
  outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
150
+ title="BiRefNet Bulk Background Removal (with fallback)",
151
+ description="Upload multiple images. If OOM occurs, we fallback to smaller batches."
152
  )
153
 
154
  if __name__ == "__main__":