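# Provenance (header captured from the hosting page, kept as comments so the file parses):
#   author: farrell236 · commit e9dc61f "Update app.py" · file size 5.03 kB
#   (page links: raw / history / blame / contribute / delete)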
import cv2
import numpy as np
import streamlit as st
import tensorflow as tf
from utils import _get_retina_bb, _pad_to_square
@st.cache_resource
def load_model(model_file):
    """Load a Keras segmentation model for inference, memoised by st.cache_resource.

    Args:
        model_file: path of the saved model to load.

    Returns:
        The loaded model. It is not compiled (compile=False) because it is
        only called for prediction.
    """
    segmentation_model = tf.keras.models.load_model(filepath=model_file, compile=False)
    print(f'Model {model_file} Loaded!')
    return segmentation_model
@st.cache_resource
def load_gatekeeper(model_file='checkpoints/ResNetV2-EyeQ-QA.tf'):
    """Load the image-quality "gatekeeper" classifier, memoised by st.cache_resource.

    Args:
        model_file: path of the saved gatekeeper model. The default is the
            checkpoint the app has always used.

    Returns:
        The loaded model.
    """
    # compile=False matches load_model: the classifier is only called for
    # inference, so there is no need to compile it after loading.
    validator_model = tf.keras.models.load_model(model_file, compile=False)
    print('Gatekeeper Model Loaded!')
    return validator_model
def parse_function(image):
    """Prepare an image batch for the gatekeeper: resize to 512x512, then cast to float32.

    Args:
        image: image batch, called here as image[None, ...] on an RGB uint8 array.

    Returns:
        The resized float32 tensor.

    NOTE(review): tf.image.resize (default bilinear) is documented to return
    float32 already. In that case convert_image_dtype does not rescale, and
    pixel values stay in [0, 255]. Confirm this matches the range the
    gatekeeper was trained on before changing the order of these calls.
    """
    resized = tf.image.resize(image, size=[512, 512])
    return tf.image.convert_image_dtype(resized, dtype=tf.float32)
# Arg-max index of the gatekeeper output that causes an upload to be refused.
# NOTE(review): presumably the "Reject" grade of an EyeQ-style
# Good/Usable/Reject quality classifier (checkpoint ResNetV2-EyeQ-QA) — confirm.
_GATEKEEPER_REJECT_CLASS = 2

# Side length in pixels of the square, retina-centred image fed to the segmentation models.
_MODEL_INPUT_SIZE = 1024


def _read_uploaded_image(uploaded_file):
    """Decode an uploaded image file into an RGB uint8 array.

    Returns None when the upload is empty or OpenCV cannot decode it, rather
    than letting cv2.imdecode / cv2.cvtColor raise an opaque OpenCV error.
    """
    # getvalue() returns the whole buffer regardless of the current read position.
    file_bytes = np.frombuffer(uploaded_file.getvalue(), dtype=np.uint8)
    if file_bytes.size == 0:
        return None
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)  # 3-channel BGR or None
    if image is None:
        return None
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _is_poor_quality(image, gatekeeper_model):
    """Return True when *gatekeeper_model*'s arg-max class for *image* is the reject class."""
    scores = gatekeeper_model(parse_function(image[None, ...]))
    return int(np.argmax(scores)) == _GATEKEEPER_REJECT_CLASS


def _center_retina(image):
    """Localise and centre the retina, returning a _MODEL_INPUT_SIZE square image.

    Crops to the box reported by _get_retina_bb, then pads with _pad_to_square.
    The padding presumably keeps the aspect ratio undistorted by the resize;
    verify against utils.
    """
    x, y, w, h, _ = _get_retina_bb(image)
    image = image[y:y + h, x:x + w, :]
    image = _pad_to_square(image, border=0)
    return cv2.resize(image, (_MODEL_INPUT_SIZE, _MODEL_INPUT_SIZE))


def _apply_clahe(image):
    """Return a new RGB image with CLAHE applied to the LAB lightness channel of *image*."""
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(16, 16))
    lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    lab[:, :, 0] = clahe.apply(lab[:, :, 0])
    return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)


def _load_input(uploaded_file, gatekeeper_model):
    """Decode, quality-check and centre an uploaded fundus image.

    Args:
        uploaded_file: the object returned by st.file_uploader.
        gatekeeper_model: quality classifier, or None to skip the quality check.

    Returns:
        The centred RGB uint8 image, or None if the upload was refused. In that
        case an explanatory message has already been rendered.
    """
    image = _read_uploaded_image(uploaded_file)
    if image is None:
        st.error('Could not read the uploaded file as an image')
        return None
    if gatekeeper_model is not None and _is_poor_quality(image, gatekeeper_model):
        st.image(image)
        st.info('Image is of poor quality')
        return None
    return _center_retina(image)


def main():
    """Render the retina-segmentation Streamlit app.

    The sidebar selects the segmentation model and whether the quality
    gatekeeper runs. An uploaded image is centred, CLAHE-enhanced and passed
    to the selected model, and the predicted maps are displayed.
    """
    st.title('Retina Segmentation')
    st.sidebar.title('Segmentation Model')
    options = st.sidebar.selectbox('Select Option:', ('Vessels', 'Lesions (BETA)'))
    gatekeeper = st.sidebar.radio("Gatekeeper:", ('Enabled', 'Disabled'))
    # Only load (and later run) the quality classifier when it is enabled.
    gatekeeper_model = load_gatekeeper() if gatekeeper == 'Enabled' else None

    if options == 'Vessels':
        uploaded_file = st.file_uploader('Choose an image...', type=('png', 'jpg', 'jpeg'))
        model = load_model('checkpoints/DeeplabV3Plus_DRIVE.tf')
        if not uploaded_file:
            return
        image = _load_input(uploaded_file, gatekeeper_model)
        if image is None:
            return

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Uploaded Image")
            st.image(image)

        # The vessel model is fed a single grey channel as a (1, H, W, 1)
        # batch; convert_image_dtype scales the uint8 pixels to [0, 1].
        gray = cv2.cvtColor(_apply_clahe(image), cv2.COLOR_RGB2GRAY)
        gray = tf.image.convert_image_dtype(gray, tf.float32)
        y_pred = model(gray[None, ..., None])[0].numpy()
        with col2:
            st.subheader("Predicted Vessel")
            st.image(y_pred)

    elif options == 'Lesions (BETA)':
        st.write('```--- WARNING: This model is highly experimental ---```')
        uploaded_file = st.file_uploader('Choose an image...', type=('png', 'jpg', 'jpeg'))
        model = load_model('checkpoints/DeeplabV3Plus_FGADR.tf')
        if not uploaded_file:
            return
        image = _load_input(uploaded_file, gatekeeper_model)
        if image is None:
            return

        col1, col2, col3 = st.columns(3)
        with col1:
            st.subheader("Uploaded Image")
            st.image(image)

        # The image already carries 3 channels, so it is batched as
        # (1, H, W, 3) with no extra trailing channel axis.
        # NOTE(review): assumes the FGADR checkpoint expects 3-channel RGB input,
        # consistent with no grey conversion in this branch — confirm.
        rgb = tf.image.convert_image_dtype(_apply_clahe(image), tf.float32)
        y_pred = model(rgb[None, ...])[0].numpy()

        # Output channels 1-5 are the lesion maps. They fill the 3-column grid
        # row-major, continuing after the uploaded image in col1.
        lesion_maps = (('MA', 1), ('HE', 2), ('EX', 3), ('SE', 4), ('OD', 5))
        columns = (col2, col3, col1, col2, col3)
        for column, (label, channel) in zip(columns, lesion_maps):
            with column:
                st.subheader(label)
                st.image(y_pred[..., channel])
if __name__ == '__main__':
    # Run inference on CPU only: hide every GPU from TensorFlow before main()
    # loads any model.
    tf.config.set_visible_devices(devices=[], device_type='GPU')
    main()