henry000 commited on
Commit
f1bfd74
Β·
1 Parent(s): 054a1c7

πŸ› [Fix] #56 bugs, the using converter in ipynb

Browse files
examples/notebook_TensorRT.ipynb CHANGED
@@ -18,7 +18,15 @@
18
  "project_root = Path().resolve().parent\n",
19
  "sys.path.append(str(project_root))\n",
20
  "\n",
21
- "from yolo import AugmentationComposer, bbox_nms, create_model, custom_logger, draw_bboxes, Vec2Box\n",
 
 
 
 
 
 
 
 
22
  "from yolo.config.config import NMSConfig"
23
  ]
24
  },
@@ -49,6 +57,8 @@
49
  "metadata": {},
50
  "outputs": [],
51
  "source": [
 
 
52
  "if os.path.exists(TRT_WEIGHT_PATH):\n",
53
  " from torch2trt import TRTModule\n",
54
  "\n",
@@ -57,8 +67,6 @@
57
  "else:\n",
58
  " from torch2trt import torch2trt\n",
59
  "\n",
60
- " with open(MODEL_CONFIG) as stream:\n",
61
- " cfg_model = OmegaConf.load(stream)\n",
62
  "\n",
63
  " model = create_model(cfg_model, weight_path=WEIGHT_PATH)\n",
64
  " model = model.to(device).eval()\n",
@@ -70,7 +78,7 @@
70
  " logger.info(f\"πŸ“₯ TensorRT model saved to oonx.pt\")\n",
71
  "\n",
72
  "transform = AugmentationComposer([], IMAGE_SIZE)\n",
73
- "vec2box = Vec2Box(model_trt, IMAGE_SIZE, device)\n"
74
  ]
75
  },
76
  {
@@ -79,7 +87,7 @@
79
  "metadata": {},
80
  "outputs": [],
81
  "source": [
82
- "image, bbox = transform(image, torch.zeros(0, 5))\n",
83
  "image = image.to(device)[None]"
84
  ]
85
  },
@@ -91,7 +99,7 @@
91
  "source": [
92
  "with torch.no_grad():\n",
93
  " predict = model_trt(image)\n",
94
- " predict = vec2box(predict[\"Main\"])\n",
95
  "predict_box = bbox_nms(predict[0], predict[2], NMSConfig(0.5, 0.5))\n",
96
  "draw_bboxes(image, predict_box)"
97
  ]
@@ -122,7 +130,7 @@
122
  "name": "python",
123
  "nbconvert_exporter": "python",
124
  "pygments_lexer": "ipython3",
125
- "version": "3.1.undefined"
126
  }
127
  },
128
  "nbformat": 4,
 
18
  "project_root = Path().resolve().parent\n",
19
  "sys.path.append(str(project_root))\n",
20
  "\n",
21
+ "from yolo import (\n",
22
+ " AugmentationComposer, \n",
23
+ " bbox_nms, \n",
24
+ " create_model, \n",
25
+ " custom_logger, \n",
26
+ " create_converter,\n",
27
+ " draw_bboxes, \n",
28
+ " Vec2Box\n",
29
+ ")\n",
30
  "from yolo.config.config import NMSConfig"
31
  ]
32
  },
 
57
  "metadata": {},
58
  "outputs": [],
59
  "source": [
60
+ "with open(MODEL_CONFIG) as stream:\n",
61
+ " cfg_model = OmegaConf.load(stream)\n",
62
  "if os.path.exists(TRT_WEIGHT_PATH):\n",
63
  " from torch2trt import TRTModule\n",
64
  "\n",
 
67
  "else:\n",
68
  " from torch2trt import torch2trt\n",
69
  "\n",
 
 
70
  "\n",
71
  " model = create_model(cfg_model, weight_path=WEIGHT_PATH)\n",
72
  " model = model.to(device).eval()\n",
 
78
  " logger.info(f\"πŸ“₯ TensorRT model saved to oonx.pt\")\n",
79
  "\n",
80
  "transform = AugmentationComposer([], IMAGE_SIZE)\n",
81
+ "converter = create_converter(cfg_model.name, model_trt, cfg_model.anchor, IMAGE_SIZE, device)\n"
82
  ]
83
  },
84
  {
 
87
  "metadata": {},
88
  "outputs": [],
89
  "source": [
90
+ "image, bbox, rev_tensor = transform(image, torch.zeros(0, 5))\n",
91
  "image = image.to(device)[None]"
92
  ]
93
  },
 
99
  "source": [
100
  "with torch.no_grad():\n",
101
  " predict = model_trt(image)\n",
102
+ " predict = converter(predict[\"Main\"])\n",
103
  "predict_box = bbox_nms(predict[0], predict[2], NMSConfig(0.5, 0.5))\n",
104
  "draw_bboxes(image, predict_box)"
105
  ]
 
130
  "name": "python",
131
  "nbconvert_exporter": "python",
132
  "pygments_lexer": "ipython3",
133
+ "version": "3.10.14"
134
  }
135
  },
136
  "nbformat": 4,
examples/notebook_inference.ipynb CHANGED
@@ -1,5 +1,15 @@
1
  {
2
  "cells": [
 
 
 
 
 
 
 
 
 
 
3
  {
4
  "cell_type": "code",
5
  "execution_count": null,
@@ -35,7 +45,7 @@
35
  "source": [
36
  "CONFIG_PATH = \"../yolo/config\"\n",
37
  "CONFIG_NAME = \"config\"\n",
38
- "MODEL = \"v7-base\"\n",
39
  "\n",
40
  "DEVICE = 'cuda:0'\n",
41
  "CLASS_NUM = 80\n",
@@ -54,7 +64,9 @@
54
  "with initialize(config_path=CONFIG_PATH, version_base=None, job_name=\"notebook_job\"):\n",
55
  " cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", f\"model={MODEL}\"])\n",
56
  " model = create_model(cfg.model, class_num=CLASS_NUM).to(device)\n",
 
57
  " transform = AugmentationComposer([], cfg.image_size)\n",
 
58
  " converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)\n",
59
  " post_proccess = PostProccess(converter, cfg.task.nms)"
60
  ]
@@ -81,7 +93,7 @@
81
  " predict = model(image)\n",
82
  " pred_bbox = post_proccess(predict, rev_tensor)\n",
83
  "\n",
84
- "draw_bboxes(pil_image, pred_bbox, idx2label=cfg.class_list)"
85
  ]
86
  },
87
  {
@@ -92,6 +104,11 @@
92
  "\n",
93
  "![image](../demo/images/output/visualize.png)"
94
  ]
 
 
 
 
 
95
  }
96
  ],
97
  "metadata": {
 
1
  {
2
  "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "%load_ext autoreload\n",
10
+ "%autoreload 2"
11
+ ]
12
+ },
13
  {
14
  "cell_type": "code",
15
  "execution_count": null,
 
45
  "source": [
46
  "CONFIG_PATH = \"../yolo/config\"\n",
47
  "CONFIG_NAME = \"config\"\n",
48
+ "MODEL = \"v9-c\"\n",
49
  "\n",
50
  "DEVICE = 'cuda:0'\n",
51
  "CLASS_NUM = 80\n",
 
64
  "with initialize(config_path=CONFIG_PATH, version_base=None, job_name=\"notebook_job\"):\n",
65
  " cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", f\"model={MODEL}\"])\n",
66
  " model = create_model(cfg.model, class_num=CLASS_NUM).to(device)\n",
67
+ "\n",
68
  " transform = AugmentationComposer([], cfg.image_size)\n",
69
+ "\n",
70
  " converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)\n",
71
  " post_proccess = PostProccess(converter, cfg.task.nms)"
72
  ]
 
93
  " predict = model(image)\n",
94
  " pred_bbox = post_proccess(predict, rev_tensor)\n",
95
  "\n",
96
+ "draw_bboxes(pil_image, pred_bbox, idx2label=cfg.dataset.class_list)"
97
  ]
98
  },
99
  {
 
104
  "\n",
105
  "![image](../demo/images/output/visualize.png)"
106
  ]
107
+ },
108
+ {
109
+ "cell_type": "markdown",
110
+ "metadata": {},
111
+ "source": []
112
  }
113
  ],
114
  "metadata": {
examples/notebook_smallobject.ipynb CHANGED
@@ -30,7 +30,17 @@
30
  "project_root = Path().resolve().parent\n",
31
  "sys.path.append(str(project_root))\n",
32
  "\n",
33
- "from yolo import AugmentationComposer, bbox_nms, Config, create_model, custom_logger, draw_bboxes, Vec2Box, NMSConfig, PostProccess"
 
 
 
 
 
 
 
 
 
 
34
  ]
35
  },
36
  {
@@ -62,8 +72,8 @@
62
  " cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", f\"model={MODEL}\"])\n",
63
  " model = create_model(cfg.model, class_num=CLASS_NUM).to(device)\n",
64
  " transform = AugmentationComposer([], cfg.image_size)\n",
65
- " vec2box = Vec2Box(model, cfg.image_size, device)\n",
66
- " post_proccess = PostProccess(vec2box, NMSConfig(0.5, 0.9))\n",
67
  " "
68
  ]
69
  },
@@ -112,7 +122,7 @@
112
  "with torch.no_grad():\n",
113
  " total_image, total_shift = slide_image(image)\n",
114
  " predict = model(total_image)\n",
115
- " pred_class, _, pred_bbox = vec2box(predict[\"Main\"])\n",
116
  "pred_bbox[1:] = (pred_bbox[1: ] + total_shift[:, None]) / SLIDE\n",
117
  "pred_bbox = pred_bbox.view(1, -1, 4)\n",
118
  "pred_class = pred_class.view(1, -1, 80)\n",
@@ -126,7 +136,7 @@
126
  "metadata": {},
127
  "outputs": [],
128
  "source": [
129
- "draw_bboxes(pil_image, predict_box, idx2label=cfg.class_list)"
130
  ]
131
  },
132
  {
 
30
  "project_root = Path().resolve().parent\n",
31
  "sys.path.append(str(project_root))\n",
32
  "\n",
33
+ "from yolo import (\n",
34
+ " AugmentationComposer, \n",
35
+ " Config, \n",
36
+ " NMSConfig, \n",
37
+ " PostProccess,\n",
38
+ " bbox_nms, \n",
39
+ " create_model, \n",
40
+ " create_converter, \n",
41
+ " custom_logger, \n",
42
+ " draw_bboxes, \n",
43
+ ")"
44
  ]
45
  },
46
  {
 
72
  " cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", f\"model={MODEL}\"])\n",
73
  " model = create_model(cfg.model, class_num=CLASS_NUM).to(device)\n",
74
  " transform = AugmentationComposer([], cfg.image_size)\n",
75
+ " converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)\n",
76
+ " post_proccess = PostProccess(converter, NMSConfig(0.5, 0.9))\n",
77
  " "
78
  ]
79
  },
 
122
  "with torch.no_grad():\n",
123
  " total_image, total_shift = slide_image(image)\n",
124
  " predict = model(total_image)\n",
125
+ " pred_class, _, pred_bbox = converter(predict[\"Main\"])\n",
126
  "pred_bbox[1:] = (pred_bbox[1: ] + total_shift[:, None]) / SLIDE\n",
127
  "pred_bbox = pred_bbox.view(1, -1, 4)\n",
128
  "pred_class = pred_class.view(1, -1, 80)\n",
 
136
  "metadata": {},
137
  "outputs": [],
138
  "source": [
139
+ "draw_bboxes(pil_image, predict_box, idx2label=cfg.dataset.class_list)"
140
  ]
141
  },
142
  {
examples/sample_inference.py CHANGED
@@ -2,29 +2,39 @@ import sys
2
  from pathlib import Path
3
 
4
  import hydra
5
- import torch
6
 
7
  project_root = Path(__file__).resolve().parent.parent
8
  sys.path.append(str(project_root))
9
 
10
- from yolo.config.config import Config
11
- from yolo.model.yolo import create_model
12
- from yolo.tools.data_loader import create_dataloader
13
- from yolo.tools.solver import ModelTester
14
- from yolo.utils.logging_utils import custom_logger, validate_log_directory
15
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- @hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
18
- def main(cfg: Config):
19
- custom_logger()
20
- save_path = validate_log_directory(cfg, cfg.name)
21
- dataloader = create_dataloader(cfg)
22
-
23
- device = torch.device(cfg.device)
24
- model = create_model(cfg).to(device)
25
 
26
- tester = ModelTester(cfg, model, save_path, device)
27
- tester.solve(dataloader)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
 
30
  if __name__ == "__main__":
 
2
  from pathlib import Path
3
 
4
  import hydra
 
5
 
6
  project_root = Path(__file__).resolve().parent.parent
7
  sys.path.append(str(project_root))
8
 
 
 
 
 
 
9
 
10
+ from yolo import (
11
+ Config,
12
+ FastModelLoader,
13
+ ModelTester,
14
+ ProgressLogger,
15
+ create_converter,
16
+ create_dataloader,
17
+ create_model,
18
+ )
19
+ from yolo.utils.model_utils import get_device
20
 
 
 
 
 
 
 
 
 
21
 
22
+ @hydra.main(config_path="config", config_name="config", version_base=None)
23
+ def main(cfg: Config):
24
+ progress = ProgressLogger(cfg, exp_name=cfg.name)
25
+ device, use_ddp = get_device(cfg.device)
26
+ dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task, use_ddp)
27
+ if getattr(cfg.task, "fast_inference", False):
28
+ model = FastModelLoader(cfg).load_model(device)
29
+ else:
30
+ model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
31
+ model = model.to(device)
32
+
33
+ converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)
34
+
35
+ solver = ModelTester(cfg, model, converter, progress, device)
36
+ progress.start()
37
+ solver.solve(dataloader)
38
 
39
 
40
  if __name__ == "__main__":
examples/sample_train.py CHANGED
@@ -2,29 +2,35 @@ import sys
2
  from pathlib import Path
3
 
4
  import hydra
5
- import torch
6
 
7
  project_root = Path(__file__).resolve().parent.parent
8
  sys.path.append(str(project_root))
9
 
10
- from yolo.config.config import Config
11
- from yolo.model.yolo import create_model
12
- from yolo.tools.data_loader import create_dataloader
13
- from yolo.tools.solver import ModelTrainer
14
- from yolo.utils.logging_utils import custom_logger, validate_log_directory
15
 
 
 
 
 
 
 
 
 
 
16
 
17
- @hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
 
18
  def main(cfg: Config):
19
- custom_logger()
20
- save_path = validate_log_directory(cfg, cfg.name)
21
- dataloader = create_dataloader(cfg)
22
- # TODO: get_device or rank, for DDP mode
23
- device = torch.device(cfg.device)
24
- model = create_model(cfg).to(device)
25
-
26
- trainer = ModelTrainer(cfg, model, save_path, device)
27
- trainer.solve(dataloader, cfg.task.epoch)
 
 
28
 
29
 
30
  if __name__ == "__main__":
 
2
  from pathlib import Path
3
 
4
  import hydra
 
5
 
6
  project_root = Path(__file__).resolve().parent.parent
7
  sys.path.append(str(project_root))
8
 
 
 
 
 
 
9
 
10
+ from yolo import (
11
+ Config,
12
+ ModelTrainer,
13
+ ProgressLogger,
14
+ create_converter,
15
+ create_dataloader,
16
+ create_model,
17
+ )
18
+ from yolo.utils.model_utils import get_device
19
 
20
+
21
+ @hydra.main(config_path="config", config_name="config", version_base=None)
22
  def main(cfg: Config):
23
+ progress = ProgressLogger(cfg, exp_name=cfg.name)
24
+ device, use_ddp = get_device(cfg.device)
25
+ dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task, use_ddp)
26
+ model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
27
+ model = model.to(device)
28
+
29
+ converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)
30
+
31
+ solver = ModelTrainer(cfg, model, converter, progress, device)
32
+ progress.start()
33
+ solver.solve(dataloader)
34
 
35
 
36
  if __name__ == "__main__":
yolo/__init__.py CHANGED
@@ -5,12 +5,13 @@ from yolo.tools.drawer import draw_bboxes
5
  from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
6
  from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter
7
  from yolo.utils.deploy_utils import FastModelLoader
8
- from yolo.utils.logging_utils import custom_logger
9
  from yolo.utils.model_utils import PostProccess
10
 
11
  all = [
12
  "create_model",
13
  "Config",
 
14
  "NMSConfig",
15
  "custom_logger",
16
  "validate_log_directory",
 
5
  from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
6
  from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter
7
  from yolo.utils.deploy_utils import FastModelLoader
8
+ from yolo.utils.logging_utils import ProgressLogger, custom_logger
9
  from yolo.utils.model_utils import PostProccess
10
 
11
  all = [
12
  "create_model",
13
  "Config",
14
+ "ProgressLogger",
15
  "NMSConfig",
16
  "custom_logger",
17
  "validate_log_directory",