Update app.py

app.py CHANGED
@@ -22,10 +22,15 @@ subprocess.run(
     shell=True
 )
 
+# Set torch backend configurations for Flux RealismLora
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+torch.backends.cuda.matmul.allow_tf32 = True
+
 # -------------------------------
 # CONFIGURATION & UTILITY FUNCTIONS
 # -------------------------------
-MAX_SEED =
+MAX_SEED = 2**32 - 1
 
 def save_image(img: Image.Image) -> str:
     """Save a PIL image with a unique filename and return its path."""
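The new backend flags trade a little speed for reproducibility (deterministic cuDNN kernel selection) while TF32 keeps matmuls fast on Ampere-class GPUs, and MAX_SEED = 2**32 - 1 keeps seeds in the uint32 range. A minimal sketch of the determinism this buys, assuming a CUDA device (a hypothetical check, not part of the commit):

    # Hypothetical check: identically seeded CUDA generators reproduce draws;
    # the cuDNN flags above extend run-to-run determinism to the convolution
    # kernels inside the diffusion pipeline as well.
    import torch

    g1 = torch.Generator(device="cuda").manual_seed(42)
    g2 = torch.Generator(device="cuda").manual_seed(42)
    assert torch.equal(torch.randn(4, generator=g1, device="cuda"),
                       torch.randn(4, generator=g2, device="cuda"))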
@@ -38,79 +43,66 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         seed = random.randint(0, MAX_SEED)
     return seed
 
-
-
-
+def progress_bar_html(label: str) -> str:
+    """
+    Returns an HTML snippet for an animated progress bar with a given label.
+    """
+    return f'''
+    <div style="display: flex; align-items: center;">
+        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+        <div style="width: 110px; height: 5px; background-color: #FFC0CB; border-radius: 2px; overflow: hidden;">
+            <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
+        </div>
+    </div>
+    <style>
+    @keyframes loading {{
+        0% {{ transform: translateX(-100%); }}
+        100% {{ transform: translateX(100%); }}
+    }}
+    </style>
+    '''
 
 # -------------------------------
-# FLUX
+# FLUX REALISMLORA IMAGE GENERATION SETUP (New Implementation)
 # -------------------------------
 from diffusers import DiffusionPipeline
 
 base_model = "black-forest-labs/FLUX.1-dev"
-pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=
-lora_repo = "
-trigger_word = "
+pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
+lora_repo = "XLabs-AI/flux-RealismLora"
+trigger_word = ""  # No trigger word used.
 pipe.load_lora_weights(lora_repo)
 pipe.to("cuda")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    seed: int = 0,
-    width: int = 1024,
-    height: int = 1024,
-    guidance_scale: float = 3,
-    randomize_seed: bool = False,
-    style_name: str = DEFAULT_STYLE_NAME,
-    progress=gr.Progress(track_tqdm=True),
-):
-    """Generate an image using the Flux.1 pipeline with a chosen style."""
-    torch.cuda.empty_cache()  # Clear unused GPU memory to prevent allocation errors
-    seed = int(randomize_seed_fn(seed, randomize_seed))
-    positive_prompt = apply_style(style_name, prompt)
-    if trigger_word:
-        positive_prompt = f"{trigger_word} {positive_prompt}"
-    # Wrap the diffusion call in no_grad to avoid unnecessary gradient state.
-    with torch.no_grad():
-        images = pipe(
-            prompt=positive_prompt,
-            width=width,
-            height=height,
-            guidance_scale=guidance_scale,
-            num_inference_steps=28,
-            num_images_per_prompt=1,
-            output_type="pil",
-        ).images
-    torch.cuda.synchronize()  # Ensure all CUDA operations have completed
-    image_paths = [save_image(img) for img in images]
-    return image_paths, seed
+@spaces.GPU()
+def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
+    # Set random seed for reproducibility
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+
+    # Update progress bar (0% at start); gr.Progress expects fractions in [0, 1]
+    progress(0, "Starting image generation...")
+
+    # Simulated coarse progress updates (real per-step progress comes from track_tqdm)
+    for i in range(1, steps + 1):
+        if steps >= 10 and i % (steps // 10) == 0:
+            progress(i / steps, f"Processing step {i} of {steps}...")
+
+    # Generate image using the pipeline
+    image = pipe(
+        prompt=f"{prompt} {trigger_word}",
+        num_inference_steps=steps,
+        guidance_scale=cfg_scale,
+        width=width,
+        height=height,
+        generator=generator,
+        joint_attention_kwargs={"scale": lora_scale},
+    ).images[0]
+
+    # Final progress update (100%)
+    progress(1.0, "Completed!")
+    yield image, seed
 
 # -------------------------------
 # SMOLVLM2 SETUP (Default Text/Multimodal Model)
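Because run_lora ends in a yield, it is a generator and must be iterated even though it produces a single (image, seed) pair. A minimal consumption sketch with a hypothetical prompt, assuming the pipe/LoRA setup above has run on a CUDA machine and that the default gr.Progress object tolerates being driven outside a live Gradio event:

    # Drain the single-yield generator (prompt and filename are hypothetical).
    final_image, used_seed = None, None
    for image, seed in run_lora(
        "a candid street portrait at golden hour",
        cfg_scale=3.2, steps=32, randomize_seed=True, seed=0,
        width=1024, height=1024, lora_scale=0.85,
    ):
        final_image, used_seed = image, seed
    final_image.save(f"flux_realism_{used_seed}.png")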
@@ -121,31 +113,12 @@ smol_processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Inst
 smol_model = AutoModelForImageTextToText.from_pretrained(
     "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
     _attn_implementation="flash_attention_2",
-    torch_dtype=
+    torch_dtype=torch.float16
 ).to("cuda:0")
 
 # -------------------------------
-# UTILITY FUNCTIONS
+# TTS UTILITY FUNCTIONS
 # -------------------------------
-def progress_bar_html(label: str) -> str:
-    """
-    Returns an HTML snippet for an animated progress bar with a given label.
-    """
-    return f'''
-    <div style="display: flex; align-items: center;">
-        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-        <div style="width: 110px; height: 5px; background-color: #FFC0CB; border-radius: 2px; overflow: hidden;">
-            <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
-        </div>
-    </div>
-    <style>
-    @keyframes loading {{
-        0% {{ transform: translateX(-100%); }}
-        100% {{ transform: translateX(100%); }}
-    }}
-    </style>
-    '''
-
 TTS_VOICES = [
     "en-US-JennyNeural",  # @tts1
     "en-US-GuyNeural",    # @tts2
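The text_to_speech coroutine named in the next hunk's context line is untouched by this commit, so its body never appears in the diff. For orientation, a plausible minimal body matching that signature, assuming the Space uses the edge-tts package (an assumption; the actual implementation may differ):

    import edge_tts

    async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
        # Synthesize `text` with the given Edge voice (e.g. "en-US-JennyNeural")
        # and save it as an MP3, returning the file path for the chat UI.
        communicate = edge_tts.Communicate(text, voice)
        await communicate.save(output_file)
        return output_file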
@@ -161,36 +134,32 @@ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
 # CHAT / MULTIMODAL GENERATION FUNCTION
 # -------------------------------
 @spaces.GPU
-def generate(
-    input_dict: dict,
-    chat_history: list[dict],
-    max_tokens: int = 200,
-):
+def generate(input_dict: dict, chat_history: list[dict], max_tokens: int = 200):
     """
-    Generates chatbot responses using SmolVLM2
+    Generates chatbot responses using SmolVLM2 with support for multimodal inputs and TTS.
     Special commands:
-    - "@image": triggers image generation using the
+    - "@image": triggers image generation using the RealismLora flux implementation.
     - "@tts1" or "@tts2": triggers text-to-speech after generation.
     """
-    torch.cuda.empty_cache()
+    torch.cuda.empty_cache()
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
-    # If the query starts with "@image", use
+    # If the query starts with "@image", use RealismLora to generate an image.
     if text.strip().lower().startswith("@image"):
         prompt = text[len("@image"):].strip()
-        yield progress_bar_html("Hold Tight Generating Flux
-
-
-
-
-
-
-
-
-
-
-        yield gr.Image(
+        yield progress_bar_html("Hold Tight Generating Flux RealismLora Image")
+        # Default parameters for RealismLora generation
+        default_cfg_scale = 3.2
+        default_steps = 32
+        default_width = 1152
+        default_height = 896
+        default_seed = 3981632454
+        default_lora_scale = 0.85
+        # Call the new run_lora function and yield its final result
+        for result in run_lora(prompt, default_cfg_scale, default_steps, True, default_seed, default_width, default_height, default_lora_scale, progress=gr.Progress(track_tqdm=True)):
+            final_result = result
+        yield gr.Image(final_result[0])
         return
 
     # Handle TTS commands if present.
@@ -203,7 +172,6 @@ def generate(
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
 
-    # Use SmolVLM2 for chat/multimodal text generation.
     yield "Processing with SmolVLM2"
 
     # Build conversation messages based on input and history.
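The context lines above sit inside command parsing that the diff never shows in full; a hypothetical reconstruction consistent with those lines and the TTS_VOICES comments (the Space's actual code may differ):

    # Hypothetical surrounding logic: map "@tts1"/"@tts2" to a voice, then
    # strip the command prefix from the user text before generation.
    tts_prefix = "@tts"
    voice = None
    if text.strip().lower().startswith(tts_prefix):
        try:
            voice_index = int(text.strip()[len(tts_prefix)])  # "@tts1" -> 1
            voice = TTS_VOICES[voice_index - 1]
            text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
        except (ValueError, IndexError):
            voice = None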
@@ -272,7 +240,6 @@ def generate(
         yield "Please input a text query along with the image(s)."
         return
 
-    print("resulting_messages", resulting_messages)
     inputs = smol_processor.apply_chat_template(
         resulting_messages,
         add_generation_prompt=True,
@@ -280,9 +247,8 @@ def generate(
         return_dict=True,
         return_tensors="pt",
     )
-    # Explicitly cast pixel values to the preferred dtype to match model weights.
     if "pixel_values" in inputs:
-        inputs["pixel_values"] = inputs["pixel_values"].to(
+        inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)
     inputs = inputs.to(smol_model.device)
 
     streamer = TextIteratorStreamer(smol_processor, skip_prompt=True, skip_special_tokens=True)
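The added float16 cast keeps pixel_values in the same dtype as the SmolVLM2 weights loaded earlier (torch_dtype=torch.float16); with _attn_implementation="flash_attention_2", float32 image tensors would otherwise hit a dtype error inside the attention kernels, since FlashAttention only supports fp16/bf16 inputs.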
@@ -305,7 +271,7 @@ def generate(
 # -------------------------------
 # GRADIO CHAT INTERFACE
 # -------------------------------
-DESCRIPTION = "# Flux
+DESCRIPTION = "# Flux RealismLora + SmolVLM2 Chat"
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>⚠️Running on CPU, this may not work as expected.</p>"
 
@@ -328,7 +294,7 @@ demo = gr.ChatInterface(
         gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens"),
     ],
     examples=[
-        [{"text": "@image A futuristic cityscape at dusk in hyper-realistic
+        [{"text": "@image A futuristic cityscape at dusk in hyper-realistic style"}],
         [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
         [{"text": "What does this document say?", "files": ["example_images/document.jpg"]}],
         [{"text": "@tts1 Explain the weather patterns shown in this diagram.", "files": ["example_images/examples_weather_events.png"]}],
@@ -342,7 +308,7 @@ demo = gr.ChatInterface(
         label="Query Input",
         file_types=["image", ".mp4"],
         file_count="multiple",
-        placeholder="Type text and/or upload media. Use '@image' for
+        placeholder="Type text and/or upload media. Use '@image' for image gen, '@tts1' or '@tts2' for TTS."
     ),
     stop_btn="Stop Generation",
     multimodal=True,