Ubuntu commited on
Commit
4a3d4e5
ยท
1 Parent(s): 8470d62

add application file

Browse files
Files changed (4) hide show
  1. .gitignore +6 -0
  2. app.py +49 -0
  3. logo/logo_small_EQUES.png +0 -0
  4. run.py +73 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ __pycache__
2
+ model/dataset
3
+ model/img
4
+ model/loss
5
+ model/PIE
6
+ */__pycache__
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+
4
+ with st.sidebar:
5
+ st.image("logo/logo_small_EQUES.png")
6
+ st.button("ใ‚ซใƒกใƒฉใจๆŽฅ็ถš")
7
+
8
+ from PIL import Image
9
+ import io
10
+ uploaded_file = st.file_uploader('็”ปๅƒใ‚’ใ‚ขใƒƒใƒ—ใƒญใƒผใƒ‰')
11
+
12
+ if uploaded_file is not None:
13
+ image = Image.open(uploaded_file)
14
+ img_array = np.array(image)
15
+ st.image(
16
+ image, caption='upload images',
17
+ use_column_width=True
18
+ )
19
+
20
+ option = st.selectbox(
21
+ "ใƒ†ใ‚นใƒˆ็”ปๅƒใฎไฝฟ็”จ",
22
+ ("ใƒ†ใ‚นใƒˆ็”ปๅƒ๏ผ‘", "ใƒ†ใ‚นใƒˆ็”ปๅƒ๏ผ’", "ใƒ†ใ‚นใƒˆ็”ปๅƒ๏ผ“"),
23
+ index=None,
24
+ placeholder="ใƒ†ใ‚นใƒˆ็”ปๅƒใ‚’้ธๆŠžใ—ใฆใใ ใ•ใ„ใ€‚",
25
+ )
26
+
27
+ if st.button("ไธๅฏฉ่€…ใ‚’ๆคœ็Ÿฅ"):
28
+ if option != None:
29
+ image = Image.open(f"img/sample/{option}.JPG")
30
+ assert image != None,"็”ปๅƒใ‚’ใ‚ขใƒƒใƒ—ใƒญใƒผใƒ‰ใ—ใฆใใ ใ•ใ„ใ€‚"
31
+ from run import inference
32
+ with st.spinner("Operation in progress. Please wait."):
33
+ output, (labels, scores, indices) = inference(image)
34
+ st.image("img/detected.png")
35
+
36
+ st.title("่งฃๆž็ตๆžœ")
37
+
38
+ for a,b in zip(labels, scores):
39
+ st.write(a,b)
40
+
41
+ if len(indices) != 0:
42
+ st.warning('ไธๅฏฉ่€…ใŒๆคœ็Ÿฅใ•ใ‚ŒใŸๅฏ่ƒฝๆ€งใŒใ‚ใ‚Šใพใ™', icon="โš ๏ธ")
43
+ else:
44
+ st.info('ไธๅฏฉ่€…ใฏๆคœ็Ÿฅใ•ใ‚Œใพใ›ใ‚“ใงใ—ใŸ', icon="โ„น๏ธ")
45
+
46
+
47
+ st.title("่งฃๆž็ตๆžœ่ฉณ็ดฐ")
48
+ st.write(output)
49
+
logo/logo_small_EQUES.png ADDED
run.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch
4
+ from model.detector import *
5
+ from model.backbone import *
6
+ from model.loss import *
7
+ from model.data import Therin
8
+ import datetime
9
+
10
+ from model.detector.fasterRCNN import FasterRCNN
11
+ from model.backbone.densenet import DenseNet
12
+ from model.utils.engine import *
13
+ from torchvision.models.detection.backbone_utils import resnet_fpn_backbone, _resnet_fpn_extractor
14
+ from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_V2_Weights
15
+ from torchvision import transforms as T
16
+ from PIL import Image, ImageDraw
17
+
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ def label_to_text_en(l):
21
+ d = {0: "creeping", 1: "crawling", 2: "stooping", 3: "climbing", 4: "other"}
22
+ return d[l]
23
+
24
+ def label_to_text_ja(l):
25
+ d = {0: "ใ—ใฎใณใ“ใ‚“ใงใ„ใ‚‹", 1: "้€™ใฃใฆใ„ใ‚‹", 2: "ใ‹ใŒใ‚“ใงใ„ใ‚‹", 3: "ใ‚ˆใ˜็™ปใฃใฆใ„ใ‚‹", 4: "ใใฎไป–"}
26
+ return d[l]
27
+
28
+ def show_bb(img, x, y, w, h, text, textcolor, bbcolor):
29
+ draw = ImageDraw.Draw(img)
30
+ text_w, text_h = draw.textsize(text)
31
+ label_y = y if y <= text_h else y - text_h
32
+ draw.rectangle((x, label_y, x+w, label_y+h), outline=bbcolor)
33
+ draw.rectangle((x, label_y, x+text_w, label_y+text_h), outline=bbcolor, fill=bbcolor)
34
+ draw.text((x, label_y), text, fill=textcolor)
35
+
36
+ def postprocess(true_image, o):
37
+ copy_im = true_image.copy()
38
+ data = o[0]
39
+ boxes = data["boxes"]
40
+ labels = data["labels"].tolist()
41
+ scores = data["scores"].tolist()
42
+
43
+ selected_labels = []
44
+ selected_scores = []
45
+ selected_indices = []
46
+ thresh = 0.30
47
+ for i, box in enumerate(boxes.tolist()):
48
+ # if scores[i] > thresh:
49
+
50
+ if i == scores.index(max(scores)):
51
+ show_bb(copy_im, box[0],box[1],box[2],box[3], label_to_text_en(labels[i]) , (255, 255, 255), (255, 0, 0)) #xywh
52
+
53
+ selected_labels.append(label_to_text_ja(labels[i]))
54
+ selected_scores.append( '{:.3f}'.format(scores[i]))
55
+ selected_indices.append(i)
56
+ copy_im.show()
57
+ copy_im.save("img/detected.png")
58
+ return selected_labels, selected_scores, selected_indices
59
+
60
+
61
+ def inference(image_pil):
62
+ num_classes = 5
63
+ backbone = resnet_fpn_backbone('resnet18', False)
64
+ model = FasterRCNN(backbone, num_classes)
65
+ model.eval()
66
+ state_dict = torch.load('model/model/densenet-model-9-mAp--1.0.pth')
67
+ model.load_state_dict(state_dict["model"])
68
+ _transform = T.Compose([T.ToTensor()])
69
+ image = image_pil.convert("RGB")
70
+ image = _transform(image)
71
+ output = model([image])
72
+ res = postprocess(image_pil, output)
73
+ return output, res