Spaces:
Running
Running
import streamlit as st | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Static page chrome: browser-tab metadata, the visible heading, and a short
# introduction shown above the input form.
APP_TITLE = "ArXiv Paper Classifier"
INTRO_TEXT = """
This app classifies papers based on their abstract.
Enter the paper details and the model will predict the most likely topic categories.
"""

st.set_page_config(page_title=APP_TITLE, page_icon="📚")
st.title(APP_TITLE)
st.markdown(INTRO_TEXT)
@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer():
    """Load the sequence classifier, its tokenizer and its label mapping.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached as a shared resource: the weights are instantiated
    once and the same objects are handed back on later reruns.
    ``show_spinner=False`` because the caller already wraps this call in its
    own ``st.spinner``.

    Returns:
        A ``(model, tokenizer, id2label)`` tuple, where ``id2label`` is taken
        from ``model.config.id2label``.
    """
    model_path = "goldov/arxiv-classifier-debertav3"  # TODO: change later
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer, model.config.id2label
# Obtain the inference objects while showing a progress message.
with st.spinner("Loading model... This may take a minute."):
    loaded = load_model_and_tokenizer()
model, tokenizer, id2label = loaded

# Input form: widgets inside it are submitted together via the form button.
st.subheader("Paper Information")
paper_form = st.form(key="paper_form")
with paper_form:
    title = st.text_input(label="Title", placeholder="Enter the paper title")
    abstract = st.text_area(
        label="Abstract (optional)",
        placeholder="Enter the paper abstract (optional)",
    )
    submit_button = st.form_submit_button(label="Classify Paper")
def predict_topics(title, abstract="", threshold=0.95, max_length=512):
    """Return the top categories whose probabilities sum to ``threshold``.

    The input text is ``"Title: ..."`` optionally followed by
    ``" Abstract: ..."``. It is tokenized with the module-level ``tokenizer``
    and scored on CPU by the module-level ``model``. Categories are taken in
    descending probability order until their cumulative softmax probability
    reaches ``threshold``.

    Args:
        title: Paper title.
        abstract: Paper abstract; an empty value means title-only input.
        threshold: Cumulative probability at which to stop adding categories.
        max_length: Token limit passed to the tokenizer; longer inputs are
            truncated.

    Returns:
        A list of ``(label, probability)`` tuples, most probable first, with
        labels looked up in the module-level ``id2label``.
    """
    text = f"Title: {title} Abstract: {abstract}" if abstract else f"Title: {title}"
    tokens_info = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    model.eval()
    model.cpu()
    with torch.no_grad():
        out = model(**tokens_info)

    probs = torch.nn.functional.softmax(out.logits, dim=-1).squeeze(0)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=0)

    # Positions at which the running total has reached the threshold. This can
    # be empty (e.g. threshold of 1.0 with a float32 sum that lands just
    # below it); in that case keep every category instead of raising.
    reached = torch.where(cumulative_probs >= threshold)[0]
    if reached.numel():
        cutoff_idx = reached[0].item() + 1
    else:
        cutoff_idx = sorted_probs.numel()

    top_indices = sorted_indices[:cutoff_idx].tolist()
    top_probs = sorted_probs[:cutoff_idx].tolist()
    return [(id2label[idx], prob) for idx, prob in zip(top_indices, top_probs)]
def _render_results(results, decimals):
    """Render the category/probability table and the covered-probability note.

    Args:
        results: Sequence of ``(category, probability)`` pairs, as returned
            by ``predict_topics``.
        decimals: Number of decimal places used when formatting percentages.
    """
    st.markdown("#### Top Categories")
    # Header row; each later row gets a fresh pair of columns with the same
    # 3:1 width ratio so the cells line up.
    header_left, header_right = st.columns([3, 1])
    with header_left:
        st.markdown("**Category**")
    with header_right:
        st.markdown("**Probability**")
    for category, prob in results:
        row_left, row_right = st.columns([3, 1])
        with row_left:
            st.markdown(f"{category}")
        with row_right:
            st.progress(prob)
            st.markdown(f"{prob:.{decimals}%}")
    total_prob = sum(prob for _, prob in results)
    st.info(f"Total probability covered: {total_prob:.{decimals}%}")


# Handle a submitted form: a title is mandatory, the abstract is optional.
if submit_button:
    if not title:
        st.error("Please enter a paper title.")
    else:
        with st.spinner("Classifying..."):
            results = predict_topics(title, abstract)
        st.subheader("Prediction Results")
        if abstract:
            st.text("Classification based on title and abstract")
        else:
            st.text("Classification based on title")
        _render_results(results, decimals=2)

# Add example section
if st.button("Try An Example!"):
    example_title = "Attention Is All You Need"
    example_abstract = """The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration.
The best performing models also connect the encoder and decoder through an attention mechanism.
We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely.
Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train.
Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU.
On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature.
We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data."""
    with st.spinner("Classifying example..."):
        results = predict_topics(example_title, example_abstract)
    st.subheader("Example Prediction Results")
    st.text(f"Title: {example_title}")
    st.text(f"Abstract: {example_abstract}")
    st.text("Classification based on title and abstract")
    # The example section shows percentages with one decimal place.
    _render_results(results, decimals=1)
# Footer: a horizontal rule followed by the author credit.
FOOTER_LINES = ("---", "ArXiv Paper Classifier by Ivan Goldov")
for footer_line in FOOTER_LINES:
    st.markdown(footer_line)