mostlycached commited on
Commit
b344378
·
verified ·
1 Parent(s): 28889af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -80
app.py CHANGED
@@ -2,11 +2,10 @@ import gradio as gr
2
  import torch
3
  import numpy as np
4
  import cv2
5
- from PIL import Image
6
  from transformers import SamModel, SamProcessor
7
  from diffusers import StableDiffusionInpaintPipeline
8
- import requests
9
- from io import BytesIO
10
 
11
  # Set up device
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -24,27 +23,35 @@ inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
24
  torch_dtype=torch.float16 if device == "cuda" else torch.float32
25
  ).to(device)
26
 
27
- def get_sam_mask(image, points=None):
28
- """Get segmentation mask using SAM model"""
29
- if points is None:
30
- # If no points provided, use center point
31
- height, width = image.shape[:2]
32
- points = [[[width // 2, height // 2]]]
33
-
34
- # Convert to PIL if needed
35
- if not isinstance(image, Image.Image):
36
- image_pil = Image.fromarray(image)
37
  else:
38
- image_pil = image
39
 
40
- # Process the image and point prompts
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  inputs = sam_processor(
42
- images=image_pil,
43
  input_points=points,
44
  return_tensors="pt"
45
  ).to(device)
46
 
47
- # Generate mask
48
  with torch.no_grad():
49
  outputs = sam_model(**inputs)
50
  masks = sam_processor.image_processor.post_process_masks(
@@ -53,86 +60,123 @@ def get_sam_mask(image, points=None):
53
  inputs["reshaped_input_sizes"].cpu()
54
  )
55
 
56
- # Get the mask
57
- mask = masks[0][0].numpy()
58
- return mask
 
 
 
 
 
 
 
59
 
60
- def adjust_aspect_ratio(image, mask, target_ratio, prompt=""):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  """Adjust image to target aspect ratio while preserving important content"""
62
  # Convert PIL to numpy if needed
63
  if isinstance(image, Image.Image):
 
64
  image_np = np.array(image)
65
  else:
66
  image_np = image
 
67
 
 
68
  h, w = image_np.shape[:2]
69
  current_ratio = w / h
70
  target_ratio_value = eval(target_ratio.replace(':', '/'))
71
 
72
- # Determine if we need to add width or height
 
 
 
73
  if current_ratio < target_ratio_value:
74
  # Need to add width (outpaint left/right)
75
  new_width = int(h * target_ratio_value)
76
  new_height = h
77
-
78
- # Calculate padding
79
- pad_width = new_width - w
80
- pad_left = pad_width // 2
81
- pad_right = pad_width - pad_left
82
-
83
- # Create canvas with padding
84
- result = np.zeros((new_height, new_width, 3), dtype=np.uint8)
85
- # Place original image in the center
86
- result[:, pad_left:pad_left+w, :] = image_np
87
-
88
- # Create mask for inpainting
89
- inpaint_mask = np.ones((new_height, new_width), dtype=np.uint8) * 255
90
- inpaint_mask[:, pad_left:pad_left+w] = 0
91
-
92
- # Perform outpainting using Stable Diffusion
93
- result = outpaint_regions(result, inpaint_mask, prompt)
94
-
95
  else:
96
  # Need to add height (outpaint top/bottom)
97
  new_width = w
98
  new_height = int(w / target_ratio_value)
99
-
100
- # Calculate padding
101
- pad_height = new_height - h
102
- pad_top = pad_height // 2
103
- pad_bottom = pad_height - pad_top
104
-
105
- # Create canvas with padding
106
- result = np.zeros((new_height, new_width, 3), dtype=np.uint8)
107
- # Place original image in the center
108
- result[pad_top:pad_top+h, :, :] = image_np
109
-
110
- # Create mask for inpainting
111
- inpaint_mask = np.ones((new_height, new_width), dtype=np.uint8) * 255
112
- inpaint_mask[pad_top:pad_top+h, :] = 0
113
-
114
- # Perform outpainting using Stable Diffusion
115
- result = outpaint_regions(result, inpaint_mask, prompt)
116
 
117
- return result
118
-
119
- def outpaint_regions(image, mask, prompt):
120
- """Use Stable Diffusion to outpaint masked regions"""
121
- # Convert to PIL images
122
- image_pil = Image.fromarray(image)
 
 
 
 
 
 
 
123
  mask_pil = Image.fromarray(mask)
124
 
125
- # If prompt is empty, use a generic one
126
  if not prompt or prompt.strip() == "":
127
- prompt = "seamless extension of the image, same style, same scene"
 
 
 
128
 
129
- # Generate the outpainting
130
  output = inpaint_model(
131
  prompt=prompt,
132
- image=image_pil,
133
  mask_image=mask_pil,
134
  guidance_scale=7.5,
135
- num_inference_steps=25
136
  ).images[0]
137
 
138
  return np.array(output)
@@ -140,7 +184,7 @@ def outpaint_regions(image, mask, prompt):
140
  def process_image(input_image, target_ratio="16:9", prompt=""):
141
  """Main processing function for the Gradio interface"""
142
  try:
143
- # Convert from Gradio format
144
  if isinstance(input_image, dict) and 'image' in input_image:
145
  image = input_image['image']
146
  else:
@@ -152,11 +196,8 @@ def process_image(input_image, target_ratio="16:9", prompt=""):
152
  else:
153
  image_np = image
154
 
155
- # Get SAM mask to identify important regions
156
- mask = get_sam_mask(image_np)
157
-
158
  # Adjust aspect ratio while preserving content
159
- result = adjust_aspect_ratio(image_np, mask, target_ratio, prompt)
160
 
161
  # Convert result to PIL for visualization
162
  result_pil = Image.fromarray(result)
@@ -168,9 +209,9 @@ def process_image(input_image, target_ratio="16:9", prompt=""):
168
  return None
169
 
170
  # Create the Gradio interface
171
- with gr.Blocks(title="Automatic Aspect Ratio Adjuster") as demo:
172
- gr.Markdown("# Automatic Aspect Ratio Adjuster")
173
- gr.Markdown("Upload an image, choose your target aspect ratio, and let the AI adjust it while preserving important content.")
174
 
175
  with gr.Row():
176
  with gr.Column():
@@ -178,7 +219,7 @@ with gr.Blocks(title="Automatic Aspect Ratio Adjuster") as demo:
178
 
179
  with gr.Row():
180
  aspect_ratio = gr.Dropdown(
181
- choices=["16:9", "4:3", "1:1", "9:16", "3:4"],
182
  value="16:9",
183
  label="Target Aspect Ratio"
184
  )
@@ -201,9 +242,9 @@ with gr.Blocks(title="Automatic Aspect Ratio Adjuster") as demo:
201
 
202
  gr.Markdown("""
203
  ## How it works
204
- 1. SAM (Segment Anything Model) identifies important content in your image
205
- 2. The algorithm calculates how to adjust the aspect ratio while preserving this content
206
- 3. Stable Diffusion fills in the new areas with AI-generated content that matches the original image
207
 
208
  ## Tips
209
  - For best results, provide a descriptive prompt that matches the scene
 
2
  import torch
3
  import numpy as np
4
  import cv2
5
+ from PIL import Image, ImageOps
6
  from transformers import SamModel, SamProcessor
7
  from diffusers import StableDiffusionInpaintPipeline
8
+ import os
 
9
 
10
  # Set up device
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
23
  torch_dtype=torch.float16 if device == "cuda" else torch.float32
24
  ).to(device)
25
 
26
+ def get_importance_map(image, points=None):
27
+ """Get importance map using SAM model to identify key content regions"""
28
+ # Convert to numpy if needed
29
+ if isinstance(image, Image.Image):
30
+ image_np = np.array(image)
 
 
 
 
 
31
  else:
32
+ image_np = image
33
 
34
+ h, w = image_np.shape[:2]
35
+
36
+ # If no points provided, use grid sampling to identify important areas
37
+ if points is None:
38
+ # Create a grid of points to sample the image
39
+ x_points = np.linspace(w//4, 3*w//4, 5, dtype=int)
40
+ y_points = np.linspace(h//4, 3*h//4, 5, dtype=int)
41
+ grid_points = []
42
+ for y in y_points:
43
+ for x in x_points:
44
+ grid_points.append([x, y])
45
+ points = [grid_points]
46
+
47
+ # Process image through SAM
48
  inputs = sam_processor(
49
+ images=image_np,
50
  input_points=points,
51
  return_tensors="pt"
52
  ).to(device)
53
 
54
+ # Generate masks
55
  with torch.no_grad():
56
  outputs = sam_model(**inputs)
57
  masks = sam_processor.image_processor.post_process_masks(
 
60
  inputs["reshaped_input_sizes"].cpu()
61
  )
62
 
63
+ # Combine all masks to create importance map
64
+ importance_map = np.zeros((h, w), dtype=np.float32)
65
+ for i in range(len(masks[0])):
66
+ importance_map += masks[0][i].numpy().astype(np.float32)
67
+
68
+ # Normalize to 0-1
69
+ if importance_map.max() > 0:
70
+ importance_map = importance_map / importance_map.max()
71
+
72
+ return importance_map
73
 
74
+ def find_optimal_placement(importance_map, original_size, new_size):
75
+ """Find the optimal placement for the original image within the new canvas based on importance"""
76
+ oh, ow = original_size
77
+ nh, nw = new_size
78
+
79
+ # If the new size is smaller in any dimension, then just center it
80
+ if nh <= oh or nw <= ow:
81
+ x_offset = max(0, (nw - ow) // 2)
82
+ y_offset = max(0, (nh - oh) // 2)
83
+ return x_offset, y_offset
84
+
85
+ # Calculate all possible positions
86
+ possible_x = nw - ow + 1
87
+ possible_y = nh - oh + 1
88
+
89
+ best_score = -np.inf
90
+ best_x = 0
91
+ best_y = 0
92
+
93
+ # Create a border-weighted importance map (gives extra weight to content near borders)
94
+ y_coords, x_coords = np.ogrid[:oh, :ow]
95
+ border_weight = np.minimum(np.minimum(x_coords, ow-1-x_coords), np.minimum(y_coords, oh-1-y_coords))
96
+ border_weight = 1.0 - border_weight / border_weight.max()
97
+ weighted_importance = importance_map * (1.0 + 0.5 * border_weight)
98
+
99
+ # Optimize for 9 positions (corners, center of edges, and center)
100
+ positions = [
101
+ (0, 0), # Top-left
102
+ (0, (possible_y-1)//2), # Middle-left
103
+ (0, possible_y-1), # Bottom-left
104
+ ((possible_x-1)//2, 0), # Top-center
105
+ ((possible_x-1)//2, (possible_y-1)//2), # Center
106
+ ((possible_x-1)//2, possible_y-1), # Bottom-center
107
+ (possible_x-1, 0), # Top-right
108
+ (possible_x-1, (possible_y-1)//2), # Middle-right
109
+ (possible_x-1, possible_y-1) # Bottom-right
110
+ ]
111
+
112
+ # Find position with highest importance score
113
+ for x, y in positions:
114
+ # Calculate importance score for this position
115
+ score = weighted_importance.sum()
116
+ if score > best_score:
117
+ best_score = score
118
+ best_x = x
119
+ best_y = y
120
+
121
+ return best_x, best_y
122
+
123
+ def adjust_aspect_ratio(image, target_ratio, prompt=""):
124
  """Adjust image to target aspect ratio while preserving important content"""
125
  # Convert PIL to numpy if needed
126
  if isinstance(image, Image.Image):
127
+ image_pil = image
128
  image_np = np.array(image)
129
  else:
130
  image_np = image
131
+ image_pil = Image.fromarray(image_np)
132
 
133
+ # Get dimensions
134
  h, w = image_np.shape[:2]
135
  current_ratio = w / h
136
  target_ratio_value = eval(target_ratio.replace(':', '/'))
137
 
138
+ # Generate importance map to identify key regions
139
+ importance_map = get_importance_map(image_np)
140
+
141
+ # Calculate new dimensions
142
  if current_ratio < target_ratio_value:
143
  # Need to add width (outpaint left/right)
144
  new_width = int(h * target_ratio_value)
145
  new_height = h
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  else:
147
  # Need to add height (outpaint top/bottom)
148
  new_width = w
149
  new_height = int(w / target_ratio_value)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
+ # Find optimal placement based on importance map
152
+ x_offset, y_offset = find_optimal_placement(importance_map, (h, w), (new_height, new_width))
153
+
154
+ # Create new canvas
155
+ result = np.zeros((new_height, new_width, 3), dtype=np.uint8)
156
+ mask = np.ones((new_height, new_width), dtype=np.uint8) * 255
157
+
158
+ # Place original image at calculated position
159
+ result[y_offset:y_offset+h, x_offset:x_offset+w] = image_np
160
+ mask[y_offset:y_offset+h, x_offset:x_offset+w] = 0
161
+
162
+ # Convert to PIL for inpainting
163
+ result_pil = Image.fromarray(result)
164
  mask_pil = Image.fromarray(mask)
165
 
166
+ # Use default prompt if none provided
167
  if not prompt or prompt.strip() == "":
168
+ if len(image_np.shape) == 3 and image_np.shape[2] == 4: # Check if image has alpha channel
169
+ prompt = "seamless extension of the image, same style and content"
170
+ else:
171
+ prompt = "seamless extension of the image, same style, same scene, consistent lighting"
172
 
173
+ # Perform outpainting using Stable Diffusion
174
  output = inpaint_model(
175
  prompt=prompt,
176
+ image=result_pil,
177
  mask_image=mask_pil,
178
  guidance_scale=7.5,
179
+ num_inference_steps=30
180
  ).images[0]
181
 
182
  return np.array(output)
 
184
  def process_image(input_image, target_ratio="16:9", prompt=""):
185
  """Main processing function for the Gradio interface"""
186
  try:
187
+ # Convert from Gradio format if needed
188
  if isinstance(input_image, dict) and 'image' in input_image:
189
  image = input_image['image']
190
  else:
 
196
  else:
197
  image_np = image
198
 
 
 
 
199
  # Adjust aspect ratio while preserving content
200
+ result = adjust_aspect_ratio(image_np, target_ratio, prompt)
201
 
202
  # Convert result to PIL for visualization
203
  result_pil = Image.fromarray(result)
 
209
  return None
210
 
211
  # Create the Gradio interface
212
+ with gr.Blocks(title="Smart Aspect Ratio Adjuster") as demo:
213
+ gr.Markdown("# Smart Aspect Ratio Adjuster")
214
+ gr.Markdown("Upload an image, choose your target aspect ratio, and the AI will adjust it while intelligently preserving important content.")
215
 
216
  with gr.Row():
217
  with gr.Column():
 
219
 
220
  with gr.Row():
221
  aspect_ratio = gr.Dropdown(
222
+ choices=["16:9", "4:3", "1:1", "9:16", "3:4", "2:1", "1:2"],
223
  value="16:9",
224
  label="Target Aspect Ratio"
225
  )
 
242
 
243
  gr.Markdown("""
244
  ## How it works
245
+ 1. **Content Analysis**: SAM (Segment Anything Model) identifies important regions in your image
246
+ 2. **Smart Placement**: The algorithm calculates optimal positioning to preserve key content
247
+ 3. **AI Outpainting**: Stable Diffusion fills in new areas with matching content
248
 
249
  ## Tips
250
  - For best results, provide a descriptive prompt that matches the scene