# Spaces:
# Sleeping
# Sleeping
import argparse | |
import os | |
import torch | |
from model.detector import * | |
from model.backbone import * | |
from model.data import Therin | |
import datetime | |
from model.detector.fasterRCNN import FasterRCNN | |
from model.backbone.densenet import DenseNet | |
from model.utils.engine import * | |
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone, _resnet_fpn_extractor | |
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_V2_Weights | |
from torchvision import transforms as T | |
from PIL import Image, ImageDraw | |
# Device used for the model and its inputs: the current CUDA device when
# PyTorch reports CUDA as available, otherwise the CPU. Chosen once at import.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# English display names for the detector's class indices.
_EN_LABEL_NAMES = {
    0: "creeping",
    1: "crawling",
    2: "stooping",
    3: "climbing",
    4: "other",
}


def label_to_text_en(l):
    """Return the English display name for class index ``l``.

    Raises:
        KeyError: if ``l`` is not one of the known indices 0-4.
    """
    return _EN_LABEL_NAMES[l]
# Japanese display names for the detector's class indices, in the same order
# as label_to_text_en: creeping, crawling, stooping, climbing, other.
# NOTE(review): the literals that used to be here were mojibake (UTF-8
# Japanese mis-decoded through a Thai code page, with some bytes lost). These
# strings were reconstructed from the surviving bytes; entry 0 in particular
# is a best-effort reading -- confirm against the original source.
_JA_LABEL_NAMES = {
    0: "しのびこんでいる",
    1: "這っている",
    2: "かがんでいる",
    3: "よじ登っている",
    4: "その他",
}


def label_to_text_ja(l):
    """Return the Japanese display name for class index ``l``.

    Raises:
        KeyError: if ``l`` is not one of the known indices 0-4.
    """
    return _JA_LABEL_NAMES[l]
def _text_size(draw, text):
    """Return the (width, height) in pixels covered by ``text`` drawn at the
    origin with ``draw``'s default font.

    ``ImageDraw.textsize`` was deprecated in Pillow 9.2 and removed in 10.0,
    so ``textbbox`` is preferred. ``textsize`` is only used where ``textbbox``
    is missing (Pillow < 8) or rejects the default bitmap font (Pillow < 9.2).
    """
    try:
        # Bounding box of the text anchored at (0, 0); its right/bottom edges
        # are the extent a background rectangle must cover.
        _, _, right, bottom = draw.textbbox((0, 0), text)
    except (AttributeError, ValueError):
        return draw.textsize(text)
    return right, bottom


def show_bb(img, x, y, w, h, text, textcolor, bbcolor):
    """Draw a box outline plus a filled text label onto ``img`` in place.

    Args:
        img: PIL image to draw on; it is modified in place.
        x, y: top-left corner of the box, in pixels.
        w, h: width and height of the box, in pixels.
        text: label drawn at the box's left edge.
        textcolor: fill colour of the label text.
        bbcolor: colour of the box outline and of the label background.
    """
    draw = ImageDraw.Draw(img)
    text_w, text_h = _text_size(draw, text)
    # Place the label just above the box when there is room; otherwise (box
    # near the top edge of the image) place it inside the box at its top.
    label_y = y if y <= text_h else y - text_h
    # The box itself always starts at y -- only the label is shifted.
    draw.rectangle((x, y, x + w, y + h), outline=bbcolor)
    draw.rectangle(
        (x, label_y, x + text_w, label_y + text_h), outline=bbcolor, fill=bbcolor
    )
    draw.text((x, label_y), text, fill=textcolor)
def postprocess(true_image, o, *, thresh=None, save_path="img/detected.png", show=True):
    """Annotate detections on a copy of ``true_image`` and collect their labels.

    Args:
        true_image: PIL image the detections refer to. It is not modified; a
            copy is annotated.
        o: detector output. ``o[0]`` must map ``"boxes"``, ``"labels"`` and
            ``"scores"`` to array-likes supporting ``.tolist()`` (presumably
            tensors), aligned index by index.
        thresh: if None (default), keep only the single highest-scoring
            detection (the first one on ties). Otherwise keep every detection
            whose score is strictly greater than ``thresh``.
        save_path: file the annotated copy is written to. Its parent
            directory is created if it does not exist.
        show: if True (default), also call ``show()`` on the annotated copy.

    Returns:
        ``(labels, scores, indices)`` for the kept detections: Japanese label
        names, scores formatted with three decimals, and the detections'
        indices in ``o[0]``. All three lists are empty when nothing is kept,
        including when the detector returned no detections at all.
    """
    copy_im = true_image.copy()
    data = o[0]
    boxes = data["boxes"].tolist()
    labels = data["labels"].tolist()
    scores = data["scores"].tolist()

    if thresh is None:
        # Compute the argmax once. list.index returns the first maximum on
        # ties, and an empty detection list selects nothing instead of
        # letting max() raise ValueError.
        keep = [scores.index(max(scores))] if scores else []
    else:
        keep = [i for i, score in enumerate(scores) if score > thresh]

    selected_labels = []
    selected_scores = []
    selected_indices = []
    for i in keep:
        box = boxes[i]
        # NOTE(review): the box is handed to show_bb as (x, y, w, h).
        # Torchvision-style Faster R-CNN models return (x1, y1, x2, y2); if
        # this detector does too, box[2]/box[3] are corners, not sizes --
        # confirm the box format.
        show_bb(
            copy_im,
            box[0],
            box[1],
            box[2],
            box[3],
            label_to_text_en(labels[i]),
            (255, 255, 255),
            (255, 0, 0),
        )
        selected_labels.append(label_to_text_ja(labels[i]))
        selected_scores.append("{:.3f}".format(scores[i]))
        selected_indices.append(i)

    if show:
        copy_im.show()
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    copy_im.save(save_path)
    return selected_labels, selected_scores, selected_indices
# Default checkpoint location (relative to the working directory).
_DEFAULT_WEIGHTS = "model/model/densenet-model-9-mAp--1.0.pth"
# Models already built and loaded, keyed by checkpoint path, so repeated
# inference calls do not rebuild the network or re-read the checkpoint.
_MODEL_CACHE = {}


def _load_model(weights_path):
    """Build the project's 5-class FasterRCNN on a ResNet-18 FPN backbone,
    load ``weights_path`` into it, and return it in eval mode on ``device``.

    The result is cached per ``weights_path``.
    """
    model = _MODEL_CACHE.get(weights_path)
    if model is None:
        num_classes = 5
        # Keyword form required by torchvision >= 0.13 (already needed by this
        # file's FasterRCNN_ResNet50_FPN_V2_Weights import): `weights=None`
        # replaces the legacy positional `pretrained=False`. The checkpoint
        # below provides the trained weights.
        backbone = resnet_fpn_backbone(backbone_name="resnet18", weights=None)
        model = FasterRCNN(backbone, num_classes)
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints. The filename says "densenet" while a ResNet-18
        # backbone is built here; confirm this is the intended checkpoint.
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict["model"])
        model.to(device)
        model.eval()
        _MODEL_CACHE[weights_path] = model
    return model


def inference(image_pil, weights_path=_DEFAULT_WEIGHTS):
    """Run the detector on one PIL image and annotate its detections.

    Args:
        image_pil: input PIL image, in any mode convertible to RGB.
        weights_path: checkpoint to load; defaults to the bundled model.

    Returns:
        ``(output, res)``: the raw detector output (a list with one dict per
        image, tensor values moved to the CPU) and the
        ``(labels, scores, indices)`` tuple returned by ``postprocess``.
    """
    model = _load_model(weights_path)
    rgb_image = image_pil.convert("RGB")
    image = T.ToTensor()(rgb_image).to(device)
    with torch.no_grad():
        output = model([image])
    # Return CPU tensors so callers see the same result whichever device ran
    # the model.
    output = [
        {key: value.cpu() if torch.is_tensor(value) else value
         for key, value in detections.items()}
        for detections in output
    ]
    # Annotate the RGB copy so the RGB colour tuples used for drawing are
    # valid even for greyscale or palette inputs.
    res = postprocess(rgb_image, output)
    return output, res