goldov committed on
Commit
82d81c8
·
verified ·
1 Parent(s): 5c83dcb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+
5
# The same title is shown in the browser tab and as the on-page heading.
APP_TITLE = "ArXiv Paper Classifier"

# Introductory blurb rendered as Markdown directly under the heading.
APP_INTRO = """
This app classifies papers based on their abstract.
Enter the paper details and the model will predict the most likely topic categories.
"""

st.set_page_config(page_title=APP_TITLE, page_icon="📚")
st.title(APP_TITLE)
st.markdown(APP_INTRO)
17
+
18
+
19
@st.cache_resource
def load_model_and_tokenizer():
    """Load the sequence-classification model and its tokenizer.

    Both are fetched with ``from_pretrained`` from the same checkpoint name.
    The function is wrapped in ``st.cache_resource`` so the loaded objects
    are reused instead of being reloaded on every call.

    Returns:
        A ``(model, tokenizer, id2label)`` tuple, where ``id2label`` is taken
        from ``model.config.id2label``.
    """
    checkpoint = "goldov/arxiv-classifier-debertav3"  # TODO: change later

    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    label_names = classifier.config.id2label
    return classifier, text_tokenizer, label_names
27
+
28
+
29
# Load the classifier before drawing the form; the spinner is shown while
# load_model_and_tokenizer() runs.
with st.spinner("Loading model... This may take a minute."):
    model, tokenizer, id2label = load_model_and_tokenizer()

st.subheader("Paper Information")

# Form written in object notation: the widgets are created through the form
# handle instead of inside a ``with st.form(...)`` block.
paper_form = st.form(key="paper_form")
title = paper_form.text_input("Title", placeholder="Enter the paper title")
abstract = paper_form.text_area(
    "Abstract (optional)",
    placeholder="Enter the paper abstract (optional)",
)
submit_button = paper_form.form_submit_button(label="Classify Paper")
38
+
39
+
40
def predict_topics(title, abstract="", *, threshold=0.95, max_length=512):
    """Return the most probable categories for a paper.

    Builds a single prompt from the title (and the abstract, when given),
    runs it through the module-level ``model`` on CPU, and keeps the
    highest-probability categories until their cumulative probability
    reaches ``threshold``.

    Args:
        title: Paper title.
        abstract: Paper abstract; when empty only the title is used.
        threshold: Cumulative probability mass at which to stop adding
            categories. Defaults to 0.95.
        max_length: Token limit passed to the tokenizer; longer inputs are
            truncated. Defaults to 512.

    Returns:
        A list of ``(category_label, probability)`` tuples sorted by
        descending probability. Labels are looked up in the module-level
        ``id2label`` mapping.
    """
    if abstract:
        text = f"Title: {title} Abstract: {abstract}"
    else:
        text = f"Title: {title}"

    encoded = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    model.eval()
    model.cpu()
    with torch.no_grad():
        logits = model(**encoded).logits
        # Single input, so drop the batch dimension after the softmax.
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)

    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=0)

    # Position of the first category at which the running total reaches the
    # threshold. If it is never reached (NaN logits, or a threshold of 1.0
    # missed by float rounding) keep every category rather than raising
    # IndexError on an empty result.
    reached = torch.nonzero(cumulative >= threshold)
    if reached.numel() > 0:
        cutoff = int(reached[0].item()) + 1
    else:
        cutoff = sorted_probs.numel()

    top_indices = sorted_indices[:cutoff].tolist()
    top_probs = sorted_probs[:cutoff].tolist()
    return [(id2label[idx], prob) for idx, prob in zip(top_indices, top_probs)]
69
+
70
+
71
EXAMPLE_TITLE = "Attention Is All You Need"
EXAMPLE_ABSTRACT = (
    "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration.\n"
    "The best performing models also connect the encoder and decoder through an attention mechanism.\n"
    "We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely.\n"
    "Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train.\n"
    "Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU.\n"
    "On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature.\n"
    "We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data."
)


def show_predictions(results, *, pct_format=".2%"):
    """Render prediction results as a two-column category/probability table.

    Args:
        results: Iterable of ``(category, probability)`` pairs, as returned
            by ``predict_topics``.
        pct_format: Format spec used for every percentage shown. The form
            results use ``".2%"`` and the example uses ``".1%"``.
    """
    results = list(results)

    st.markdown("#### Top Categories")

    header_left, header_right = st.columns([3, 1])
    with header_left:
        st.markdown("**Category**")
    with header_right:
        st.markdown("**Probability**")

    for category, prob in results:
        name_col, prob_col = st.columns([3, 1])
        with name_col:
            st.markdown(f"{category}")
        with prob_col:
            # st.progress only accepts floats in [0.0, 1.0]; clamp so a value
            # that is marginally out of range cannot abort the page render.
            st.progress(min(max(prob, 0.0), 1.0))
            st.markdown(f"{prob:{pct_format}}")

    total_prob = sum(prob for _, prob in results)
    st.info(f"Total probability covered: {total_prob:{pct_format}}")


# Classify the paper entered in the form.
if submit_button:
    # A title made only of whitespace carries no information for the model,
    # so treat it the same as an empty title.
    if not title or not title.strip():
        st.error("Please enter a paper title.")
    else:
        with st.spinner("Classifying..."):
            results = predict_topics(title, abstract)

        st.subheader("Prediction Results")

        if abstract:
            st.text("Classification based on title and abstract")
        else:
            st.text("Classification based on title")

        show_predictions(results)


# Add example section
if st.button("Try An Example!"):
    with st.spinner("Classifying example..."):
        results = predict_topics(EXAMPLE_TITLE, EXAMPLE_ABSTRACT)

    st.subheader("Example Prediction Results")
    st.text(f"Title: {EXAMPLE_TITLE}")
    st.text(f"Abstract: {EXAMPLE_ABSTRACT}")
    st.text("Classification based on title and abstract")

    # The example section shows percentages with one decimal place.
    show_predictions(results, pct_format=".1%")

st.markdown("---")
st.markdown("ArXiv Paper Classifier by Ivan Goldov")