update app, and requitemts

#1
Files changed (2) hide show
  1. app.py +62 -66
  2. requirements.txt +4 -9
app.py CHANGED
@@ -1,98 +1,94 @@
1
  import streamlit as st
2
  import torch
3
  from torchvision import transforms
4
- from PIL import Image, ImageDraw, ImageFont
5
  import numpy as np
6
- import time
7
 
8
- # Simplified YOLO-style model definition (Pillow-only version)
9
- class PlantDiseaseModel(torch.nn.Module):
10
  def __init__(self, num_classes=2):
11
  super().__init__()
12
- # Example backbone (replace with your actual model architecture)
13
- self.model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
14
- self.model.classes = num_classes # Set your number of classes
15
-
 
 
 
 
 
 
 
 
16
  def forward(self, x):
17
  return self.model(x)
18
 
19
- # Load model
20
  @st.cache_resource
21
  def load_model():
22
- model = PlantDiseaseModel(num_classes=2) # Update class count
23
  try:
24
  model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
25
- except:
26
- st.warning("Using pretrained weights (custom weights not found)")
 
27
  return model
28
 
29
- # Draw bounding boxes with Pillow
30
- def draw_boxes_pillow(image, predictions):
31
- """Draw boxes/labels on image using Pillow only"""
32
- draw = ImageDraw.Draw(image)
33
- try:
34
- font = ImageFont.load_default()
35
- for _, row in predictions.iterrows():
36
- xmin, ymin, xmax, ymax = row['xmin'], row['ymin'], row['xmax'], row['ymax']
37
- label = f"{row['name']} {row['confidence']:.2f}"
38
-
39
- # Draw rectangle
40
- draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=3)
41
-
42
- # Draw label background
43
- text_width, text_height = font.getsize(label)
44
- draw.rectangle([xmin, ymin-text_height, xmin+text_width, ymin], fill="red")
45
-
46
- # Draw text
47
- draw.text((xmin, ymin-text_height), label, fill="white", font=font)
48
- except Exception as e:
49
- st.error(f"Error drawing boxes: {str(e)}")
50
- return image
51
 
52
  def main():
53
- st.set_page_config(page_title="Plant Disease Detector", layout="wide")
54
- st.title("🌱 Plant Disease Detection (Tomato or Corn Maiza)")
55
 
56
- model = load_model()
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # File uploader
59
- uploaded_file = st.file_uploader("Upload a plant image...", type=["jpg", "jpeg", "png"])
60
 
61
- if uploaded_file is not None:
62
- # Load with Pillow
63
  image = Image.open(uploaded_file).convert("RGB")
64
  col1, col2 = st.columns(2)
65
 
66
  with col1:
67
- st.image(image, caption="Original Image", use_column_width=True)
68
 
69
- # Process and predict
70
  with st.spinner("Analyzing..."):
71
- # Convert to tensor (Pillow-compatible preprocessing)
72
- transform = transforms.Compose([
73
- transforms.Resize(640),
74
- transforms.ToTensor(),
75
- ])
76
- input_tensor = transform(image).unsqueeze(0)
77
-
78
- # Predict
79
- with torch.no_grad():
80
- results = model(input_tensor)
81
 
82
- # Convert results to Pandas
83
- try:
84
- results_df = results.pandas().xyxy[0]
 
 
85
 
86
- # Draw boxes using Pillow
87
- output_image = image.copy()
88
- output_image = draw_boxes_pillow(output_image, results_df)
89
-
90
- with col2:
91
- st.image(output_image, caption="Detection Results", use_column_width=True)
92
- st.dataframe(results_df[['name', 'confidence', 'xmin', 'ymin', 'xmax', 'ymax']])
93
-
94
- except Exception as e:
95
- st.error(f"Prediction error: {str(e)}")
96
 
97
  if __name__ == "__main__":
98
  main()
 
1
  import streamlit as st
2
  import torch
3
  from torchvision import transforms
4
+ from PIL import Image
5
  import numpy as np
 
6
 
7
+ # Custom model class (replace with your actual architecture)
8
+ class PlantDiseaseClassifier(torch.nn.Module):
9
  def __init__(self, num_classes=2):
10
  super().__init__()
11
+ # Example architecture - REPLACE WITH YOUR ACTUAL MODEL
12
+ self.model = torch.nn.Sequential(
13
+ torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
14
+ torch.nn.ReLU(),
15
+ torch.nn.MaxPool2d(2),
16
+ torch.nn.Conv2d(16, 32, kernel_size=3, padding=1),
17
+ torch.nn.ReLU(),
18
+ torch.nn.MaxPool2d(2),
19
+ torch.nn.Flatten(),
20
+ torch.nn.Linear(32*56*56, num_classes) # Adjust input dimensions
21
+ )
22
+
23
  def forward(self, x):
24
  return self.model(x)
25
 
 
26
  @st.cache_resource
27
  def load_model():
28
+ model = PlantDiseaseClassifier(num_classes=2) # Update with your class count
29
  try:
30
  model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
31
+ st.success("Model loaded successfully!")
32
+ except Exception as e:
33
+ st.error(f"Error loading model: {str(e)}")
34
  return model
35
 
36
+ def predict(image, model, class_names):
37
+ """Run prediction and return top class"""
38
+ transform = transforms.Compose([
39
+ transforms.Resize(256),
40
+ transforms.CenterCrop(224),
41
+ transforms.ToTensor(),
42
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
43
+ ])
44
+
45
+ input_tensor = transform(image).unsqueeze(0)
46
+
47
+ with torch.no_grad():
48
+ output = model(input_tensor)
49
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
50
+ top_prob, top_class = torch.topk(probabilities, 1)
51
+
52
+ return class_names[top_class.item()], top_prob.item()
 
 
 
 
 
53
 
54
  def main():
55
+ st.title("🌱 Plant Disease Classifier")
 
56
 
57
+ # Update with your actual class names and care tips
58
+ CLASS_NAMES = {
59
+ 0: "Healthy",
60
+ 1: "Late Blight",
61
+ 2: "Powdery Mildew" # Add all your classes
62
+ }
63
+
64
+ CARE_TIPS = {
65
+ "Healthy": ["Continue regular watering", "Monitor plant growth"],
66
+ "Late Blight": ["Remove infected leaves", "Apply fungicide"],
67
+ "Powdery Mildew": ["Improve air circulation", "Apply sulfur spray"]
68
+ }
69
 
70
+ model = load_model()
71
+ uploaded_file = st.file_uploader("Upload plant image", type=["jpg", "png", "jpeg"])
72
 
73
+ if uploaded_file and model is not None:
 
74
  image = Image.open(uploaded_file).convert("RGB")
75
  col1, col2 = st.columns(2)
76
 
77
  with col1:
78
+ st.image(image, caption="Uploaded Image", use_column_width=True)
79
 
 
80
  with st.spinner("Analyzing..."):
81
+ predicted_class, confidence = predict(image, model, CLASS_NAMES)
 
 
 
 
 
 
 
 
 
82
 
83
+ with col2:
84
+ if "healthy" in predicted_class.lower():
85
+ st.success(f"Prediction: {predicted_class} ({confidence*100:.1f}%)")
86
+ else:
87
+ st.error(f"Prediction: {predicted_class} ({confidence*100:.1f}%)")
88
 
89
+ st.subheader("Care Recommendations")
90
+ for tip in CARE_TIPS.get(predicted_class, ["No specific recommendations"]):
91
+ st.write(f"• {tip}")
 
 
 
 
 
 
 
92
 
93
  if __name__ == "__main__":
94
  main()
requirements.txt CHANGED
@@ -1,10 +1,5 @@
1
- torch>=2.2.2
2
- torchvision>=0.17.2
3
- streamlit>=1.29.0
4
  Pillow>=10.0.0
5
- PyYAML>=6.0.1
6
- pandas>=2.1.0
7
- scikit-learn>=1.3.0
8
- tqdm>=4.65.0
9
- numpy==1.26.4
10
- opencv-python-headless>=4.8.0.76 # Updated to available version
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ streamlit>=1.25.0
4
  Pillow>=10.0.0
5
+ numpy>=1.20.0