padmanabhbosamia committed on
Commit
8e1a13f
·
verified ·
1 Parent(s): e77c6a3

Modified with Examples

Browse files

Modified with Examples.

Files changed (1) hide show
  1. app.py +173 -157
app.py CHANGED
@@ -6,6 +6,8 @@ from huggingface_hub import hf_hub_download
6
  import os
7
  from pathlib import Path
8
  import traceback
 
 
9
 
10
  # Reuse the same load_learned_embed_in_clip and Distance_loss functions
11
  def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
@@ -92,7 +94,6 @@ class StyleGenerator:
92
  "Bird Style"
93
  ]
94
  self.is_initialized = False
95
- # Check if CUDA is available
96
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
97
  if self.device == "cpu":
98
  print("NVIDIA GPU not found. Running on CPU (this will be slower)")
@@ -132,63 +133,44 @@ class StyleGenerator:
132
  print(traceback.format_exc())
133
  raise
134
 
135
- def generate_images(self, prompt, apply_loss=False, num_inference_steps=50, guidance_scale=7.5):
136
- if not self.is_initialized:
137
- self.initialize_model()
138
-
139
- images = []
140
- style_names = []
141
-
142
  try:
143
- def callback_fn(i, t, latents):
144
- if i % 5 == 0 and apply_loss:
145
- try:
146
- # Ensure latents are in the correct format and require gradients
147
- latents = latents.float()
148
- latents.requires_grad_(True)
149
-
150
- # Compute loss
151
- loss = Distance_loss(latents)
152
-
153
- # Compute gradients manually
154
- grads = torch.autograd.grad(
155
- outputs=loss,
156
- inputs=latents,
157
- create_graph=False,
158
- retain_graph=False,
159
- only_inputs=True
160
- )[0]
161
-
162
- # Update latents
163
- with torch.no_grad():
164
- latents = latents - 0.1 * grads
165
-
166
- except Exception as e:
167
- print(f"Error in callback: {e}")
168
- return latents
169
-
170
- return latents
171
-
172
- for style_token, style_name in zip(self.style_tokens, self.style_names):
173
- styled_prompt = f"{prompt}, {style_token}"
174
- style_names.append(style_name)
175
-
176
- # Disable autocast for better gradient computation
177
- image = self.pipe(
178
  styled_prompt,
179
- num_inference_steps=num_inference_steps,
180
- guidance_scale=guidance_scale,
181
- callback=callback_fn if apply_loss else None,
182
- callback_steps=5
183
  ).images[0]
184
-
185
- images.append(image)
186
 
187
- return images, style_names
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  except Exception as e:
190
- print(f"Error during image generation: {str(e)}")
191
- print(traceback.format_exc())
192
  raise
193
 
194
  def callback_fn(self, i, t, latents):
@@ -219,38 +201,44 @@ class StyleGenerator:
219
 
220
  return latents
221
 
222
- def generate_all_variations(prompt):
223
  try:
224
  generator = StyleGenerator.get_instance()
225
  if not generator.is_initialized:
226
  generator.initialize_model()
227
 
228
- # Generate images without loss
229
- regular_images, style_names = generator.generate_images(prompt, apply_loss=False)
230
 
231
- # Generate images with loss
232
- loss_images, _ = generator.generate_images(prompt, apply_loss=True)
233
-
234
- return regular_images, loss_images, style_names
235
-
236
  except Exception as e:
237
- print(f"Error in generate_all_variations: {str(e)}")
238
- print(traceback.format_exc())
239
- raise
 
 
 
240
 
241
- def gradio_interface(prompt):
242
- try:
243
- regular_images, loss_images, style_names = generate_all_variations(prompt)
244
-
245
- return (
246
- regular_images, # Just return the images directly
247
- loss_images # Just return the images directly
248
- )
249
- except Exception as e:
250
- print(f"Error in interface: {str(e)}")
251
- print(traceback.format_exc())
252
- # Return empty lists in case of error
253
- return [], []
 
 
 
 
254
 
255
  # Create a more beautiful interface with custom styling
256
  with gr.Blocks(css="""
@@ -265,73 +253,43 @@ with gr.Blocks(css="""
265
  border: 1px solid #374151;
266
  color: #f3f4f6;
267
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  """) as iface:
269
- # Header section with dark theme
270
  gr.Markdown(
271
  """
272
- <div class="dark-theme" style="text-align: center; max-width: 800px; margin: 0 auto;">
273
  # 🎨 AI Style Transfer Studio
274
- ### Transform your ideas into artistic masterpieces with custom styles and enhanced colors
275
  </div>
276
  """
277
  )
278
 
279
- # Define the generate_single_style function first
280
- def generate_single_style(prompt, selected_style):
281
- try:
282
- generator = StyleGenerator.get_instance()
283
- if not generator.is_initialized:
284
- generator.initialize_model()
285
-
286
- # Find the index of the selected style
287
- style_idx = generator.style_names.index(generator.style_names[selected_style])
288
-
289
- # Generate single image with selected style
290
- styled_prompt = f"{prompt}, {generator.style_tokens[style_idx]}"
291
-
292
- # Set seed for reproducibility
293
- generator_seed = 42
294
- torch.manual_seed(generator_seed)
295
- if generator.device == "cuda":
296
- torch.cuda.manual_seed(generator_seed)
297
-
298
- # Generate base image
299
- with autocast(generator.device):
300
- base_image = generator.pipe(
301
- styled_prompt,
302
- num_inference_steps=50,
303
- guidance_scale=7.5,
304
- generator=torch.Generator(generator.device).manual_seed(generator_seed)
305
- ).images[0]
306
-
307
- # Generate same image with loss
308
- with autocast(generator.device):
309
- loss_image = generator.pipe(
310
- styled_prompt,
311
- num_inference_steps=50,
312
- guidance_scale=7.5,
313
- callback=generator.callback_fn,
314
- callback_steps=5,
315
- generator=torch.Generator(generator.device).manual_seed(generator_seed)
316
- ).images[0]
317
-
318
- return [
319
- gr.update(visible=False), # error_message
320
- base_image, # original_image
321
- loss_image # loss_image
322
- ]
323
- except Exception as e:
324
- print(f"Error in generate_single_style: {e}")
325
- return [
326
- gr.update(value=f"Error: {str(e)}", visible=True), # error_message
327
- None, # original_image
328
- None # loss_image
329
- ]
330
-
331
- # Main content
332
  with gr.Row():
333
- # Left sidebar for controls
334
- with gr.Column(scale=1, min_width=300):
335
  gr.Markdown("## 🎯 Controls")
336
 
337
  prompt = gr.Textbox(
@@ -359,29 +317,88 @@ with gr.Blocks(css="""
359
  size="lg"
360
  )
361
 
362
- # Error messages
363
  error_message = gr.Markdown(visible=False)
364
-
365
- # Style description
366
  style_description = gr.Markdown()
367
-
368
- # Right side for image display
369
- with gr.Column(scale=2):
370
- gr.Markdown("## 🖼️ Generated Artwork")
371
- with gr.Row():
372
- with gr.Column():
373
- original_image = gr.Image(
374
- label="Original Style",
375
- show_label=True,
376
- height=400
377
- )
378
- with gr.Column():
379
- loss_image = gr.Image(
380
- label="Color Enhanced",
381
- show_label=True,
382
- height=400
383
- )
384
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  # Info section
386
  with gr.Row():
387
  with gr.Column():
@@ -441,7 +458,7 @@ with gr.Blocks(css="""
441
  "Specialized in capturing the beauty of nature and wildlife"
442
  ]
443
  styles = ["Ronaldo Style", "Canna Lily", "Three Stooges", "Pop Art", "Bird Style"]
444
- return f"### Selected: {styles[style_idx]}\n{descriptions[style_idx]}"
445
 
446
  style_radio.change(
447
  fn=update_style_description,
@@ -449,7 +466,6 @@ with gr.Blocks(css="""
449
  outputs=style_description
450
  )
451
 
452
- # Connect the generate button
453
  generate_btn.click(
454
  fn=generate_single_style,
455
  inputs=[prompt, style_radio],
 
6
  import os
7
  from pathlib import Path
8
  import traceback
9
+ import glob
10
+ from PIL import Image
11
 
12
  # Reuse the same load_learned_embed_in_clip and Distance_loss functions
13
  def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
 
94
  "Bird Style"
95
  ]
96
  self.is_initialized = False
 
97
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
98
  if self.device == "cpu":
99
  print("NVIDIA GPU not found. Running on CPU (this will be slower)")
 
133
  print(traceback.format_exc())
134
  raise
135
 
136
+ def generate_single_style(self, prompt, selected_style):
 
 
 
 
 
 
137
  try:
138
+ # Find the index of the selected style
139
+ style_idx = self.style_names.index(self.style_names[selected_style])
140
+
141
+ # Generate single image with selected style
142
+ styled_prompt = f"{prompt}, {self.style_tokens[style_idx]}"
143
+
144
+ # Set seed for reproducibility
145
+ generator_seed = 42
146
+ torch.manual_seed(generator_seed)
147
+ if self.device == "cuda":
148
+ torch.cuda.manual_seed(generator_seed)
149
+
150
+ # Generate base image
151
+ with autocast(self.device):
152
+ base_image = self.pipe(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  styled_prompt,
154
+ num_inference_steps=50,
155
+ guidance_scale=7.5,
156
+ generator=torch.Generator(self.device).manual_seed(generator_seed)
 
157
  ).images[0]
 
 
158
 
159
+ # Generate same image with loss
160
+ with autocast(self.device):
161
+ loss_image = self.pipe(
162
+ styled_prompt,
163
+ num_inference_steps=50,
164
+ guidance_scale=7.5,
165
+ callback=self.callback_fn,
166
+ callback_steps=5,
167
+ generator=torch.Generator(self.device).manual_seed(generator_seed)
168
+ ).images[0]
169
+
170
+ return base_image, loss_image
171
 
172
  except Exception as e:
173
+ print(f"Error in generate_single_style: {e}")
 
174
  raise
175
 
176
  def callback_fn(self, i, t, latents):
 
201
 
202
  return latents
203
 
204
+ def generate_single_style(prompt, selected_style):
205
  try:
206
  generator = StyleGenerator.get_instance()
207
  if not generator.is_initialized:
208
  generator.initialize_model()
209
 
210
+ base_image, loss_image = generator.generate_single_style(prompt, selected_style)
 
211
 
212
+ return [
213
+ gr.update(visible=False), # error_message
214
+ base_image, # original_image
215
+ loss_image # loss_image
216
+ ]
217
  except Exception as e:
218
+ print(f"Error in generate_single_style: {e}")
219
+ return [
220
+ gr.update(value=f"Error: {str(e)}", visible=True), # error_message
221
+ None, # original_image
222
+ None # loss_image
223
+ ]
224
 
225
+ # Add at the start of your script
226
+ def debug_image_paths():
227
+ output_dir = Path("Outputs")
228
+ enhanced_dir = output_dir / "Color_Enhanced"
229
+ print(f"\nChecking image paths:")
230
+ print(f"Current working directory: {Path.cwd()}")
231
+ print(f"Looking for images in: {enhanced_dir.absolute()}")
232
+
233
+ if enhanced_dir.exists():
234
+ print("\nFound files:")
235
+ for file in enhanced_dir.glob("*.webp"):
236
+ print(f"- {file.name}")
237
+ else:
238
+ print("\nDirectory not found!")
239
+
240
+ # Call this function before creating the interface
241
+ debug_image_paths()
242
 
243
  # Create a more beautiful interface with custom styling
244
  with gr.Blocks(css="""
 
253
  border: 1px solid #374151;
254
  color: #f3f4f6;
255
  }
256
+ /* Enhanced Tab Styling */
257
+ .tabs.svelte-710i53 {
258
+ margin-bottom: 0 !important;
259
+ }
260
+ .tab-nav.svelte-710i53 {
261
+ background: transparent !important;
262
+ border: none !important;
263
+ padding: 12px 24px !important;
264
+ margin: 0 2px !important;
265
+ color: #9CA3AF !important;
266
+ font-weight: 500 !important;
267
+ transition: all 0.2s ease !important;
268
+ border-bottom: 2px solid transparent !important;
269
+ }
270
+ .tab-nav.svelte-710i53.selected {
271
+ background: transparent !important;
272
+ color: #F3F4F6 !important;
273
+ border-bottom: 2px solid #6366F1 !important;
274
+ }
275
+ .tab-nav.svelte-710i53:hover {
276
+ color: #F3F4F6 !important;
277
+ border-bottom: 2px solid #4F46E5 !important;
278
+ }
279
  """) as iface:
280
+ # Header section
281
  gr.Markdown(
282
  """
283
+ <div class="dark-theme" style="text-align: center;">
284
  # 🎨 AI Style Transfer Studio
285
+ ### Transform your ideas into artistic masterpieces
286
  </div>
287
  """
288
  )
289
 
290
+ # Controls section
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  with gr.Row():
292
+ with gr.Column(scale=1):
 
293
  gr.Markdown("## 🎯 Controls")
294
 
295
  prompt = gr.Textbox(
 
317
  size="lg"
318
  )
319
 
 
320
  error_message = gr.Markdown(visible=False)
 
 
321
  style_description = gr.Markdown()
322
+
323
+ # Generated Images
324
+ with gr.Row():
325
+ with gr.Column():
326
+ original_image = gr.Image(
327
+ label="Original Style",
328
+ show_label=True,
329
+ height=300
330
+ )
331
+ with gr.Column():
332
+ loss_image = gr.Image(
333
+ label="Color Enhanced",
334
+ show_label=True,
335
+ height=300
336
+ )
337
+
338
+ # Example Gallery
339
+ gr.Markdown(
340
+ """
341
+ <div class="dark-theme">
342
+ ## 🎆 Example Gallery
343
+ Compare original and enhanced versions for each style:
344
+ </div>
345
+ """
346
+ )
347
+
348
+ # Example Images
349
+ with gr.Row():
350
+ try:
351
+ output_dir = Path("Outputs")
352
+ original_dir = output_dir
353
+ enhanced_dir = output_dir / "Color_Enhanced"
354
+
355
+ if enhanced_dir.exists():
356
+ original_images = {
357
+ Path(f).stem.split('_example')[0]: f
358
+ for f in original_dir.glob("*.webp")
359
+ if '_example' in f.name
360
+ }
361
+ enhanced_images = {
362
+ Path(f).stem.split('_example')[0]: f
363
+ for f in enhanced_dir.glob("*.webp")
364
+ if '_example' in f.name
365
+ }
366
+
367
+ styles = [
368
+ ("ronaldo", "Ronaldo Style"),
369
+ ("canna_lily", "Canna Lily"),
370
+ ("three_stooges", "Three Stooges"),
371
+ ("pop_art", "Pop Art"),
372
+ ("bird_style", "Bird Style")
373
+ ]
374
+
375
+ # Create a grid of all styles
376
+ for style_key, style_name in styles:
377
+ if style_key in original_images and style_key in enhanced_images:
378
+ with gr.Row():
379
+ gr.Markdown(f"### {style_name}")
380
+ with gr.Row():
381
+ with gr.Column(scale=1):
382
+ gr.Image(
383
+ value=str(original_images[style_key]),
384
+ label="Original",
385
+ show_label=True,
386
+ height=180
387
+ )
388
+ with gr.Column(scale=1):
389
+ gr.Image(
390
+ value=str(enhanced_images[style_key]),
391
+ label="Color Enhanced",
392
+ show_label=True,
393
+ height=180
394
+ )
395
+ # Add a small spacing between styles
396
+ gr.Markdown("<div style='margin: 10px 0;'></div>")
397
+
398
+ except Exception as e:
399
+ print(f"Error in example gallery: {e}")
400
+ gr.Markdown(f"Error loading example gallery: {str(e)}")
401
+
402
  # Info section
403
  with gr.Row():
404
  with gr.Column():
 
458
  "Specialized in capturing the beauty of nature and wildlife"
459
  ]
460
  styles = ["Ronaldo Style", "Canna Lily", "Three Stooges", "Pop Art", "Bird Style"]
461
+ return f"### Selected Style: {styles[style_idx]}\n{descriptions[style_idx]}"
462
 
463
  style_radio.change(
464
  fn=update_style_description,
 
466
  outputs=style_description
467
  )
468
 
 
469
  generate_btn.click(
470
  fn=generate_single_style,
471
  inputs=[prompt, style_radio],