Color_express / app.py
gaur3009's picture
Create app.py
1e3319b verified
raw
history blame contribute delete
2.93 kB
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from PIL import Image, ImageFilter
import numpy as np
import gradio as gr
import cv2
# Base Stable Diffusion checkpoint and the Canny-edge ControlNet that conditions it.
# Both are used for inference only; nothing is trained in this file.
model_id = "runwayml/stable-diffusion-v1-5"
controlnet_id = "lllyasviel/control_v11p_sd15_canny"  # ControlNet for edge detection-based control

# Pick the device first so the weights are loaded in a dtype that device can
# actually run: half precision on GPU, full precision on the CPU fallback.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Load ControlNet model
controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch_dtype)

# Load Stable Diffusion pipeline with ControlNet
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    model_id, controlnet=controlnet, torch_dtype=torch_dtype
)

# Use an efficient scheduler
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Move pipeline to the selected device
pipe.to(device)
# Function to generate control image (edge detection using Canny filter)
def generate_control_image(
    input_image_path,
    *,
    low_threshold=100,
    high_threshold=200,
    size=(512, 512),
    output_path="control_image.jpg",
):
    """Write a Canny edge map of an image to disk and return its path.

    Args:
        input_image_path: Path of the image to read.
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.
        size: ``(width, height)`` the edge map is resized to before saving.
        output_path: Where the edge map is saved. The default is a fixed
            file name in the working directory, so it is overwritten on
            every call.

    Returns:
        ``output_path``, i.e. the path of the saved edge map.

    Raises:
        ValueError: If the image at ``input_image_path`` cannot be read.
    """
    image = cv2.imread(input_image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # rather than raising; fail here with a message that names the file.
        raise ValueError(f"Could not read image: {input_image_path!r}")

    edges = cv2.Canny(image, low_threshold, high_threshold)  # Apply Canny edge detection

    # The edge map is a 2-D uint8 array, which PIL already loads as mode "L".
    control_image = Image.fromarray(edges).resize(size)  # Resize to match model requirements
    control_image.save(output_path)
    return output_path
# Function to apply color change
def apply_color_change(input_image, prompt):
    """Generate a new image from ``prompt`` that follows the edges of ``input_image``.

    The upload is saved to disk, reduced to a Canny edge map by
    ``generate_control_image``, and that edge map is used as the ControlNet
    conditioning image. The text-to-image ControlNet pipeline loaded in this
    file takes no init image, so the original pixels influence the result
    only through their edges.

    Args:
        input_image: PIL image supplied by the caller.
        prompt: Text prompt describing the desired result.

    Returns:
        Path of the saved output image (``"output_color_changed.png"``).

    Raises:
        ValueError: If ``input_image`` is ``None``.
    """
    if input_image is None:
        raise ValueError("No input image was provided.")

    # Save input image temporarily. JPEG cannot store alpha or palette modes,
    # so normalise to RGB before writing.
    input_image_path = "input_image.jpg"
    input_image.convert("RGB").save(input_image_path)

    # Generate control image (edges)
    control_image_path = generate_control_image(input_image_path)

    # Load the edge map as a three-channel conditioning image
    control_image = Image.open(control_image_path).convert("RGB")

    # Local generator: reproducible output without reseeding torch's global RNG
    generator = torch.Generator(device="cpu").manual_seed(42)

    # StableDiffusionControlNetPipeline takes the ControlNet conditioning
    # image through ``image``; it has no ``control_image`` parameter.
    output_image = pipe(
        prompt=prompt,
        image=control_image,
        generator=generator,
        num_inference_steps=30,
    ).images[0]

    output_image.save("output_color_changed.png")
    return "output_color_changed.png"
# Gradio callback
def gradio_interface(input_image, prompt):
    """Forward the uploaded image and prompt to ``apply_color_change``.

    Returns whatever ``apply_color_change`` returns, unchanged.
    """
    return apply_color_change(input_image, prompt)
# Build the Gradio UI: an image upload (drag and drop) plus a prompt box,
# wired to gradio_interface, with a single image output.
_image_input = gr.Image(type="pil", label="Upload your image")
_prompt_input = gr.Textbox(
    label="Enter prompt",
    placeholder="e.g. A hoodie with blue and white design",
)
_result_view = gr.Image(label="Color Changed Output")

interface = gr.Interface(
    fn=gradio_interface,
    inputs=[_image_input, _prompt_input],
    outputs=_result_view,
    title="AI-Powered Clothing Color Changer",
    description="Upload an image of clothing, enter a prompt, and get a redesigned color version.",
)

# Start serving the app
interface.launch()