import gradio as gr
import spaces
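# `spaces` provides the @spaces.GPU decorator used below; on ZeroGPU hardware
# it requests a GPU slice for the duration of the decorated call.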
import torch
import sys
import traceback
from diffusers import AutoencoderKL, TCDScheduler
from diffusers.models.model_loading_utils import load_state_dict
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from PIL import Image

# Helper: print a clearly delimited error report with the full traceback
def print_error(error_message):
    print("=" * 50)
    print(f"ERROR: {error_message}")
    print("-" * 50)
    print(traceback.format_exc())
    print("=" * 50)

try:
    from controlnet_union import ControlNetModel_Union
    from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
except Exception as e:
    print_error(f"Failed to import required modules: {e}")
    print("Ensure the controlnet_union and pipeline_fill_sd_xl modules are available")
    sys.exit(1)

MODELS = {
    "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
}
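# Only one checkpoint is listed; the Model dropdown below is informational,
# since the pipeline is loaded once at startup and not swapped per request.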

# Translation stub: the Korean-to-English translation model was removed.
# Korean input is detected and logged, but the text is returned unchanged.
def translate_if_korean(text):
    if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in text):
        print(f"Korean text detected: {text}")
        print("Translation is disabled - using original text")
    return text
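# e.g. translate_if_korean("강아지 초상화") logs the detection and returns the
# string unchanged; English prompts pass through silently.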

# Wrap model-configuration loading in try/except to catch download errors
try:
    config_file = hf_hub_download(
        "xinsir/controlnet-union-sdxl-1.0",
        filename="config_promax.json",
    )
    config = ControlNetModel_Union.load_config(config_file)
    controlnet_model = ControlNetModel_Union.from_config(config)
    model_file = hf_hub_download(
        "xinsir/controlnet-union-sdxl-1.0",
        filename="diffusion_pytorch_model_promax.safetensors",
    )
except Exception as e:
    print_error(f"Failed to load model configuration: {e}")
    print("Attempting to use direct model loading as fallback...")
    # Mark the failure with None; the loading code below checks for it
    config_file = None
    config = None
    controlnet_model = None
    model_file = None
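# The config and weights are fetched as individual files, apparently so the
# "promax" variant of the checkpoint is used; a plain from_pretrained call
# would resolve the repo's default (non-promax) files instead.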

# Load the ControlNet weights. The private _load_pretrained_model helper has
# changed signature across diffusers versions, so try both call forms.
model = None
if model_file is not None:
    state_dict = load_state_dict(model_file)
    try:
        # Older signature: no loaded_keys argument
        model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
            controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
        )
    except TypeError:
        # Newer signature additionally expects the list of state-dict keys
        print("Using alternative model loading approach...")
        loaded_keys = list(state_dict.keys())
        try:
            model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
                controlnet_model, state_dict, model_file,
                "xinsir/controlnet-union-sdxl-1.0", loaded_keys
            )
        except Exception as e:
            print(f"Advanced loading failed: {e}")
            print("Falling back to direct loading...")
            try:
                # Load the weights straight into the freshly constructed model
                controlnet_model.load_state_dict(state_dict)
                model = controlnet_model
            except Exception as load_err:
                print(f"Direct loading failed: {load_err}")

if model is None:
    # Final fallback (also reached when the downloads above failed):
    # let diffusers resolve the checkpoint from the Hub directly
    try:
        model = ControlNetModel_Union.from_pretrained(
            "xinsir/controlnet-union-sdxl-1.0",
            torch_dtype=torch.float16,
        )
    except Exception as e:
        print_error(f"All ControlNet loading strategies failed: {e}")
        model = None

# Move the ControlNet to the GPU in float16 (skip if loading failed)
if model is not None:
    model.to(device="cuda", dtype=torch.float16)
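# Note: on ZeroGPU Spaces a physical GPU is attached only while a
# @spaces.GPU-decorated call runs; the spaces package is designed to make
# module-level .to("cuda") calls like the ones here work regardless.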

# Flag to track whether we are in fallback mode (no ControlNet)
using_fallback = False
pipe = None
try:
    # Load the fp16-safe VAE
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    ).to("cuda")
    # Set up the pipeline with ControlNet if it loaded successfully
    if model is not None:
        pipe = StableDiffusionXLFillPipeline.from_pretrained(
            "SG161222/RealVisXL_V5.0_Lightning",
            torch_dtype=torch.float16,
            vae=vae,
            controlnet=model,
            variant="fp16",
        ).to("cuda")
    else:
        # Fall back to a plain StableDiffusionXLPipeline without ControlNet
        print("Loading without ControlNet as fallback")
        using_fallback = True
        from diffusers import StableDiffusionXLPipeline
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "SG161222/RealVisXL_V5.0_Lightning",
            torch_dtype=torch.float16,
            vae=vae,
            variant="fp16",
        ).to("cuda")
    # TCDScheduler suits the few-step, Lightning-distilled checkpoint
    pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
except Exception as e:
    print_error(f"Failed to initialize pipeline: {e}")
    # pipe stays None; fill_image below catches the resulting error and
    # returns the input image unchanged instead of crashing

@spaces.GPU
def fill_image(prompt, image, model_selection):
    # Check whether we are running without ControlNet
    global using_fallback
    # Pass the prompt through the (stubbed) translation step
    translated_prompt = translate_if_korean(prompt)
    try:
        # Extract the source image and the drawn mask layer
        source = image["background"]
        mask = image["layers"][0]
        # Build a binary mask from the mask layer's alpha channel
        alpha_channel = mask.split()[3]
        binary_mask = alpha_channel.point(lambda p: 255 if p > 0 else 0)
        # Branch on whether the ControlNet pipeline is available
        if using_fallback:
            # Plain StableDiffusionXLPipeline without ControlNet: generate
            # from the prompt alone and composite into the masked area
            print("Using fallback pipeline without ControlNet")
            try:
                generated = pipe(
                    prompt=translated_prompt,
                    negative_prompt="low quality, worst quality, bad anatomy, bad composition, poor, low effort",
                    num_inference_steps=30,
                    guidance_scale=7.5,
                ).images[0]
                # Composite the generated image into the masked area
                result = source.copy()
                result.paste(generated, (0, 0), binary_mask)
                # Yield both the original and the composited result
                yield source, result
            except Exception as e:
                print_error(f"Fallback generation failed: {e}")
                # If even this fails, return the source image unchanged
                yield source, source
        else:
            # Normal operation with ControlNet: black out the masked area
            # to form the ControlNet conditioning image
            cnet_image = source.copy()
            cnet_image.paste(0, (0, 0), binary_mask)
            # Encode the prompt
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = pipe.encode_prompt(translated_prompt, "cuda", True)
            # Stream intermediate previews as the pipeline denoises
            for generated in pipe(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                image=cnet_image,
            ):
                yield generated, cnet_image
            # Composite the final result into the masked area
            generated = generated.convert("RGBA")
            cnet_image.paste(generated, (0, 0), binary_mask)
            yield source, cnet_image
    except Exception as e:
        print_error(f"Error during image generation: {e}")
        # Return the original image if it was extracted successfully
        if 'source' in locals():
            yield source, source
        else:
            print("Critical error: Source image not available")
            # Fall back to a blank white image
            blank = Image.new('RGB', (512, 512), color=(255, 255, 255))
            yield blank, blank
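# Because fill_image is a generator, Gradio streams each yielded
# (before, after) pair to the ImageSlider, showing denoising progress live.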

def clear_result():
    return gr.update(value=None)

css = """
footer {
visibility: hidden;
}
.sample-image {
display: flex;
justify-content: center;
margin-top: 20px;
}
"""

with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                info="Describe what to fill in the mask area (Korean or English)",
                lines=3,
            )
        with gr.Column():
            model_selection = gr.Dropdown(
                choices=list(MODELS.keys()),
                value="RealVisXL V5.0 Lightning",
                label="Model",
            )
            run_button = gr.Button("Generate")
    with gr.Row():
        input_image = gr.ImageMask(
            type="pil",
            label="Input Image",
            crop_size=(1024, 1024),
            layers=False,
        )
        result = ImageSlider(
            interactive=False,
            label="Generated Image",
        )
    use_as_input_button = gr.Button("Use as Input Image", visible=False)
    # Sample image displayed below the editor
    with gr.Row(elem_classes="sample-image"):
        sample_image = gr.Image("sample.png", label="Sample Image", height=256, width=256)
    # The ImageSlider value is a (before, after) pair; reuse the "after" image
    def use_output_as_input(output_image):
        return gr.update(value=output_image[1])

    use_as_input_button.click(
        fn=use_output_as_input,
        inputs=[result],
        outputs=[input_image],
    )
    # Chain: clear the result, hide the reuse button, generate, then show it
    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=False),
        inputs=None,
        outputs=use_as_input_button,
    ).then(
        fn=fill_image,
        inputs=[prompt, input_image, model_selection],
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=use_as_input_button,
    )
    prompt.submit(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=False),
        inputs=None,
        outputs=use_as_input_button,
    ).then(
        fn=fill_image,
        inputs=[prompt, input_image, model_selection],
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=use_as_input_button,
    )
demo.launch(share=False)
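# Note: on Hugging Face Spaces this script is executed directly, so launching
# at module level is conventional; share links are unnecessary there.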