import spaces  # Hugging Face Spaces SDK; not referenced below, but kept for Spaces deployments
import gradio as gr
import cv2
import numpy as np
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

import torch
from diffusers import (
    ControlNetModel,
    EulerDiscreteScheduler,
    StableDiffusionControlNetInpaintPipeline,
)
from PIL import Image

# Device configuration
device = torch.device("cpu")  # Ensure everything is set to run on CPU

# Constants for colors
BG_COLOR = (0, 0, 0, 255)  # black with full opacity
MASK_COLOR = (255, 255, 255, 255)  # white with full opacity

# Create the options that will be used for ImageSegmenter
base_options = python.BaseOptions(model_asset_path='emirhan.tflite')
options = vision.ImageSegmenterOptions(base_options=base_options,
                                       output_category_mask=True)
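# The .tflite asset is a hair-segmentation model and must sit next to this script;
# output_category_mask=True makes segment() return a per-pixel category mask.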

# Initialize ControlNet inpainting pipeline
controlnet = ControlNetModel.from_pretrained(
    'lllyasviel/control_v11p_sd15_inpaint',
    torch_dtype=torch.float32,  # Use float32 for CPU
).to(device)

pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5',
    safety_checker=None,
    controlnet=controlnet,
    torch_dtype=torch.float32,  # Use float32 for CPU
).to(device)
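# safety_checker=None skips loading the NSFW filter, saving load time and memory on CPU.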

# Set the K_EULER scheduler
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
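# Euler generally reaches usable quality in fewer steps than the default PNDM
# scheduler, which matters for the 25-step CPU run below.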

# Function to segment hair and generate mask
def segment_hair(image):
    # Gradio delivers RGB frames; convert to RGBA since the segmenter is fed SRGBA input
    rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
    rgba_image[:, :, 3] = 0  # Clear the alpha channel

    # Create an MP Image object from the numpy array via the public API
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGBA, data=rgba_image)

    # Create the image segmenter
    with vision.ImageSegmenter.create_from_options(options) as segmenter:
        # Retrieve the masks for the segmented image
        segmentation_result = segmenter.segment(mp_image)
        category_mask = segmentation_result.category_mask

        # Generate solid color images for showing the output segmentation mask.
        image_data = mp_image.numpy_view()
        fg_image = np.zeros(image_data.shape, dtype=np.uint8)
        fg_image[:] = MASK_COLOR
        bg_image = np.zeros(image_data.shape, dtype=np.uint8)
        bg_image[:] = BG_COLOR

        condition = np.stack((category_mask.numpy_view(),) * 4, axis=-1) > 0.2
        output_image = np.where(condition, fg_image, bg_image)

        return output_image  # Return the RGBA mask
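
# Quick sanity check for the segmenter (hypothetical file names):
#   rgb = cv2.cvtColor(cv2.imread("portrait.jpg"), cv2.COLOR_BGR2RGB)
#   cv2.imwrite("hair_mask.png", segment_hair(rgb))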

# Function to resize image while maintaining aspect ratio
def resize_image(image, max_size=1536):
    h, w = image.shape[:2]
    if max(h, w) > max_size:
        scale = max_size / max(h, w)
        new_size = (int(w * scale), int(h * scale))
        image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
    return image

# Function to inpaint the hair area using ControlNet
def inpaint_hair(image, prompt):
    # Downscale if needed; resize_image is already a no-op below 1536 px
    image = resize_image(image)
    
    # Segment hair to get the mask
    mask = segment_hair(image)
    # Convert to PIL image for the inpainting pipeline
    image_pil = Image.fromarray(image)
    mask_pil = Image.fromarray(cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY))
    mask_pil = mask_pil.convert("L")  # Ensure it's a single-channel (grayscale) image

    # Prepare the ControlNet inpainting condition: masked pixels are set to -1.0
    # so control_v11p_sd15_inpaint can distinguish them from valid image content
    image_np = np.array(image_pil).astype(np.float32) / 255.0
    mask_np = np.array(mask_pil).astype(np.float32) / 255.0  # mask_pil is already mode "L"
    image_np[mask_np > 0.5] = -1.0  # Mark masked (hair) pixels
    inpaint_condition = torch.from_numpy(np.expand_dims(image_np, 0).transpose(0, 3, 1, 2)).to(device)

    # Generate inpainted image
    generator = torch.manual_seed(42)
    negative_prompt = "lowres, bad quality, poor quality"
    
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=image_pil,
        mask_image=mask_pil,
        control_image=inpaint_condition,
        num_inference_steps=25,
        guidance_scale=7.5,
        generator=generator
    ).images[0]

    return np.array(output)
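
# Programmatic usage without the Gradio UI (hypothetical paths):
#   img = cv2.cvtColor(cv2.imread("example.jpeg"), cv2.COLOR_BGR2RGB)
#   result = inpaint_hair(img, "pink hair")
#   Image.fromarray(result).save("result.png")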

# Gradio interface
iface = gr.Interface(
    fn=inpaint_hair,
    inputs=[
        gr.Image(type="numpy"),
        gr.Textbox(label="Prompt", placeholder="Describe the desired inpainting result...")
    ],
    outputs=gr.Image(type="numpy"),
    title="Hair Inpainting with ControlNet",
    description="Upload an image, and provide a prompt to inpaint the hair area using ControlNet.",
    examples=[["example.jpeg", "dreadlocks"], ["example2.jpg", "pink hair"]]
)
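
# Note: example.jpeg and example2.jpg must exist next to this script for the
# Gradio examples panel to load.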

if __name__ == "__main__":
    iface.launch()