import argparse

import torch
from torchvision import transforms as T
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from PIL import Image, ImageDraw

from model.detector.fasterRCNN import FasterRCNN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def label_to_text_en(l):
    d = {0: "creeping", 1: "crawling", 2: "stooping", 3: "climbing", 4: "other"}
    return d[l]

def label_to_text_ja(l):
    # Japanese display labels corresponding to the English ones above.
    d = {0: "しのびこんでいる", 1: "這っている", 2: "かがんでいる", 3: "よじ登っている", 4: "その他"}
    return d[l]

def show_bb(img, x1, y1, x2, y2, text, textcolor, bbcolor):
    # Draw a bounding box given (x1, y1, x2, y2) pixel coordinates plus a filled text label.
    draw = ImageDraw.Draw(img)
    # textbbox replaces the textsize API that was removed in Pillow 10.
    _, _, text_w, text_h = draw.textbbox((0, 0), text)
    # Put the label just above the box unless that would run off the top of the image.
    label_y = y1 if y1 <= text_h else y1 - text_h
    draw.rectangle((x1, y1, x2, y2), outline=bbcolor)
    draw.rectangle((x1, label_y, x1 + text_w, label_y + text_h), outline=bbcolor, fill=bbcolor)
    draw.text((x1, label_y), text, fill=textcolor)

def postprocess(true_image, output):
    # Draw the top-scoring detection on a copy of the input image and return its
    # Japanese label, score, and index.
    copy_im = true_image.copy()
    data = output[0]
    boxes = data["boxes"]  # (x1, y1, x2, y2) in pixel coordinates
    labels = data["labels"].tolist()
    scores = data["scores"].tolist()

    selected_labels = []
    selected_scores = []
    selected_indices = []
    # thresh = 0.30  # alternative: keep every detection with scores[i] > thresh
    for i, box in enumerate(boxes.tolist()):
        # Keep only the single highest-scoring detection.
        if i == scores.index(max(scores)):
            show_bb(copy_im, box[0], box[1], box[2], box[3],
                    label_to_text_en(labels[i]), (255, 255, 255), (255, 0, 0))
            selected_labels.append(label_to_text_ja(labels[i]))
            selected_scores.append("{:.3f}".format(scores[i]))
            selected_indices.append(i)
    copy_im.show()
    copy_im.save("img/detected.png")  # assumes an img/ directory exists
    return selected_labels, selected_scores, selected_indices


def inference(image_pil):
    # Build the detector, load the trained weights, and run a single forward pass.
    num_classes = 5
    # The second positional argument is pretrained=False; newer torchvision releases
    # expect keyword arguments (backbone_name=..., weights=None) instead.
    backbone = resnet_fpn_backbone('resnet18', False)
    model = FasterRCNN(backbone, num_classes)
    state_dict = torch.load('model/model/densenet-model-9-mAp--1.0.pth', map_location=device)
    model.load_state_dict(state_dict["model"])
    model.to(device)
    model.eval()

    _transform = T.Compose([T.ToTensor()])
    image = image_pil.convert("RGB")
    image = _transform(image).to(device)
    with torch.no_grad():
        output = model([image])
    res = postprocess(image_pil, output)
    return output, res
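
# Minimal command-line entry point: a sketch of how inference() might be invoked.
# The --image argument name is illustrative (not part of the original code), and the
# script is assumed to be run from the repository root so the relative paths resolve.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the pose detector on a single image.")
    parser.add_argument("--image", required=True, help="Path to the input image file.")
    args = parser.parse_args()

    pil_image = Image.open(args.image)
    output, (labels, scores, indices) = inference(pil_image)
    print("labels :", labels)
    print("scores :", scores)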