prithivMLmods committed
Commit 8b14dd2 · verified · 1 Parent(s): 622d235

Update app.py

Files changed (1)
  1. app.py +140 -278
app.py CHANGED
@@ -1,110 +1,37 @@
- import os
- import random
- import uuid
- import json
- import time
- import asyncio
- from threading import Thread
-
- import gradio as gr
  import spaces
  import torch
- import numpy as np
  from PIL import Image
- import edge_tts
-
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     TextIteratorStreamer,
-     Qwen2VLForConditionalGeneration,
-     AutoProcessor,
- )
- from transformers.image_utils import load_image
  from diffusers import DiffusionPipeline

- DESCRIPTION = """
- # QwQ Edge 💬 with Flux.1
- """
-
- css = '''
- h1 {
-     text-align: center;
-     display: block;
- }
-
- #duplicate-button {
-     margin: auto;
-     color: #fff;
-     background: #1565c0;
-     border-radius: 100vh;
- }
- '''
-
- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- # --------------------------
- # Text Generation Components
- # --------------------------
-
- # Load text-only model and tokenizer
- model_id = "prithivMLmods/FastThink-0.5B-Tiny"
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     device_map="auto",
-     torch_dtype=torch.bfloat16,
- )
- model.eval()
-
- TTS_VOICES = [
-     "en-US-JennyNeural",  # @tts1
-     "en-US-GuyNeural",    # @tts2
- ]
-
- # Multimodal model (text+vision)
- MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
- model_m = Qwen2VLForConditionalGeneration.from_pretrained(
-     MODEL_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to("cuda").eval()

- async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
-     """Convert text to speech using Edge TTS and save as MP3"""
-     communicate = edge_tts.Communicate(text, voice)
-     await communicate.save(output_file)
-     return output_file

- def clean_chat_history(chat_history):
-     """
-     Filter out any chat entries whose "content" is not a string.
-     This helps prevent errors when concatenating previous messages.
-     """
-     cleaned = []
-     for msg in chat_history:
-         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
-             cleaned.append(msg)
-     return cleaned

- # --------------------------
- # Flux.1 Image Generation
- # --------------------------

- # Set up the Flux.1 pipeline
  base_model = "black-forest-labs/FLUX.1-dev"
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
  trigger_word = "Super Realism"  # Leave trigger_word blank if not used.
  pipe.load_lora_weights(lora_repo)
  pipe.to("cuda")

- # Define style prompts
  style_list = [
      {
          "name": "3840 x 2160",
@@ -123,48 +50,17 @@ style_list = [
          "prompt": "{prompt}",
      },
  ]
  styles = {k["name"]: k["prompt"] for k in style_list}
  DEFAULT_STYLE_NAME = "3840 x 2160"
  STYLE_NAMES = list(styles.keys())

  def apply_style(style_name: str, positive: str) -> str:
      return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)

- MAX_SEED = np.iinfo(np.int32).max
-
- def save_image(img: Image.Image) -> str:
-     """Save a PIL image with a unique filename and return the path."""
-     unique_name = str(uuid.uuid4()) + ".png"
-     img.save(unique_name)
-     return unique_name
-
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed
-
- def progress_bar_html(label: str) -> str:
-     """
-     Returns an HTML snippet for a thin progress bar with a label.
-     The progress bar is styled as a dark red animated bar.
-     """
-     return f'''
- <div style="display: flex; align-items: center;">
-     <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-     <div style="width: 110px; height: 5px; background-color: #f0f0f0; border-radius: 2px; overflow: hidden;">
-         <div style="width: 100%; height: 100%; background-color: #ff5900; animation: loading 1.5s linear infinite;"></div>
-     </div>
- </div>
- <style>
- @keyframes loading {{
-     0% {{ transform: translateX(-100%); }}
-     100% {{ transform: translateX(100%); }}
- }}
- </style>
- '''
-
- @spaces.GPU(duration=60, enable_queue=True)
- def generate_image_fn(
      prompt: str,
      seed: int = 0,
      width: int = 1024,
@@ -174,11 +70,13 @@ def generate_image_fn(
      style_name: str = DEFAULT_STYLE_NAME,
      progress=gr.Progress(track_tqdm=True),
  ):
-     """Generate images using the Flux.1 pipeline."""
      seed = int(randomize_seed_fn(seed, randomize_seed))
      positive_prompt = apply_style(style_name, prompt)
      if trigger_word:
          positive_prompt = f"{trigger_word} {positive_prompt}"
      images = pipe(
          prompt=positive_prompt,
          width=width,
@@ -189,160 +87,124 @@ def generate_image_fn(
          output_type="pil",
      ).images
      image_paths = [save_image(img) for img in images]
      return image_paths, seed

- # --------------------------
- # Chat and Multimodal Generation
- # --------------------------
-
- @spaces.GPU
- def generate(
-     input_dict: dict,
-     chat_history: list[dict],
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     repetition_penalty: float = 1.2,
- ):
-     """
-     Generates chatbot responses with support for multimodal input, TTS, and image generation using Flux.1.
-     Special commands:
-     - "@tts1" or "@tts2": triggers text-to-speech.
-     - "@image": triggers image generation using the Flux.1 pipeline.
-     """
-     text = input_dict["text"]
-     files = input_dict.get("files", [])
-
-     if text.strip().lower().startswith("@image"):
-         # Remove the "@image" tag and use the rest as prompt
-         prompt_img = text[len("@image"):].strip()
-         # Show animated progress bar for image generation
-         yield progress_bar_html("Generating Image")
-         image_paths, used_seed = generate_image_fn(
-             prompt=prompt_img,
-             seed=1,
-             width=1024,
-             height=1024,
-             guidance_scale=3,
-             randomize_seed=True,
-             style_name=DEFAULT_STYLE_NAME,
-         )
-         # Once done, yield the generated image
-         yield gr.Image(image_paths[0])
-         return  # Exit early
-
-     tts_prefix = "@tts"
-     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
-     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
-
-     if is_tts and voice_index:
-         voice = TTS_VOICES[voice_index - 1]
-         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-         # Clear previous chat history for a fresh TTS request.
-         conversation = [{"role": "user", "content": text}]
-     else:
-         voice = None
-         # Remove any stray @tts tags and build the conversation history.
-         text = text.replace(tts_prefix, "").strip()
-         conversation = clean_chat_history(chat_history)
-         conversation.append({"role": "user", "content": text})
-
-     if files:
-         if len(files) > 1:
-             images = [load_image(image) for image in files]
-         elif len(files) == 1:
-             images = [load_image(files[0])]
-         else:
-             images = []
-         messages = [{
-             "role": "user",
-             "content": [
-                 *[{"type": "image", "image": image} for image in images],
-                 {"type": "text", "text": text},
-             ]
-         }]
-         prompt_multimodal = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         inputs = processor(text=[prompt_multimodal], images=images, return_tensors="pt", padding=True).to("cuda")
-         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-         thread.start()
-
-         buffer = ""
-         # Show animated progress bar for multimodal generation
-         yield progress_bar_html("Thinking...")
-         for new_text in streamer:
-             buffer += new_text
-             buffer = buffer.replace("<|im_end|>", "")
-             time.sleep(0.01)
-             yield buffer
-     else:
-         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-         input_ids = input_ids.to(model.device)
-         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {
-             "input_ids": input_ids,
-             "streamer": streamer,
-             "max_new_tokens": max_new_tokens,
-             "do_sample": True,
-             "top_p": top_p,
-             "top_k": top_k,
-             "temperature": temperature,
-             "num_beams": 1,
-             "repetition_penalty": repetition_penalty,
-         }
-         t = Thread(target=model.generate, kwargs=generation_kwargs)
-         t.start()
-
-         outputs = []
-         # Show animated progress bar for text generation
-         yield progress_bar_html("Thinking...")
-         for new_text in streamer:
-             outputs.append(new_text)
-             yield "".join(outputs)
-
-         final_response = "".join(outputs)
-         yield final_response
-
-         # If TTS was requested, convert the final response to speech.
-         if is_tts and voice:
-             output_file = asyncio.run(text_to_speech(final_response, voice))
-             yield gr.Audio(output_file, autoplay=True)

- # --------------------------
- # Gradio Chat Interface
- # --------------------------

- demo = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
-         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
-         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
-         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
-     ],
-     examples=[
-         ["@image A futuristic cityscape at sunset with vibrant colors"],
-         ["Python Program for Array Rotation"],
-         ["@tts1 Who is Nikola Tesla, and why did he die?"],
-         [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
-         [{"text": "summarize the letter", "files": ["examples/1.png"]}],
-         ["@tts2 What causes rainbows to form?"],
-     ],
-     cache_examples=False,
-     type="messages",
-     description=DESCRIPTION,
-     css=css,
-     fill_height=True,
-     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="@tts1, @tts2-voices, @image-image gen, default [text, vision]"),
-     stop_btn="Stop Generation",
-     multimodal=True,
- )

  if __name__ == "__main__":
-     demo.queue(max_size=20).launch(share=True)
  import spaces
+ import gradio as gr
  import torch
  from PIL import Image
  from diffusers import DiffusionPipeline
+ import random
+ import uuid
+ from typing import Tuple
+ import numpy as np

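+ # Save a PIL image with a unique filename and return the path.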
+ def save_image(img):
+     unique_name = str(uuid.uuid4()) + ".png"
+     img.save(unique_name)
+     return unique_name

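+ # Pick a random seed when randomization is enabled; otherwise keep the caller's seed.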
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed

+ MAX_SEED = np.iinfo(np.int32).max

+ DESCRIPTIONz = ""  # Defined here so the CPU warning below has something to append to.
+ if not torch.cuda.is_available():
+     DESCRIPTIONz += "\n<p>⚠️ Running on CPU; this may not work on CPU.</p>"

  base_model = "black-forest-labs/FLUX.1-dev"
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
+
  lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
  trigger_word = "Super Realism"  # Leave trigger_word blank if not used.
+
  pipe.load_lora_weights(lora_repo)
  pipe.to("cuda")

  style_list = [
      {
          "name": "3840 x 2160",

          "prompt": "{prompt}",
      },
  ]
+
  styles = {k["name"]: k["prompt"] for k in style_list}
+
  DEFAULT_STYLE_NAME = "3840 x 2160"
  STYLE_NAMES = list(styles.keys())

  def apply_style(style_name: str, positive: str) -> str:
      return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)

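+ # Reserve GPU time for each call (duration in seconds) before running generation.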
+ @spaces.GPU(duration=60, enable_queue=True)
+ def generate(
      prompt: str,
      seed: int = 0,
      width: int = 1024,

      style_name: str = DEFAULT_STYLE_NAME,
      progress=gr.Progress(track_tqdm=True),
  ):
      seed = int(randomize_seed_fn(seed, randomize_seed))
+
      positive_prompt = apply_style(style_name, prompt)
+
      if trigger_word:
          positive_prompt = f"{trigger_word} {positive_prompt}"
+
      images = pipe(
          prompt=positive_prompt,
          width=width,

          output_type="pil",
      ).images
      image_paths = [save_image(img) for img in images]
+     print(image_paths)
      return image_paths, seed

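+ # Example prompts for the UI; most include the "Super Realism" trigger word to activate the LoRA.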
+ examples = [
+     "Super Realism, High-resolution photograph, woman, UHD, photorealistic, shot on a Sony A7III --chaos 20 --ar 1:2 --style raw --stylize 250",
+     "Woman in a red jacket, snowy, in the style of hyper-realistic portraiture, caninecore, mountainous vistas, timeless beauty, palewave, iconic, distinctive noses --ar 72:101 --stylize 750 --v 6",
+     "Super Realism, Headshot of handsome young man, wearing dark gray sweater with buttons and big shawl collar, brown hair and short beard, serious look on his face, black background, soft studio lighting, portrait photography --ar 85:128 --v 6.0 --style",
+     "Super-realism, Purple Dreamy, a medium-angle shot of a young woman with long brown hair, wearing a pair of eye-level glasses, stands in front of a backdrop of purple and white lights. The womans eyes are closed, her lips are slightly parted, as if she is looking up at the sky. Her hair is cascading over her shoulders, framing her face. She is wearing a sleeveless top, adorned with tiny white dots, and a gold chain necklace around her neck. Her left earrings are dangling from her ears, adding a pop of color to the scene."
+ ]

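+ # Page styling: cap the container width, center the title, hide the footer, and color the submit button.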
+ css = '''
+ .gradio-container{max-width: 888px !important}
+ h1{text-align:center}
+ footer {
+     visibility: hidden
+ }
+ .submit-btn {
+     background-color: #e34949 !important;
+     color: white !important;
+ }
+ .submit-btn:hover {
+     background-color: #ff3b3b !important;
+ }
+ '''

+ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
+     with gr.Row():
+         with gr.Column(scale=1):
+             prompt = gr.Text(
+                 label="Prompt",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter your prompt",
+                 container=False,
+             )
+             run_button = gr.Button("Generate as ( 768 x 1024 )🤗", scale=0, elem_classes="submit-btn")
+
+             with gr.Accordion("Advanced options", open=True, visible=True):
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=0,
+                     visible=True
+                 )
+                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+                 with gr.Row(visible=True):
+                     width = gr.Slider(
+                         label="Width",
+                         minimum=512,
+                         maximum=2048,
+                         step=64,
+                         value=768,
+                     )
+                     height = gr.Slider(
+                         label="Height",
+                         minimum=512,
+                         maximum=2048,
+                         step=64,
+                         value=1024,
+                     )
+
+                 with gr.Row():
+                     guidance_scale = gr.Slider(
+                         label="Guidance Scale",
+                         minimum=0.1,
+                         maximum=20.0,
+                         step=0.1,
+                         value=3.0,
+                     )
+                     num_inference_steps = gr.Slider(
+                         label="Number of inference steps",
+                         minimum=1,
+                         maximum=40,
+                         step=1,
+                         value=28,
+                     )
+
+                 style_selection = gr.Radio(
+                     show_label=True,
+                     container=True,
+                     interactive=True,
+                     choices=STYLE_NAMES,
+                     value=DEFAULT_STYLE_NAME,
+                     label="Quality Style",
+                 )
+
+         with gr.Column(scale=2):
+             result = gr.Gallery(label="Result", columns=1, show_label=False)
+
+             gr.Examples(
+                 examples=examples,
+                 inputs=prompt,
+                 outputs=[result, seed],
+                 fn=generate,
+                 cache_examples=False,
+             )
+
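+     # Trigger generate() from both the prompt textbox submit and the button click.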
+     gr.on(
+         triggers=[
+             prompt.submit,
+             run_button.click,
+         ],
+         fn=generate,
+         inputs=[
+             prompt,
+             seed,
+             width,
+             height,
+             guidance_scale,
+             randomize_seed,
+             style_selection,
+         ],
+         outputs=[result, seed],
+         api_name="run",
+     )

  if __name__ == "__main__":
+     demo.queue(max_size=40).launch()