prithivMLmods committed (verified)
Commit 5663d15 · 1 Parent(s): fe1e01e

Update app.py

Files changed (1): app.py (+80, -114)
app.py CHANGED
@@ -22,10 +22,15 @@ subprocess.run(
     shell=True
 )
 
+# Set torch backend configurations for Flux RealismLora
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+torch.backends.cuda.matmul.allow_tf32 = True
+
 # -------------------------------
 # CONFIGURATION & UTILITY FUNCTIONS
 # -------------------------------
-MAX_SEED = np.iinfo(np.int32).max
+MAX_SEED = 2**32 - 1
 
 def save_image(img: Image.Image) -> str:
     """Save a PIL image with a unique filename and return its path."""
@@ -38,79 +43,66 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         seed = random.randint(0, MAX_SEED)
     return seed
 
-# Determine preferred torch dtype based on GPU support.
-bf16_supported = torch.cuda.is_bf16_supported()
-preferred_dtype = torch.bfloat16 if bf16_supported else torch.float16
+def progress_bar_html(label: str) -> str:
+    """
+    Returns an HTML snippet for an animated progress bar with a given label.
+    """
+    return f'''
+    <div style="display: flex; align-items: center;">
+        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+        <div style="width: 110px; height: 5px; background-color: #FFC0CB; border-radius: 2px; overflow: hidden;">
+            <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
+        </div>
+    </div>
+    <style>
+    @keyframes loading {{
+        0% {{ transform: translateX(-100%); }}
+        100% {{ transform: translateX(100%); }}
+    }}
+    </style>
+    '''
 
 # -------------------------------
-# FLUX.1 IMAGE GENERATION SETUP
+# FLUX REALISMLORA IMAGE GENERATION SETUP (New Implementation)
 # -------------------------------
 from diffusers import DiffusionPipeline
 
 base_model = "black-forest-labs/FLUX.1-dev"
-pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=preferred_dtype)
-lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
-trigger_word = "Super Realism"  # Leave blank if no trigger word is needed.
+pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
+lora_repo = "XLabs-AI/flux-RealismLora"
+trigger_word = ""  # No trigger word used.
 pipe.load_lora_weights(lora_repo)
 pipe.to("cuda")
 
-# Define style prompts for Flux.1
-style_list = [
-    {
-        "name": "3840 x 2160",
-        "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
-    },
-    {
-        "name": "2560 x 1440",
-        "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
-    },
-    {
-        "name": "HD+",
-        "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
-    },
-    {
-        "name": "Style Zero",
-        "prompt": "{prompt}",
-    },
-]
-styles = {s["name"]: s["prompt"] for s in style_list}
-DEFAULT_STYLE_NAME = "3840 x 2160"
-STYLE_NAMES = list(styles.keys())
-
-def apply_style(style_name: str, positive: str) -> str:
-    return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)
-
-@spaces.GPU(duration=60, enable_queue=True)
-def generate_image_flux(
-    prompt: str,
-    seed: int = 0,
-    width: int = 1024,
-    height: int = 1024,
-    guidance_scale: float = 3,
-    randomize_seed: bool = False,
-    style_name: str = DEFAULT_STYLE_NAME,
-    progress=gr.Progress(track_tqdm=True),
-):
-    """Generate an image using the Flux.1 pipeline with a chosen style."""
-    torch.cuda.empty_cache()  # Clear unused GPU memory to prevent allocation errors
-    seed = int(randomize_seed_fn(seed, randomize_seed))
-    positive_prompt = apply_style(style_name, prompt)
-    if trigger_word:
-        positive_prompt = f"{trigger_word} {positive_prompt}"
-    # Wrap the diffusion call in no_grad to avoid unnecessary gradient state.
-    with torch.no_grad():
-        images = pipe(
-            prompt=positive_prompt,
-            width=width,
-            height=height,
-            guidance_scale=guidance_scale,
-            num_inference_steps=28,
-            num_images_per_prompt=1,
-            output_type="pil",
-        ).images
-    torch.cuda.synchronize()  # Ensure all CUDA operations have completed
-    image_paths = [save_image(img) for img in images]
-    return image_paths, seed
+@spaces.GPU()
+def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
+    # Set random seed for reproducibility
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+
+    # Update progress bar (0% at start)
+    progress(0, "Starting image generation...")
+
+    # Simulate progress updates during the steps
+    for i in range(1, steps + 1):
+        if steps >= 10 and i % (steps // 10) == 0:
+            progress(i / steps * 100, f"Processing step {i} of {steps}...")
+
+    # Generate image using the pipeline
+    image = pipe(
+        prompt=f"{prompt} {trigger_word}",
+        num_inference_steps=steps,
+        guidance_scale=cfg_scale,
+        width=width,
+        height=height,
+        generator=generator,
+        joint_attention_kwargs={"scale": lora_scale},
+    ).images[0]
+
+    # Final progress update (100%)
+    progress(100, "Completed!")
+    yield image, seed
 
 # -------------------------------
 # SMOLVLM2 SETUP (Default Text/Multimodal Model)
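Reviewer note: to exercise the new LoRA path outside the Space, here is a standalone sketch using the same repos and the defaults this commit wires into the chat handler (assumes a CUDA GPU and access to the gated FLUX.1-dev weights; the prompt and output filename are illustrative):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("XLabs-AI/flux-RealismLora")
pipe.to("cuda")

# Same defaults as the "@image" branch below: 32 steps, cfg 3.2, 1152x896, LoRA scale 0.85.
generator = torch.Generator(device="cuda").manual_seed(3981632454)
image = pipe(
    prompt="A futuristic cityscape at dusk in hyper-realistic style",
    num_inference_steps=32,
    guidance_scale=3.2,
    width=1152,
    height=896,
    generator=generator,
    joint_attention_kwargs={"scale": 0.85},  # LoRA strength, as in run_lora
).images[0]
image.save("realism_lora_sample.png")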
@@ -121,31 +113,12 @@ smol_processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Inst
 smol_model = AutoModelForImageTextToText.from_pretrained(
     "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
     _attn_implementation="flash_attention_2",
-    torch_dtype=preferred_dtype
+    torch_dtype=torch.float16
 ).to("cuda:0")
 
 # -------------------------------
-# UTILITY FUNCTIONS
+# TTS UTILITY FUNCTIONS
 # -------------------------------
-def progress_bar_html(label: str) -> str:
-    """
-    Returns an HTML snippet for an animated progress bar with a given label.
-    """
-    return f'''
-    <div style="display: flex; align-items: center;">
-        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-        <div style="width: 110px; height: 5px; background-color: #FFC0CB; border-radius: 2px; overflow: hidden;">
-            <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
-        </div>
-    </div>
-    <style>
-    @keyframes loading {{
-        0% {{ transform: translateX(-100%); }}
-        100% {{ transform: translateX(100%); }}
-    }}
-    </style>
-    '''
-
 TTS_VOICES = [
     "en-US-JennyNeural",  # @tts1
     "en-US-GuyNeural",    # @tts2
@@ -161,36 +134,32 @@ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
 # CHAT / MULTIMODAL GENERATION FUNCTION
 # -------------------------------
 @spaces.GPU
-def generate(
-    input_dict: dict,
-    chat_history: list[dict],
-    max_tokens: int = 200,
-):
+def generate(input_dict: dict, chat_history: list[dict], max_tokens: int = 200):
     """
-    Generates chatbot responses using SmolVLM2 by default—with support for multimodal inputs and TTS.
+    Generates chatbot responses using SmolVLM2 with support for multimodal inputs and TTS.
     Special commands:
-    - "@image": triggers image generation using the Flux.1 pipeline.
+    - "@image": triggers image generation using the RealismLora flux implementation.
     - "@tts1" or "@tts2": triggers text-to-speech after generation.
     """
-    torch.cuda.empty_cache()  # Clear unused GPU memory for consistency
+    torch.cuda.empty_cache()
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
-    # If the query starts with "@image", use Flux.1 to generate an image.
+    # If the query starts with "@image", use RealismLora to generate an image.
     if text.strip().lower().startswith("@image"):
         prompt = text[len("@image"):].strip()
-        yield progress_bar_html("Hold Tight Generating Flux.1 Image")
-        image_paths, used_seed = generate_image_flux(
-            prompt=prompt,
-            seed=1,
-            width=1024,
-            height=1024,
-            guidance_scale=3,
-            randomize_seed=True,
-            style_name=DEFAULT_STYLE_NAME,
-            progress=gr.Progress(track_tqdm=True),
-        )
-        yield gr.Image(image_paths[0])
+        yield progress_bar_html("Hold Tight Generating Flux RealismLora Image")
+        # Default parameters for RealismLora generation
+        default_cfg_scale = 3.2
+        default_steps = 32
+        default_width = 1152
+        default_height = 896
+        default_seed = 3981632454
+        default_lora_scale = 0.85
+        # Call the new run_lora function and yield its final result
+        for result in run_lora(prompt, default_cfg_scale, default_steps, True, default_seed, default_width, default_height, default_lora_scale, progress=gr.Progress(track_tqdm=True)):
+            final_result = result
+        yield gr.Image(final_result[0])
         return
 
     # Handle TTS commands if present.
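Reviewer note: run_lora yields exactly one (image, seed) tuple, so the for-loop above simply keeps the last yield; calling next() on the generator is an equivalent, tighter pattern inside the "@image" branch (a sketch, not part of the commit). Also worth flagging for a follow-up: gr.Progress expects fractions in [0, 1], so the i / steps * 100 and progress(100, ...) calls in run_lora overshoot that range.

# Equivalent consumption of the single-yield generator:
image, used_seed = next(run_lora(
    prompt, default_cfg_scale, default_steps, True, default_seed,
    default_width, default_height, default_lora_scale,
))
yield gr.Image(image)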
@@ -203,7 +172,6 @@ def generate(
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
 
-    # Use SmolVLM2 for chat/multimodal text generation.
     yield "Processing with SmolVLM2"
 
     # Build conversation messages based on input and history.
@@ -272,7 +240,6 @@ def generate(
         yield "Please input a text query along with the image(s)."
         return
 
-    print("resulting_messages", resulting_messages)
    inputs = smol_processor.apply_chat_template(
        resulting_messages,
        add_generation_prompt=True,
@@ -280,9 +247,8 @@ def generate(
         return_dict=True,
         return_tensors="pt",
     )
-    # Explicitly cast pixel values to the preferred dtype to match model weights.
     if "pixel_values" in inputs:
-        inputs["pixel_values"] = inputs["pixel_values"].to(preferred_dtype)
+        inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)
     inputs = inputs.to(smol_model.device)
 
     streamer = TextIteratorStreamer(smol_processor, skip_prompt=True, skip_special_tokens=True)
@@ -305,7 +271,7 @@ def generate(
 # -------------------------------
 # GRADIO CHAT INTERFACE
 # -------------------------------
-DESCRIPTION = "# Flux.1 Realism 🥖 + SmolVLM2 Chat"
+DESCRIPTION = "# Flux RealismLora + SmolVLM2 Chat"
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>⚠️Running on CPU, this may not work as expected.</p>"
 
@@ -328,7 +294,7 @@ demo = gr.ChatInterface(
         gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens"),
     ],
     examples=[
-        [{"text": "@image A futuristic cityscape at dusk in hyper-realistic 8K"}],
+        [{"text": "@image A futuristic cityscape at dusk in hyper-realistic style"}],
         [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
         [{"text": "What does this document say?", "files": ["example_images/document.jpg"]}],
         [{"text": "@tts1 Explain the weather patterns shown in this diagram.", "files": ["example_images/examples_weather_events.png"]}],
@@ -342,7 +308,7 @@ demo = gr.ChatInterface(
         label="Query Input",
         file_types=["image", ".mp4"],
         file_count="multiple",
-        placeholder="Type text and/or upload media. Use '@image' for Flux.1 image gen, '@tts1' or '@tts2' for TTS."
+        placeholder="Type text and/or upload media. Use '@image' for image gen, '@tts1' or '@tts2' for TTS."
     ),
     stop_btn="Stop Generation",
     multimodal=True,
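Reviewer note: the TTS branch is untouched by this commit; only the signature of text_to_speech is visible in a hunk header above. For context, a helper with that signature is typically an edge-tts wrapper along these lines (a sketch assuming the edge-tts package, not code from this repo):

import asyncio
import edge_tts

async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    # Stream the synthesized speech to an mp3 file and return its path.
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file

# Voices wired to the chat commands: @tts1 -> JennyNeural, @tts2 -> GuyNeural.
asyncio.run(text_to_speech("Hello from SmolVLM2", "en-US-JennyNeural"))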
 