prithivMLmods commited on
Commit
2230883
·
verified ·
1 Parent(s): 8e870ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +269 -135
app.py CHANGED
@@ -1,14 +1,118 @@
1
- import spaces
 
 
 
 
 
 
 
2
  import gradio as gr
 
3
  import torch
 
4
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
5
  from diffusers import DiffusionPipeline
6
- import random
7
- import uuid
8
- from typing import Tuple
9
- import numpy as np
10
 
11
- def save_image(img):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  unique_name = str(uuid.uuid4()) + ".png"
13
  img.save(unique_name)
14
  return unique_name
@@ -18,20 +122,15 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
18
  seed = random.randint(0, MAX_SEED)
19
  return seed
20
 
21
- MAX_SEED = np.iinfo(np.int32).max
22
-
23
- if not torch.cuda.is_available():
24
- DESCRIPTIONz += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"
25
-
26
  base_model = "black-forest-labs/FLUX.1-dev"
27
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
28
-
29
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
30
- trigger_word = "Super Realism" # Leave trigger_word blank if not used.
31
-
32
  pipe.load_lora_weights(lora_repo)
33
  pipe.to("cuda")
34
 
 
35
  style_list = [
36
  {
37
  "name": "3840 x 2160",
@@ -50,9 +149,7 @@ style_list = [
50
  "prompt": "{prompt}",
51
  },
52
  ]
53
-
54
  styles = {k["name"]: k["prompt"] for k in style_list}
55
-
56
  DEFAULT_STYLE_NAME = "3840 x 2160"
57
  STYLE_NAMES = list(styles.keys())
58
 
@@ -60,7 +157,7 @@ def apply_style(style_name: str, positive: str) -> str:
60
  return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)
61
 
62
  @spaces.GPU(duration=60, enable_queue=True)
63
- def generate(
64
  prompt: str,
65
  seed: int = 0,
66
  width: int = 1024,
@@ -70,13 +167,11 @@ def generate(
70
  style_name: str = DEFAULT_STYLE_NAME,
71
  progress=gr.Progress(track_tqdm=True),
72
  ):
 
73
  seed = int(randomize_seed_fn(seed, randomize_seed))
74
-
75
  positive_prompt = apply_style(style_name, prompt)
76
-
77
  if trigger_word:
78
  positive_prompt = f"{trigger_word} {positive_prompt}"
79
-
80
  images = pipe(
81
  prompt=positive_prompt,
82
  width=width,
@@ -87,124 +182,163 @@ def generate(
87
  output_type="pil",
88
  ).images
89
  image_paths = [save_image(img) for img in images]
90
- print(image_paths)
91
  return image_paths, seed
92
 
93
- examples = [
94
- "Super Realism, High-resolution photograph, woman, UHD, photorealistic, shot on a Sony A7III --chaos 20 --ar 1:2 --style raw --stylize 250",
95
- "Woman in a red jacket, snowy, in the style of hyper-realistic portraiture, caninecore, mountainous vistas, timeless beauty, palewave, iconic, distinctive noses --ar 72:101 --stylize 750 --v 6",
96
- "Super Realism, Headshot of handsome young man, wearing dark gray sweater with buttons and big shawl collar, brown hair and short beard, serious look on his face, black background, soft studio lighting, portrait photography --ar 85:128 --v 6.0 --style",
97
- "Super-realism, Purple Dreamy, a medium-angle shot of a young woman with long brown hair, wearing a pair of eye-level glasses, stands in front of a backdrop of purple and white lights. The womans eyes are closed, her lips are slightly parted, as if she is looking up at the sky. Her hair is cascading over her shoulders, framing her face. She is wearing a sleeveless top, adorned with tiny white dots, and a gold chain necklace around her neck. Her left earrings are dangling from her ears, adding a pop of color to the scene."
98
- ]
99
 
100
- css = '''
101
- .gradio-container{max-width: 888px !important}
102
- h1{text-align:center}
103
- footer {
104
- visibility: hidden
105
- }
106
- .submit-btn {
107
- background-color: #e34949 !important;
108
- color: white !important;
109
- }
110
- .submit-btn:hover {
111
- background-color: #ff3b3b !important;
112
- }
113
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
116
- with gr.Row():
117
- with gr.Column(scale=1):
118
- prompt = gr.Text(
119
- label="Prompt",
120
- show_label=False,
121
- max_lines=1,
122
- placeholder="Enter your prompt",
123
- container=False,
124
- )
125
- run_button = gr.Button("Generate as ( 768 x 1024 )🤗", scale=0, elem_classes="submit-btn")
126
-
127
- with gr.Accordion("Advanced options", open=True, visible=True):
128
- seed = gr.Slider(
129
- label="Seed",
130
- minimum=0,
131
- maximum=MAX_SEED,
132
- step=1,
133
- value=0,
134
- visible=True
135
- )
136
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
137
-
138
- with gr.Row(visible=True):
139
- width = gr.Slider(
140
- label="Width",
141
- minimum=512,
142
- maximum=2048,
143
- step=64,
144
- value=768,
145
- )
146
- height = gr.Slider(
147
- label="Height",
148
- minimum=512,
149
- maximum=2048,
150
- step=64,
151
- value=1024,
152
- )
153
-
154
- with gr.Row():
155
- guidance_scale = gr.Slider(
156
- label="Guidance Scale",
157
- minimum=0.1,
158
- maximum=20.0,
159
- step=0.1,
160
- value=3.0,
161
- )
162
- num_inference_steps = gr.Slider(
163
- label="Number of inference steps",
164
- minimum=1,
165
- maximum=40,
166
- step=1,
167
- value=28,
168
- )
169
-
170
- style_selection = gr.Radio(
171
- show_label=True,
172
- container=True,
173
- interactive=True,
174
- choices=STYLE_NAMES,
175
- value=DEFAULT_STYLE_NAME,
176
- label="Quality Style",
177
- )
178
-
179
- with gr.Column(scale=2):
180
- result = gr.Gallery(label="Result", columns=1, show_label=False)
181
-
182
- gr.Examples(
183
- examples=examples,
184
- inputs=prompt,
185
- outputs=[result, seed],
186
- fn=generate,
187
- cache_examples=False,
188
- )
189
-
190
- gr.on(
191
- triggers=[
192
- prompt.submit,
193
- run_button.click,
194
- ],
195
- fn=generate,
196
- inputs=[
197
- prompt,
198
- seed,
199
- width,
200
- height,
201
- guidance_scale,
202
- randomize_seed,
203
- style_selection,
204
- ],
205
- outputs=[result, seed],
206
- api_name="run",
207
- )
208
 
209
  if __name__ == "__main__":
210
- demo.queue(max_size=40).launch()
 
1
+ import os
2
+ import random
3
+ import uuid
4
+ import json
5
+ import time
6
+ import asyncio
7
+ from threading import Thread
8
+
9
  import gradio as gr
10
+ import spaces
11
  import torch
12
+ import numpy as np
13
  from PIL import Image
14
+ import edge_tts
15
+
16
+ from transformers import (
17
+ AutoModelForCausalLM,
18
+ AutoTokenizer,
19
+ TextIteratorStreamer,
20
+ Qwen2_5_VLForConditionalGeneration,
21
+ AutoProcessor,
22
+ )
23
+ from transformers.image_utils import load_image
24
  from diffusers import DiffusionPipeline
 
 
 
 
25
 
26
+ DESCRIPTION = "# Flux.1 Realism 🥖"
27
+ if not torch.cuda.is_available():
28
+ DESCRIPTION += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"
29
+
30
+ css = '''
31
+ h1 {
32
+ text-align: center;
33
+ display: block;
34
+ }
35
+
36
+ #duplicate-button {
37
+ margin: auto;
38
+ color: #fff;
39
+ background: #1565c0;
40
+ border-radius: 100vh;
41
+ }
42
+ '''
43
+
44
+ MAX_MAX_NEW_TOKENS = 2048
45
+ DEFAULT_MAX_NEW_TOKENS = 1024
46
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
47
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
48
+
49
+
50
+ # Load text-only model and tokenizer
51
+ model_id = "prithivMLmods/FastThink-0.5B-Tiny"
52
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
53
+ model = AutoModelForCausalLM.from_pretrained(
54
+ model_id,
55
+ device_map="auto",
56
+ torch_dtype=torch.bfloat16,
57
+ )
58
+ model.eval()
59
+
60
+ TTS_VOICES = [
61
+ "en-US-JennyNeural", # @tts1
62
+ "en-US-GuyNeural", # @tts2
63
+ ]
64
+
65
+ # Load multimodal Qwen model & processor
66
+ MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
67
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
68
+ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
69
+ MODEL_ID,
70
+ trust_remote_code=True,
71
+ torch_dtype=torch.float16
72
+ ).to("cuda").eval()
73
+
74
+ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
75
+ """Convert text to speech using Edge TTS and save as MP3"""
76
+ communicate = edge_tts.Communicate(text, voice)
77
+ await communicate.save(output_file)
78
+ return output_file
79
+
80
+ def clean_chat_history(chat_history):
81
+ """
82
+ Filter out any chat entries whose "content" is not a string.
83
+ This helps prevent errors when concatenating previous messages.
84
+ """
85
+ cleaned = []
86
+ for msg in chat_history:
87
+ if isinstance(msg, dict) and isinstance(msg.get("content"), str):
88
+ cleaned.append(msg)
89
+ return cleaned
90
+
91
+ def progress_bar_html(label: str) -> str:
92
+ """
93
+ Returns an HTML snippet for a thin progress bar with a label.
94
+ The progress bar is styled as a dark red animated bar.
95
+ """
96
+ return f'''
97
+ <div style="display: flex; align-items: center;">
98
+ <span style="margin-right: 10px; font-size: 14px;">{label}</span>
99
+ <div style="width: 110px; height: 5px; background-color: #f0f0f0; border-radius: 2px; overflow: hidden;">
100
+ <div style="width: 100%; height: 100%; background-color: #ff5900; animation: loading 1.5s linear infinite;"></div>
101
+ </div>
102
+ </div>
103
+ <style>
104
+ @keyframes loading {{
105
+ 0% {{ transform: translateX(-100%); }}
106
+ 100% {{ transform: translateX(100%); }}
107
+ }}
108
+ </style>
109
+ '''
110
+
111
+ # FLUX.1 IMAGE GENERATION SETUP
112
+ MAX_SEED = np.iinfo(np.int32).max
113
+
114
+ def save_image(img: Image.Image) -> str:
115
+ """Save a PIL image with a unique filename and return the path."""
116
  unique_name = str(uuid.uuid4()) + ".png"
117
  img.save(unique_name)
118
  return unique_name
 
122
  seed = random.randint(0, MAX_SEED)
123
  return seed
124
 
125
+ # Initialize Flux.1 pipeline
 
 
 
 
126
  base_model = "black-forest-labs/FLUX.1-dev"
127
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
 
128
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
129
+ trigger_word = "Super Realism" # Leave blank if no trigger word is needed.
 
130
  pipe.load_lora_weights(lora_repo)
131
  pipe.to("cuda")
132
 
133
+ # Define style prompts for Flux.1
134
  style_list = [
135
  {
136
  "name": "3840 x 2160",
 
149
  "prompt": "{prompt}",
150
  },
151
  ]
 
152
  styles = {k["name"]: k["prompt"] for k in style_list}
 
153
  DEFAULT_STYLE_NAME = "3840 x 2160"
154
  STYLE_NAMES = list(styles.keys())
155
 
 
157
  return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)
158
 
159
  @spaces.GPU(duration=60, enable_queue=True)
160
+ def generate_image_flux(
161
  prompt: str,
162
  seed: int = 0,
163
  width: int = 1024,
 
167
  style_name: str = DEFAULT_STYLE_NAME,
168
  progress=gr.Progress(track_tqdm=True),
169
  ):
170
+ """Generate images using the Flux.1 pipeline with style prompts."""
171
  seed = int(randomize_seed_fn(seed, randomize_seed))
 
172
  positive_prompt = apply_style(style_name, prompt)
 
173
  if trigger_word:
174
  positive_prompt = f"{trigger_word} {positive_prompt}"
 
175
  images = pipe(
176
  prompt=positive_prompt,
177
  width=width,
 
182
  output_type="pil",
183
  ).images
184
  image_paths = [save_image(img) for img in images]
 
185
  return image_paths, seed
186
 
187
+ # CHAT GENERATION FUNCTION (TEXT & MULTIMODAL)
 
 
 
 
 
188
 
189
+ @spaces.GPU
190
+ def generate(
191
+ input_dict: dict,
192
+ chat_history: list[dict],
193
+ max_new_tokens: int = 1024,
194
+ temperature: float = 0.6,
195
+ top_p: float = 0.9,
196
+ top_k: int = 50,
197
+ repetition_penalty: float = 1.2,
198
+ ):
199
+ """
200
+ Generates chatbot responses with support for multimodal input, TTS, and image generation.
201
+ Special commands:
202
+ - "@tts1" or "@tts2": triggers text-to-speech.
203
+ - "@image": triggers image generation using the Flux.1 pipeline.
204
+ """
205
+ text = input_dict["text"]
206
+ files = input_dict.get("files", [])
207
+
208
+ # If the text begins with "@image", use Flux for image generation.
209
+ if text.strip().lower().startswith("@image"):
210
+ # Remove the "@image" tag and use the remainder as the prompt.
211
+ prompt = text[len("@image"):].strip()
212
+ yield progress_bar_html("Hold Tight Generating Flux.1 Image")
213
+ image_paths, used_seed = generate_image_flux(
214
+ prompt=prompt,
215
+ seed=1,
216
+ width=1024,
217
+ height=1024,
218
+ guidance_scale=3,
219
+ randomize_seed=True,
220
+ style_name=DEFAULT_STYLE_NAME,
221
+ progress=gr.Progress(track_tqdm=True),
222
+ )
223
+ yield gr.Image(image_paths[0])
224
+ return # Exit early after image generation.
225
+
226
+ # Check if a TTS command is issued.
227
+ tts_prefix = "@tts"
228
+ is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
229
+ voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
230
+
231
+ if is_tts and voice_index:
232
+ voice = TTS_VOICES[voice_index - 1]
233
+ text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
234
+ # Clear previous chat history for a fresh TTS request.
235
+ conversation = [{"role": "user", "content": text}]
236
+ else:
237
+ voice = None
238
+ # Remove any stray @tts tags and build the conversation history.
239
+ text = text.replace(tts_prefix, "").strip()
240
+ conversation = clean_chat_history(chat_history)
241
+ conversation.append({"role": "user", "content": text})
242
+
243
+ # Handle multimodal input if files are provided.
244
+ if files:
245
+ if len(files) > 1:
246
+ images = [load_image(image) for image in files]
247
+ elif len(files) == 1:
248
+ images = [load_image(files[0])]
249
+ else:
250
+ images = []
251
+ messages = [{
252
+ "role": "user",
253
+ "content": [
254
+ *[{"type": "image", "image": image} for image in images],
255
+ {"type": "text", "text": text},
256
+ ]
257
+ }]
258
+ prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
259
+ inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
260
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
261
+ generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
262
+ thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
263
+ thread.start()
264
+
265
+ buffer = ""
266
+ yield progress_bar_html("Thinking...")
267
+ for new_text in streamer:
268
+ buffer += new_text
269
+ buffer = buffer.replace("<|im_end|>", "")
270
+ time.sleep(0.01)
271
+ yield buffer
272
+ else:
273
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
274
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
275
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
276
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
277
+ input_ids = input_ids.to(model.device)
278
+ streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
279
+ generation_kwargs = {
280
+ "input_ids": input_ids,
281
+ "streamer": streamer,
282
+ "max_new_tokens": max_new_tokens,
283
+ "do_sample": True,
284
+ "top_p": top_p,
285
+ "top_k": top_k,
286
+ "temperature": temperature,
287
+ "num_beams": 1,
288
+ "repetition_penalty": repetition_penalty,
289
+ }
290
+ t = Thread(target=model.generate, kwargs=generation_kwargs)
291
+ t.start()
292
+
293
+ outputs = []
294
+ yield progress_bar_html("Thinking...")
295
+ for new_text in streamer:
296
+ outputs.append(new_text)
297
+ yield "".join(outputs)
298
+
299
+ final_response = "".join(outputs)
300
+ yield final_response
301
+
302
+ # If TTS was requested, convert the final response to speech.
303
+ if is_tts and voice:
304
+ output_file = asyncio.run(text_to_speech(final_response, voice))
305
+ yield gr.Audio(output_file, autoplay=True)
306
 
307
+ # GRADIO CHAT INTERFACE
308
+ demo = gr.ChatInterface(
309
+ fn=generate,
310
+ additional_inputs=[
311
+ gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
312
+ gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
313
+ gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
314
+ gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
315
+ gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
316
+ ],
317
+ examples=[
318
+ ["@image Chocolate dripping from a donut against a yellow background, in the style of hyper-realistic 8K"],
319
+ ["@image Super Realism, High-resolution photograph, woman, UHD, photorealistic, shot on a Sony A7III --chaos 20 --ar 1:2 --style raw --stylize 250"],
320
+ ["@image Woman in a red jacket, snowy, in the style of hyper-realistic portraiture, caninecore, mountainous vistas, timeless beauty, palewave, iconic, distinctive noses --ar 72:101 --stylize 750 --v 6"],
321
+ ["@image Super-realism, Purple Dreamy, a medium-angle shot of a young woman with long brown hair, wearing a pair of eye-level glasses, stands in front of a backdrop of purple and white lights. The womans eyes are closed, her lips are slightly parted, as if she is looking up at the sky. Her hair is cascading over her shoulders, framing her face. She is wearing a sleeveless top, adorned with tiny white dots, and a gold chain necklace around her neck. Her left earrings are dangling from her ears, adding a pop of color to the scene."]
322
+ ["Python Program for Array Rotation"],
323
+ ["@tts1 Who is Nikola Tesla, and why did he die?"],
324
+ [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
325
+ [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
326
+ ["@tts2 What causes rainbows to form?"],
327
+ ],
328
+ cache_examples=False,
329
+ type="messages",
330
+ description=DESCRIPTION,
331
+ css=css,
332
+ fill_height=True,
333
+ textbox=gr.MultimodalTextbox(
334
+ label="Query Input",
335
+ file_types=["image"],
336
+ file_count="multiple",
337
+ placeholder="‎ @image-flux.1 image gen, @tts1, @tts2-voices, default [text, vision]"
338
+ ),
339
+ stop_btn="Stop Generation",
340
+ multimodal=True,
341
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
 
343
  if __name__ == "__main__":
344
+ demo.queue(max_size=20).launch(share=True)