henry000 committed
Commit 5f0e785 · 1 Parent(s): da4f0bf

👽️ [Update] HF_Demo due to new converter

Files changed (1)
  1. demo/hf_demo.py +11 -11
demo/hf_demo.py CHANGED
@@ -10,7 +10,7 @@ sys.path.append(str(Path(__file__).resolve().parent.parent))
 from yolo import (
     AugmentationComposer,
     NMSConfig,
-    PostProccess,
+    PostProcess,
     create_converter,
     create_model,
     draw_bboxes,
@@ -20,27 +20,26 @@ DEFAULT_MODEL = "v9-c"
 IMAGE_SIZE = (640, 640)
 
 
-def load_model(model_name, device):
+def load_model(model_name):
     model_cfg = OmegaConf.load(f"yolo/config/model/{model_name}.yaml")
     model_cfg.model.auxiliary = {}
     model = create_model(model_cfg, True)
-    model.to(device).eval()
-    return model, model_cfg
+    converter = create_converter(model_cfg.name, model, model_cfg.anchor, IMAGE_SIZE, device)
+    model = model.to(device).eval()
+    return model, converter
 
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model, model_cfg = load_model(DEFAULT_MODEL, device)
-converter = create_converter(model_cfg.name, model, model_cfg.anchor, IMAGE_SIZE, device)
+model, converter = load_model(DEFAULT_MODEL)
 class_list = OmegaConf.load("yolo/config/dataset/coco.yaml").class_list
 
 transform = AugmentationComposer([])
 
 
-def predict(model_name, image, nms_confidence, nms_iou):
+def predict(model_name, image, nms_confidence, nms_iou, max_bbox):
     global DEFAULT_MODEL, model, device, converter, class_list, post_proccess
     if model_name != DEFAULT_MODEL:
-        model, model_cfg = load_model(model_name, device)
-        converter = create_converter(model_cfg.name, model, model_cfg.anchor, IMAGE_SIZE, device)
+        model, converter = load_model(model_name)
         DEFAULT_MODEL = model_name
 
     image_tensor, _, rev_tensor = transform(image)
@@ -48,8 +47,8 @@ def predict(model_name, image, nms_confidence, nms_iou):
     image_tensor = image_tensor.to(device)[None]
     rev_tensor = rev_tensor.to(device)[None]
 
-    nms_config = NMSConfig(nms_confidence, nms_iou)
-    post_proccess = PostProccess(converter, nms_config)
+    nms_config = NMSConfig(nms_confidence, nms_iou, max_bbox)
+    post_proccess = PostProcess(converter, nms_config)
 
     with torch.no_grad():
         predict = model(image_tensor)
@@ -67,6 +66,7 @@ interface = gradio.Interface(
         gradio.components.Image(type="pil", label="Input Image"),
         gradio.components.Slider(0, 1, step=0.01, value=0.5, label="NMS Confidence Threshold"),
         gradio.components.Slider(0, 1, step=0.01, value=0.5, label="NMS IoU Threshold"),
+        gradio.components.Slider(0, 1000, step=10, value=400, label="Max Bounding Box Number"),
     ],
     outputs=gradio.components.Image(type="pil", label="Output Image"),
 )
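
For context, a minimal sketch of exercising the refactored demo outside the Gradio UI. It assumes the definitions from demo/hf_demo.py are in scope and that execution happens from the repository root so the relative yolo/config/... paths resolve; sample.jpg and output.jpg are hypothetical file names, and it assumes predict returns the annotated PIL image, as the type="pil" output component suggests.

# Hypothetical usage sketch (not part of the commit); assumes the
# definitions from demo/hf_demo.py are already in scope.
from PIL import Image

image = Image.open("sample.jpg").convert("RGB")  # placeholder file name

# Mirror the defaults the three sliders supply: confidence 0.5,
# IoU 0.5, and the newly added cap of 400 boxes.
annotated = predict("v9-c", image, nms_confidence=0.5, nms_iou=0.5, max_bbox=400)
annotated.save("output.jpg")  # placeholder file name

# Or serve the full demo, as the script itself does:
# interface.launch()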