xingyang1 commited on
Commit
642c115
·
verified ·
1 Parent(s): d71ed99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -19
app.py CHANGED
@@ -11,6 +11,7 @@ from huggingface_hub import hf_hub_download
11
  from safetensors.torch import load_file
12
  from gradio_imageslider import ImageSlider
13
  import spaces
 
14
 
15
  # Helper function to load model from Hugging Face
16
  def load_model_by_name(arch_name, checkpoint_path, device):
@@ -31,7 +32,7 @@ def load_model_by_name(arch_name, checkpoint_path, device):
31
  # Image processing function
32
  def process_image(image, model, device):
33
  if model is None:
34
- return None
35
 
36
  # Preprocess the image
37
  image_np = np.array(image)[..., ::-1] / 255
@@ -45,32 +46,40 @@ def process_image(image, model, device):
45
  image_tensor = transform({'image': image_np})['image']
46
  image_tensor = torch.from_numpy(image_tensor).unsqueeze(0).to(device)
47
 
48
- with torch.no_grad(): # Disable autograd since we don't need gradients on CPU
49
  pred_disp, _ = model(image_tensor)
 
50
 
51
- # Ensure the depth map is in the correct shape before colorization
52
- pred_disp_np = pred_disp.cpu().detach().numpy()[0, 0, :, :] # Remove extra singleton dimensions
53
 
54
  # Normalize depth map
55
- pred_disp = (pred_disp_np - pred_disp_np.min()) / (pred_disp_np.max() - pred_disp_np.min())
56
 
57
- # Colorize depth map
58
  cmap = "Spectral_r"
59
- depth_colored = colorize_depth_maps(pred_disp[None, ..., None], 0, 1, cmap=cmap).squeeze() # Ensure correct dimension
60
-
61
- # Convert to uint8 for image display
62
  depth_colored = (depth_colored * 255).astype(np.uint8)
63
-
64
- # Convert to HWC format (height, width, channels)
65
  depth_colored_hwc = chw2hwc(depth_colored)
66
 
67
- # Resize to match the original image dimensions (height, width)
 
 
 
 
 
 
 
 
 
68
  h, w = image_np.shape[:2]
69
  depth_colored_hwc = cv2.resize(depth_colored_hwc, (w, h), cv2.INTER_LINEAR)
 
70
 
71
- # Convert to a PIL image
72
- depth_image = Image.fromarray(depth_colored_hwc)
73
- return image, depth_image
 
74
 
75
  # Gradio interface function with GPU support
76
  @spaces.GPU
@@ -105,17 +114,20 @@ def gradio_interface(image):
105
  model = model.to(device) # 确保模型在正确的设备上
106
 
107
  if model is None:
108
- return None
109
 
110
  # Process image and return output
111
- depth_image = process_image(image, model, device)
112
- return depth_image
113
 
114
  # Create Gradio interface
115
  iface = gr.Interface(
116
  fn=gradio_interface,
117
  inputs=gr.Image(type="pil"), # Only image input, no mode selection
118
- outputs = ImageSlider(label="Depth slider", type="pil", slider_color="pink"), # Depth image out with a slider
 
 
 
119
  title="Depth Estimation Demo",
120
  description="Upload an image to see the depth estimation results. Our model is running on GPU for faster processing.",
121
  examples=["1.jpg", "2.jpg", "4.png", "5.jpg", "6.jpg"],
 
11
  from safetensors.torch import load_file
12
  from gradio_imageslider import ImageSlider
13
  import spaces
14
+ import tempfile
15
 
16
  # Helper function to load model from Hugging Face
17
  def load_model_by_name(arch_name, checkpoint_path, device):
 
32
  # Image processing function
33
  def process_image(image, model, device):
34
  if model is None:
35
+ return None, None, None, None
36
 
37
  # Preprocess the image
38
  image_np = np.array(image)[..., ::-1] / 255
 
46
  image_tensor = transform({'image': image_np})['image']
47
  image_tensor = torch.from_numpy(image_tensor).unsqueeze(0).to(device)
48
 
49
+ with torch.no_grad():
50
  pred_disp, _ = model(image_tensor)
51
+ torch.cuda.empty_cache()
52
 
53
+ # Convert depth map to numpy
54
+ pred_disp_np = pred_disp.cpu().detach().numpy()[0, 0, :, :]
55
 
56
  # Normalize depth map
57
+ pred_disp_normalized = (pred_disp_np - pred_disp_np.min()) / (pred_disp_np.max() - pred_disp_np.min())
58
 
59
+ # Colorized depth map
60
  cmap = "Spectral_r"
61
+ depth_colored = colorize_depth_maps(pred_disp_normalized[None, ..., None], 0, 1, cmap=cmap).squeeze()
 
 
62
  depth_colored = (depth_colored * 255).astype(np.uint8)
 
 
63
  depth_colored_hwc = chw2hwc(depth_colored)
64
 
65
+ # Gray depth map
66
+ depth_gray = (pred_disp_normalized * 255).astype(np.uint8)
67
+ depth_gray_hwc = np.stack([depth_gray] * 3, axis=-1) # Convert to 3-channel grayscale
68
+
69
+ # Save raw depth map as a temporary npy file
70
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".npy") as temp_file:
71
+ np.save(temp_file.name, pred_disp_normalized)
72
+ depth_raw_path = temp_file.name
73
+
74
+ # Resize outputs to match original image size
75
  h, w = image_np.shape[:2]
76
  depth_colored_hwc = cv2.resize(depth_colored_hwc, (w, h), cv2.INTER_LINEAR)
77
+ depth_gray_hwc = cv2.resize(depth_gray_hwc, (w, h), cv2.INTER_LINEAR)
78
 
79
+ # Convert to PIL images
80
+ return image, Image.fromarray(depth_colored_hwc), Image.fromarray(depth_gray_hwc), depth_raw_path
81
+
82
+
83
 
84
  # Gradio interface function with GPU support
85
  @spaces.GPU
 
114
  model = model.to(device) # 确保模型在正确的设备上
115
 
116
  if model is None:
117
+ return None, None, None, None
118
 
119
  # Process image and return output
120
+ image, depth_image, depth_gray, depth_raw = process_image(image, model, device)
121
+ return (image, depth_image), depth_gray, depth_raw
122
 
123
  # Create Gradio interface
124
  iface = gr.Interface(
125
  fn=gradio_interface,
126
  inputs=gr.Image(type="pil"), # Only image input, no mode selection
127
+ outputs = [ImageSlider(label="Depth slider", type="pil", slider_color="pink"), # Depth image out with a slider
128
+ gr.Image(type="pil", label="Gray Depth"),
129
+ gr.File(label="Raw Depth (NumPy File)")
130
+ ],
131
  title="Depth Estimation Demo",
132
  description="Upload an image to see the depth estimation results. Our model is running on GPU for faster processing.",
133
  examples=["1.jpg", "2.jpg", "4.png", "5.jpg", "6.jpg"],