diallomama commited on
Commit
cef1466
·
1 Parent(s): 0d3952c
Files changed (3) hide show
  1. app.py +38 -29
  2. pytorch_model.bin +3 -0
  3. requirements.txt +3 -1
app.py CHANGED
@@ -2,6 +2,12 @@ import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
  from torch.utils.data import Dataset
 
 
 
 
 
 
5
 
6
  class CNN(nn.Module):
7
  def __init__(self):
@@ -19,7 +25,9 @@ class CNN(nn.Module):
19
  self.relu3 = nn.ReLU()
20
  self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
21
 
22
- self.fc1 = nn.Linear(in_features=262144, out_features=512)
 
 
23
  self.relu4 = nn.ReLU()
24
 
25
  self.fc2 = nn.Linear(in_features=512, out_features=2)
@@ -46,42 +54,39 @@ class CNN(nn.Module):
46
  x = self.fc2(x)
47
 
48
  return x
49
-
50
- # Convert dataset to PyTorch dataset
51
- class AiornotDataset(Dataset):
52
- def __init__(self, image, transform=None):
53
- self.image = image
54
- self.transform = transform
55
-
56
- def __getitem__(self, idx):
57
- # Load image
58
- #img_byte = BytesIO(self.dataset[idx]['image'].tobytes())
59
- #img = self.dataset[idx]['image']
60
- # Apply transform
61
- if self.transform:
62
- img = self.transform(img)
63
-
64
- # Load label
65
- #label = self.dataset[idx]['label']
66
-
67
- return img
68
-
69
 
70
  model = CNN()
71
- model.load_state_dict(torch.load('./best_model.nn'))
 
 
72
  model.eval()
73
 
74
- def predict(image, model):
75
- img = AiornotDataset(image)
 
76
 
 
77
  with torch.no_grad():
78
  pred = model(img)
79
 
80
- is_ai = torch.max(pred.data, 0)[1]
81
- #probabilities = model(img).softmax(-1)[0,1].item()
82
- if is_ai == 1:
83
- return "The input image is generated by an AI"
84
- return "The input image is not generated by an AI"
85
 
86
  """
87
  gr.Interface.load(
@@ -90,9 +95,13 @@ gr.Interface.load(
90
  outputs = "text"
91
  ).launch()
92
  """
 
93
  gr.Interface(
94
  predict,
95
  inputs = gr.Image(label="Uploat an image", type="filepath"),
96
  #outputs = gr.outputs.Label(num_top_classes=2)
97
  outputs = "text"
98
  ).launch()
 
 
 
 
2
  import torch
3
  import torch.nn as nn
4
  from torch.utils.data import Dataset
5
+ import torchvision
6
+
7
+ from torchvision import transforms
8
+
9
+ #from torchvision import transforms
10
+ from PIL import Image
11
 
12
  class CNN(nn.Module):
13
  def __init__(self):
 
25
  self.relu3 = nn.ReLU()
26
  self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
27
 
28
+ #self.fc1 = nn.Linear(in_features=262144, out_features=512)
29
+ #self.fc1 = nn.Linear(in_features=4096, out_features=512) # hr_pytorch_model.py
30
+ self.fc1 = nn.Linear(in_features=784, out_features=512)
31
  self.relu4 = nn.ReLU()
32
 
33
  self.fc2 = nn.Linear(in_features=512, out_features=2)
 
54
  x = self.fc2(x)
55
 
56
  return x
57
+ """
58
+ transform = transforms.Compose(
59
+ [transforms.Pad(2),
60
+ transforms.ToTensor(),
61
+ transforms.Normalize((0.5,), (0.5,))])
62
+ """
63
+ # other transform
64
+ transform = transforms.Compose([
65
+ transforms.Resize(256),
66
+ transforms.CenterCrop(224),
67
+ transforms.ToTensor(),
68
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
69
+ ])
 
 
 
 
 
 
 
70
 
71
  model = CNN()
72
+ #model.load_state_dict(torch.load('./best_model.nn'))
73
+ state_dict = torch.load('./pytorch_model.bin', map_location=torch.device('cpu'))
74
+ model.load_state_dict(state_dict, strict=False)
75
  model.eval()
76
 
77
+ def predict(image):
78
+ img = Image.open(image)
79
+ img = transform(img)
80
 
81
+ print("===============", img.shape)
82
  with torch.no_grad():
83
  pred = model(img)
84
 
85
+ #is_ai = torch.max(pred.data, 0)[1]
86
+ #print("===============", is_ai)
87
+ probabilities = model(img).softmax(-1)[0,1].item()
88
+ print("=============== prob", probabilities)
89
+ return "AI" if probabilities > 0.3 else "Not AI"
90
 
91
  """
92
  gr.Interface.load(
 
95
  outputs = "text"
96
  ).launch()
97
  """
98
+
99
  gr.Interface(
100
  predict,
101
  inputs = gr.Image(label="Uploat an image", type="filepath"),
102
  #outputs = gr.outputs.Label(num_top_classes=2)
103
  outputs = "text"
104
  ).launch()
105
+ """
106
+ gr.Interface(predict, inputs=gr.inputs.Image(shape=(512,512,3)), outputs="text").launch()
107
+ """
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54a04776d4bf8dda8b4beebf6019f30567bee86c9a1e89b5f04e81f8e5a58392
3
+ size 233468565
requirements.txt CHANGED
@@ -1 +1,3 @@
1
- torch
 
 
 
1
+ torch
2
+ gradio
3
+ torchvision