jwlarocque commited on
Commit
51352c6
·
1 Parent(s): 306c0aa

Use HF Hub for model downloads

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  os.system("pip install ./MultiScaleDeformableAttention-1.0-py3-none-any.whl")
3
 
4
  import gradio as gr
 
5
  import numpy as np
6
  import numpy as np
7
  import torch
@@ -61,8 +62,8 @@ def show_box(box, ax):
61
  ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
62
 
63
 
64
-
65
- sam_checkpoint = r"~/.cache/huggingface/hub/andzhang01/segment_anything/sam_vit_l_0b3195.pth"
66
  model_type = "vit_l"
67
 
68
  sam = sam_model_registry[model_type](checkpoint=sam_checkpoint, device=device)
@@ -180,8 +181,11 @@ def predict_one(net, image, mask, box, transforms, hypar, device):
180
 
181
  hypar = {} # paramters for inferencing
182
 
183
- hypar["model_path"] ="~/.cache/huggingface/hub/jwlarocque/DIS-SAM"
184
- hypar["restore_model"] = "DIS-SAM-checkpoint.pth"
 
 
 
185
  hypar["model_digit"] = "full"
186
  hypar["input_size"] = [1024, 1024]
187
  hypar["model"] = ISNetDIS(in_ch=5)
 
2
  os.system("pip install ./MultiScaleDeformableAttention-1.0-py3-none-any.whl")
3
 
4
  import gradio as gr
5
+ from huggingface_hub import hf_hub_download
6
  import numpy as np
7
  import numpy as np
8
  import torch
 
62
  ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
63
 
64
 
65
+ sam_checkpoint = hf_hub_download(repo_id="andzhang01/segment-anything", filename="sam_vit_l_0b3195.pth")
66
+ # sam_checkpoint = r"~/.cache/huggingface/hub/models--andzhang01--segment-anything/sam_vit_l_0b3195.pth"
67
  model_type = "vit_l"
68
 
69
  sam = sam_model_registry[model_type](checkpoint=sam_checkpoint, device=device)
 
181
 
182
  hypar = {} # paramters for inferencing
183
 
184
+ dis_model_path = hf_hub_download(repo_id="jwlarocque/DIS-SAM", filename="DIS-SAM-checkpoint.pth")
185
+ # hypar["model_path"] ="~/.cache/huggingface/hub/jwlarocque/DIS-SAM"
186
+ hypar["model_path"] = os.path.split(dis_model_path)[0]
187
+ # hypar["restore_model"] = "DIS-SAM-checkpoint.pth"
188
+ hypar["restore_model"] = os.path.split(dis_model_path)[1]
189
  hypar["model_digit"] = "full"
190
  hypar["input_size"] = [1024, 1024]
191
  hypar["model"] = ISNetDIS(in_ch=5)