start3406 committed on
Commit 32ac17a · verified · 1 Parent(s): f8eb849

Update app.py

Files changed (1)
  1. app.py +188 -190
app.py CHANGED
@@ -1,25 +1,27 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import pipeline, set_seed
4
- # Import AutoPipelineForText2Image for compatibility across model types
5
- from diffusers import AutoPipelineForText2Image
6
  import openai
7
  import os
8
  import time
9
- import traceback # For detailed error logging
10
 
11
  # ---- Configuration & API Key ----
12
- # Check for OpenAI API Key in Hugging Face Secrets
13
- api_key = os.environ.get("OPENAI_API_KEY")
14
- openai_client = None
15
- openai_available = False
16
 
17
  if api_key:
18
  try:
19
- # Starting with openai v1, client instantiation is preferred
 
20
  openai_client = openai.OpenAI(api_key=api_key)
21
- # Simple test to check if the key is valid (optional, but good)
22
- # openai_client.models.list() # This call might incur small cost/quota usage
23
  openai_available = True
24
  print("OpenAI API key found and client initialized.")
25
  except Exception as e:
@@ -28,95 +30,93 @@ if api_key:
28
  else:
29
  print("WARNING: OPENAI_API_KEY secret not found. Prompt enhancement via OpenAI is disabled.")
30
 
31
- # Force CPU usage
32
- device = "cpu"
33
  print(f"Using device: {device}")
34
 
35
  # ---- Model Loading (CPU Focused) ----
36
 
37
- # 1. Speech-to-text model (Whisper) - bonus feature
38
  asr_pipeline = None
39
  try:
40
  print("Loading ASR pipeline (Whisper) on CPU...")
41
- # Force CPU usage with device=-1 or device="cpu"
42
- # fp16 would be faster but needs a GPU; use float32 on CPU
43
  asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device, torch_dtype=torch.float32)
44
  print("ASR pipeline loaded successfully on CPU.")
45
  except Exception as e:
46
- print(f"Could not load ASR pipeline: {e}. Voice input will be disabled.")
47
- traceback.print_exc() # Print full traceback for debugging
48
 
49
# 2. Text-to-image model (nota-ai/bk-sdm-tiny) - a resource-friendly model
50
- image_generator_pipe = None
51
- # Use the nota-ai/bk-sdm-tiny model
52
- model_id = "nota-ai/bk-sdm-tiny"
53
  try:
54
  print(f"Loading Text-to-Image pipeline ({model_id}) on CPU...")
55
  print("NOTE: Using a small model for resource efficiency. Image quality and details may differ from larger models.")
56
# Use AutoPipelineForText2Image to auto-detect the model type
57
- image_generator_pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float32)
58
- image_generator_pipe = image_generator_pipe.to(device)
59
  print(f"Text-to-Image pipeline ({model_id}) loaded successfully on CPU.")
60
  except Exception as e:
61
  print(f"CRITICAL: Could not load Text-to-Image pipeline ({model_id}): {e}. Image generation will fail.")
62
- traceback.print_exc() # Print full traceback for debugging
63
- # Define a dummy object to prevent crashes later if loading failed
64
- class DummyPipe:
65
- def __call__(self, *args, **kwargs):
66
- raise RuntimeError(f"Text-to-Image model failed to load: {e}")
67
- image_generator_pipe = DummyPipe()
68
 
69
 
70
  # ---- Core Function Definitions ----
71
 
72
- # Step 1: Prompt-to-Prompt (using OpenAI API)
73
- def enhance_prompt_openai(short_prompt, style_modifier="cinematic", quality_boost="photorealistic, highly detailed"):
74
- """Uses OpenAI API to enhance the short description."""
75
  if not openai_available or not openai_client:
76
- # Fallback or error if OpenAI key is missing/invalid
77
  print("OpenAI not available. Returning original prompt with modifiers.")
78
- # Basic fallback prompt enhancement
79
- if short_prompt:
80
- return f"{short_prompt}, {style_modifier}, {quality_boost}"
81
- else:
82
- # If short prompt is empty, fallback should also indicate error
83
- raise gr.Error("Input description cannot be empty.")
84
-
85
-
86
- if not short_prompt:
87
- # Return an error message formatted for Gradio output
88
- raise gr.Error("Input description cannot be empty.")
89
 
90
- # Construct the prompt for the OpenAI model
91
- system_message = (
92
  "You are an expert prompt engineer for AI image generation models. "
93
  "Expand the user's short description into a detailed, vivid, and coherent prompt, suitable for smaller, faster text-to-image models. "
94
  "Focus on clear subjects, objects, and main scene elements. "
95
  "Incorporate the requested style and quality keywords naturally, but keep the overall prompt concise enough for smaller models. Avoid conversational text."
96
- # Adjusting guidance for smaller models
97
  )
98
- user_message = (
99
- f"Enhance this description: \"{short_prompt}\". "
100
  f"Style: '{style_modifier}'. Quality: '{quality_boost}'."
101
  )
102
 
103
- print(f"Sending request to OpenAI for prompt enhancement: {short_prompt}")
104
 
105
  try:
106
  response = openai_client.chat.completions.create(
107
- model="gpt-3.5-turbo", # Cost-effective choice
108
  messages=[
109
  {"role": "system", "content": system_message},
110
  {"role": "user", "content": user_message},
111
  ],
112
- temperature=0.7, # Controls creativity vs predictability
113
- max_tokens=100, # Limit output length - reduced for potentially shorter prompts for smaller models
114
- n=1, # Generate one response
115
- stop=None # Let the model decide when to stop
116
  )
117
- enhanced_prompt = response.choices[0].message.content.strip()
118
  print("OpenAI enhancement successful.")
119
- # Basic cleanup: remove potential quotes around the whole response
120
  if enhanced_prompt.startswith('"') and enhanced_prompt.endswith('"'):
121
  enhanced_prompt = enhanced_prompt[1:-1]
122
  return enhanced_prompt
@@ -135,207 +135,204 @@ def enhance_prompt_openai(short_prompt, style_modifier="cinematic", quality_boos
135
  raise gr.Error(f"Prompt enhancement failed: {e}")
136
 
137
 
138
- # Step 2: Prompt-to-Image (CPU)
139
- def generate_image_cpu(prompt, negative_prompt, guidance_scale, num_inference_steps):
140
- """Generates image using the loaded model on CPU."""
141
- # 检查加载的模型是否是期望的pipeline类型或DummyPipe
142
- if not isinstance(image_generator_pipe, AutoPipelineForText2Image):
143
- # If it's a DummyPipe or None for some reason
144
- if isinstance(image_generator_pipe, DummyPipe):
145
- # DummyPipe will raise its own error when called, so just let it
146
- pass # The call below will raise the intended error
147
- else:
148
- # Handle unexpected case where pipe is not loaded correctly
149
- raise gr.Error("Image generation pipeline is not available (failed to load model).")
150
-
151
 
 
152
  if not prompt or "[Error:" in prompt or "Error:" in prompt:
153
- # Check if the prompt itself is an error message from the previous step
154
  raise gr.Error("Cannot generate image due to invalid or missing prompt.")
155
 
156
- print(f"Generating image on CPU for prompt: {prompt[:100]}...") # Log truncated prompt
157
- # Note: Negative prompt and guidance scale might have less impact or behave differently
158
- # on very small models.
159
- print(f"Negative prompt: {negative_prompt}") # Will likely be ignored by tiny model
160
- print(f"Guidance scale: {guidance_scale}, Steps: {num_inference_steps}") # Steps might be fixed internally by tiny model
161
 
162
- start_time = time.time()
163
 
164
  try:
165
- # Use torch.inference_mode() or torch.no_grad() for efficiency
166
  with torch.no_grad():
167
- # Seed for reproducibility (optional, but good practice)
168
- # generator = torch.Generator(device=device).manual_seed(int(time.time())) # Tiny model might not use generator param
169
- # Call the pipeline - assuming standard parameters are accepted
170
  output = image_generator_pipe(
171
  prompt=prompt,
172
- # It's possible tiny models ignore some parameters, but passing them is safer
173
  negative_prompt=negative_prompt,
174
  guidance_scale=float(guidance_scale),
175
  num_inference_steps=int(num_inference_steps),
176
- # generator=generator, # Omit if tiny model pipeline doesn't accept it
177
- # height and width might need to be specified or limited for tiny models
178
  # height=..., width=...
179
  )
180
 
181
- # Access the generated image(s). Assuming standard diffusers output structure (.images[0])
182
  if hasattr(output, 'images') and isinstance(output.images, list) and len(output.images) > 0:
183
- image = output.images[0] # Access the first image
184
  else:
185
- # Handle cases where output format is different (less common for AutoPipelines)
186
  print("Warning: Pipeline output format unexpected. Attempting to use the output directly.")
187
- image = output # Assume output is the image
188
 
189
- end_time = time.time()
190
  print(f"Image generated successfully on CPU in {end_time - start_time:.2f} seconds (using {model_id}).")
191
  return image
192
  except Exception as e:
193
  print(f"Error during image generation on CPU ({model_id}): {e}")
194
  traceback.print_exc()
195
- # Propagate error to Gradio UI
196
  raise gr.Error(f"Image generation failed on CPU ({model_id}): {e}")
197
 
198
 
199
  # Bonus: Voice-to-Text (CPU)
200
- def transcribe_audio(audio_file_path):
201
- """Transcribes audio to text using Whisper on CPU."""
 
202
  if not asr_pipeline:
203
- # This case should ideally be handled by hiding the control, but double-check
204
  return "[Error: ASR model not loaded]", audio_file_path
205
  if audio_file_path is None:
206
- return "", audio_file_path # No audio input
 
207
 
208
  print(f"Transcribing audio file: {audio_file_path} on CPU...")
209
- start_time = time.time()
210
  try:
211
- # Ensure the pipeline uses the correct device (should be CPU based on loading)
212
- # Ensure input is in expected format for Whisper pipeline (filepath or audio array)
213
- if isinstance(audio_file_path, tuple): # Handle case where Gradio might pass tuple
214
- # Assuming tuple is (samplerate, numpy_array), need to save to temp file or process directly
215
- # For simplicity with type="filepath", assume it passes path directly
216
- print("Warning: Audio input was tuple, expecting filepath. This might fail.")
217
- # Attempting to process numpy array if it's the second element
218
- if isinstance(audio_file_path[1], torch.Tensor) or isinstance(audio_file_path[1], list) or isinstance(audio_file_path[1], (int, float)):
219
- # This path is complex, sticking to filepath assumption for now
220
- pass # Let the pipeline call below handle potential error
221
- audio_input_for_pipeline = audio_file_path # Pass original tuple, let pipeline handle
222
- else:
223
- audio_input_for_pipeline = audio_file_path # Expected filepath
224
-
225
- transcription = asr_pipeline(audio_input_for_pipeline)["text"]
226
- end_time = time.time()
227
  print(f"Transcription successful in {end_time - start_time:.2f} seconds.")
228
  print(f"Transcription result: {transcription}")
229
  return transcription, audio_file_path
230
  except Exception as e:
231
  print(f"Error during audio transcription on CPU: {e}")
232
  traceback.print_exc()
233
- # Return error message in the expected tuple format
234
  return f"[Error: Transcription failed: {e}]", audio_file_path
235
 
236
 
237
  # ---- Gradio Application Flow ----
238
 
239
- def process_input(input_text, audio_file, style_choice, quality_choice, neg_prompt, guidance, steps):
240
- """Main function triggered by Gradio button."""
241
- final_text_input = ""
242
- enhanced_prompt = ""
243
- generated_image = None
244
- status_message = "" # To gather status/errors for the prompt box
245
-
246
- # 1. Determine Input (Text or Audio)
247
  if input_text and input_text.strip():
248
  final_text_input = input_text.strip()
249
  print(f"Using text input: '{final_text_input}'")
250
  elif audio_file is not None:
251
  print("Processing audio input...")
252
  try:
253
- # transcribe_audio handles different Gradio audio output types potentially
254
  transcribed_text, _ = transcribe_audio(audio_file)
255
 
256
  if "[Error:" in transcribed_text:
257
- # Display transcription error clearly
258
  status_message = transcribed_text
259
  print(status_message)
260
- return status_message, None # Return error in prompt field, no image
261
- elif transcribed_text:
262
- final_text_input = transcribed_text
263
  print(f"Using transcribed audio input: '{final_text_input}'")
264
  else:
265
- status_message = "[Error: Audio input received but transcription was empty.]"
266
  print(status_message)
267
- return status_message, None # Return error
268
  except Exception as e:
269
  status_message = f"[Unexpected Audio Transcription Error: {e}]"
270
  print(status_message)
271
  traceback.print_exc()
272
- return status_message, None # Return error
273
 
274
  else:
275
  status_message = "[Error: No input provided. Please enter text or record audio.]"
276
  print(status_message)
277
- return status_message, None # Return error
278
 
279
- # 2. Enhance Prompt (using OpenAI if available)
280
  if final_text_input:
281
  try:
282
  enhanced_prompt = enhance_prompt_openai(final_text_input, style_choice, quality_choice)
283
- status_message = enhanced_prompt # Display the prompt initially
284
  print(f"Enhanced prompt: {enhanced_prompt}")
285
  except gr.Error as e:
286
- # Catch Gradio-specific errors from enhancement function
287
  status_message = f"[Prompt Enhancement Error: {e}]"
288
  print(status_message)
289
- # Return the error, no image generation attempt
290
  return status_message, None
291
  except Exception as e:
292
- # Catch any other unexpected errors
293
  status_message = f"[Unexpected Prompt Enhancement Error: {e}]"
294
  print(status_message)
295
  traceback.print_exc()
296
  return status_message, None
297
 
298
- # 3. Generate Image (if prompt is valid)
299
- # Check if the enhanced prompt step resulted in an error message
300
  if enhanced_prompt and not status_message.startswith("[Error:") and not status_message.startswith("[Prompt Enhancement Error:"):
301
  try:
302
- # Show "Generating..." message while waiting
303
  gr.Info(f"Starting image generation on CPU using {model_id}. This should be faster than full SD, but might still take time.")
304
  generated_image = generate_image_cpu(enhanced_prompt, neg_prompt, guidance, steps)
305
  gr.Info("Image generation complete!")
306
  except gr.Error as e:
307
- # Catch Gradio errors from generation function
308
- # Prepend original enhanced prompt to the error message for context
309
  status_message = f"{enhanced_prompt}\n\n[Image Generation Error: {e}]"
310
  print(f"Image Generation Error: {e}")
311
- generated_image = None # Ensure image is None on error
312
  except Exception as e:
313
- # Catch any other unexpected errors
314
  status_message = f"{enhanced_prompt}\n\n[Unexpected Image Generation Error: {e}]"
315
  print(f"Unexpected Image Generation Error: {e}")
316
  traceback.print_exc()
317
- generated_image = None # Ensure image is None on error
318
 
319
  else:
320
- # If prompt enhancement failed, status_message already contains the error
321
- # In this case, we just return the existing status_message and None image
322
  print("Skipping image generation due to prompt enhancement failure.")
323
 
324
 
325
- # 4. Return results to Gradio UI
326
- # Return the status message (enhanced prompt or error) and the image (or None if error)
327
  return status_message, generated_image
328
 
329
 
330
  # ---- Gradio Interface Construction ----
331
 
332
- style_options = ["cinematic", "photorealistic", "anime", "fantasy art", "cyberpunk", "steampunk", "watercolor", "illustration", "low poly"]
333
- quality_options = ["highly detailed", "sharp focus", "intricate details", "4k", "masterpiece", "best quality", "professional lighting"]
334
 
335
- # Adjust steps/guidance defaults for a smaller model, still might be ignored by some pipelines
336
- default_steps = 20
337
- max_steps = 40 # Adjusted max steps
338
- default_guidance = 5.0 # Adjusted default guidance
 
339
 
340
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
341
  gr.Markdown("# AI Image Generator (CPU Version - Using Small Model)")
@@ -343,95 +340,96 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
343
  "**Enter a short description or use voice input.** The app uses OpenAI (if API key is provided) "
344
  f"to create a detailed prompt, then generates an image using a **small model ({model_id}) on the CPU**."
345
  )
346
- # Add specific warning about CPU speed and potential resource issues for this specific model
347
  gr.HTML("<p style='color:orange;font-weight:bold;'>⚠️ Note: Using a small model for better compatibility on CPU. Generation should be faster than full Stable Diffusion, but quality/details may differ.</p>")
348
  gr.HTML("<p style='color:red;font-weight:bold;'>⏰ CPU generation can still take 1-5 minutes per image depending on load and model specifics.</p>")
349
 
350
 
351
- # Display OpenAI availability status
352
  if not openai_available:
353
  gr.Markdown("**Note:** OpenAI API key not found or invalid. Prompt enhancement will use a basic fallback.")
354
  else:
355
  gr.Markdown("**Note:** OpenAI API key found. Prompt will be enhanced using OpenAI.")
356
 
357
- # Display Model loading status
358
- # Check against AutoPipelineForText2Image type
359
- if not isinstance(image_generator_pipe, AutoPipelineForText2Image):
360
  gr.Markdown(f"**CRITICAL:** Image generation model ({model_id}) failed to load. Image generation is disabled. Check Space logs for details.")
361
 
362
-
363
  with gr.Row():
364
  with gr.Column(scale=1):
365
- # --- Inputs ---
366
  inp_text = gr.Textbox(label="Enter short description", placeholder="e.g., A cute robot drinking coffee on Mars")
367
 
368
- # Only show Audio input if ASR model loaded successfully
369
  if asr_pipeline:
 
370
  inp_audio = gr.Audio(sources=["microphone"], type="filepath", label="Or record your idea (clears text box if used)")
371
  else:
372
  gr.Markdown("**Voice input disabled:** Whisper model failed to load.")
373
- # Using gr.State as a placeholder that holds None
374
  inp_audio = gr.State(None)
375
 
376
- # --- Controls ---
377
- # Note: These controls might have less impact than on larger models
378
  gr.Markdown("*(Optional controls - Note: Their impact might vary on this small model)*")
379
- # Control 1: Dropdown
380
  inp_style = gr.Dropdown(label="Base Style", choices=style_options, value="cinematic")
381
- # Control 2: Radio
382
  inp_quality = gr.Radio(label="Quality Boost", choices=quality_options, value="highly detailed")
383
- # Control 3: Textbox (Negative Prompt)
384
  inp_neg_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, text, watermark, signature, deformed")
385
- # Control 4: Slider (Guidance Scale)
386
- inp_guidance = gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=default_guidance, label="Guidance Scale (CFG)") # Lower max guidance
387
- # Control 5: Slider (Inference Steps) - Adjusted max/default
388
- inp_steps = gr.Slider(minimum=5, maximum=max_steps, step=1, value=default_steps, label=f"Inference Steps (lower = faster but less detail, max {max_steps})") # Lower min steps
389
 
390
- # --- Action Button ---
391
- # Disable button if model failed to load
392
- btn_generate = gr.Button("Generate Image", variant="primary", interactive=isinstance(image_generator_pipe, AutoPipelineForText2Image))
393
 
394
  with gr.Column(scale=1):
395
- # --- Outputs ---
396
- out_prompt = gr.Textbox(label="Generated Prompt / Status", interactive=False, lines=5) # Show prompt or error status here
397
- out_image = gr.Image(label="Generated Image", type="pil", show_label=True) # Ensure label is shown
398
 
399
- # --- Event Handling ---
400
- # Define inputs list carefully, handling potentially invisible audio input
401
  inputs_list = [inp_text]
 
402
  if asr_pipeline:
403
  inputs_list.append(inp_audio)
404
  else:
405
- inputs_list.append(inp_audio) # Pass the gr.State(None) placeholder
 
406
 
407
  inputs_list.extend([inp_style, inp_quality, inp_neg_prompt, inp_guidance, inp_steps])
408
 
409
- # Link button click to processing function
410
  btn_generate.click(
411
  fn=process_input,
412
  inputs=inputs_list,
413
  outputs=[out_prompt, out_image]
414
  )
415
 
416
- # Clear text input if audio is used (only if ASR is available)
417
  if asr_pipeline:
418
- def clear_text_on_audio_change(audio_data):
419
- # Check if audio_data is not None or empty (depending on how Gradio signals recording)
420
  if audio_data is not None:
421
  print("Audio input detected, clearing text box.")
422
- return "" # Clear text box
423
- # If audio_data becomes None (e.g., recording cleared), don't clear text
424
  return gr.update()
425
 
426
- # .change event fires when the value changes, including becoming None if cleared
427
  inp_audio.change(fn=clear_text_on_audio_change, inputs=inp_audio, outputs=inp_text, api_name="clear_text_on_audio")
428
 
429
 
430
  # ---- Application Launch ----
431
  if __name__ == "__main__":
432
- # Final check before launch
433
- # Check against AutoPipelineForText2Image type
434
- if not isinstance(image_generator_pipe, AutoPipelineForText2Image):
435
  print("\n" + "="*50)
436
  print("CRITICAL WARNING:")
437
  print(f"Image generation model ({model_id}) failed to load during startup.")
@@ -440,6 +438,6 @@ if __name__ == "__main__":
440
  print("="*50 + "\n")
441
 
442
 
443
- # Launch the Gradio app
444
- # Running on 0.0.0.0 is necessary for Hugging Face Spaces
445
  demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
  import torch
3
  from transformers import pipeline, set_seed
4
+ from diffusers import AutoPipelineForText2Image # Import AutoPipelineForText2Image for compatibility across model types
 
5
  import openai
6
  import os
7
  import time
8
+ import traceback # For detailed error logging
9
+ from typing import Optional, Tuple, Union # For type hints
10
+ from PIL import Image # For the image type hint
11
 
12
  # ---- Configuration & API Key ----
13
+ # Check Hugging Face Secrets for an OpenAI API key
14
+ api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
15
+ openai_client: Optional[openai.OpenAI] = None
16
+ openai_available: bool = False
17
 
18
  if api_key:
19
  try:
20
+ # With openai v1, instantiating a client is the recommended approach
21
+ # openai.api_key = api_key # legacy style; v1 prefers the client instantiation below
22
+ openai_client = openai.OpenAI(api_key=api_key)
23
+ # Optional: a simple check that the key is valid (may incur a small cost / use quota)
24
+ # openai_client.models.list()
25
  openai_available = True
26
  print("OpenAI API key found and client initialized.")
27
  except Exception as e:
 
30
  else:
31
  print("WARNING: OPENAI_API_KEY secret not found. Prompt enhancement via OpenAI is disabled.")
32
 
33
+ # Force CPU usage
34
+ device: str = "cpu"
35
  print(f"Using device: {device}")
36
 
37
+ # Define the DummyPipe class as a placeholder for when model loading fails
38
+ # It must be defined before the model-loading block
39
+ class DummyPipe:
40
+ """
41
+ A placeholder class used when the actual image generation pipeline fails to load.
42
+ Its __call__ method raises a RuntimeError indicating the failure.
43
+ """
44
+ def __call__(self, *args, **kwargs) -> None:
45
+ # This error message is caught and displayed by the callers (process_input -> generate_image_cpu)
46
+ raise RuntimeError("Image generation pipeline is not available (failed to load model).")
47
+
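# A minimal sketch of how this placeholder behaves once instantiated (names taken
# from the class above):
#   pipe = DummyPipe()
#   isinstance(pipe, DummyPipe)  # True -> used later to disable the Generate button
#   pipe("a prompt")             # raises RuntimeError: pipeline not available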
48
  # ---- Model Loading (CPU Focused) ----
49
 
50
+ # 1. Speech-to-text model (Whisper) - optional feature
51
  asr_pipeline = None
52
  try:
53
  print("Loading ASR pipeline (Whisper) on CPU...")
54
+ # Force CPU and use float32 for CPU compatibility
 
55
  asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device, torch_dtype=torch.float32)
56
  print("ASR pipeline loaded successfully on CPU.")
57
  except Exception as e:
58
+ print(f"Could not load ASR pipeline (Whisper): {e}. Voice input will be disabled.")
59
+ traceback.print_exc() # Print the full traceback for debugging
60
 
61
# 2. Text-to-image model (nota-ai/bk-sdm-tiny) - a resource-friendly model
62
+ image_generator_pipe: Union[AutoPipelineForText2Image, DummyPipe] = DummyPipe() # Initialize as a DummyPipe
63
+ model_id: str = "nota-ai/bk-sdm-tiny" # Use the nota-ai/bk-sdm-tiny model
 
64
  try:
65
  print(f"Loading Text-to-Image pipeline ({model_id}) on CPU...")
66
  print("NOTE: Using a small model for resource efficiency. Image quality and details may differ from larger models.")
67
# Use AutoPipelineForText2Image to auto-detect the model type
68
+ pipeline_instance = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float32)
69
+ image_generator_pipe = pipeline_instance.to(device)
70
  print(f"Text-to-Image pipeline ({model_id}) loaded successfully on CPU.")
71
  except Exception as e:
72
  print(f"CRITICAL: Could not load Text-to-Image pipeline ({model_id}): {e}. Image generation will fail.")
73
+ traceback.print_exc() # Print the full traceback for debugging
74
+ # image_generator_pipe stays as the DummyPipe() it was initialized to
75
 
76
 
77
  # ---- Core Function Definitions ----
78
 
79
+ # Step 1: Prompt Enhancement (using OpenAI API or Fallback)
80
+ def enhance_prompt_openai(short_prompt: str, style_modifier: str = "cinematic", quality_boost: str = "photorealistic, highly detailed") -> str:
81
+ """使用 OpenAI API (如果可用) 增强用户输入的简短描述。"""
82
+ if not short_prompt or not short_prompt.strip():
83
+ # Raise immediately if the input is empty
84
+ raise gr.Error("Input description cannot be empty.")
85
+
86
  if not openai_available or not openai_client:
87
+ # Fall back to a basic enhancement when OpenAI is unavailable
88
  print("OpenAI not available. Returning original prompt with modifiers.")
89
+ return f"{short_prompt.strip()}, {style_modifier}, {quality_boost}"
90
 
91
+ # OpenAI is available: build and send the request
92
+ system_message: str = (
93
  "You are an expert prompt engineer for AI image generation models. "
94
  "Expand the user's short description into a detailed, vivid, and coherent prompt, suitable for smaller, faster text-to-image models. "
95
  "Focus on clear subjects, objects, and main scene elements. "
96
  "Incorporate the requested style and quality keywords naturally, but keep the overall prompt concise enough for smaller models. Avoid conversational text."
 
97
  )
98
+ user_message: str = (
99
+ f"Enhance this description: \"{short_prompt.strip()}\". "
100
  f"Style: '{style_modifier}'. Quality: '{quality_boost}'."
101
  )
102
 
103
+ print(f"Sending request to OpenAI for prompt enhancement: '{short_prompt.strip()}'")
104
 
105
  try:
106
  response = openai_client.chat.completions.create(
107
+ model="gpt-3.5-turbo", # 成本效益高的选择
108
  messages=[
109
  {"role": "system", "content": system_message},
110
  {"role": "user", "content": user_message},
111
  ],
112
+ temperature=0.7, # Controls creativity vs. predictability
113
+ max_tokens=100, # Limit output length; kept short for smaller image models
114
+ n=1, # Generate one response
115
+ stop=None # Let the model decide when to stop
116
  )
117
+ enhanced_prompt: str = response.choices[0].message.content.strip()
118
  print("OpenAI enhancement successful.")
119
+ # Basic cleanup: strip quotes that may wrap the whole response
120
  if enhanced_prompt.startswith('"') and enhanced_prompt.endswith('"'):
121
  enhanced_prompt = enhanced_prompt[1:-1]
122
  return enhanced_prompt
 
135
  raise gr.Error(f"Prompt enhancement failed: {e}")
136
 
137
 
138
+ # Step 2: Image Generation (CPU)
139
+ def generate_image_cpu(prompt: str, negative_prompt: str, guidance_scale: float, num_inference_steps: int) -> Image.Image:
140
+ """ CPU 上使用加载的模型生成图像。"""
141
+ # 检查模型是否成功加载 (是否是 DummyPipe)
142
+ if isinstance(image_generator_pipe, DummyPipe):
143
+ # 如果是 DummyPipe,调用它会抛出加载失败的错误
144
+ image_generator_pipe() # 这会直接抛出 intended 的错误
145
 
146
+ # Otherwise it should be an AutoPipelineForText2Image instance
147
  if not prompt or "[Error:" in prompt or "Error:" in prompt:
148
+ # Check whether the prompt itself is an error message from the previous step
149
  raise gr.Error("Cannot generate image due to invalid or missing prompt.")
150
 
151
+ print(f"Generating image on CPU for prompt: {prompt[:100]}...") # 记录截断的提示词
152
+ # 注意:负面提示词、guidance_scale num_inference_steps 对小型模型影响可能较小或行为不同
153
+ print(f"Negative prompt: {negative_prompt}")
154
+ print(f"Guidance scale: {guidance_scale}, Steps: {num_inference_steps}")
 
155
 
156
+ start_time: float = time.time()
157
 
158
  try:
159
+ # Use torch.no_grad() for efficiency
160
with torch.no_grad():
161
+ # Call the pipeline
162
+ # Pass the standard parameters, even though a small model may ignore some of them
 
163
  output = image_generator_pipe(
164
  prompt=prompt,
 
165
  negative_prompt=negative_prompt,
166
  guidance_scale=float(guidance_scale),
167
  num_inference_steps=int(num_inference_steps),
168
+ # The generator and height/width parameters may need adjusting or omitting for a given small model
169
+ # generator=torch.Generator(device=device).manual_seed(int(time.time())),
170
  # height=..., width=...
171
  )
172
 
173
+ # Retrieve the generated image, assuming the standard diffusers output structure (.images[0])
174
  if hasattr(output, 'images') and isinstance(output.images, list) and len(output.images) > 0:
175
+ image: Image.Image = output.images[0] # Take the first image
176
  else:
177
+ # Handle a different output format (rare for AutoPipelines)
178
  print("Warning: Pipeline output format unexpected. Attempting to use the output directly.")
179
+ # Try treating the whole output as an image; this may need adjusting for the actual model's output type
180
+ if isinstance(output, Image.Image):
181
+ image = output
182
+ else:
183
+ # If the output has no .images and is not a PIL Image, treat it as a failure
184
+ raise RuntimeError(f"Image generation pipeline returned unexpected output type: {type(output)}")
185
+
186
 
187
+ end_time: float = time.time()
188
  print(f"Image generated successfully on CPU in {end_time - start_time:.2f} seconds (using {model_id}).")
189
  return image
190
  except Exception as e:
191
  print(f"Error during image generation on CPU ({model_id}): {e}")
192
  traceback.print_exc()
193
+ # Propagate the error to the Gradio UI
194
  raise gr.Error(f"Image generation failed on CPU ({model_id}): {e}")
195
 
196
 
197
  # Bonus: Voice-to-Text (CPU)
198
+ def transcribe_audio(audio_file_path: Optional[str]) -> Tuple[str, Optional[str]]:
199
+ """使用 Whisper CPU 上将音频转录为文本。"""
200
+ # 检查 ASR pipeline 是否加载成功
201
  if not asr_pipeline:
202
+ # 返回错误信息 tuple
203
  return "[Error: ASR model not loaded]", audio_file_path
204
  if audio_file_path is None:
205
+ # No audio input; return an empty string
206
+ return "", audio_file_path
207
 
208
  print(f"Transcribing audio file: {audio_file_path} on CPU...")
209
+ start_time: float = time.time()
210
  try:
211
+ # Assume audio_file_path is a string path, since the Gradio Audio component uses type="filepath"
212
+ # asr_pipeline accepts a file path string or an audio data array
213
+ # Here we rely on type="filepath" passing a file path
214
+ transcription: str = asr_pipeline(audio_file_path)["text"]
215
+ end_time: float = time.time()
216
  print(f"Transcription successful in {end_time - start_time:.2f} seconds.")
217
  print(f"Transcription result: {transcription}")
218
  return transcription, audio_file_path
219
  except Exception as e:
220
  print(f"Error during audio transcription on CPU: {e}")
221
  traceback.print_exc()
222
+ # Return the error message tuple
223
  return f"[Error: Transcription failed: {e}]", audio_file_path
224
 
225
 
226
  # ---- Gradio Application Flow ----
227
 
228
+ def process_input(
229
+ input_text: str,
230
+ audio_file: Optional[str], # With type="filepath", a string path or None
231
+ style_choice: str,
232
+ quality_choice: str,
233
+ neg_prompt: str,
234
+ guidance: float,
235
+ steps: int
236
+ ) -> Tuple[str, Optional[Image.Image]]:
237
+ """由 Gradio 按钮触发的主处理函数。"""
238
+ final_text_input: str = ""
239
+ enhanced_prompt: str = ""
240
+ generated_image: Optional[Image.Image] = None
241
+ status_message: str = "" # 用于在 prompt 输出框显示状态/错误
242
+
243
+ # 1. Determine the input (text or audio)
244
  if input_text and input_text.strip():
245
  final_text_input = input_text.strip()
246
  print(f"Using text input: '{final_text_input}'")
247
  elif audio_file is not None:
248
  print("Processing audio input...")
249
  try:
 
250
  transcribed_text, _ = transcribe_audio(audio_file)
251
 
252
  if "[Error:" in transcribed_text:
253
+ # Surface the transcription error clearly
254
  status_message = transcribed_text
255
  print(status_message)
256
+ return status_message, None # Return the error in the prompt field; no image
257
+ elif transcribed_text and transcribed_text.strip(): # Make sure the transcription is non-empty
258
+ final_text_input = transcribed_text.strip()
259
  print(f"Using transcribed audio input: '{final_text_input}'")
260
  else:
261
+ status_message = "[Error: Audio input received but transcription was empty or whitespace.]"
262
  print(status_message)
263
+ return status_message, None # Return the error
264
  except Exception as e:
265
  status_message = f"[Unexpected Audio Transcription Error: {e}]"
266
  print(status_message)
267
  traceback.print_exc()
268
+ return status_message, None # Return the error
269
 
270
  else:
271
  status_message = "[Error: No input provided. Please enter text or record audio.]"
272
  print(status_message)
273
+ return status_message, None # Return the error
274
 
275
+ # 2. Enhance the prompt (using OpenAI if available)
276
  if final_text_input:
277
  try:
278
  enhanced_prompt = enhance_prompt_openai(final_text_input, style_choice, quality_choice)
279
+ status_message = enhanced_prompt # Initially show the enhanced prompt
280
  print(f"Enhanced prompt: {enhanced_prompt}")
281
  except gr.Error as e:
282
+ # Catch Gradio-specific errors from the enhancement function
283
  status_message = f"[Prompt Enhancement Error: {e}]"
284
  print(status_message)
285
+ # Return the error; do not attempt image generation
286
  return status_message, None
287
  except Exception as e:
288
+ # Catch any other unexpected errors
289
  status_message = f"[Unexpected Prompt Enhancement Error: {e}]"
290
  print(status_message)
291
  traceback.print_exc()
292
  return status_message, None
293
 
294
+ # 3. Generate the image (if the prompt is valid)
295
+ # Check whether the enhancement step returned an error message
296
  if enhanced_prompt and not status_message.startswith("[Error:") and not status_message.startswith("[Prompt Enhancement Error:"):
297
  try:
298
+ # Show a "Generating..." message
299
  gr.Info(f"Starting image generation on CPU using {model_id}. This should be faster than full SD, but might still take time.")
300
  generated_image = generate_image_cpu(enhanced_prompt, neg_prompt, guidance, steps)
301
  gr.Info("Image generation complete!")
302
  except gr.Error as e:
303
+ # Catch Gradio errors from the generation function
304
+ # Prepend the enhanced prompt to the error message for context
305
  status_message = f"{enhanced_prompt}\n\n[Image Generation Error: {e}]"
306
  print(f"Image Generation Error: {e}")
307
+ generated_image = None # Ensure the image is None on error
308
  except Exception as e:
309
+ # Catch any other unexpected errors
310
  status_message = f"{enhanced_prompt}\n\n[Unexpected Image Generation Error: {e}]"
311
  print(f"Unexpected Image Generation Error: {e}")
312
  traceback.print_exc()
313
+ generated_image = None # Ensure the image is None on error
314
 
315
  else:
316
+ # If prompt enhancement failed, status_message already holds the error
317
+ # In that case, return the existing status_message and a None image
318
  print("Skipping image generation due to prompt enhancement failure.")
319
 
320
 
321
+ # 4. Return the results to the Gradio UI
322
+ # Return the status message (enhanced prompt or error) and the image (None on error)
323
  return status_message, generated_image
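# The two returned values bind positionally to the click outputs wired up below:
#   btn_generate.click(fn=process_input, inputs=inputs_list, outputs=[out_prompt, out_image])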
324
 
325
 
326
  # ---- Gradio Interface Construction ----
327
 
328
+ style_options: list[str] = ["cinematic", "photorealistic", "anime", "fantasy art", "cyberpunk", "steampunk", "watercolor", "illustration", "low poly"]
329
+ quality_options: list[str] = ["highly detailed", "sharp focus", "intricate details", "4k", "masterpiece", "best quality", "professional lighting"]
330
 
331
+ # Adjust the step/guidance defaults and maxima for the small model; their impact may be less pronounced than on larger models
332
+ default_steps: int = 20
333
+ max_steps: int = 40 # Adjusted maximum steps
334
+ default_guidance: float = 5.0 # Adjusted default guidance scale
335
+ max_guidance: float = 10.0 # Adjusted maximum guidance scale
336
 
337
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
338
  gr.Markdown("# AI Image Generator (CPU Version - Using Small Model)")
 
340
  "**Enter a short description or use voice input.** The app uses OpenAI (if API key is provided) "
341
  f"to create a detailed prompt, then generates an image using a **small model ({model_id}) on the CPU**."
342
  )
343
+ # Add warnings about CPU speed and model characteristics
344
  gr.HTML("<p style='color:orange;font-weight:bold;'>⚠️ Note: Using a small model for better compatibility on CPU. Generation should be faster than full Stable Diffusion, but quality/details may differ.</p>")
345
  gr.HTML("<p style='color:red;font-weight:bold;'>⏰ CPU generation can still take 1-5 minutes per image depending on load and model specifics.</p>")
346
 
347
 
348
+ # Show OpenAI availability status
349
  if not openai_available:
350
  gr.Markdown("**Note:** OpenAI API key not found or invalid. Prompt enhancement will use a basic fallback.")
351
  else:
352
  gr.Markdown("**Note:** OpenAI API key found. Prompt will be enhanced using OpenAI.")
353
 
354
+ # Show the model loading status - updated check logic
355
+ # If image_generator_pipe is a DummyPipe, loading failed
356
+ if isinstance(image_generator_pipe, DummyPipe):
357
  gr.Markdown(f"**CRITICAL:** Image generation model ({model_id}) failed to load. Image generation is disabled. Check Space logs for details.")
358
 
 
359
  with gr.Row():
360
  with gr.Column(scale=1):
361
+ # --- Inputs ---
362
  inp_text = gr.Textbox(label="Enter short description", placeholder="e.g., A cute robot drinking coffee on Mars")
363
 
364
+ # Only show the audio input when the ASR model loaded successfully
365
  if asr_pipeline:
366
+ # type="filepath" 会将录音保存为临时文件并传递文件路径
367
  inp_audio = gr.Audio(sources=["microphone"], type="filepath", label="Or record your idea (clears text box if used)")
368
  else:
369
  gr.Markdown("**Voice input disabled:** Whisper model failed to load.")
370
+ # Use gr.State holding None as a placeholder
371
  inp_audio = gr.State(None)
372
 
373
+ # --- Controls ---
374
+ # Note: these controls may have less impact on this small model than on larger ones
375
  gr.Markdown("*(Optional controls - Note: Their impact might vary on this small model)*")
376
+ # Control 1: Dropdown
377
  inp_style = gr.Dropdown(label="Base Style", choices=style_options, value="cinematic")
378
+ # Control 2: Radio buttons
379
  inp_quality = gr.Radio(label="Quality Boost", choices=quality_options, value="highly detailed")
380
+ # Control 3: Textbox (negative prompt)
381
  inp_neg_prompt = gr.Textbox(label="Negative Prompt (optional)", placeholder="e.g., blurry, low quality, text, watermark, signature, deformed")
382
+ # Control 4: Slider (guidance scale)
383
+ inp_guidance = gr.Slider(minimum=1.0, maximum=max_guidance, step=0.5, value=default_guidance, label="Guidance Scale (CFG)") # Lowered maximum and default
384
+ # Control 5: Slider (inference steps) - adjusted maximum and default
385
+ inp_steps = gr.Slider(minimum=5, maximum=max_steps, step=1, value=default_steps, label=f"Inference Steps (lower = faster but less detail, max {max_steps})") # Adjusted minimum, maximum, and default
386
 
387
+ # --- Action button ---
388
+ # Disable the button if the model failed to load (i.e., it is a DummyPipe)
389
+ btn_generate = gr.Button("Generate Image", variant="primary", interactive=not isinstance(image_generator_pipe, DummyPipe))
390
 
391
  with gr.Column(scale=1):
392
+ # --- Outputs ---
393
+ out_prompt = gr.Textbox(label="Generated Prompt / Status", interactive=False, lines=5) # Shows the prompt or the error status
394
+ out_image = gr.Image(label="Generated Image", type="pil", show_label=True) # Make sure the label is shown
395
 
396
+ # --- Event handling ---
397
+ # Define the inputs list carefully, handling the possibly hidden audio input
398
  inputs_list = [inp_text]
399
+ # If ASR is available, add inp_audio to the inputs list
400
  if asr_pipeline:
401
  inputs_list.append(inp_audio)
402
  else:
403
+ # Otherwise add the gr.State(None) placeholder to the inputs list
404
+ inputs_list.append(inp_audio)
405
 
406
  inputs_list.extend([inp_style, inp_quality, inp_neg_prompt, inp_guidance, inp_steps])
407
 
408
+ # Wire the button click to the main handler
409
  btn_generate.click(
410
  fn=process_input,
411
  inputs=inputs_list,
412
  outputs=[out_prompt, out_image]
413
  )
414
 
415
+ # Clear the text box when audio input is used (only when ASR is available)
416
  if asr_pipeline:
417
+ def clear_text_on_audio_change(audio_data: Optional[str]) -> Union[str, dict]: # gr.update() returns a dict, so gr.update is not a valid annotation
418
+ # Check that audio_data is not None or empty
419
  if audio_data is not None:
420
  print("Audio input detected, clearing text box.")
421
+ return "" # 清空文本框
422
+ # 如果 audio_data 变为 None (例如,录音被清除),则不改变文本框
423
  return gr.update()
424
 
425
+ # The .change event fires whenever the value changes, including to None (if the component supports it)
426
  inp_audio.change(fn=clear_text_on_audio_change, inputs=inp_audio, outputs=inp_text, api_name="clear_text_on_audio")
427
 
428
 
429
  # ---- Application Launch ----
430
  if __name__ == "__main__":
431
+ # Final check: print a warning if image_generator_pipe is still a DummyPipe
432
+ if isinstance(image_generator_pipe, DummyPipe):
 
433
  print("\n" + "="*50)
434
  print("CRITICAL WARNING:")
435
  print(f"Image generation model ({model_id}) failed to load during startup.")
 
438
  print("="*50 + "\n")
439
 
440
 
441
+ # Launch the Gradio app
442
+ # On Hugging Face Spaces, listen on 0.0.0.0, port 7860
443
  demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
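To try the updated app locally, a minimal setup sketch (package names inferred from the imports above; versions unpinned):

  pip install gradio torch transformers diffusers openai pillow
  OPENAI_API_KEY=sk-... python app.py  # the key is optional; without it the fallback enhancement is used

The server then listens on http://localhost:7860, matching the demo.launch() call above.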