prithivMLmods committed
Commit cbc2b17 · verified · 1 Parent(s): 0f33c32

Update app.py

Files changed (1):
  1. app.py +193 -196

app.py CHANGED
@@ -4,6 +4,7 @@ import uuid
 import json
 import time
 import asyncio
+import re
 from threading import Thread
 
 import gradio as gr
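
The `import re` added here backs the `<image>`/`<video>` placeholder parsing introduced later in this diff: `re.split` with a capturing group returns the delimiters alongside the surrounding text, which is what lets `generate` interleave text chunks and media entries. A minimal sketch of that behavior:

    import re

    # A capturing group in the pattern makes re.split keep the delimiters
    # in the result list instead of discarding them.
    parts = re.split(r'(<image>|<video>)', "Compare <image> with <video> please")
    print(parts)
    # ['Compare ', '<image>', ' with ', '<video>', ' please']
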
@@ -12,106 +13,22 @@ import torch
 import numpy as np
 from PIL import Image
 import edge_tts
+import subprocess
 
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    TextIteratorStreamer,
-    Qwen2_5_VLForConditionalGeneration,
-    AutoProcessor,
+# Install flash-attn with our environment flag (if needed)
+subprocess.run(
+    'pip install flash-attn --no-build-isolation',
+    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
+    shell=True
 )
-from transformers.image_utils import load_image
-from diffusers import DiffusionPipeline
-
-DESCRIPTION = "# Flux.1 Realism 🥖"
-if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"
-
-css = '''
-h1 {
-    text-align: center;
-    display: block;
-}
-
-#duplicate-button {
-    margin: auto;
-    color: #fff;
-    background: #1565c0;
-    border-radius: 100vh;
-}
-'''
-
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-# Load text-only model and tokenizer
-model_id = "prithivMLmods/FastThink-0.5B-Tiny"
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
-)
-model.eval()
-
-TTS_VOICES = [
-    "en-US-JennyNeural",  # @tts1
-    "en-US-GuyNeural",  # @tts2
-]
-
-# Load multimodal Qwen model & processor
-MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
-processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
-model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to("cuda").eval()
-
-async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
-    """Convert text to speech using Edge TTS and save as MP3"""
-    communicate = edge_tts.Communicate(text, voice)
-    await communicate.save(output_file)
-    return output_file
-
-def clean_chat_history(chat_history):
-    """
-    Filter out any chat entries whose "content" is not a string.
-    This helps prevent errors when concatenating previous messages.
-    """
-    cleaned = []
-    for msg in chat_history:
-        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
-            cleaned.append(msg)
-    return cleaned
-
-def progress_bar_html(label: str) -> str:
-    """
-    Returns an HTML snippet for a thin progress bar with a label.
-    The progress bar is styled as a dark red animated bar.
-    """
-    return f'''
-    <div style="display: flex; align-items: center;">
-        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-        <div style="width: 110px; height: 5px; background-color: #FFC0CB; border-radius: 2px; overflow: hidden;">
-            <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
-        </div>
-    </div>
-    <style>
-    @keyframes loading {{
-        0% {{ transform: translateX(-100%); }}
-        100% {{ transform: translateX(100%); }}
-    }}
-    </style>
-    '''
 
+# -------------------------------
 # FLUX.1 IMAGE GENERATION SETUP
+# -------------------------------
 MAX_SEED = np.iinfo(np.int32).max
 
 def save_image(img: Image.Image) -> str:
-    """Save a PIL image with a unique filename and return the path."""
+    """Save a PIL image with a unique filename and return its path."""
     unique_name = str(uuid.uuid4()) + ".png"
     img.save(unique_name)
     return unique_name
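
The new install step runs pip at import time with `FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE` so flash-attn skips compiling CUDA kernels. Note that `subprocess.run(..., env={...})` replaces the entire environment rather than extending it; a common variant (a sketch, not the committed code) preserves the parent environment and fails loudly:

    import os
    import subprocess

    # Sketch: merge the flag into the inherited environment so pip still
    # sees PATH, HF_HOME, CUDA_* etc., and raise if the install fails.
    subprocess.run(
        'pip install flash-attn --no-build-isolation',
        env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
        shell=True,
        check=True,
    )
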
@@ -121,7 +38,9 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         seed = random.randint(0, MAX_SEED)
     return seed
 
-# Initialize Flux.1 pipeline
+# Load Flux.1 pipeline and LoRA weights
+from diffusers import DiffusionPipeline
+
 base_model = "black-forest-labs/FLUX.1-dev"
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
 lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
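
This hunk moves the `diffusers` import next to the pipeline it serves; `lora_repo` is defined here while the actual attach call sits outside the hunk. For orientation, the usual diffusers pattern looks like this (a sketch under that assumption, not lines from the commit):

    import torch
    from diffusers import DiffusionPipeline

    # Load the base pipeline, attach the LoRA adapter, move to GPU.
    pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )
    pipe.load_lora_weights("strangerzonehf/Flux-Super-Realism-LoRA")
    pipe.to("cuda")
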
@@ -148,14 +67,14 @@ style_list = [
         "prompt": "{prompt}",
     },
 ]
-styles = {k["name"]: k["prompt"] for k in style_list}
+styles = {s["name"]: s["prompt"] for s in style_list}
 DEFAULT_STYLE_NAME = "3840 x 2160"
 STYLE_NAMES = list(styles.keys())
 
 def apply_style(style_name: str, positive: str) -> str:
     return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)
 
-@spaces.GPU
+@spaces.GPU(duration=60, enable_queue=True)
 def generate_image_flux(
     prompt: str,
     seed: int = 0,
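
The comprehension rename (`k` to `s`) is cosmetic; `apply_style` substitutes the user prompt into the selected template and falls back to the default style. An illustration with a made-up template string, since the real "3840 x 2160" prompt text is elided in this hunk:

    style_list = [
        # The first template below is hypothetical; only the passthrough
        # "{prompt}" entry appears verbatim in the diff.
        {"name": "3840 x 2160", "prompt": "hyper-detailed 8K UHD photo, {prompt}"},
        {"name": "Style Zero", "prompt": "{prompt}"},
    ]
    styles = {s["name"]: s["prompt"] for s in style_list}
    DEFAULT_STYLE_NAME = "3840 x 2160"

    def apply_style(style_name: str, positive: str) -> str:
        return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)

    print(apply_style("3840 x 2160", "a lighthouse at dawn"))
    # hyper-detailed 8K UHD photo, a lighthouse at dawn
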
@@ -166,7 +85,7 @@ def generate_image_flux(
     style_name: str = DEFAULT_STYLE_NAME,
     progress=gr.Progress(track_tqdm=True),
 ):
-    """Generate images using the Flux.1 pipeline with style prompts."""
+    """Generate an image using the Flux.1 pipeline with a chosen style."""
     seed = int(randomize_seed_fn(seed, randomize_seed))
     positive_prompt = apply_style(style_name, prompt)
     if trigger_word:
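
Only the docstring changes here. For reference, a call shaped like the one `generate` makes further down; parameter names beyond `prompt`, `seed`, `randomize_seed`, and `style_name` are elided by the hunks, so treat this as a sketch:

    # generate_image_flux returns a list of saved image paths plus the
    # seed actually used (other keyword arguments are elided in the diff).
    image_paths, used_seed = generate_image_flux(
        prompt="portrait photo of an astronaut",
        seed=0,
        randomize_seed=True,
        style_name=DEFAULT_STYLE_NAME,
    )
    print(image_paths[0], used_seed)
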
@@ -183,29 +102,72 @@
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
-# CHAT GENERATION FUNCTION (TEXT & MULTIMODAL)
+# -------------------------------
+# SMOLVLM2 SETUP (Default Text/Multimodal Model)
+# -------------------------------
+from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
+# Load the SmolVLM2 processor and model with flash attention enabled.
+smol_processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
+smol_model = AutoModelForImageTextToText.from_pretrained(
+    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
+    _attn_implementation="flash_attention_2",
+    torch_dtype=torch.bfloat16
+).to("cuda:0")
+
+# -------------------------------
+# UTILITY FUNCTIONS
+# -------------------------------
+def progress_bar_html(label: str) -> str:
+    """
+    Returns an HTML snippet for an animated progress bar with a given label.
+    """
+    return f'''
+    <div style="display: flex; align-items: center;">
+        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+        <div style="width: 110px; height: 5px; background-color: #FFC0CB; border-radius: 2px; overflow: hidden;">
+            <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
+        </div>
+    </div>
+    <style>
+    @keyframes loading {{
+        0% {{ transform: translateX(-100%); }}
+        100% {{ transform: translateX(100%); }}
+    }}
+    </style>
+    '''
+
+# TTS voices (if using TTS commands)
+TTS_VOICES = [
+    "en-US-JennyNeural",  # @tts1
+    "en-US-GuyNeural",  # @tts2
+]
+
+async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
+    """Convert text to speech using Edge TTS and save the output as MP3."""
+    communicate = edge_tts.Communicate(text, voice)
+    await communicate.save(output_file)
+    return output_file
+
+# -------------------------------
+# CHAT / MULTIMODAL GENERATION FUNCTION
+# -------------------------------
 @spaces.GPU
 def generate(
     input_dict: dict,
     chat_history: list[dict],
-    max_new_tokens: int = 1024,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
+    max_tokens: int = 200,
 ):
     """
-    Generates chatbot responses with support for multimodal input, TTS, and image generation.
+    Generates chatbot responses using SmolVLM2 by default—with support for multimodal inputs and TTS.
     Special commands:
-    - "@tts1" or "@tts2": triggers text-to-speech.
     - "@image": triggers image generation using the Flux.1 pipeline.
+    - "@tts1" or "@tts2": triggers text-to-speech after generation.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
-    # If the text begins with "@image", use Flux for image generation.
+    # If the query starts with "@image", use Flux.1 to generate an image.
     if text.strip().lower().startswith("@image"):
-        # Remove the "@image" tag and use the remainder as the prompt.
         prompt = text[len("@image"):].strip()
         yield progress_bar_html("Hold Tight Generating Flux.1 Image")
         image_paths, used_seed = generate_image_flux(
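
Among the pieces reintroduced above, `text_to_speech` is an async helper that `generate` later drives synchronously via `asyncio.run`. A self-contained sketch of that pairing, using the same `edge_tts` API:

    import asyncio
    import edge_tts

    async def speak(text: str, voice: str = "en-US-JennyNeural", path: str = "output.mp3") -> str:
        # Communicate streams Microsoft Edge's neural TTS; save() writes the MP3.
        await edge_tts.Communicate(text, voice).save(path)
        return path

    # Synchronous call site, mirroring how generate() invokes the helper.
    print(asyncio.run(speak("Hello from Edge TTS")))
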
@@ -219,109 +181,144 @@ def generate(
             progress=gr.Progress(track_tqdm=True),
         )
         yield gr.Image(image_paths[0])
-        return  # Exit early after image generation.
+        return
 
-    # Check if a TTS command is issued.
+    # Handle TTS commands if present.
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
-    voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
-
-    if is_tts and voice_index:
-        voice = TTS_VOICES[voice_index - 1]
-        text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-        # Clear previous chat history for a fresh TTS request.
-        conversation = [{"role": "user", "content": text}]
-    else:
-        voice = None
-        # Remove any stray @tts tags and build the conversation history.
-        text = text.replace(tts_prefix, "").strip()
-        conversation = clean_chat_history(chat_history)
-        conversation.append({"role": "user", "content": text})
+    voice = None
+    if is_tts:
+        voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
+        if voice_index:
+            voice = TTS_VOICES[voice_index - 1]
+            text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
 
-    # Handle multimodal input if files are provided.
-    if files:
-        if len(files) > 1:
-            images = [load_image(image) for image in files]
-        elif len(files) == 1:
-            images = [load_image(files[0])]
-        else:
-            images = []
-        messages = [{
-            "role": "user",
-            "content": [
-                *[{"type": "image", "image": image} for image in images],
-                {"type": "text", "text": text},
-            ]
-        }]
-        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
-        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-        thread.start()
+    # Now use SmolVLM2 for chat/multimodal text generation.
+    yield "Processing with SmolVLM2"
 
-        buffer = ""
-        yield progress_bar_html("Thinking...")
-        for new_text in streamer:
-            buffer += new_text
-            buffer = buffer.replace("<|im_end|>", "")
-            time.sleep(0.01)
-            yield buffer
+    # Build conversation messages based on input and history.
+    user_content = []
+    media_queue = []
+    if chat_history == []:
+        text = text.strip()
+        for file in files:
+            if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
+                media_queue.append({"type": "image", "path": file})
+            elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
+                media_queue.append({"type": "video", "path": file})
+        if "<image>" in text or "<video>" in text:
+            parts = re.split(r'(<image>|<video>)', text)
+            for part in parts:
+                if part == "<image>" and media_queue:
+                    user_content.append(media_queue.pop(0))
+                elif part == "<video>" and media_queue:
+                    user_content.append(media_queue.pop(0))
+                elif part.strip():
+                    user_content.append({"type": "text", "text": part.strip()})
+        else:
+            user_content.append({"type": "text", "text": text})
+            for media in media_queue:
+                user_content.append(media)
+        resulting_messages = [{"role": "user", "content": user_content}]
     else:
-        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-        input_ids = input_ids.to(model.device)
-        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {
-            "input_ids": input_ids,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "do_sample": True,
-            "top_p": top_p,
-            "top_k": top_k,
-            "temperature": temperature,
-            "num_beams": 1,
-            "repetition_penalty": repetition_penalty,
-        }
-        t = Thread(target=model.generate, kwargs=generation_kwargs)
-        t.start()
+        resulting_messages = []
+        user_content = []
+        media_queue = []
+        for hist in chat_history:
+            if hist["role"] == "user" and isinstance(hist["content"], tuple):
+                file_name = hist["content"][0]
+                if file_name.endswith((".png", ".jpg", ".jpeg")):
+                    media_queue.append({"type": "image", "path": file_name})
+                elif file_name.endswith(".mp4"):
+                    media_queue.append({"type": "video", "path": file_name})
+        for hist in chat_history:
+            if hist["role"] == "user" and isinstance(hist["content"], str):
+                txt = hist["content"]
+                parts = re.split(r'(<image>|<video>)', txt)
+                for part in parts:
+                    if part == "<image>" and media_queue:
+                        user_content.append(media_queue.pop(0))
+                    elif part == "<video>" and media_queue:
+                        user_content.append(media_queue.pop(0))
+                    elif part.strip():
+                        user_content.append({"type": "text", "text": part.strip()})
+            elif hist["role"] == "assistant":
+                resulting_messages.append({
+                    "role": "user",
+                    "content": user_content
+                })
+                resulting_messages.append({
+                    "role": "assistant",
+                    "content": [{"type": "text", "text": hist["content"]}]
+                })
+                user_content = []
+        if not resulting_messages:
+            resulting_messages = [{"role": "user", "content": user_content}]
+
+    if text == "" and not files:
+        yield "Please input a query and optionally image(s)."
+        return
+    if text == "" and files:
+        yield "Please input a text query along with the image(s)."
+        return
 
-    outputs = []
-    yield progress_bar_html("Thinking...")
-    for new_text in streamer:
-        outputs.append(new_text)
-        yield "".join(outputs)
+    print("resulting_messages", resulting_messages)
+    inputs = smol_processor.apply_chat_template(
+        resulting_messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to(smol_model.device)
 
-    final_response = "".join(outputs)
-    yield final_response
+    streamer = TextIteratorStreamer(smol_processor, skip_prompt=True, skip_special_tokens=True)
+    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
+    thread = Thread(target=smol_model.generate, kwargs=generation_args)
+    thread.start()
 
-    # If TTS was requested, convert the final response to speech.
-    if is_tts and voice:
-        output_file = asyncio.run(text_to_speech(final_response, voice))
-        yield gr.Audio(output_file, autoplay=True)
+    yield "..."
+    buffer = ""
+    for new_text in streamer:
+        buffer += new_text
+        time.sleep(0.01)
+        yield buffer
 
+    if is_tts and voice:
+        final_response = buffer
+        output_file = asyncio.run(text_to_speech(final_response, voice))
+        yield gr.Audio(output_file, autoplay=True)
+
+# -------------------------------
 # GRADIO CHAT INTERFACE
+# -------------------------------
+DESCRIPTION = "# Flux.1 Realism 🥖 + SmolVLM2 Chat"
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>⚠️Running on CPU, this may not work as expected.</p>"
+
+css = '''
+h1 {
+    text-align: center;
+    display: block;
+}
+#duplicate-button {
+    margin: auto;
+    color: #fff;
+    background: #1565c0;
+    border-radius: 100vh;
+}
+'''
+
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
-        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
-        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
-        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
+        gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens"),
     ],
     examples=[
-        [{"text": "@image Chocolate dripping from a donut against a yellow background, in the style of hyper-realistic 8K"}],
-        [{"text": "@image Super Realism, High-resolution photograph, woman, UHD, photorealistic, shot on a Sony A7III --chaos 20 --ar 1:2 --style raw --stylize 250"}],
-        [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
-        [{"text": "@image Woman in a red jacket, snowy, in the style of hyper-realistic portraiture, caninecore, mountainous vistas, timeless beauty, palewave, iconic, distinctive noses --ar 72:101 --stylize 750 --v 6"}],
-        [{"text": "@image Super-realism, Purple Dreamy, a medium-angle shot of a young woman with long brown hair, wearing a pair of eye-level glasses, stands in front of a backdrop of purple and white lights. The womans eyes are closed, her lips are slightly parted, as if she is looking up at the sky. Her hair is cascading over her shoulders, framing her face. She is wearing a sleeveless top, adorned with tiny white dots, and a gold chain necklace around her neck. Her left earrings are dangling from her ears, adding a pop of color to the scene."}],
-        [{"text": "Python Program for Array Rotation"}],
-        [{"text": "@tts1 Who is Nikola Tesla, and why did he die?"}],
-        [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
-        [{"text": "@tts2 What causes rainbows to form?"}],
+        [{"text": "@image A futuristic cityscape at dusk in hyper-realistic 8K"}],
+        [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
+        [{"text": "What does this document say?", "files": ["example_images/document.jpg"]}],
+        [{"text": "@tts1 Explain the weather patterns shown in this diagram.", "files": ["example_images/examples_weather_events.png"]}],
     ],
     cache_examples=False,
     type="messages",
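
The SmolVLM2 path keeps the background-thread streaming idiom from the deleted branches: `model.generate` runs in a `Thread` while the caller iterates a `TextIteratorStreamer` and yields partial text. A minimal text-only sketch of the idiom ("gpt2" is a stand-in model; the app pairs the streamer with SmolVLM2's processor):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tok("The quick brown fox", return_tensors="pt")
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    # generate() runs in the background and pushes decoded chunks into the streamer.
    thread = Thread(target=model.generate,
                    kwargs=dict(inputs, streamer=streamer, max_new_tokens=30))
    thread.start()

    buffer = ""
    for new_text in streamer:  # yields text incrementally as tokens are produced
        buffer += new_text
    thread.join()
    print(buffer)
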
@@ -330,9 +327,9 @@ demo = gr.ChatInterface(
     fill_height=True,
     textbox=gr.MultimodalTextbox(
         label="Query Input",
-        file_types=["image"],
+        file_types=["image", ".mp4"],
         file_count="multiple",
-        placeholder=" @image-flux.1 image gen, @tts1, @tts2-voices, default [text, vision]"
+        placeholder="Type text and/or upload media. Use '@image' for Flux.1 image gen, '@tts1' or '@tts2' for TTS."
     ),
     stop_btn="Stop Generation",
     multimodal=True,
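
With `multimodal=True` and a `MultimodalTextbox`, Gradio delivers each user turn to `fn` as a dict of text plus uploaded file paths, which is the `input_dict` contract `generate` relies on. A minimal echo sketch of that contract:

    import gradio as gr

    # With multimodal=True the user turn reaches fn as
    # {"text": str, "files": [paths]}, matching input_dict above.
    def echo(input_dict: dict, chat_history: list):
        n = len(input_dict.get("files", []))
        return f"Got {n} file(s) and text: {input_dict['text']!r}"

    gr.ChatInterface(
        fn=echo,
        type="messages",
        multimodal=True,
        textbox=gr.MultimodalTextbox(file_types=["image", ".mp4"], file_count="multiple"),
    ).launch()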
 
4
  import json
5
  import time
6
  import asyncio
7
+ import re
8
  from threading import Thread
9
 
10
  import gradio as gr
 
13
  import numpy as np
14
  from PIL import Image
15
  import edge_tts
16
+ import subprocess
17
 
18
+ # Install flash-attn with our environment flag (if needed)
19
+ subprocess.run(
20
+ 'pip install flash-attn --no-build-isolation',
21
+ env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
22
+ shell=True
 
23
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # -------------------------------
26
  # FLUX.1 IMAGE GENERATION SETUP
27
+ # -------------------------------
28
  MAX_SEED = np.iinfo(np.int32).max
29
 
30
  def save_image(img: Image.Image) -> str:
31
+ """Save a PIL image with a unique filename and return its path."""
32
  unique_name = str(uuid.uuid4()) + ".png"
33
  img.save(unique_name)
34
  return unique_name
 
38
  seed = random.randint(0, MAX_SEED)
39
  return seed
40
 
41
+ # Load Flux.1 pipeline and LoRA weights
42
+ from diffusers import DiffusionPipeline
43
+
44
  base_model = "black-forest-labs/FLUX.1-dev"
45
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
46
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
 
67
  "prompt": "{prompt}",
68
  },
69
  ]
70
+ styles = {s["name"]: s["prompt"] for s in style_list}
71
  DEFAULT_STYLE_NAME = "3840 x 2160"
72
  STYLE_NAMES = list(styles.keys())
73
 
74
  def apply_style(style_name: str, positive: str) -> str:
75
  return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)
76
 
77
+ @spaces.GPU(duration=60, enable_queue=True)
78
  def generate_image_flux(
79
  prompt: str,
80
  seed: int = 0,
 
85
  style_name: str = DEFAULT_STYLE_NAME,
86
  progress=gr.Progress(track_tqdm=True),
87
  ):
88
+ """Generate an image using the Flux.1 pipeline with a chosen style."""
89
  seed = int(randomize_seed_fn(seed, randomize_seed))
90
  positive_prompt = apply_style(style_name, prompt)
91
  if trigger_word:
 
102
  image_paths = [save_image(img) for img in images]
103
  return image_paths, seed
104
 
105
+ # -------------------------------
106
+ # SMOLVLM2 SETUP (Default Text/Multimodal Model)
107
+ # -------------------------------
108
+ from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
109
+ # Load the SmolVLM2 processor and model with flash attention enabled.
110
+ smol_processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
111
+ smol_model = AutoModelForImageTextToText.from_pretrained(
112
+ "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
113
+ _attn_implementation="flash_attention_2",
114
+ torch_dtype=torch.bfloat16
115
+ ).to("cuda:0")
116
+
117
+ # -------------------------------
118
+ # UTILITY FUNCTIONS
119
+ # -------------------------------
120
+ def progress_bar_html(label: str) -> str:
121
+ """
122
+ Returns an HTML snippet for an animated progress bar with a given label.
123
+ """
124
+ return f'''
125
+ <div style="display: flex; align-items: center;">
126
+ <span style="margin-right: 10px; font-size: 14px;">{label}</span>
127
+ <div style="width: 110px; height: 5px; background-color: #FFC0CB; border-radius: 2px; overflow: hidden;">
128
+ <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
129
+ </div>
130
+ </div>
131
+ <style>
132
+ @keyframes loading {{
133
+ 0% {{ transform: translateX(-100%); }}
134
+ 100% {{ transform: translateX(100%); }}
135
+ }}
136
+ </style>
137
+ '''
138
+
139
+ # TTS voices (if using TTS commands)
140
+ TTS_VOICES = [
141
+ "en-US-JennyNeural", # @tts1
142
+ "en-US-GuyNeural", # @tts2
143
+ ]
144
+
145
+ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
146
+ """Convert text to speech using Edge TTS and save the output as MP3."""
147
+ communicate = edge_tts.Communicate(text, voice)
148
+ await communicate.save(output_file)
149
+ return output_file
150
+
151
+ # -------------------------------
152
+ # CHAT / MULTIMODAL GENERATION FUNCTION
153
+ # -------------------------------
154
  @spaces.GPU
155
  def generate(
156
  input_dict: dict,
157
  chat_history: list[dict],
158
+ max_tokens: int = 200,
 
 
 
 
159
  ):
160
  """
161
+ Generates chatbot responses using SmolVLM2 by default—with support for multimodal inputs and TTS.
162
  Special commands:
 
163
  - "@image": triggers image generation using the Flux.1 pipeline.
164
+ - "@tts1" or "@tts2": triggers text-to-speech after generation.
165
  """
166
  text = input_dict["text"]
167
  files = input_dict.get("files", [])
168
 
169
+ # If the query starts with "@image", use Flux.1 to generate an image.
170
  if text.strip().lower().startswith("@image"):
 
171
  prompt = text[len("@image"):].strip()
172
  yield progress_bar_html("Hold Tight Generating Flux.1 Image")
173
  image_paths, used_seed = generate_image_flux(
 
181
  progress=gr.Progress(track_tqdm=True),
182
  )
183
  yield gr.Image(image_paths[0])
184
+ return
185
 
186
+ # Handle TTS commands if present.
187
  tts_prefix = "@tts"
188
  is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
189
+ voice = None
190
+ if is_tts:
191
+ voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
192
+ if voice_index:
193
+ voice = TTS_VOICES[voice_index - 1]
194
+ text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
 
 
 
 
 
 
 
195
 
196
+ # Now use SmolVLM2 for chat/multimodal text generation.
197
+ yield "Processing with SmolVLM2"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
+ # Build conversation messages based on input and history.
200
+ user_content = []
201
+ media_queue = []
202
+ if chat_history == []:
203
+ text = text.strip()
204
+ for file in files:
205
+ if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
206
+ media_queue.append({"type": "image", "path": file})
207
+ elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
208
+ media_queue.append({"type": "video", "path": file})
209
+ if "<image>" in text or "<video>" in text:
210
+ parts = re.split(r'(<image>|<video>)', text)
211
+ for part in parts:
212
+ if part == "<image>" and media_queue:
213
+ user_content.append(media_queue.pop(0))
214
+ elif part == "<video>" and media_queue:
215
+ user_content.append(media_queue.pop(0))
216
+ elif part.strip():
217
+ user_content.append({"type": "text", "text": part.strip()})
218
+ else:
219
+ user_content.append({"type": "text", "text": text})
220
+ for media in media_queue:
221
+ user_content.append(media)
222
+ resulting_messages = [{"role": "user", "content": user_content}]
223
  else:
224
+ resulting_messages = []
225
+ user_content = []
226
+ media_queue = []
227
+ for hist in chat_history:
228
+ if hist["role"] == "user" and isinstance(hist["content"], tuple):
229
+ file_name = hist["content"][0]
230
+ if file_name.endswith((".png", ".jpg", ".jpeg")):
231
+ media_queue.append({"type": "image", "path": file_name})
232
+ elif file_name.endswith(".mp4"):
233
+ media_queue.append({"type": "video", "path": file_name})
234
+ for hist in chat_history:
235
+ if hist["role"] == "user" and isinstance(hist["content"], str):
236
+ txt = hist["content"]
237
+ parts = re.split(r'(<image>|<video>)', txt)
238
+ for part in parts:
239
+ if part == "<image>" and media_queue:
240
+ user_content.append(media_queue.pop(0))
241
+ elif part == "<video>" and media_queue:
242
+ user_content.append(media_queue.pop(0))
243
+ elif part.strip():
244
+ user_content.append({"type": "text", "text": part.strip()})
245
+ elif hist["role"] == "assistant":
246
+ resulting_messages.append({
247
+ "role": "user",
248
+ "content": user_content
249
+ })
250
+ resulting_messages.append({
251
+ "role": "assistant",
252
+ "content": [{"type": "text", "text": hist["content"]}]
253
+ })
254
+ user_content = []
255
+ if not resulting_messages:
256
+ resulting_messages = [{"role": "user", "content": user_content}]
257
+
258
+ if text == "" and not files:
259
+ yield "Please input a query and optionally image(s)."
260
+ return
261
+ if text == "" and files:
262
+ yield "Please input a text query along with the image(s)."
263
+ return
264
 
265
+ print("resulting_messages", resulting_messages)
266
+ inputs = smol_processor.apply_chat_template(
267
+ resulting_messages,
268
+ add_generation_prompt=True,
269
+ tokenize=True,
270
+ return_dict=True,
271
+ return_tensors="pt",
272
+ )
273
+ inputs = inputs.to(smol_model.device)
274
 
275
+ streamer = TextIteratorStreamer(smol_processor, skip_prompt=True, skip_special_tokens=True)
276
+ generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
277
+ thread = Thread(target=smol_model.generate, kwargs=generation_args)
278
+ thread.start()
279
 
280
+ yield "..."
281
+ buffer = ""
282
+ for new_text in streamer:
283
+ buffer += new_text
284
+ time.sleep(0.01)
285
+ yield buffer
286
 
287
+ if is_tts and voice:
288
+ final_response = buffer
289
+ output_file = asyncio.run(text_to_speech(final_response, voice))
290
+ yield gr.Audio(output_file, autoplay=True)
291
+
292
+ # -------------------------------
293
  # GRADIO CHAT INTERFACE
294
+ # -------------------------------
295
+ DESCRIPTION = "# Flux.1 Realism 🥖 + SmolVLM2 Chat"
296
+ if not torch.cuda.is_available():
297
+ DESCRIPTION += "\n<p>⚠️Running on CPU, this may not work as expected.</p>"
298
+
299
+ css = '''
300
+ h1 {
301
+ text-align: center;
302
+ display: block;
303
+ }
304
+ #duplicate-button {
305
+ margin: auto;
306
+ color: #fff;
307
+ background: #1565c0;
308
+ border-radius: 100vh;
309
+ }
310
+ '''
311
+
312
  demo = gr.ChatInterface(
313
  fn=generate,
314
  additional_inputs=[
315
+ gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens"),
 
 
 
 
316
  ],
317
  examples=[
318
+ [{"text": "@image A futuristic cityscape at dusk in hyper-realistic 8K"}],
319
+ [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
320
+ [{"text": "What does this document say?", "files": ["example_images/document.jpg"]}],
321
+ [{"text": "@tts1 Explain the weather patterns shown in this diagram.", "files": ["example_images/examples_weather_events.png"]}],
 
 
 
 
 
322
  ],
323
  cache_examples=False,
324
  type="messages",
 
327
  fill_height=True,
328
  textbox=gr.MultimodalTextbox(
329
  label="Query Input",
330
+ file_types=["image", ".mp4"],
331
  file_count="multiple",
332
+ placeholder="Type text and/or upload media. Use '@image' for Flux.1 image gen, '@tts1' or '@tts2' for TTS."
333
  ),
334
  stop_btn="Stop Generation",
335
  multimodal=True,