mostlycached committed
Commit b16f2d1 · verified · 1 Parent(s): 70b3d28

Create app.py

Files changed (1):
    app.py +216 -0
app.py ADDED
@@ -0,0 +1,216 @@
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from transformers import SamModel, SamProcessor
from diffusers import StableDiffusionInpaintPipeline

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load SAM model for segmentation
print("Loading SAM model...")
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Load Stable Diffusion for outpainting
print("Loading Stable Diffusion model...")
inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)
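
# NOTE: both checkpoints are downloaded from the Hugging Face Hub on first
# run, and the inpainting pipeline is heavy; CPU-only inference can take
# minutes per image.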


def get_sam_mask(image, points=None):
    """Get a segmentation mask using the SAM model."""
    # Convert to PIL if needed
    if not isinstance(image, Image.Image):
        image_pil = Image.fromarray(image)
    else:
        image_pil = image

    if points is None:
        # With no point prompt, fall back to the image center
        width, height = image_pil.size
        points = [[[width // 2, height // 2]]]

    # Process the image and point prompts
    inputs = sam_processor(
        images=image_pil,
        input_points=points,
        return_tensors="pt"
    ).to(device)

    # Generate masks
    with torch.no_grad():
        outputs = sam_model(**inputs)
        masks = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )

    # SAM predicts three candidate masks per point; keep the one with the
    # highest predicted IoU rather than always taking the first
    best = outputs.iou_scores.cpu()[0, 0].argmax().item()
    mask = masks[0][0][best].numpy()
    return mask
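
# Example: prompt SAM with an explicit point instead of the image center
# (nesting is image batch -> point set -> (x, y) point in pixels):
#     mask = get_sam_mask(image_np, points=[[[200, 150]]])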


def adjust_aspect_ratio(image, mask, target_ratio, prompt=""):
    """Adjust the image to the target aspect ratio while preserving content."""
    # Convert PIL to numpy if needed
    if isinstance(image, Image.Image):
        image_np = np.array(image)
    else:
        image_np = image

    h, w = image_np.shape[:2]
    current_ratio = w / h
    # Parse "W:H" explicitly rather than eval()-ing user-facing text
    ratio_w, ratio_h = target_ratio.split(':')
    target_ratio_value = float(ratio_w) / float(ratio_h)
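    # NOTE: `mask` is accepted so content-aware placement could be added later
    # (e.g. shifting the canvas toward the segmented subject); this version
    # pads symmetrically and does not yet use the mask.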

    # Already at the target ratio (within rounding)? Nothing to outpaint.
    if abs(current_ratio - target_ratio_value) < 1e-3:
        return image_np

    # Determine whether we need to add width or height
    if current_ratio < target_ratio_value:
        # Need to add width (outpaint left/right)
        new_width = int(h * target_ratio_value)
        new_height = h

        # Calculate padding
        pad_width = new_width - w
        pad_left = pad_width // 2
        pad_right = pad_width - pad_left

        # Create canvas with padding
        result = np.zeros((new_height, new_width, 3), dtype=np.uint8)
        # Place the original image in the center
        result[:, pad_left:pad_left + w, :] = image_np

        # Create the inpainting mask: 255 marks the regions to generate
        inpaint_mask = np.ones((new_height, new_width), dtype=np.uint8) * 255
        inpaint_mask[:, pad_left:pad_left + w] = 0

        # Perform outpainting using Stable Diffusion
        result = outpaint_regions(result, inpaint_mask, prompt)

    else:
        # Need to add height (outpaint top/bottom)
        new_width = w
        new_height = int(w / target_ratio_value)

        # Calculate padding
        pad_height = new_height - h
        pad_top = pad_height // 2
        pad_bottom = pad_height - pad_top

        # Create canvas with padding
        result = np.zeros((new_height, new_width, 3), dtype=np.uint8)
        # Place the original image in the center
        result[pad_top:pad_top + h, :, :] = image_np

        # Create the inpainting mask: 255 marks the regions to generate
        inpaint_mask = np.ones((new_height, new_width), dtype=np.uint8) * 255
        inpaint_mask[pad_top:pad_top + h, :] = 0

        # Perform outpainting using Stable Diffusion
        result = outpaint_regions(result, inpaint_mask, prompt)

    return result


def outpaint_regions(image, mask, prompt):
    """Use Stable Diffusion to outpaint the masked (white) regions."""
    # Convert to PIL images
    image_pil = Image.fromarray(image)
    mask_pil = Image.fromarray(mask)

    # If the prompt is empty, use a generic one
    if not prompt or prompt.strip() == "":
        prompt = "seamless extension of the image, same style, same scene"

    # The pipeline defaults to 512x512 when height/width are omitted, which
    # would squash the padded canvas back to a square, so pass the canvas
    # size explicitly (rounded down to multiples of 8 for the VAE)
    h, w = image.shape[:2]
    out_h, out_w = (h // 8) * 8, (w // 8) * 8

    # Generate the outpainting
    output = inpaint_model(
        prompt=prompt,
        image=image_pil,
        mask_image=mask_pil,
        height=out_h,
        width=out_w,
        guidance_scale=7.5,
        num_inference_steps=25
    ).images[0]

    return np.array(output)
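
# Tuning note: num_inference_steps trades speed for quality (25 is a quick
# preview; around 50 is a common final setting), and guidance_scale near 7-8
# keeps the generated fill close to the prompt.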


def process_image(input_image, target_ratio="16:9", prompt=""):
    """Main processing function for the Gradio interface."""
    try:
        # Unwrap the dict that some Gradio image components return
        if isinstance(input_image, dict) and 'image' in input_image:
            image = input_image['image']
        else:
            image = input_image

        # Convert PIL to numpy if needed
        if isinstance(image, Image.Image):
            image_np = np.array(image)
        else:
            image_np = image

        # Ensure 3-channel RGB (uploads can be grayscale or RGBA)
        if image_np.ndim == 2:
            image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
        elif image_np.shape[2] == 4:
            image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)

        # Get SAM mask to identify important regions
        mask = get_sam_mask(image_np)

        # Adjust aspect ratio while preserving content
        result = adjust_aspect_ratio(image_np, mask, target_ratio, prompt)

        # Convert result to PIL for display
        result_pil = Image.fromarray(result)
        return result_pil

    except Exception as e:
        print(f"Error processing image: {e}")
        return None
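
# Example headless usage (hypothetical file name, shown for illustration):
#     from PIL import Image
#     out = process_image(Image.open("photo.jpg"), target_ratio="16:9",
#                         prompt="a wide sandy beach at sunset")
#     out.save("photo_16x9.jpg")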


# Create the Gradio interface
with gr.Blocks(title="Automatic Aspect Ratio Adjuster") as demo:
    gr.Markdown("# Automatic Aspect Ratio Adjuster")
    gr.Markdown("Upload an image, choose your target aspect ratio, and let the AI adjust it while preserving important content.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image")

            with gr.Row():
                aspect_ratio = gr.Dropdown(
                    choices=["16:9", "4:3", "1:1", "9:16", "3:4"],
                    value="16:9",
                    label="Target Aspect Ratio"
                )

            prompt = gr.Textbox(
                label="Outpainting Prompt (optional)",
                placeholder="Describe the scene for better outpainting"
            )

            submit_btn = gr.Button("Process Image")

        with gr.Column():
            output_image = gr.Image(label="Processed Image")

    submit_btn.click(
        process_image,
        inputs=[input_image, aspect_ratio, prompt],
        outputs=output_image
    )

    gr.Markdown("""
    ## How it works
    1. SAM (Segment Anything Model) identifies the important content in your image
    2. The app pads the image symmetrically to reach the target aspect ratio
    3. Stable Diffusion fills the new areas with AI-generated content that matches the original image

    ## Tips
    - For best results, provide a descriptive prompt that matches the scene
    - Try different aspect ratios to see what works best
    - The model works best with clear, well-lit images
    """)

# Launch the app
if __name__ == "__main__":
    demo.launch()