Spaces: Running on Zero

Update app.py
Browse files

app.py CHANGED
Old version (removed lines are marked with "-"):

@@ -4,6 +4,7 @@ import uuid
 import json
 import time
 import asyncio
 from threading import Thread
 
 import gradio as gr
@@ -12,106 +13,22 @@ import torch
 import numpy as np
 from PIL import Image
 import edge_tts
 
-
-
-
-
-
-    AutoProcessor,
 )
-from transformers.image_utils import load_image
-from diffusers import DiffusionPipeline
-
-DESCRIPTION = "# Flux.1 Realism 🥖"
-if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"
-
-css = '''
-h1 {
-  text-align: center;
-  display: block;
-}
-
-#duplicate-button {
-  margin: auto;
-  color: #fff;
-  background: #1565c0;
-  border-radius: 100vh;
-}
-'''
-
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-# Load text-only model and tokenizer
-model_id = "prithivMLmods/FastThink-0.5B-Tiny"
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
-)
-model.eval()
-
-TTS_VOICES = [
-    "en-US-JennyNeural",  # @tts1
-    "en-US-GuyNeural",    # @tts2
-]
-
-# Load multimodal Qwen model & processor
-MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
-processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
-model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to("cuda").eval()
-
-async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
-    """Convert text to speech using Edge TTS and save as MP3"""
-    communicate = edge_tts.Communicate(text, voice)
-    await communicate.save(output_file)
-    return output_file
-
-def clean_chat_history(chat_history):
-    """
-    Filter out any chat entries whose "content" is not a string.
-    This helps prevent errors when concatenating previous messages.
-    """
-    cleaned = []
-    for msg in chat_history:
-        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
-            cleaned.append(msg)
-    return cleaned
-
-def progress_bar_html(label: str) -> str:
-    """
-    Returns an HTML snippet for a thin progress bar with a label.
-    The progress bar is styled as a dark red animated bar.
-    """
-    return f'''
-<div style="display: flex; align-items: center;">
-    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-    <div style="width: 110px; height: 5px; background-color: #FFC0CB; border-radius: 2px; overflow: hidden;">
-        <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
-    </div>
-</div>
-<style>
-@keyframes loading {{
-    0% {{ transform: translateX(-100%); }}
-    100% {{ transform: translateX(100%); }}
-}}
-</style>
-    '''
 
 # FLUX.1 IMAGE GENERATION SETUP
 MAX_SEED = np.iinfo(np.int32).max
 
 def save_image(img: Image.Image) -> str:
-    """Save a PIL image with a unique filename and return
     unique_name = str(uuid.uuid4()) + ".png"
     img.save(unique_name)
     return unique_name
@@ -121,7 +38,9 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         seed = random.randint(0, MAX_SEED)
     return seed
 
-#
 base_model = "black-forest-labs/FLUX.1-dev"
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
 lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
@@ -148,14 +67,14 @@ style_list = [
         "prompt": "{prompt}",
     },
 ]
-styles = {
 DEFAULT_STYLE_NAME = "3840 x 2160"
 STYLE_NAMES = list(styles.keys())
 
 def apply_style(style_name: str, positive: str) -> str:
     return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)
 
-@spaces.GPU
 def generate_image_flux(
     prompt: str,
     seed: int = 0,
@@ -166,7 +85,7 @@ def generate_image_flux(
     style_name: str = DEFAULT_STYLE_NAME,
     progress=gr.Progress(track_tqdm=True),
 ):
-    """Generate
     seed = int(randomize_seed_fn(seed, randomize_seed))
     positive_prompt = apply_style(style_name, prompt)
     if trigger_word:
@@ -183,29 +102,72 @@ def generate_image_flux(
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
-#
 @spaces.GPU
 def generate(
     input_dict: dict,
     chat_history: list[dict],
-
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
 ):
     """
-    Generates chatbot responses with support for multimodal
     Special commands:
-    - "@tts1" or "@tts2": triggers text-to-speech.
     - "@image": triggers image generation using the Flux.1 pipeline.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
-    # If the
     if text.strip().lower().startswith("@image"):
-        # Remove the "@image" tag and use the remainder as the prompt.
         prompt = text[len("@image"):].strip()
         yield progress_bar_html("Hold Tight Generating Flux.1 Image")
         image_paths, used_seed = generate_image_flux(
@@ -219,109 +181,144 @@ def generate(
             progress=gr.Progress(track_tqdm=True),
         )
         yield gr.Image(image_paths[0])
-        return
 
-    #
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
-
-
-
-
-
-
-        conversation = [{"role": "user", "content": text}]
-    else:
-        voice = None
-        # Remove any stray @tts tags and build the conversation history.
-        text = text.replace(tts_prefix, "").strip()
-        conversation = clean_chat_history(chat_history)
-        conversation.append({"role": "user", "content": text})
 
-    #
-
-    if len(files) > 1:
-        images = [load_image(image) for image in files]
-    elif len(files) == 1:
-        images = [load_image(files[0])]
-    else:
-        images = []
-    messages = [{
-        "role": "user",
-        "content": [
-            *[{"type": "image", "image": image} for image in images],
-            {"type": "text", "text": text},
-        ]
-    }]
-    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-    thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-    thread.start()
 
-
-
-
-
-
-
-
     else:
-
-
-
-
-
-
-
-
-
-
-
-    "
-
-
-
-
-
-
-
 
-
-
-
-
-
 
-
-
 
-
-
-
-
 
 # GRADIO CHAT INTERFACE
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Slider(
-        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
-        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
-        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
     ],
     examples=[
-        [{"text": "@image
-        [{"text": "
-        [{"text": "
-        [{"text": "@
-        [{"text": "@image Super-realism, Purple Dreamy, a medium-angle shot of a young woman with long brown hair, wearing a pair of eye-level glasses, stands in front of a backdrop of purple and white lights. The womans eyes are closed, her lips are slightly parted, as if she is looking up at the sky. Her hair is cascading over her shoulders, framing her face. She is wearing a sleeveless top, adorned with tiny white dots, and a gold chain necklace around her neck. Her left earrings are dangling from her ears, adding a pop of color to the scene."}],
-        [{"text": "Python Program for Array Rotation"}],
-        [{"text": "@tts1 Who is Nikola Tesla, and why did he die?"}],
-        [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
-        [{"text": "@tts2 What causes rainbows to form?"}],
     ],
     cache_examples=False,
     type="messages",
@@ -330,9 +327,9 @@ demo = gr.ChatInterface(
     fill_height=True,
     textbox=gr.MultimodalTextbox(
         label="Query Input",
-        file_types=["image"],
        file_count="multiple",
-        placeholder="
     ),
     stop_btn="Stop Generation",
     multimodal=True,
New version (added lines are marked with "+"):

@@ -4,6 +4,7 @@ import uuid
 import json
 import time
 import asyncio
+import re
 from threading import Thread
 
 import gradio as gr
@@ -12,106 +13,22 @@ import torch
 import numpy as np
 from PIL import Image
 import edge_tts
+import subprocess
 
+# Install flash-attn with our environment flag (if needed)
+subprocess.run(
+    'pip install flash-attn --no-build-isolation',
+    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
+    shell=True
 )
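Note on the install step above: passing env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"} to subprocess.run replaces the child process's entire environment rather than adding one variable, so PATH and CUDA settings are not inherited. A minimal sketch of a safer variant (my assumption, not something in this commit) merges the flag into the current environment:

    import os
    import subprocess

    # Merge the skip-build flag into the existing environment so pip and CUDA
    # paths remain visible to the child process.
    env = {**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env=env,
        shell=True,
        check=True,  # fail loudly if the wheel cannot be built or installed
    )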
 
+# -------------------------------
 # FLUX.1 IMAGE GENERATION SETUP
+# -------------------------------
 MAX_SEED = np.iinfo(np.int32).max
 
 def save_image(img: Image.Image) -> str:
+    """Save a PIL image with a unique filename and return its path."""
     unique_name = str(uuid.uuid4()) + ".png"
     img.save(unique_name)
     return unique_name
@@ -121,7 +38,9 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         seed = random.randint(0, MAX_SEED)
     return seed
 
+# Load Flux.1 pipeline and LoRA weights
+from diffusers import DiffusionPipeline
+
 base_model = "black-forest-labs/FLUX.1-dev"
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
 lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
@@ -148,14 +67,14 @@ style_list = [
         "prompt": "{prompt}",
     },
 ]
+styles = {s["name"]: s["prompt"] for s in style_list}
 DEFAULT_STYLE_NAME = "3840 x 2160"
 STYLE_NAMES = list(styles.keys())
 
 def apply_style(style_name: str, positive: str) -> str:
     return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)
 
+@spaces.GPU(duration=60, enable_queue=True)
 def generate_image_flux(
     prompt: str,
     seed: int = 0,
@@ -166,7 +85,7 @@ def generate_image_flux(
     style_name: str = DEFAULT_STYLE_NAME,
     progress=gr.Progress(track_tqdm=True),
 ):
+    """Generate an image using the Flux.1 pipeline with a chosen style."""
     seed = int(randomize_seed_fn(seed, randomize_seed))
     positive_prompt = apply_style(style_name, prompt)
     if trigger_word:
@@ -183,29 +102,72 @@ def generate_image_flux(
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
+# -------------------------------
+# SMOLVLM2 SETUP (Default Text/Multimodal Model)
+# -------------------------------
+from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
+# Load the SmolVLM2 processor and model with flash attention enabled.
+smol_processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
+smol_model = AutoModelForImageTextToText.from_pretrained(
+    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
+    _attn_implementation="flash_attention_2",
+    torch_dtype=torch.bfloat16
+).to("cuda:0")
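The SmolVLM2 load above hard-codes _attn_implementation="flash_attention_2" and cuda:0. A small sketch of a defensive variant (hypothetical, not part of this commit) that falls back to the default attention kernel when flash-attn is missing and to CPU when no GPU is present:

    import importlib.util
    import torch
    from transformers import AutoModelForImageTextToText

    # Use FlashAttention 2 only when the flash_attn package is importable.
    attn_impl = "flash_attention_2" if importlib.util.find_spec("flash_attn") else "eager"
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    smol_model = AutoModelForImageTextToText.from_pretrained(
        "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
        _attn_implementation=attn_impl,
        torch_dtype=torch.bfloat16,
    ).to(device)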
+
+# -------------------------------
+# UTILITY FUNCTIONS
+# -------------------------------
+def progress_bar_html(label: str) -> str:
+    """
+    Returns an HTML snippet for an animated progress bar with a given label.
+    """
+    return f'''
+<div style="display: flex; align-items: center;">
+    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+    <div style="width: 110px; height: 5px; background-color: #FFC0CB; border-radius: 2px; overflow: hidden;">
+        <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
+    </div>
+</div>
+<style>
+@keyframes loading {{
+    0% {{ transform: translateX(-100%); }}
+    100% {{ transform: translateX(100%); }}
+}}
+</style>
+    '''
+
+# TTS voices (if using TTS commands)
+TTS_VOICES = [
+    "en-US-JennyNeural",  # @tts1
+    "en-US-GuyNeural",    # @tts2
+]
+
+async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
+    """Convert text to speech using Edge TTS and save the output as MP3."""
+    communicate = edge_tts.Communicate(text, voice)
+    await communicate.save(output_file)
+    return output_file
+
+# -------------------------------
+# CHAT / MULTIMODAL GENERATION FUNCTION
+# -------------------------------
 @spaces.GPU
 def generate(
     input_dict: dict,
     chat_history: list[dict],
+    max_tokens: int = 200,
 ):
     """
+    Generates chatbot responses using SmolVLM2 by default—with support for multimodal inputs and TTS.
     Special commands:
     - "@image": triggers image generation using the Flux.1 pipeline.
+    - "@tts1" or "@tts2": triggers text-to-speech after generation.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
+    # If the query starts with "@image", use Flux.1 to generate an image.
     if text.strip().lower().startswith("@image"):
         prompt = text[len("@image"):].strip()
         yield progress_bar_html("Hold Tight Generating Flux.1 Image")
         image_paths, used_seed = generate_image_flux(
@@ -219,109 +181,144 @@ def generate(
             progress=gr.Progress(track_tqdm=True),
         )
         yield gr.Image(image_paths[0])
+        return
 
+    # Handle TTS commands if present.
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
+    voice = None
+    if is_tts:
+        voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
+        if voice_index:
+            voice = TTS_VOICES[voice_index - 1]
+            text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
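For reference, the "@tts" prefix handling just above can be exercised on its own; this snippet mirrors the same logic outside the app (illustrative only):

    TTS_VOICES = ["en-US-JennyNeural", "en-US-GuyNeural"]
    tts_prefix = "@tts"
    text = "@tts1 What causes rainbows to form?"

    # Pick the voice index from the prefix, then strip the command from the query.
    voice_index = next(
        (i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None
    )
    voice = TTS_VOICES[voice_index - 1] if voice_index else None
    if voice_index:
        text = text.replace(f"{tts_prefix}{voice_index}", "").strip()

    print(voice, "|", text)  # en-US-JennyNeural | What causes rainbows to form?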
 
+    # Now use SmolVLM2 for chat/multimodal text generation.
+    yield "Processing with SmolVLM2"
 
+    # Build conversation messages based on input and history.
+    user_content = []
+    media_queue = []
+    if chat_history == []:
+        text = text.strip()
+        for file in files:
+            if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
+                media_queue.append({"type": "image", "path": file})
+            elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
+                media_queue.append({"type": "video", "path": file})
+        if "<image>" in text or "<video>" in text:
+            parts = re.split(r'(<image>|<video>)', text)
+            for part in parts:
+                if part == "<image>" and media_queue:
+                    user_content.append(media_queue.pop(0))
+                elif part == "<video>" and media_queue:
+                    user_content.append(media_queue.pop(0))
+                elif part.strip():
+                    user_content.append({"type": "text", "text": part.strip()})
+        else:
+            user_content.append({"type": "text", "text": text})
+            for media in media_queue:
+                user_content.append(media)
+        resulting_messages = [{"role": "user", "content": user_content}]
     else:
+        resulting_messages = []
+        user_content = []
+        media_queue = []
+        for hist in chat_history:
+            if hist["role"] == "user" and isinstance(hist["content"], tuple):
+                file_name = hist["content"][0]
+                if file_name.endswith((".png", ".jpg", ".jpeg")):
+                    media_queue.append({"type": "image", "path": file_name})
+                elif file_name.endswith(".mp4"):
+                    media_queue.append({"type": "video", "path": file_name})
+        for hist in chat_history:
+            if hist["role"] == "user" and isinstance(hist["content"], str):
+                txt = hist["content"]
+                parts = re.split(r'(<image>|<video>)', txt)
+                for part in parts:
+                    if part == "<image>" and media_queue:
+                        user_content.append(media_queue.pop(0))
+                    elif part == "<video>" and media_queue:
+                        user_content.append(media_queue.pop(0))
+                    elif part.strip():
+                        user_content.append({"type": "text", "text": part.strip()})
+            elif hist["role"] == "assistant":
+                resulting_messages.append({
+                    "role": "user",
+                    "content": user_content
+                })
+                resulting_messages.append({
+                    "role": "assistant",
+                    "content": [{"type": "text", "text": hist["content"]}]
+                })
+                user_content = []
+        if not resulting_messages:
+            resulting_messages = [{"role": "user", "content": user_content}]
+
+    if text == "" and not files:
+        yield "Please input a query and optionally image(s)."
+        return
+    if text == "" and files:
+        yield "Please input a text query along with the image(s)."
+        return
 
+    print("resulting_messages", resulting_messages)
+    inputs = smol_processor.apply_chat_template(
+        resulting_messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to(smol_model.device)
 
+    streamer = TextIteratorStreamer(smol_processor, skip_prompt=True, skip_special_tokens=True)
+    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
+    thread = Thread(target=smol_model.generate, kwargs=generation_args)
+    thread.start()
 
+    yield "..."
+    buffer = ""
+    for new_text in streamer:
+        buffer += new_text
+        time.sleep(0.01)
+        yield buffer
 
+    if is_tts and voice:
+        final_response = buffer
+        output_file = asyncio.run(text_to_speech(final_response, voice))
+        yield gr.Audio(output_file, autoplay=True)
+
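The generation path above uses the standard transformers streaming idiom: model.generate runs on a background thread while the TextIteratorStreamer is consumed in the main generator to yield partial text. A self-contained sketch of that idiom with a small text-only model (the model choice is an arbitrary placeholder, not the one used in this Space):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tok("Streaming demo:", return_tensors="pt")
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

    # generate() writes tokens into the streamer from a worker thread...
    Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=32)).start()

    # ...while the caller iterates the streamer and receives text incrementally.
    for chunk in streamer:
        print(chunk, end="", flush=True)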
+# -------------------------------
 # GRADIO CHAT INTERFACE
+# -------------------------------
+DESCRIPTION = "# Flux.1 Realism 🥖 + SmolVLM2 Chat"
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>⚠️Running on CPU, this may not work as expected.</p>"
+
+css = '''
+h1 {
+  text-align: center;
+  display: block;
+}
+#duplicate-button {
+  margin: auto;
+  color: #fff;
+  background: #1565c0;
+  border-radius: 100vh;
+}
+'''
+
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
+        gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens"),
     ],
     examples=[
+        [{"text": "@image A futuristic cityscape at dusk in hyper-realistic 8K"}],
+        [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
+        [{"text": "What does this document say?", "files": ["example_images/document.jpg"]}],
+        [{"text": "@tts1 Explain the weather patterns shown in this diagram.", "files": ["example_images/examples_weather_events.png"]}],
     ],
     cache_examples=False,
     type="messages",
@@ -330,9 +327,9 @@ demo = gr.ChatInterface(
     fill_height=True,
     textbox=gr.MultimodalTextbox(
         label="Query Input",
+        file_types=["image", ".mp4"],
         file_count="multiple",
+        placeholder="Type text and/or upload media. Use '@image' for Flux.1 image gen, '@tts1' or '@tts2' for TTS."
     ),
     stop_btn="Stop Generation",
     multimodal=True,
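For context, gr.MultimodalTextbox hands generate() a dict with "text" and "files" keys; the sketches below show the payload shapes for the three command paths (the prompts are illustrative, and the file path is taken from the examples above):

    # "@image" routes to the Flux.1 pipeline, "@tts1"/"@tts2" add speech output,
    # and plain text (optionally with files) goes to SmolVLM2.
    flux_query = {"text": "@image A futuristic cityscape at dusk in hyper-realistic 8K", "files": []}
    tts_query = {"text": "@tts2 What causes rainbows to form?", "files": []}
    vision_query = {"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}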