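"""Gradio Space (runs on ZeroGPU): mask-and-fill inpainting with SDXL.

Uses RealVisXL V5.0 Lightning plus the xinsir ControlNet Union "promax"
checkpoint, with layered fallbacks when either fails to load.
"""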
import sys
import traceback

import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL, TCDScheduler
from diffusers.models.model_loading_utils import load_state_dict
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from PIL import Image

# Add better error handling
def print_error(error_message):
    print("=" * 50)
    print(f"ERROR: {error_message}")
    print("-" * 50)
    print(traceback.format_exc())
    print("=" * 50)

try:
    from controlnet_union import ControlNetModel_Union
    from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
except Exception as e:
    print_error(f"Failed to import required modules: {e}")
    print("Ensure the controlnet_union and pipeline_fill_sd_xl modules are available")
    sys.exit(1)
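# (controlnet_union.py and pipeline_fill_sd_xl.py are assumed to sit next to
# this file in the Space repository; they are not pip-installable packages.)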

MODELS = {
    "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
}
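# Note: only this single checkpoint is wired up; the model dropdown in the UI
# is informational for now, and fill_image() ignores its model_selection value.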

# Replace the problematic translation model with a simpler detection-only function
def translate_if_korean(text):
    # Just log that Korean was detected, but return the original text
    if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in text):
        print(f"Korean text detected: {text}")
        print("Translation is disabled - using original text")
    return text
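# e.g. translate_if_korean("바다 풍경") logs the detection and returns the
# string unchanged; ASCII-only prompts pass through silently. The ranges cover
# Hangul compatibility jamo (U+3131-U+318E) and Hangul syllables (U+AC00-U+D7A3).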

# Wrap with try/except to catch any model loading errors
try:
    config_file = hf_hub_download(
        "xinsir/controlnet-union-sdxl-1.0",
        filename="config_promax.json",
    )
    config = ControlNetModel_Union.load_config(config_file)
    controlnet_model = ControlNetModel_Union.from_config(config)
    model_file = hf_hub_download(
        "xinsir/controlnet-union-sdxl-1.0",
        filename="diffusion_pytorch_model_promax.safetensors",
    )
except Exception as e:
    print_error(f"Failed to load model configuration: {e}")
    print("Attempting to use direct model loading as fallback...")
    # Set these to None to signal failure; handled below
    config_file = None
    config = None
    controlnet_model = None
    model_file = None

# Load the ControlNet weights. The _load_pretrained_model helper has changed
# signature across diffusers releases, so try each variant in turn.
model = None
if model_file is not None:
    state_dict = load_state_dict(model_file)
    try:
        # Older diffusers: no loaded_keys argument
        model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
            controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
        )
    except TypeError:
        # Newer diffusers require a loaded_keys argument
        print("Using alternative model loading approach...")
        loaded_keys = list(state_dict.keys())
        try:
            model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
                controlnet_model,
                state_dict,
                model_file,
                "xinsir/controlnet-union-sdxl-1.0",
                loaded_keys,
            )
        except Exception as e:
            print(f"Advanced loading failed: {e}")
            print("Falling back to direct loading...")
            try:
                # As a last resort before from_pretrained, load the weights directly
                controlnet_model.load_state_dict(state_dict)
                model = controlnet_model
            except Exception as load_err:
                print(f"Direct loading failed: {load_err}")
                model = None

if model is None:
    # Final fallback: let diffusers resolve the checkpoint itself
    try:
        model = ControlNetModel_Union.from_pretrained(
            "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
        )
    except Exception as e:
        print_error(f"All ControlNet loading strategies failed: {e}")
        model = None

if model is not None:
    # Move the ControlNet to the GPU in half precision
    model.to(device="cuda", dtype=torch.float16)
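# Caveat: _load_pretrained_model is private diffusers API, which is why its
# signature drifts between releases and the layered fallbacks above exist.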

# Define flag to track if we're in fallback mode (no ControlNet)
using_fallback = False

try:
    # Load the fp16-fixed SDXL VAE
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    ).to("cuda")

    if model is not None:
        # Set up the fill pipeline with the ControlNet
        pipe = StableDiffusionXLFillPipeline.from_pretrained(
            "SG161222/RealVisXL_V5.0_Lightning",
            torch_dtype=torch.float16,
            vae=vae,
            controlnet=model,
            variant="fp16",
        ).to("cuda")
    else:
        # Fall back to a plain StableDiffusionXLPipeline if the ControlNet failed
        print("Loading without ControlNet as fallback")
        using_fallback = True
        from diffusers import StableDiffusionXLPipeline

        pipe = StableDiffusionXLPipeline.from_pretrained(
            "SG161222/RealVisXL_V5.0_Lightning",
            torch_dtype=torch.float16,
            vae=vae,
            variant="fp16",
        ).to("cuda")

    # Set scheduler
    pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
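    # TCD (Trajectory Consistency Distillation) is a few-step sampler, which
    # pairs naturally with the distilled Lightning checkpoint used above.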
except Exception as e:
    print_error(f"Failed to initialize pipeline: {e}")
    # Even the fallback pipeline failed; mark it unavailable so fill_image()
    # below can degrade to returning the input image instead of crashing.
    pipe = None

@spaces.GPU  # Request a ZeroGPU slot per call (this Space runs on ZeroGPU)
def fill_image(prompt, image, model_selection):
    # Check if we're in fallback mode (no ControlNet)
    global using_fallback
    # Get the (possibly) translated prompt
    translated_prompt = translate_if_korean(prompt)
    try:
        # Extract the source image and the user-drawn mask layer
        source = image["background"]
        mask = image["layers"][0]
        if pipe is None:
            # Pipeline never initialized; return the input unchanged
            print("Pipeline unavailable - returning the source image")
            yield source, source
            return
        # Create a binary mask from the alpha channel of the drawn layer
        alpha_channel = mask.split()[3]
        binary_mask = alpha_channel.point(lambda p: 255 if p > 0 else 0)
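        # point() maps each alpha value per pixel: anything the user painted
        # (alpha > 0) becomes 255 (region to fill); untouched pixels stay 0.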
        # Handle based on whether we're using the regular pipeline or ControlNet
        if using_fallback:
            # Using regular StableDiffusionXLPipeline without ControlNet
            print("Using fallback pipeline without ControlNet")
            # Best effort: generate from the prompt alone, then composite
            # the result into the masked area
            try:
                generated = pipe(
                    prompt=translated_prompt,
                    negative_prompt="low quality, worst quality, bad anatomy, bad composition, poor, low effort",
                    num_inference_steps=30,
                    guidance_scale=7.5,
                ).images[0]
                # Match the source size so the masked paste lines up
                generated = generated.resize(source.size)
                # Composite the generated image into the masked area
                result = source.copy()
                result.paste(generated, (0, 0), binary_mask)
                # Return both the original and the result
                yield source, result
            except Exception as e:
                print_error(f"Fallback generation failed: {e}")
                # If even this fails, just return the source image
                yield source, source
        else:
            # Normal operation with ControlNet
            # Prepare the ControlNet input: source with the masked area blanked out
            cnet_image = source.copy()
            cnet_image.paste(0, (0, 0), binary_mask)
            # Encode the positive and negative prompts once up front
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = pipe.encode_prompt(translated_prompt, "cuda", True)
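            # SDXL uses two text encoders, so encode_prompt() returns four
            # tensors: per-token and pooled embeddings for both the positive
            # and the negative prompt.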
            # Generate the image, streaming intermediate results as the fill
            # pipeline yields them. Use a fresh name to avoid shadowing the
            # `image` editor-dict parameter.
            for step_image in pipe(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                image=cnet_image,
            ):
                yield step_image, cnet_image
            # Composite the final result into the blanked ControlNet input
            step_image = step_image.convert("RGBA")
            cnet_image.paste(step_image, (0, 0), binary_mask)
            yield source, cnet_image
    except Exception as e:
        print_error(f"Error during image generation: {e}")
        # Return the original image in case of error
        if 'source' in locals():
            yield source, source
        else:
            print("Critical error: Source image not available")
            # Create a blank image if we can't get the source
            blank = Image.new('RGB', (512, 512), color=(255, 255, 255))
            yield blank, blank

def clear_result():
    return gr.update(value=None)
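# clear_result() is chained before each generation below so the result slider
# empties while the new image streams in.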
css = """ | |
footer { | |
visibility: hidden; | |
} | |
.sample-image { | |
display: flex; | |
justify-content: center; | |
margin-top: 20px; | |
} | |
""" | |

with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                info="Describe what to fill in the mask area (Korean or English)",
                lines=3,
            )
        with gr.Column():
            model_selection = gr.Dropdown(
                choices=list(MODELS.keys()),
                value="RealVisXL V5.0 Lightning",
                label="Model",
            )
            run_button = gr.Button("Generate")

    with gr.Row():
        input_image = gr.ImageMask(
            type="pil",
            label="Input Image",
            crop_size=(1024, 1024),
            layers=False,
        )
        result = ImageSlider(
            interactive=False,
            label="Generated Image",
        )

    use_as_input_button = gr.Button("Use as Input Image", visible=False)

    # Add sample image
    with gr.Row(elem_classes="sample-image"):
        sample_image = gr.Image("sample.png", label="Sample Image", height=256, width=256)

    def use_output_as_input(output_image):
        # The slider holds a (before, after) pair; feed the generated half back in
        return gr.update(value=output_image[1])

    use_as_input_button.click(
        fn=use_output_as_input,
        inputs=[result],
        outputs=[input_image],
    )
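    # The click and submit handlers below share the same four-step chain:
    # clear the slider, hide the reuse button, stream fill_image() output,
    # then reveal the reuse button again.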
    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=False),
        inputs=None,
        outputs=use_as_input_button,
    ).then(
        fn=fill_image,
        inputs=[prompt, input_image, model_selection],
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=use_as_input_button,
    )

    prompt.submit(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=False),
        inputs=None,
        outputs=use_as_input_button,
    ).then(
        fn=fill_image,
        inputs=[prompt, input_image, model_selection],
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=use_as_input_button,
    )

demo.launch(share=False)