Update app.py

app.py CHANGED
@@ -1,5 +1,4 @@
 import spaces
-
 import gradio as gr
 import cv2
 import numpy as np
@@ -42,7 +41,6 @@ pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
     torch_dtype=torch.float16,
 ).to(device)
 
-
 # Set the K_EULER scheduler
 pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
 
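A side note on the unchanged context here: swapping the scheduler via `from_config` is the standard diffusers pattern, since it reuses the beta schedule and other settings the checkpoint was trained with. A minimal sketch of the same pattern in isolation (the model ID is illustrative, not the one this Space loads):

```python
import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

# Load any diffusers pipeline (the model ID below is illustrative).
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Rebuild a K_EULER scheduler from the existing scheduler's config so the
# noise schedule stays consistent with the checkpoint.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
```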
@@ -80,12 +78,15 @@ def resize_image(image, max_size=1536):
     new_size = (int(w * scale), int(h * scale))
     image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
     return image
-
+
 @spaces.GPU(duration=60)
 # Function to inpaint the hair area using ControlNet
 def inpaint_hair(image, prompt):
-    #
-
+    # Only resize the input image if it's larger than 1536 in any dimension
+    h, w = image.shape[:2]
+    if max(h, w) > 1536:
+        image = resize_image(image)
+
     # Segment hair to get the mask
     mask = segment_hair(image)
     # Convert to PIL image for the inpainting pipeline
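The new guard calls `resize_image` only when the input exceeds 1536 px on its longest side. The helper's body sits mostly outside this diff; below is a sketch consistent with its visible signature and the three context lines above, so treat it as an approximation of the Space's actual code:

```python
import cv2
import numpy as np

def resize_image(image: np.ndarray, max_size: int = 1536) -> np.ndarray:
    # Downscale so the longest side is at most max_size, keeping aspect ratio.
    h, w = image.shape[:2]
    scale = max_size / max(h, w)
    if scale < 1.0:  # never upscale; INTER_AREA is the right filter for shrinking
        new_size = (int(w * scale), int(h * scale))
        image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
    return image
```

Also worth noting: the comment sitting between `@spaces.GPU(duration=60)` and `def inpaint_hair` is legal Python (comments may appear between a decorator and its function), so the decorator still applies.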
@@ -102,16 +103,18 @@ def inpaint_hair(image, prompt):
     # Generate inpainted image
     generator = torch.Generator(device).manual_seed(42)
     negative_prompt = "lowres, bad quality, poor quality"
-    output = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        image=image_pil,
-        mask_image=mask_pil,
-        control_image=inpaint_condition,
-        num_inference_steps=25,
-        guidance_scale=7.5,
-        generator=generator
-    ).images[0]
+
+    with torch.cuda.amp.autocast():  # Enable automatic mixed precision
+        output = pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            image=image_pil,
+            mask_image=mask_pil,
+            control_image=inpaint_condition,
+            num_inference_steps=25,
+            guidance_scale=7.5,
+            generator=generator
+        ).images[0]
 
     return np.array(output)
 
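On the new autocast wrapper: with the pipeline already loaded in `torch.float16`, autocast acts mostly as a safety net for stray fp32 ops, and `torch.cuda.amp.autocast()` is deprecated in recent PyTorch releases in favor of `torch.amp.autocast("cuda")`. A sketch of the equivalent modern call, assuming the same objects (`pipe`, `image_pil`, `mask_pil`, `inpaint_condition`) exist as in the diff:

```python
import torch

# pipe, image_pil, mask_pil and inpaint_condition are assumed to exist
# exactly as in the diff above.
generator = torch.Generator("cuda").manual_seed(42)  # fixed seed => reproducible output

with torch.inference_mode(), torch.amp.autocast("cuda"):
    output = pipe(
        prompt="copper-red wavy hair",  # illustrative prompt
        negative_prompt="lowres, bad quality, poor quality",
        image=image_pil,
        mask_image=mask_pil,
        control_image=inpaint_condition,
        num_inference_steps=25,
        guidance_scale=7.5,
        generator=generator,
    ).images[0]
```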
@@ -129,4 +132,4 @@ iface = gr.Interface(
 )
 
 if __name__ == "__main__":
-    iface.launch()
+    iface.launch()
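`control_image=inpaint_condition` implies a helper that builds the conditioning tensor for the inpaint ControlNet. That helper is not in this diff; the sketch below is the common implementation from the diffusers ControlNet-inpaint examples, so treat it as an assumption about what the Space actually uses:

```python
import numpy as np
import torch
from PIL import Image

def make_inpaint_condition(image: Image.Image, image_mask: Image.Image) -> torch.Tensor:
    # Normalize to [0, 1] and mark masked pixels with -1.0 so the inpaint
    # ControlNet knows which region to repaint.
    image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
    assert image.shape[:2] == mask.shape, "image and mask must match in H x W"
    image[mask > 0.5] = -1.0
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)  # HWC -> NCHW
    return torch.from_numpy(image)
```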
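For completeness, the `iface = gr.Interface(` context in the last hunk implies wiring roughly like the sketch below; the labels and component choices are assumptions, since the Interface definition itself is outside the diff:

```python
import gradio as gr

# inpaint_hair(image, prompt) -> np.ndarray, as defined in app.py above.
iface = gr.Interface(
    fn=inpaint_hair,
    inputs=[
        gr.Image(type="numpy", label="Input image"),
        gr.Textbox(label="Hair prompt"),
    ],
    outputs=gr.Image(type="numpy", label="Inpainted result"),
)

if __name__ == "__main__":
    iface.launch()
```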