lkp72 commited on
Commit
5154d0b
·
verified ·
1 Parent(s): ee88430

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -88
app.py CHANGED
@@ -1,94 +1,63 @@
1
- import torch
2
  import gradio as gr
3
  import numpy as np
4
- from torchvision.ops import nms
5
- from PIL import Image
6
  import cv2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # Load the model
9
- model = torch.jit.load("best.torchscript")
10
- model.eval()
11
-
12
- # Define the detection function
13
- def detect_taxi_plate(image):
14
- try:
15
- # Preprocess the image
16
- image_resized = Image.fromarray(image).resize((640, 640))
17
- input_tensor = torch.from_numpy(np.array(image_resized).transpose(2, 0, 1) / 255.0).unsqueeze(0).float()
18
-
19
- # Run inference
20
- output = model(input_tensor)
21
- detection_data = output[0][0].detach().numpy() # Remove batch dimension
22
-
23
- # Filter detections by confidence threshold
24
- filtered_detections = detection_data[detection_data[:, 4] >= 0.5]
25
- # Define class names
26
- class_names = ["plate", "taxi"]
27
-
28
- # Prepare boxes for NMS
29
- boxes = []
30
- confidences = []
31
- labels = []
32
- for detection in filtered_detections:
33
- if len(detection) < 7: # Ensure detection has enough elements
34
- continue
35
- x_center, y_center, width, height = detection[:4]
36
- confidence = detection[4]
37
- print(confidence)
38
- class_probs = detection[5:] # Probabilities for all classes
39
-
40
- # Get the predicted class by finding the max probability index
41
- class_index = np.argmax(class_probs)
42
- class_label = class_names[class_index]
43
- print(class_label)
44
-
45
- x_min = int(x_center - width / 2.2)
46
- y_min = int(y_center - height / 2.2)
47
- x_max = int(x_center + width / 2.2)
48
- y_max = int(y_center + height / 2.2)
49
-
50
- boxes.append([x_min, y_min, x_max, y_max])
51
- confidences.append(confidence)
52
- labels.append(class_label)
53
-
54
- if not boxes: # No valid boxes
55
- raise ValueError("No detections.")
56
-
57
- boxes_tensor = torch.tensor(boxes, dtype=torch.float32)
58
- scores_tensor = torch.tensor(confidences, dtype=torch.float32)
59
-
60
- # Apply NMS
61
- iou_threshold = 0.5
62
- nms_indices = nms(boxes_tensor, scores_tensor, iou_threshold)
63
- nms_boxes = boxes_tensor[nms_indices].tolist()
64
- nms_labels = [labels[i] for i in nms_indices]
65
-
66
- # Draw bounding boxes
67
- image_with_boxes = image.copy()
68
- for i, box in enumerate(nms_boxes):
69
- x_min, y_min, x_max, y_max = map(int, box)
70
- label = nms_labels[i]
71
- cv2.rectangle(image_with_boxes, (x_min, y_min), (x_max, y_max), (255, 0, 0), 2)
72
- cv2.putText(image_with_boxes, label, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
73
-
74
- return image_with_boxes
75
-
76
- except Exception as e:
77
- print(f"Error: {str(e)}")
78
- # Return error as text overlay on the image
79
- image_with_error = image.copy()
80
- cv2.putText(image_with_error, f"Error: {str(e)}", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
81
- return image_with_error
82
-
83
- # Define the Gradio interface
84
- interface = gr.Interface(
85
- fn=detect_taxi_plate,
86
- inputs=gr.Image(type="numpy", label="Upload Image"),
87
- outputs=gr.Image(type="numpy", label="Output Image"),
88
- title="ITI107 Assignment: Taxi & License Plate Detection",
89
- description="Admin Number: 4744695Y\n\nUpload an image to detect if a Taxi and/or License Plate is present."
90
  )
91
 
92
- # Launch the app
93
- if __name__ == "__main__":
94
- interface.launch(share=True)
 
 
 
1
  import gradio as gr
2
  import numpy as np
 
 
3
  import cv2
4
+ import os
5
+ from ultralytics import YOLO
6
+
7
+ # Load the YOLO model
8
+ model = YOLO('best.pt')
9
+
10
+ # Function for image processing
11
+ def show_preds_image(image_path):
12
+ image = cv2.imread(image_path)
13
+ results = model.predict(source=image_path)
14
+ annotated_image = results[0].plot()
15
+ return cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
16
+
17
+ # Function for video processing
18
+ def show_preds_video(video_path):
19
+ cap = cv2.VideoCapture(video_path)
20
+ out_frames = []
21
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
22
+ while cap.isOpened():
23
+ ret, frame = cap.read()
24
+ if not ret:
25
+ break
26
+ results = model.predict(source=frame)
27
+ annotated_frame = results[0].plot()
28
+ out_frames.append(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
29
+ cap.release()
30
+
31
+ # Save the annotated video
32
+ output_path = "annotated_video.mp4"
33
+ height, width, _ = out_frames[0].shape
34
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
35
+ writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
36
+ for frame in out_frames:
37
+ writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
38
+ writer.release()
39
+ return output_path
40
+
41
+ # Gradio interfaces
42
+ inputs_image = gr.Image(type="filepath", label="Input Image")
43
+ outputs_image = gr.Image(type="numpy", label="Output Image")
44
+ interface_image = gr.Interface(
45
+ fn=show_preds_image,
46
+ inputs=inputs_image,
47
+ outputs=outputs_image,
48
+ title="Taxi & License Plate Detection with Image"
49
+ )
50
 
51
+ inputs_video = gr.Video(label="Input Video")
52
+ outputs_video = gr.Video(label="Annotated Output")
53
+ interface_video = gr.Interface(
54
+ fn=show_preds_video,
55
+ inputs=inputs_video,
56
+ outputs=outputs_video,
57
+ title="Taxi & License Plate Detection with Video"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  )
59
 
60
+ gr.TabbedInterface(
61
+ [interface_image, interface_video],
62
+ tab_names=['Image Inference', 'Video Inference']
63
+ ).launch(share=True)