emirhanbilgic commited on
Commit
9aa62fe
·
verified ·
1 Parent(s): 9726d40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -16
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import spaces
2
-
3
  import gradio as gr
4
  import cv2
5
  import numpy as np
@@ -42,7 +41,6 @@ pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
42
  torch_dtype=torch.float16,
43
  ).to(device)
44
 
45
-
46
  # Set the K_EULER scheduler
47
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
48
 
@@ -80,12 +78,15 @@ def resize_image(image, max_size=1536):
80
  new_size = (int(w * scale), int(h * scale))
81
  image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
82
  return image
83
-
84
  @spaces.GPU(duration=60)
85
  # Function to inpaint the hair area using ControlNet
86
  def inpaint_hair(image, prompt):
87
- # Resize the input image if necessary
88
- image = resize_image(image)
 
 
 
89
  # Segment hair to get the mask
90
  mask = segment_hair(image)
91
  # Convert to PIL image for the inpainting pipeline
@@ -102,16 +103,18 @@ def inpaint_hair(image, prompt):
102
  # Generate inpainted image
103
  generator = torch.Generator(device).manual_seed(42)
104
  negative_prompt = "lowres, bad quality, poor quality"
105
- output = pipe(
106
- prompt=prompt,
107
- negative_prompt=negative_prompt,
108
- image=image_pil,
109
- mask_image=mask_pil,
110
- control_image=inpaint_condition,
111
- num_inference_steps=25,
112
- guidance_scale=7.5,
113
- generator=generator
114
- ).images[0]
 
 
115
 
116
  return np.array(output)
117
 
@@ -129,4 +132,4 @@ iface = gr.Interface(
129
  )
130
 
131
  if __name__ == "__main__":
132
- iface.launch()
 
1
  import spaces
 
2
  import gradio as gr
3
  import cv2
4
  import numpy as np
 
41
  torch_dtype=torch.float16,
42
  ).to(device)
43
 
 
44
  # Set the K_EULER scheduler
45
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
46
 
 
78
  new_size = (int(w * scale), int(h * scale))
79
  image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
80
  return image
81
+
82
  @spaces.GPU(duration=60)
83
  # Function to inpaint the hair area using ControlNet
84
  def inpaint_hair(image, prompt):
85
+ # Only resize the input image if it's larger than 1536 in any dimension
86
+ h, w = image.shape[:2]
87
+ if max(h, w) > 1536:
88
+ image = resize_image(image)
89
+
90
  # Segment hair to get the mask
91
  mask = segment_hair(image)
92
  # Convert to PIL image for the inpainting pipeline
 
103
  # Generate inpainted image
104
  generator = torch.Generator(device).manual_seed(42)
105
  negative_prompt = "lowres, bad quality, poor quality"
106
+
107
+ with torch.cuda.amp.autocast(): # Enable automatic mixed precision
108
+ output = pipe(
109
+ prompt=prompt,
110
+ negative_prompt=negative_prompt,
111
+ image=image_pil,
112
+ mask_image=mask_pil,
113
+ control_image=inpaint_condition,
114
+ num_inference_steps=25,
115
+ guidance_scale=7.5,
116
+ generator=generator
117
+ ).images[0]
118
 
119
  return np.array(output)
120
 
 
132
  )
133
 
134
  if __name__ == "__main__":
135
+ iface.launch()