Spaces:
Sleeping
Sleeping
Modified with Examples
Browse filesModified with Examples.
app.py
CHANGED
@@ -6,6 +6,8 @@ from huggingface_hub import hf_hub_download
|
|
6 |
import os
|
7 |
from pathlib import Path
|
8 |
import traceback
|
|
|
|
|
9 |
|
10 |
# Reuse the same load_learned_embed_in_clip and Distance_loss functions
|
11 |
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
|
@@ -92,7 +94,6 @@ class StyleGenerator:
|
|
92 |
"Bird Style"
|
93 |
]
|
94 |
self.is_initialized = False
|
95 |
-
# Check if CUDA is available
|
96 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
97 |
if self.device == "cpu":
|
98 |
print("NVIDIA GPU not found. Running on CPU (this will be slower)")
|
@@ -132,63 +133,44 @@ class StyleGenerator:
|
|
132 |
print(traceback.format_exc())
|
133 |
raise
|
134 |
|
135 |
-
def
|
136 |
-
if not self.is_initialized:
|
137 |
-
self.initialize_model()
|
138 |
-
|
139 |
-
images = []
|
140 |
-
style_names = []
|
141 |
-
|
142 |
try:
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
retain_graph=False,
|
159 |
-
only_inputs=True
|
160 |
-
)[0]
|
161 |
-
|
162 |
-
# Update latents
|
163 |
-
with torch.no_grad():
|
164 |
-
latents = latents - 0.1 * grads
|
165 |
-
|
166 |
-
except Exception as e:
|
167 |
-
print(f"Error in callback: {e}")
|
168 |
-
return latents
|
169 |
-
|
170 |
-
return latents
|
171 |
-
|
172 |
-
for style_token, style_name in zip(self.style_tokens, self.style_names):
|
173 |
-
styled_prompt = f"{prompt}, {style_token}"
|
174 |
-
style_names.append(style_name)
|
175 |
-
|
176 |
-
# Disable autocast for better gradient computation
|
177 |
-
image = self.pipe(
|
178 |
styled_prompt,
|
179 |
-
num_inference_steps=
|
180 |
-
guidance_scale=
|
181 |
-
|
182 |
-
callback_steps=5
|
183 |
).images[0]
|
184 |
-
|
185 |
-
images.append(image)
|
186 |
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
except Exception as e:
|
190 |
-
print(f"Error
|
191 |
-
print(traceback.format_exc())
|
192 |
raise
|
193 |
|
194 |
def callback_fn(self, i, t, latents):
|
@@ -219,38 +201,44 @@ class StyleGenerator:
|
|
219 |
|
220 |
return latents
|
221 |
|
222 |
-
def
|
223 |
try:
|
224 |
generator = StyleGenerator.get_instance()
|
225 |
if not generator.is_initialized:
|
226 |
generator.initialize_model()
|
227 |
|
228 |
-
|
229 |
-
regular_images, style_names = generator.generate_images(prompt, apply_loss=False)
|
230 |
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
except Exception as e:
|
237 |
-
print(f"Error in
|
238 |
-
|
239 |
-
|
|
|
|
|
|
|
240 |
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
print(
|
251 |
-
|
252 |
-
|
253 |
-
|
|
|
|
|
|
|
|
|
254 |
|
255 |
# Create a more beautiful interface with custom styling
|
256 |
with gr.Blocks(css="""
|
@@ -265,73 +253,43 @@ with gr.Blocks(css="""
|
|
265 |
border: 1px solid #374151;
|
266 |
color: #f3f4f6;
|
267 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
""") as iface:
|
269 |
-
# Header section
|
270 |
gr.Markdown(
|
271 |
"""
|
272 |
-
<div class="dark-theme" style="text-align: center;
|
273 |
# π¨ AI Style Transfer Studio
|
274 |
-
### Transform your ideas into artistic masterpieces
|
275 |
</div>
|
276 |
"""
|
277 |
)
|
278 |
|
279 |
-
#
|
280 |
-
def generate_single_style(prompt, selected_style):
|
281 |
-
try:
|
282 |
-
generator = StyleGenerator.get_instance()
|
283 |
-
if not generator.is_initialized:
|
284 |
-
generator.initialize_model()
|
285 |
-
|
286 |
-
# Find the index of the selected style
|
287 |
-
style_idx = generator.style_names.index(generator.style_names[selected_style])
|
288 |
-
|
289 |
-
# Generate single image with selected style
|
290 |
-
styled_prompt = f"{prompt}, {generator.style_tokens[style_idx]}"
|
291 |
-
|
292 |
-
# Set seed for reproducibility
|
293 |
-
generator_seed = 42
|
294 |
-
torch.manual_seed(generator_seed)
|
295 |
-
if generator.device == "cuda":
|
296 |
-
torch.cuda.manual_seed(generator_seed)
|
297 |
-
|
298 |
-
# Generate base image
|
299 |
-
with autocast(generator.device):
|
300 |
-
base_image = generator.pipe(
|
301 |
-
styled_prompt,
|
302 |
-
num_inference_steps=50,
|
303 |
-
guidance_scale=7.5,
|
304 |
-
generator=torch.Generator(generator.device).manual_seed(generator_seed)
|
305 |
-
).images[0]
|
306 |
-
|
307 |
-
# Generate same image with loss
|
308 |
-
with autocast(generator.device):
|
309 |
-
loss_image = generator.pipe(
|
310 |
-
styled_prompt,
|
311 |
-
num_inference_steps=50,
|
312 |
-
guidance_scale=7.5,
|
313 |
-
callback=generator.callback_fn,
|
314 |
-
callback_steps=5,
|
315 |
-
generator=torch.Generator(generator.device).manual_seed(generator_seed)
|
316 |
-
).images[0]
|
317 |
-
|
318 |
-
return [
|
319 |
-
gr.update(visible=False), # error_message
|
320 |
-
base_image, # original_image
|
321 |
-
loss_image # loss_image
|
322 |
-
]
|
323 |
-
except Exception as e:
|
324 |
-
print(f"Error in generate_single_style: {e}")
|
325 |
-
return [
|
326 |
-
gr.update(value=f"Error: {str(e)}", visible=True), # error_message
|
327 |
-
None, # original_image
|
328 |
-
None # loss_image
|
329 |
-
]
|
330 |
-
|
331 |
-
# Main content
|
332 |
with gr.Row():
|
333 |
-
|
334 |
-
with gr.Column(scale=1, min_width=300):
|
335 |
gr.Markdown("## π― Controls")
|
336 |
|
337 |
prompt = gr.Textbox(
|
@@ -359,29 +317,88 @@ with gr.Blocks(css="""
|
|
359 |
size="lg"
|
360 |
)
|
361 |
|
362 |
-
# Error messages
|
363 |
error_message = gr.Markdown(visible=False)
|
364 |
-
|
365 |
-
# Style description
|
366 |
style_description = gr.Markdown()
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
# Info section
|
386 |
with gr.Row():
|
387 |
with gr.Column():
|
@@ -441,7 +458,7 @@ with gr.Blocks(css="""
|
|
441 |
"Specialized in capturing the beauty of nature and wildlife"
|
442 |
]
|
443 |
styles = ["Ronaldo Style", "Canna Lily", "Three Stooges", "Pop Art", "Bird Style"]
|
444 |
-
return f"### Selected: {styles[style_idx]}\n{descriptions[style_idx]}"
|
445 |
|
446 |
style_radio.change(
|
447 |
fn=update_style_description,
|
@@ -449,7 +466,6 @@ with gr.Blocks(css="""
|
|
449 |
outputs=style_description
|
450 |
)
|
451 |
|
452 |
-
# Connect the generate button
|
453 |
generate_btn.click(
|
454 |
fn=generate_single_style,
|
455 |
inputs=[prompt, style_radio],
|
|
|
6 |
import os
|
7 |
from pathlib import Path
|
8 |
import traceback
|
9 |
+
import glob
|
10 |
+
from PIL import Image
|
11 |
|
12 |
# Reuse the same load_learned_embed_in_clip and Distance_loss functions
|
13 |
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
|
|
|
94 |
"Bird Style"
|
95 |
]
|
96 |
self.is_initialized = False
|
|
|
97 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
98 |
if self.device == "cpu":
|
99 |
print("NVIDIA GPU not found. Running on CPU (this will be slower)")
|
|
|
133 |
print(traceback.format_exc())
|
134 |
raise
|
135 |
|
136 |
+
def generate_single_style(self, prompt, selected_style):
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
try:
|
138 |
+
# Find the index of the selected style
|
139 |
+
style_idx = self.style_names.index(self.style_names[selected_style])
|
140 |
+
|
141 |
+
# Generate single image with selected style
|
142 |
+
styled_prompt = f"{prompt}, {self.style_tokens[style_idx]}"
|
143 |
+
|
144 |
+
# Set seed for reproducibility
|
145 |
+
generator_seed = 42
|
146 |
+
torch.manual_seed(generator_seed)
|
147 |
+
if self.device == "cuda":
|
148 |
+
torch.cuda.manual_seed(generator_seed)
|
149 |
+
|
150 |
+
# Generate base image
|
151 |
+
with autocast(self.device):
|
152 |
+
base_image = self.pipe(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
styled_prompt,
|
154 |
+
num_inference_steps=50,
|
155 |
+
guidance_scale=7.5,
|
156 |
+
generator=torch.Generator(self.device).manual_seed(generator_seed)
|
|
|
157 |
).images[0]
|
|
|
|
|
158 |
|
159 |
+
# Generate same image with loss
|
160 |
+
with autocast(self.device):
|
161 |
+
loss_image = self.pipe(
|
162 |
+
styled_prompt,
|
163 |
+
num_inference_steps=50,
|
164 |
+
guidance_scale=7.5,
|
165 |
+
callback=self.callback_fn,
|
166 |
+
callback_steps=5,
|
167 |
+
generator=torch.Generator(self.device).manual_seed(generator_seed)
|
168 |
+
).images[0]
|
169 |
+
|
170 |
+
return base_image, loss_image
|
171 |
|
172 |
except Exception as e:
|
173 |
+
print(f"Error in generate_single_style: {e}")
|
|
|
174 |
raise
|
175 |
|
176 |
def callback_fn(self, i, t, latents):
|
|
|
201 |
|
202 |
return latents
|
203 |
|
204 |
+
def generate_single_style(prompt, selected_style):
|
205 |
try:
|
206 |
generator = StyleGenerator.get_instance()
|
207 |
if not generator.is_initialized:
|
208 |
generator.initialize_model()
|
209 |
|
210 |
+
base_image, loss_image = generator.generate_single_style(prompt, selected_style)
|
|
|
211 |
|
212 |
+
return [
|
213 |
+
gr.update(visible=False), # error_message
|
214 |
+
base_image, # original_image
|
215 |
+
loss_image # loss_image
|
216 |
+
]
|
217 |
except Exception as e:
|
218 |
+
print(f"Error in generate_single_style: {e}")
|
219 |
+
return [
|
220 |
+
gr.update(value=f"Error: {str(e)}", visible=True), # error_message
|
221 |
+
None, # original_image
|
222 |
+
None # loss_image
|
223 |
+
]
|
224 |
|
225 |
+
# Add at the start of your script
|
226 |
+
def debug_image_paths():
|
227 |
+
output_dir = Path("Outputs")
|
228 |
+
enhanced_dir = output_dir / "Color_Enhanced"
|
229 |
+
print(f"\nChecking image paths:")
|
230 |
+
print(f"Current working directory: {Path.cwd()}")
|
231 |
+
print(f"Looking for images in: {enhanced_dir.absolute()}")
|
232 |
+
|
233 |
+
if enhanced_dir.exists():
|
234 |
+
print("\nFound files:")
|
235 |
+
for file in enhanced_dir.glob("*.webp"):
|
236 |
+
print(f"- {file.name}")
|
237 |
+
else:
|
238 |
+
print("\nDirectory not found!")
|
239 |
+
|
240 |
+
# Call this function before creating the interface
|
241 |
+
debug_image_paths()
|
242 |
|
243 |
# Create a more beautiful interface with custom styling
|
244 |
with gr.Blocks(css="""
|
|
|
253 |
border: 1px solid #374151;
|
254 |
color: #f3f4f6;
|
255 |
}
|
256 |
+
/* Enhanced Tab Styling */
|
257 |
+
.tabs.svelte-710i53 {
|
258 |
+
margin-bottom: 0 !important;
|
259 |
+
}
|
260 |
+
.tab-nav.svelte-710i53 {
|
261 |
+
background: transparent !important;
|
262 |
+
border: none !important;
|
263 |
+
padding: 12px 24px !important;
|
264 |
+
margin: 0 2px !important;
|
265 |
+
color: #9CA3AF !important;
|
266 |
+
font-weight: 500 !important;
|
267 |
+
transition: all 0.2s ease !important;
|
268 |
+
border-bottom: 2px solid transparent !important;
|
269 |
+
}
|
270 |
+
.tab-nav.svelte-710i53.selected {
|
271 |
+
background: transparent !important;
|
272 |
+
color: #F3F4F6 !important;
|
273 |
+
border-bottom: 2px solid #6366F1 !important;
|
274 |
+
}
|
275 |
+
.tab-nav.svelte-710i53:hover {
|
276 |
+
color: #F3F4F6 !important;
|
277 |
+
border-bottom: 2px solid #4F46E5 !important;
|
278 |
+
}
|
279 |
""") as iface:
|
280 |
+
# Header section
|
281 |
gr.Markdown(
|
282 |
"""
|
283 |
+
<div class="dark-theme" style="text-align: center;">
|
284 |
# π¨ AI Style Transfer Studio
|
285 |
+
### Transform your ideas into artistic masterpieces
|
286 |
</div>
|
287 |
"""
|
288 |
)
|
289 |
|
290 |
+
# Controls section
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
with gr.Row():
|
292 |
+
with gr.Column(scale=1):
|
|
|
293 |
gr.Markdown("## π― Controls")
|
294 |
|
295 |
prompt = gr.Textbox(
|
|
|
317 |
size="lg"
|
318 |
)
|
319 |
|
|
|
320 |
error_message = gr.Markdown(visible=False)
|
|
|
|
|
321 |
style_description = gr.Markdown()
|
322 |
+
|
323 |
+
# Generated Images
|
324 |
+
with gr.Row():
|
325 |
+
with gr.Column():
|
326 |
+
original_image = gr.Image(
|
327 |
+
label="Original Style",
|
328 |
+
show_label=True,
|
329 |
+
height=300
|
330 |
+
)
|
331 |
+
with gr.Column():
|
332 |
+
loss_image = gr.Image(
|
333 |
+
label="Color Enhanced",
|
334 |
+
show_label=True,
|
335 |
+
height=300
|
336 |
+
)
|
337 |
+
|
338 |
+
# Example Gallery
|
339 |
+
gr.Markdown(
|
340 |
+
"""
|
341 |
+
<div class="dark-theme">
|
342 |
+
## π Example Gallery
|
343 |
+
Compare original and enhanced versions for each style:
|
344 |
+
</div>
|
345 |
+
"""
|
346 |
+
)
|
347 |
+
|
348 |
+
# Example Images
|
349 |
+
with gr.Row():
|
350 |
+
try:
|
351 |
+
output_dir = Path("Outputs")
|
352 |
+
original_dir = output_dir
|
353 |
+
enhanced_dir = output_dir / "Color_Enhanced"
|
354 |
+
|
355 |
+
if enhanced_dir.exists():
|
356 |
+
original_images = {
|
357 |
+
Path(f).stem.split('_example')[0]: f
|
358 |
+
for f in original_dir.glob("*.webp")
|
359 |
+
if '_example' in f.name
|
360 |
+
}
|
361 |
+
enhanced_images = {
|
362 |
+
Path(f).stem.split('_example')[0]: f
|
363 |
+
for f in enhanced_dir.glob("*.webp")
|
364 |
+
if '_example' in f.name
|
365 |
+
}
|
366 |
+
|
367 |
+
styles = [
|
368 |
+
("ronaldo", "Ronaldo Style"),
|
369 |
+
("canna_lily", "Canna Lily"),
|
370 |
+
("three_stooges", "Three Stooges"),
|
371 |
+
("pop_art", "Pop Art"),
|
372 |
+
("bird_style", "Bird Style")
|
373 |
+
]
|
374 |
+
|
375 |
+
# Create a grid of all styles
|
376 |
+
for style_key, style_name in styles:
|
377 |
+
if style_key in original_images and style_key in enhanced_images:
|
378 |
+
with gr.Row():
|
379 |
+
gr.Markdown(f"### {style_name}")
|
380 |
+
with gr.Row():
|
381 |
+
with gr.Column(scale=1):
|
382 |
+
gr.Image(
|
383 |
+
value=str(original_images[style_key]),
|
384 |
+
label="Original",
|
385 |
+
show_label=True,
|
386 |
+
height=180
|
387 |
+
)
|
388 |
+
with gr.Column(scale=1):
|
389 |
+
gr.Image(
|
390 |
+
value=str(enhanced_images[style_key]),
|
391 |
+
label="Color Enhanced",
|
392 |
+
show_label=True,
|
393 |
+
height=180
|
394 |
+
)
|
395 |
+
# Add a small spacing between styles
|
396 |
+
gr.Markdown("<div style='margin: 10px 0;'></div>")
|
397 |
+
|
398 |
+
except Exception as e:
|
399 |
+
print(f"Error in example gallery: {e}")
|
400 |
+
gr.Markdown(f"Error loading example gallery: {str(e)}")
|
401 |
+
|
402 |
# Info section
|
403 |
with gr.Row():
|
404 |
with gr.Column():
|
|
|
458 |
"Specialized in capturing the beauty of nature and wildlife"
|
459 |
]
|
460 |
styles = ["Ronaldo Style", "Canna Lily", "Three Stooges", "Pop Art", "Bird Style"]
|
461 |
+
return f"### Selected Style: {styles[style_idx]}\n{descriptions[style_idx]}"
|
462 |
|
463 |
style_radio.change(
|
464 |
fn=update_style_description,
|
|
|
466 |
outputs=style_description
|
467 |
)
|
468 |
|
|
|
469 |
generate_btn.click(
|
470 |
fn=generate_single_style,
|
471 |
inputs=[prompt, style_radio],
|