Ghibli-Art / app.py
Sask07's picture
Update app.py
152de82 verified
import gc
import inspect

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from controlnet_aux import CannyDetector
from PIL import Image
# Canny edge detector (from controlnet_aux), created once at import time and
# shared by every call to generate_image; its edge map is what the ControlNet
# pipeline is conditioned on.
canny = CannyDetector()
def create_pipeline(
    controlnet_id="lllyasviel/sd-controlnet-canny",
    base_model_id="nitrosocke/Ghibli-Diffusion",
):
    """Build the Canny-conditioned ControlNet pipeline on the Ghibli checkpoint.

    Args:
        controlnet_id: Hugging Face Hub id of the Canny ControlNet weights.
        base_model_id: Hugging Face Hub id of the Stable Diffusion checkpoint
            that supplies the Ghibli style.

    Returns:
        A ``StableDiffusionControlNetPipeline``. When CUDA is available it is
        loaded in float16 with model CPU offload enabled; otherwise it is
        loaded in float32 and stays on the CPU.
    """
    use_cuda = torch.cuda.is_available()
    # Half precision is only dependable on the GPU: many CPU kernels have no
    # float16 implementation, so a float16 pipeline on CPU fails at inference.
    dtype = torch.float16 if use_cuda else torch.float32

    # Release memory held from any earlier allocation before loading weights.
    if use_cuda:
        torch.cuda.empty_cache()
    gc.collect()

    controlnet = ControlNetModel.from_pretrained(
        controlnet_id,
        torch_dtype=dtype,
        use_safetensors=True,
    )
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        base_model_id,
        controlnet=controlnet,
        torch_dtype=dtype,
        safety_checker=None,  # no safety checker is loaded
    )

    if use_cuda:
        # Keep sub-models on the CPU and move each one to the GPU only while it
        # runs: lower VRAM use at some cost in speed.
        pipe.enable_model_cpu_offload()
    # slice_size=1 is the most aggressive attention slicing: lowest peak
    # memory, slowest attention.
    pipe.enable_attention_slicing(1)
    return pipe
# Build the pipeline once at import time (loading, and on first run presumably
# downloading, the model weights); generate_image reuses this single instance.
pipe = create_pipeline()
def enhance_prompt(base_prompt):
    """Wrap the user's prompt in Studio Ghibli style keywords.

    Args:
        base_prompt: The user's own description of the scene.

    Returns:
        One comma-separated prompt string: the fixed Ghibli style phrases,
        then ``base_prompt``, then generic quality tags.
    """
    ghibli_tags = (
        "Studio Ghibli masterpiece",
        "hand-painted animation style",
        "Hayao Miyazaki inspired",
        "soft detailed lighting",
        "gentle color palette",
        "delicate line art",
        "atmospheric background",
    )
    quality_tags = "high quality, detailed features, smooth lines"
    # Format the user's text exactly as an f-string would, then splice it
    # between the style phrases and the quality tags.
    return ", ".join((*ghibli_tags, f"{base_prompt}", quality_tags))
def preprocess_image(image, max_size=512):
    """Letterbox an image onto a white square canvas of side ``max_size``.

    The image is scaled (up or down) so that its longer side equals
    ``max_size`` while keeping its aspect ratio, then centred on a white RGB
    canvas.

    Args:
        image: A PIL image, or a numpy array accepted by ``Image.fromarray``.
        max_size: Side length in pixels of the square output (default 512).

    Returns:
        A ``max_size`` x ``max_size`` RGB PIL image.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Normalise the mode first: palette images would otherwise be resized with
    # NEAREST (Pillow ignores LANCZOS for mode "P"), and every other mode is
    # converted to RGB by paste() anyway.
    image = image.convert("RGB")

    ratio = max_size / max(image.size)
    # round() keeps float error from shaving a pixel off (511.999... -> 511);
    # max(1, ...) stops very thin images collapsing to a zero-size resize.
    new_size = tuple(max(1, round(side * ratio)) for side in image.size)
    image = image.resize(new_size, Image.Resampling.LANCZOS)

    # Pad to a square with white so the output size is fixed.
    canvas = Image.new("RGB", (max_size, max_size), (255, 255, 255))
    offset = ((max_size - new_size[0]) // 2, (max_size - new_size[1]) // 2)
    canvas.paste(image, offset)
    return canvas
def process_image_for_canny(image):
    """Return the image as an H x W x 3 numpy array for the Canny detector.

    Args:
        image: A PIL image, or anything ``np.array`` accepts. An existing numpy
            array that already has three channels is returned unchanged.

    Returns:
        A numpy array with three channels on the last axis. 2-D grayscale and
        (H, W, 1) inputs are replicated into three channels, and the fourth
        (alpha) channel of an RGBA input is dropped.
    """
    # np.array converts PIL images via their array interface, so PIL needs no
    # special case here.
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    if image.ndim == 2:
        # Grayscale (H, W): stack into three identical channels.
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 1:
        # Single-channel (H, W, 1): replicate along the channel axis.
        image = np.repeat(image, 3, axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 4:
        # RGBA: drop alpha; copy so downstream code gets a contiguous buffer.
        image = np.ascontiguousarray(image[..., :3])
    return image
def _release_cuda_memory():
    """Free PyTorch's cached, unused CUDA memory (if CUDA exists) and run the GC."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


def generate_image(input_image, prompt):
    """Turn an uploaded image into a Ghibli-style picture guided by its edges.

    Args:
        input_image: The uploaded image. The Gradio component is configured
            with ``type="pil"``; ``preprocess_image`` also accepts numpy arrays.
        prompt: Free-text description of the desired scene.

    Returns:
        ``(output_image, enhanced_prompt)``: the first generated image and the
        full prompt that was sent to the pipeline.

    Raises:
        gr.Error: when the image or prompt is missing, or when any step of
            generation fails, so that the message is shown in the UI.
    """
    try:
        if input_image is None:
            raise gr.Error("Please upload an image")
        # A whitespace-only prompt is as empty as "" for the model.
        if not prompt or not prompt.strip():
            raise gr.Error("Please enter a prompt")

        _release_cuda_memory()

        # Letterbox to a fixed square, then make sure it has 3 channels.
        preprocessed_image = preprocess_image(input_image)
        processed_image = process_image_for_canny(preprocessed_image)

        canny_image = canny(processed_image, low_threshold=100, high_threshold=200)
        # NOTE(review): given a numpy array, CannyDetector appears to return a
        # uint8 numpy array. Diffusers' control-image processor only rescales
        # 0-255 -> 0-1 for PIL inputs, so hand it a PIL image.
        if isinstance(canny_image, np.ndarray) and canny_image.dtype == np.uint8:
            canny_image = Image.fromarray(canny_image)

        enhanced_prompt = enhance_prompt(prompt)

        with torch.inference_mode():
            output_image = pipe(
                prompt=enhanced_prompt,
                image=canny_image,
                num_inference_steps=30,  # more steps for finer detail
                guidance_scale=8.5,  # stronger adherence to the prompt
                controlnet_conditioning_scale=1.0,  # weight of the edge map
                negative_prompt="blurry, low quality, broken lines, distorted features, asymmetrical"
            ).images[0]
        return output_image, enhanced_prompt
    except gr.Error:
        # Already user-facing; re-raise unchanged instead of re-wrapping it.
        raise
    except Exception as e:
        # Surface unexpected failures in the UI, keeping the original cause.
        raise gr.Error(str(e)) from e
    finally:
        _release_cuda_memory()
# Build the Gradio UI: inputs (image + prompt) on the left, results on the right.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("""
# 🎨 Enhanced Ghibli Art Generator
Transform your images into the magical style of Studio Ghibli with improved detail and quality
""")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="input-image",
            )
            prompt = gr.Textbox(
                label="Enter your prompt",
                placeholder="A peaceful mountain cabin surrounded by nature...",
                elem_id="prompt-input",
            )
            with gr.Row():
                generate_btn = gr.Button("🎨 Generate", variant="primary", elem_id="generate-btn")
                clear_btn = gr.Button("🗑️ Clear", elem_id="clear-btn")
        with gr.Column():
            output_image = gr.Image(label="Generated Image", elem_id="output-image")
            used_prompt = gr.Textbox(
                label="Enhanced Prompt",
                elem_id="enhanced-prompt",
                interactive=False,
            )
    # NOTE(review): original indentation was lost; this help section is
    # assumed to sit below the main row at the Blocks level — confirm.
    gr.Markdown("""
## 🌟 Improved Features
- Enhanced detail with 30 inference steps
- Stronger style adherence with 8.5 guidance scale
- Optimized edge detection
- Rich Ghibli-style prompt enhancement
## 💡 Tips
- Use clear, well-lit images
- Be specific in your prompts
- Include mood and atmosphere descriptions
- Expect 15-20 seconds for generation
""")

    # Event handlers must be registered inside the Blocks context.
    generate_btn.click(
        fn=generate_image,
        inputs=[input_image, prompt],
        outputs=[output_image, used_prompt],
    )
    # Clear resets only the result widgets, not the inputs.
    clear_btn.click(
        lambda: [None, ""],
        outputs=[output_image, used_prompt],
    )

# Run one generation at a time with at most 5 queued requests. Gradio 4+
# rejects queue(concurrency_count=...) in favour of default_concurrency_limit,
# so pass whichever keyword the installed Gradio supports.
_queue_params = inspect.signature(demo.queue).parameters
if "default_concurrency_limit" in _queue_params:
    demo.queue(max_size=5, default_concurrency_limit=1)
else:
    demo.queue(max_size=5, concurrency_count=1)

if __name__ == "__main__":
    demo.launch(
        share=False,
        debug=True,
        show_error=True,
    )