Spaces:
Sleeping
Sleeping
Ubuntu
commited on
Commit
ยท
4a3d4e5
1
Parent(s):
8470d62
add application file
Browse files- .gitignore +6 -0
- app.py +49 -0
- logo/logo_small_EQUES.png +0 -0
- run.py +73 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
model/dataset
|
3 |
+
model/img
|
4 |
+
model/loss
|
5 |
+
model/PIE
|
6 |
+
*/__pycache__
|
app.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
with st.sidebar:
|
5 |
+
st.image("logo/logo_small_EQUES.png")
|
6 |
+
st.button("ใซใกใฉใจๆฅ็ถ")
|
7 |
+
|
8 |
+
from PIL import Image
|
9 |
+
import io
|
10 |
+
uploaded_file = st.file_uploader('็ปๅใใขใใใญใผใ')
|
11 |
+
|
12 |
+
if uploaded_file is not None:
|
13 |
+
image = Image.open(uploaded_file)
|
14 |
+
img_array = np.array(image)
|
15 |
+
st.image(
|
16 |
+
image, caption='upload images',
|
17 |
+
use_column_width=True
|
18 |
+
)
|
19 |
+
|
20 |
+
option = st.selectbox(
|
21 |
+
"ใในใ็ปๅใฎไฝฟ็จ",
|
22 |
+
("ใในใ็ปๅ๏ผ", "ใในใ็ปๅ๏ผ", "ใในใ็ปๅ๏ผ"),
|
23 |
+
index=None,
|
24 |
+
placeholder="ใในใ็ปๅใ้ธๆใใฆใใ ใใใ",
|
25 |
+
)
|
26 |
+
|
27 |
+
if st.button("ไธๅฏฉ่
ใๆค็ฅ"):
|
28 |
+
if option != None:
|
29 |
+
image = Image.open(f"img/sample/{option}.JPG")
|
30 |
+
assert image != None,"็ปๅใใขใใใญใผใใใฆใใ ใใใ"
|
31 |
+
from run import inference
|
32 |
+
with st.spinner("Operation in progress. Please wait."):
|
33 |
+
output, (labels, scores, indices) = inference(image)
|
34 |
+
st.image("img/detected.png")
|
35 |
+
|
36 |
+
st.title("่งฃๆ็ตๆ")
|
37 |
+
|
38 |
+
for a,b in zip(labels, scores):
|
39 |
+
st.write(a,b)
|
40 |
+
|
41 |
+
if len(indices) != 0:
|
42 |
+
st.warning('ไธๅฏฉ่
ใๆค็ฅใใใๅฏ่ฝๆงใใใใพใ', icon="โ ๏ธ")
|
43 |
+
else:
|
44 |
+
st.info('ไธๅฏฉ่
ใฏๆค็ฅใใใพใใใงใใ', icon="โน๏ธ")
|
45 |
+
|
46 |
+
|
47 |
+
st.title("่งฃๆ็ตๆ่ฉณ็ดฐ")
|
48 |
+
st.write(output)
|
49 |
+
|
logo/logo_small_EQUES.png
ADDED
![]() |
run.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from model.detector import *
|
5 |
+
from model.backbone import *
|
6 |
+
from model.loss import *
|
7 |
+
from model.data import Therin
|
8 |
+
import datetime
|
9 |
+
|
10 |
+
from model.detector.fasterRCNN import FasterRCNN
|
11 |
+
from model.backbone.densenet import DenseNet
|
12 |
+
from model.utils.engine import *
|
13 |
+
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone, _resnet_fpn_extractor
|
14 |
+
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_V2_Weights
|
15 |
+
from torchvision import transforms as T
|
16 |
+
from PIL import Image, ImageDraw
|
17 |
+
|
18 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
19 |
+
|
20 |
+
def label_to_text_en(l):
|
21 |
+
d = {0: "creeping", 1: "crawling", 2: "stooping", 3: "climbing", 4: "other"}
|
22 |
+
return d[l]
|
23 |
+
|
24 |
+
def label_to_text_ja(l):
|
25 |
+
d = {0: "ใใฎใณใใใงใใ", 1: "้ใฃใฆใใ", 2: "ใใใใงใใ", 3: "ใใ็ปใฃใฆใใ", 4: "ใใฎไป"}
|
26 |
+
return d[l]
|
27 |
+
|
28 |
+
def show_bb(img, x, y, w, h, text, textcolor, bbcolor):
|
29 |
+
draw = ImageDraw.Draw(img)
|
30 |
+
text_w, text_h = draw.textsize(text)
|
31 |
+
label_y = y if y <= text_h else y - text_h
|
32 |
+
draw.rectangle((x, label_y, x+w, label_y+h), outline=bbcolor)
|
33 |
+
draw.rectangle((x, label_y, x+text_w, label_y+text_h), outline=bbcolor, fill=bbcolor)
|
34 |
+
draw.text((x, label_y), text, fill=textcolor)
|
35 |
+
|
36 |
+
def postprocess(true_image, o):
|
37 |
+
copy_im = true_image.copy()
|
38 |
+
data = o[0]
|
39 |
+
boxes = data["boxes"]
|
40 |
+
labels = data["labels"].tolist()
|
41 |
+
scores = data["scores"].tolist()
|
42 |
+
|
43 |
+
selected_labels = []
|
44 |
+
selected_scores = []
|
45 |
+
selected_indices = []
|
46 |
+
thresh = 0.30
|
47 |
+
for i, box in enumerate(boxes.tolist()):
|
48 |
+
# if scores[i] > thresh:
|
49 |
+
|
50 |
+
if i == scores.index(max(scores)):
|
51 |
+
show_bb(copy_im, box[0],box[1],box[2],box[3], label_to_text_en(labels[i]) , (255, 255, 255), (255, 0, 0)) #xywh
|
52 |
+
|
53 |
+
selected_labels.append(label_to_text_ja(labels[i]))
|
54 |
+
selected_scores.append( '{:.3f}'.format(scores[i]))
|
55 |
+
selected_indices.append(i)
|
56 |
+
copy_im.show()
|
57 |
+
copy_im.save("img/detected.png")
|
58 |
+
return selected_labels, selected_scores, selected_indices
|
59 |
+
|
60 |
+
|
61 |
+
def inference(image_pil):
|
62 |
+
num_classes = 5
|
63 |
+
backbone = resnet_fpn_backbone('resnet18', False)
|
64 |
+
model = FasterRCNN(backbone, num_classes)
|
65 |
+
model.eval()
|
66 |
+
state_dict = torch.load('model/model/densenet-model-9-mAp--1.0.pth')
|
67 |
+
model.load_state_dict(state_dict["model"])
|
68 |
+
_transform = T.Compose([T.ToTensor()])
|
69 |
+
image = image_pil.convert("RGB")
|
70 |
+
image = _transform(image)
|
71 |
+
output = model([image])
|
72 |
+
res = postprocess(image_pil, output)
|
73 |
+
return output, res
|