DennyW commited on
Commit
843df20
·
verified ·
1 Parent(s): a52b66f

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +53 -0
  2. model.py +47 -0
  3. requirements.txt +106 -0
  4. retinanet_best_model.pth +3 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ import torchvision.transforms as transforms
5
+ from model import RetinaNet # Import your RetinaNet model definition
6
+
7
+ # Define the image transformation pipeline
8
+ image_transform = transforms.Compose([
9
+ transforms.Resize((224, 224)),
10
+ transforms.ToTensor(),
11
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
12
+ ])
13
+
14
+ # Load the model
15
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+ model = RetinaNet(num_classes=2).to(device)
17
+ model.load_state_dict(torch.load("retinanet_best_model.pth", map_location=device))
18
+ model.eval()
19
+
20
+ # Prediction function
21
+ def predict_image(image):
22
+ # Preprocess the image
23
+ img = Image.fromarray(image).convert('RGB') # Convert Gradio input to PIL Image
24
+ input_tensor = image_transform(img).unsqueeze(0).to(device)
25
+
26
+ # Perform inference
27
+ with torch.no_grad():
28
+ prediction = model(input_tensor.float())
29
+ sum_value = abs(torch.sum(prediction[0]))
30
+ p_true = abs(prediction[0][0])
31
+ p_false = abs(prediction[0][1])
32
+
33
+ # Interpret the prediction
34
+ if p_true > 0.7:
35
+ result = "Accepted"
36
+ confidence = float(p_true)
37
+ else:
38
+ result = "Rejected"
39
+ confidence = float(p_false)
40
+
41
+ return f"Result: {result}, Confidence: {confidence:.2f}"
42
+
43
+ # Create the Gradio interface
44
+ with gr.Blocks() as demo:
45
+ gr.Markdown("# RetinaNet Model Prediction")
46
+ with gr.Row():
47
+ image_input = gr.Image(label="Upload Image", type="numpy")
48
+ output_text = gr.Textbox(label="Prediction Result")
49
+ predict_button = gr.Button("Predict")
50
+ predict_button.click(predict_image, inputs=image_input, outputs=output_text)
51
+
52
+ # Launch the app
53
+ demo.launch()
model.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ import torch.nn as nn
5
+ from torch.utils.data import DataLoader, Dataset
6
+ from torchvision import transforms, datasets, models
7
+
8
+ # Define model
9
+ class RetinaNet(nn.Module):
10
+ def __init__(self, num_classes=2):
11
+ super(RetinaNet, self).__init__()
12
+ self.backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
13
+
14
+ # Freeze backbone parameters
15
+ for param in self.backbone.parameters():
16
+ param.requires_grad = False
17
+
18
+ # Unfreeze later layers
19
+ for param in self.backbone.layer3.parameters():
20
+ param.requires_grad = True
21
+ for param in self.backbone.layer4.parameters():
22
+ param.requires_grad = False
23
+
24
+ # Modified classifier head
25
+ self.classifier = nn.Sequential(
26
+ nn.Linear(2048, 512),
27
+ nn.ReLU(),
28
+ nn.Dropout(0.5),
29
+ nn.Linear(512, num_classes)
30
+ # nn.Sigmoid()
31
+ )
32
+
33
+ def forward(self, x):
34
+ x = self.backbone.conv1(x)
35
+ x = self.backbone.bn1(x)
36
+ x = self.backbone.relu(x)
37
+ x = self.backbone.maxpool(x)
38
+
39
+ x = self.backbone.layer1(x)
40
+ x = self.backbone.layer2(x)
41
+ x = self.backbone.layer3(x)
42
+ x = self.backbone.layer4(x)
43
+
44
+ x = self.backbone.avgpool(x)
45
+ x = torch.flatten(x, 1)
46
+ x = self.classifier(x)
47
+ return x
requirements.txt ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ aiohappyeyeballs==2.4.6
3
+ aiohttp==3.11.12
4
+ aiosignal==1.3.2
5
+ astunparse==1.6.3
6
+ attrs==25.1.0
7
+ blinker==1.9.0
8
+ CacheControl==0.14.2
9
+ cachetools==5.5.1
10
+ certifi==2025.1.31
11
+ cffi==1.17.1
12
+ charset-normalizer==2.1.1
13
+ click==8.1.8
14
+ colorama==0.4.6
15
+ cryptography==44.0.1
16
+ datasets==3.2.0
17
+ dill==0.3.8
18
+ filelock==3.17.0
19
+ firebase-admin==6.6.0
20
+ Flask==3.1.0
21
+ Flask-Cors==5.0.0
22
+ flatbuffers==25.1.24
23
+ frozenlist==1.5.0
24
+ fsspec==2024.9.0
25
+ gast==0.6.0
26
+ google-api-core==2.24.1
27
+ google-api-python-client==2.161.0
28
+ google-auth==2.38.0
29
+ google-auth-httplib2==0.2.0
30
+ google-cloud-core==2.4.1
31
+ google-cloud-firestore==2.20.0
32
+ google-cloud-storage==3.0.0
33
+ google-crc32c==1.6.0
34
+ google-pasta==0.2.0
35
+ google-resumable-media==2.7.2
36
+ googleapis-common-protos==1.67.0
37
+ greenlet==3.1.1
38
+ grpcio==1.70.0
39
+ grpcio-status==1.70.0
40
+ h5py==3.12.1
41
+ httplib2==0.22.0
42
+ idna==3.10
43
+ itsdangerous==2.2.0
44
+ Jinja2==3.1.5
45
+ keras==3.8.0
46
+ libclang==18.1.1
47
+ Markdown==3.7
48
+ markdown-it-py==3.0.0
49
+ MarkupSafe==3.0.2
50
+ mdurl==0.1.2
51
+ ml-dtypes==0.4.1
52
+ mpmath==1.3.0
53
+ msgpack==1.1.0
54
+ multidict==6.1.0
55
+ multiprocess==0.70.16
56
+ namex==0.0.8
57
+ networkx==3.4.2
58
+ numpy==2.0.2
59
+ opencv-python==4.11.0.86
60
+ opt_einsum==3.4.0
61
+ optree==0.14.0
62
+ packaging==24.2
63
+ pandas==2.2.3
64
+ pillow==11.1.0
65
+ propcache==0.2.1
66
+ proto-plus==1.26.0
67
+ protobuf==5.29.3
68
+ psycopg2-binary==2.9.10
69
+ pyarrow==19.0.0
70
+ pyasn1==0.6.1
71
+ pyasn1_modules==0.4.1
72
+ pycparser==2.22
73
+ Pygments==2.19.1
74
+ PyJWT==2.10.1
75
+ pyparsing==3.2.1
76
+ python-dateutil==2.9.0.post0
77
+ pytz==2025.1
78
+ PyYAML==6.0.2
79
+ regex==2024.11.6
80
+ requests==2.32.3
81
+ rich==13.9.4
82
+ rsa==4.9
83
+ safetensors==0.5.2
84
+ setuptools==75.8.0
85
+ six==1.17.0
86
+ SQLAlchemy==2.0.38
87
+ sympy==1.13.1
88
+ tensorboard==2.18.0
89
+ tensorboard-data-server==0.7.2
90
+ tensorflow==2.18.0
91
+ tensorflow_intel==2.18.0
92
+ termcolor==2.5.0
93
+ tokenizers==0.21.0
94
+ torch==2.6.0
95
+ torchvision==0.21.0
96
+ tqdm==4.67.1
97
+ transformers==4.48.3
98
+ typing_extensions==4.12.2
99
+ tzdata==2025.1
100
+ uritemplate==4.1.1
101
+ urllib3==1.26.20
102
+ Werkzeug==3.1.3
103
+ wheel==0.45.1
104
+ wrapt==1.17.2
105
+ xxhash==3.5.0
106
+ yarl==1.18.3
retinanet_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1cf5190eab09966edb71eb3cf8c67d358d37badb96ce21bd611901bdf5b8d0cc
3
+ size 106756882