import gradio as gr
import numpy as np
import torch
from PIL import Image
import open_clip
from datasets import Dataset
import os

# Work around the duplicate-OpenMP-runtime crash (libomp can be loaded by both PyTorch and FAISS)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Load the BioCLIP model and its image preprocessing transform from the Hugging Face Hub
model, processor = open_clip.create_model_from_pretrained('hf-hub:imageomics/bioclip')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only: disable dropout / batch-norm updates

# Load the dataset of precomputed reference embeddings
embedding_path = "./data/embeddings_bioclip_False"
ds = Dataset.load_from_disk(embedding_path)

# Load FAISS indexes
cosine_faiss_path = os.path.join(embedding_path, "embeddings_cosine.faiss")
l2_faiss_path = os.path.join(embedding_path, "embeddings_l2.faiss")
ds.load_faiss_index("embeddings_cosine", cosine_faiss_path)
ds.load_faiss_index("embeddings_l2", l2_faiss_path)

def majority_vote(classes, scores=None):
    """Return the class with the largest total weight (plain counts if scores is None)."""
    if scores is None:
        # np.ones_like on a list of strings yields string "1"s, which would break
        # the weight accumulation below; use a numeric array instead.
        scores = np.ones(len(classes), dtype=np.float32)
    class_weights = {cls: 0.0 for cls in np.unique(classes)}

    for cls, weight in zip(classes, scores):
        class_weights[cls] += weight

    return max(class_weights, key=class_weights.get)
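# Example: majority_vote(["ant", "bee", "ant"]) -> "ant", while
# majority_vote(["ant", "bee", "ant"], scores=[0.1, 0.9, 0.1]) -> "bee".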

def classify_example(example, index="embeddings_l2", k=10, vote_scores=True):
    """Predict a class by (optionally score-weighted) majority vote over the k nearest neighbors."""
    features = np.array(example["embeddings"], dtype=np.float32)
    scores, nearest = ds.get_nearest_examples(index, features, k)

    class_labels = [ds.features["label"].names[c] for c in nearest["label"]]

    if vote_scores:
        if index == "embeddings_l2":
            # FAISS returns L2 distances here (lower is better); invert them into
            # similarity weights so that nearer neighbors carry more weight.
            scores = 1.0 / (1.0 + scores)
        prediction = majority_vote(class_labels, scores)
    else:
        prediction = majority_vote(class_labels)

    return prediction, class_labels, nearest["file"]

def embed_image(image: Image.Image):
    # BioCLIP's preprocessing transform expects a 3-channel RGB image
    processed_images = processor(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        embeddings = model.encode_image(processed_images.to(device))

    return {"embeddings": embeddings.cpu()}

def predict(image):
    embedding = embed_image(image)
    prediction, class_labels, file_paths = classify_example(embedding)
    
    return prediction, ", ".join(class_labels[:3]), ", ".join(file_paths[:3])

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Textbox(label="Top 3 Classes"),
        gr.Textbox(label="Top 3 File Paths")
    ],
    title="BioClip Image Classification",
    description="Upload an image to get a prediction using the BioClip model."
)

if __name__ == "__main__":
    iface.launch()
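# Quick programmatic check without launching the UI (this image path is hypothetical):
#   from PIL import Image
#   print(predict(Image.open("data/example.jpg")))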