ICLR2025 / app_pr.py
hysts's picture
hysts HF Staff
Update
1c00c70
raw
history blame contribute delete
12.2 kB
import datetime
import difflib
import json
import re
import tempfile
import gradio as gr
import polars as pl
from gradio_modal import Modal
from huggingface_hub import CommitOperationAdd, HfApi
from table import PATCH_REPO_ID, df_orig
# TODO: remove this once https://github.com/gradio-app/gradio/issues/11022 is fixed # noqa: FIX002, TD002
# Markdown notice shown above the papers table (rendered via `gr.Markdown(NOTE)`).
NOTE = (
    "#### ⚠️ Note\n"
    "You may encounter an issue when selecting table data after using the search bar.\n"
    "This is due to a known bug in Gradio.\n"
    "The issue typically occurs when multiple rows remain after filtering.\n"
    "If only one row remains, the selection should work as expected.\n"
)
# Hub client used by `open_pr` to create the pull-request commit.
api = HfApi()

# Columns shown in the selectable papers table (after the leading "Fix" column).
# `paper_id` must stay LAST: the row-selection handler reads it with
# `evt.row_value[-1]`.
PR_VIEW_COLUMNS = [
    "title",
    "authors_str",
    "openreview_md",
    "arxiv_id",
    "github_md",
    "Spaces",
    "Models",
    "Datasets",
    "paper_id",
]

# Raw (unformatted) columns that can be edited and submitted as a patch.
PR_RAW_COLUMNS = [
    "paper_id",
    "title",
    "authors",
    "arxiv_id",
    "project_page",
    "github",
    "space_ids",
    "model_ids",
    "dataset_ids",
]

# Display table: prepend a clickable "Fix" marker column, keep only the view
# columns, and show missing arXiv IDs as empty strings instead of nulls.
df_pr_view = (
    df_orig.with_columns(pl.lit("📝").alias("Fix"))
    .select(["Fix", *PR_VIEW_COLUMNS])
    .with_columns(pl.col("arxiv_id").fill_null(""))
)

# Lookup table holding the original raw values, filtered by `paper_id` on selection.
df_pr_raw = df_orig.select(PR_RAW_COLUMNS)
def df_pr_row_selected(
    evt: gr.SelectData,
) -> tuple[
    Modal,
    gr.Textbox,  # title
    gr.Textbox,  # authors
    gr.Textbox,  # arxiv_id
    gr.Textbox,  # project_page
    gr.Textbox,  # github
    gr.Textbox,  # space_ids
    gr.Textbox,  # model_ids
    gr.Textbox,  # dataset_ids
    dict | None,  # original_data
]:
    """Handle a cell selection in the papers table.

    When the selected cell holds the "📝" marker, the raw record of that row
    is looked up in ``df_pr_raw`` by ``paper_id`` (the last value of the
    selected row) and returned as pre-filled textboxes together with a
    visible modal. Any other cell yields argument-less components and
    ``None`` for the original data.

    Args:
        evt: Selection event; ``evt.value`` is the clicked cell value and
            ``evt.row_value`` the values of the clicked row.

    Returns:
        The modal, eight textboxes (title, authors, arxiv_id, project_page,
        github, space_ids, model_ids, dataset_ids) and the original record
        as a dict (or ``None``).
    """
    if evt.value != "📝":
        # Not the "Fix" cell: return one argument-less component per output.
        return (Modal(), *(gr.Textbox() for _ in range(8)), None)

    selected_paper_id = evt.row_value[-1]
    row = df_pr_raw.filter(pl.col("paper_id") == selected_paper_id)
    original_data = row.to_dicts()[0]

    # List-valued fields are edited one entry per line.
    multiline = {
        field: "\n".join(original_data[field])
        for field in ("authors", "space_ids", "model_ids", "dataset_ids")
    }
    # Scalar fields are read straight from the single matching row.
    scalar = {field: row[field].item() for field in ("title", "arxiv_id", "project_page", "github")}

    return (
        Modal(visible=True),
        gr.Textbox(value=scalar["title"]),  # title
        gr.Textbox(value=multiline["authors"]),  # authors
        gr.Textbox(value=scalar["arxiv_id"]),  # arxiv_id
        gr.Textbox(value=scalar["project_page"]),  # project_page
        gr.Textbox(value=scalar["github"]),  # github
        gr.Textbox(value=multiline["space_ids"]),  # space_ids
        gr.Textbox(value=multiline["model_ids"]),  # model_ids
        gr.Textbox(value=multiline["dataset_ids"]),  # dataset_ids
        original_data,  # original_data
    )
# Validation patterns for the user-editable fields.
#
# All patterns end in `\Z` rather than `$`: `$` also matches just before a
# trailing newline, which would let a value such as "2301.01234\n" pass
# validation and be stored with the stray newline.

# Optional http(s) scheme, dotted host name, optional port, optional path.
URL_PATTERN = re.compile(r"^(https?://)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(:\d+)?(/.*)?\Z")
# https://github.com/<owner>/<repo>, optionally followed by /tree/<ref>/<path>.
GITHUB_PATTERN = re.compile(r"^https://github\.com/[^/\s]+/[^/\s]+(/tree/[^/\s]+/[^/\s].*)?\Z")
# <namespace>/<name>; dots are allowed because Hub repo names commonly contain
# them (e.g. "Qwen/Qwen2.5-7B").
REPO_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+\Z")
# Four digits, a dot, then four or five digits (e.g. "2301.01234").
ARXIV_ID_PATTERN = re.compile(r"^\d{4}\.\d{4,5}\Z")
def is_valid_url(url: str) -> bool:
    """Return whether *url* matches ``URL_PATTERN`` (scheme optional)."""
    # A match object is always truthy, so bool() mirrors an `is not None` test.
    return bool(URL_PATTERN.match(url))
def is_valid_github_url(url: str) -> bool:
    """Return whether *url* matches ``GITHUB_PATTERN`` (an https github.com repo URL)."""
    # A match object is always truthy, so bool() mirrors an `is not None` test.
    return bool(GITHUB_PATTERN.match(url))
def is_valid_repo_id(repo_id: str) -> bool:
    """Return whether *repo_id* matches ``REPO_ID_PATTERN`` (``namespace/name``)."""
    # A match object is always truthy, so bool() mirrors an `is not None` test.
    return bool(REPO_ID_PATTERN.match(repo_id))
def is_valid_arxiv_id(arxiv_id: str) -> bool:
    """Return whether *arxiv_id* matches ``ARXIV_ID_PATTERN``."""
    # A match object is always truthy, so bool() mirrors an `is not None` test.
    return bool(ARXIV_ID_PATTERN.match(arxiv_id))
def validate_pr_data(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids: list[str],
    model_ids: list[str],
    dataset_ids: list[str],
) -> None:
    """Validate the edited paper fields, stopping at the first problem found.

    Title and authors are mandatory. The arXiv ID, project page and GitHub
    URL are optional and only checked when non-empty. Every Space/Model/
    Dataset ID must pass ``is_valid_repo_id``.

    Args:
        title_pr: Paper title as typed.
        authors_pr: Authors text, one author per line.
        arxiv_id_pr: arXiv identifier, or an empty string.
        project_page_pr: Project page URL, or an empty string.
        github_pr: GitHub repository URL, or an empty string.
        space_ids: Already-split list of Space IDs.
        model_ids: Already-split list of Model IDs.
        dataset_ids: Already-split list of Dataset IDs.

    Raises:
        gr.Error: If any field is invalid; the message names the field.
    """
    # Whitespace-only input counts as empty: a blank authors box would
    # otherwise be accepted and later turn into an empty authors list.
    if not title_pr or not title_pr.strip():
        raise gr.Error("Title cannot be empty", print_exception=False)
    if not authors_pr or not authors_pr.strip():
        raise gr.Error("Authors cannot be empty", print_exception=False)

    if arxiv_id_pr and not is_valid_arxiv_id(arxiv_id_pr):
        # arXiv identifiers are year+month followed by a sequence number.
        raise gr.Error(
            "Invalid arXiv ID format. Expected format: 'YYMM.NNNNN' (e.g., '2301.01234')", print_exception=False
        )
    if project_page_pr and not is_valid_url(project_page_pr):
        raise gr.Error("Project page must be a valid URL", print_exception=False)
    if github_pr and not is_valid_github_url(github_pr):
        raise gr.Error("GitHub must be a valid GitHub URL", print_exception=False)

    for repo_id in [*space_ids, *model_ids, *dataset_ids]:
        if not is_valid_repo_id(repo_id):
            error_msg = f"Space/Model/Dataset ID must be in the format 'org_name/repo_name'. Got: {repo_id}"
            raise gr.Error(error_msg, print_exception=False)
def format_submitted_data(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
) -> dict:
    """Validate the form values and convert them into a raw-record dict.

    Multi-line fields are split on newlines with blank lines dropped. Empty
    optional scalar fields become ``None``.

    Args:
        title_pr: Paper title.
        authors_pr: Authors, one per line.
        arxiv_id_pr: arXiv identifier or empty string.
        project_page_pr: Project page URL or empty string.
        github_pr: GitHub URL or empty string.
        space_ids_pr: Space IDs, one per line.
        model_ids_pr: Model IDs, one per line.
        dataset_ids_pr: Dataset IDs, one per line.

    Returns:
        Dict with keys ``title``, ``authors``, ``arxiv_id``, ``project_page``,
        ``github``, ``space_ids``, ``model_ids`` and ``dataset_ids``.

    Raises:
        gr.Error: Propagated from ``validate_pr_data`` when a field is invalid.
    """

    def _repo_ids(text: str) -> list[str]:
        # Strip each entry: a stray trailing space or "\r" would otherwise
        # make an otherwise valid "org/repo" line fail validation.
        return [stripped for line in text.split("\n") if (stripped := line.strip())]

    space_ids = _repo_ids(space_ids_pr)
    model_ids = _repo_ids(model_ids_pr)
    dataset_ids = _repo_ids(dataset_ids_pr)

    validate_pr_data(title_pr, authors_pr, arxiv_id_pr, project_page_pr, github_pr, space_ids, model_ids, dataset_ids)

    return {
        "title": title_pr,
        # Author names are kept exactly as typed (only blank lines dropped) so
        # untouched entries still compare equal to the original record.
        "authors": [author for author in authors_pr.split("\n") if author.strip()],
        "arxiv_id": arxiv_id_pr or None,
        "project_page": project_page_pr or None,
        "github": github_pr or None,
        "space_ids": space_ids,
        "model_ids": model_ids,
        "dataset_ids": dataset_ids,
    }
def preview_diff(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
    original_data: dict,
) -> tuple[gr.Markdown, gr.Button]:
    """Render a unified diff between the original record and the edited form.

    The form values are validated and normalised by ``format_submitted_data``;
    both records are then dumped as indented JSON and diffed line by line.

    Args:
        title_pr: Paper title.
        authors_pr: Authors, one per line.
        arxiv_id_pr: arXiv identifier or empty string.
        project_page_pr: Project page URL or empty string.
        github_pr: GitHub URL or empty string.
        space_ids_pr: Space IDs, one per line.
        model_ids_pr: Model IDs, one per line.
        dataset_ids_pr: Dataset IDs, one per line.
        original_data: Original raw record; must contain ``paper_id``.

    Returns:
        A Markdown component holding the diff in a fenced ``diff`` block and a
        Button update with ``visible=True``.

    Raises:
        gr.Error: Propagated from validation when a field is invalid.
    """
    edited = format_submitted_data(
        title_pr,
        authors_pr,
        arxiv_id_pr,
        project_page_pr,
        github_pr,
        space_ids_pr,
        model_ids_pr,
        dataset_ids_pr,
    )
    # The form has no paper_id field; carry it over so it does not show up in the diff.
    after = {"paper_id": original_data["paper_id"], **edited}

    before_lines = json.dumps(original_data, indent=2).splitlines()
    after_lines = json.dumps(after, indent=2).splitlines()
    diff_lines = difflib.unified_diff(
        before_lines,
        after_lines,
        fromfile="before",
        tofile="after",
        lineterm="",
    )
    fenced_diff = "```diff\n" + "\n".join(diff_lines) + "\n```"
    return gr.Markdown(value=fenced_diff), gr.Button(visible=True)
def open_pr(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
    original_data: dict,
    oauth_token: gr.OAuthToken | None,
) -> gr.Markdown:
    """Open a pull request on ``PATCH_REPO_ID`` containing the edited fields.

    The form values are validated and normalised, compared key by key with
    the original record, and only the changed keys are uploaded, together
    with ``paper_id``, a unified diff and a UTC timestamp, as one JSON file
    under ``data/``.

    Args:
        title_pr: Paper title.
        authors_pr: Authors, one per line.
        arxiv_id_pr: arXiv identifier or empty string.
        project_page_pr: Project page URL or empty string.
        github_pr: GitHub URL or empty string.
        space_ids_pr: Space IDs, one per line.
        model_ids_pr: Model IDs, one per line.
        dataset_ids_pr: Dataset IDs, one per line.
        original_data: Original raw record; must contain ``paper_id`` and
            every key produced by ``format_submitted_data``.
        oauth_token: Token of the logged-in user; when ``None`` no explicit
            token is passed to ``create_commit``.
            NOTE(review): confirm that falling back to the library's default
            credentials is intended.

    Returns:
        A visible Markdown component holding the PR URL, or an empty string
        when nothing changed (an info toast is shown instead).

    Raises:
        gr.Error: Propagated from validation when a field is invalid.
    """
    submitted_data = format_submitted_data(
        title_pr,
        authors_pr,
        arxiv_id_pr,
        project_page_pr,
        github_pr,
        space_ids_pr,
        model_ids_pr,
        dataset_ids_pr,
    )
    diff_dict = {key: value for key, value in submitted_data.items() if value != original_data[key]}
    if not diff_dict:
        gr.Info("No data to submit")
        return ""

    paper_id = original_data["paper_id"]
    diff_dict["paper_id"] = paper_id

    # Put paper_id on the "after" side as well, exactly as `preview_diff`
    # does; otherwise the stored diff reports paper_id as a removed line and
    # no longer matches what the user previewed.
    original_json = json.dumps(original_data, indent=2)
    submitted_json = json.dumps({"paper_id": paper_id, **submitted_data}, indent=2)
    diff_dict["diff"] = "\n".join(
        difflib.unified_diff(
            original_json.splitlines(),
            submitted_json.splitlines(),
            fromfile="before",
            tofile="after",
            lineterm="",
        )
    )

    timestamp = datetime.datetime.now(datetime.timezone.utc)
    diff_dict["timestamp"] = timestamp.isoformat()

    # Upload the payload from memory. Writing it to a
    # NamedTemporaryFile(delete=False) left one orphaned file on disk per PR.
    payload = json.dumps(diff_dict, indent=2).encode("utf-8")
    commit = CommitOperationAdd(
        path_in_repo=f"data/{paper_id}--{timestamp.strftime('%Y-%m-%d-%H-%M-%S')}.json",
        path_or_fileobj=payload,
    )
    res = api.create_commit(
        repo_id=PATCH_REPO_ID,
        operations=[commit],
        commit_message=f"Update {paper_id}",
        repo_type="dataset",
        create_pr=True,
        token=oauth_token.token if oauth_token else None,
    )
    return gr.Markdown(value=res.pr_url, visible=True)
def render_open_pr_page(profile: gr.OAuthProfile | None) -> gr.Column:
    """Return a Column update that is visible only when a user profile is present.

    Args:
        profile: OAuth profile of the current visitor, or ``None``.

    Returns:
        ``gr.Column`` with ``visible`` set to whether ``profile`` is not ``None``.
    """
    is_logged_in = profile is not None
    return gr.Column(visible=is_logged_in)
# Gradio UI: a login button, a searchable papers table, and a modal form for
# proposing fixes to a paper's metadata as a pull request.
# NOTE(review): the nesting of the layout contexts below was reconstructed
# from a source whose indentation was lost — confirm against the real file.
with gr.Blocks() as demo:
    gr.LoginButton()
    # Hidden initially; `demo.load` at the bottom runs `render_open_pr_page`,
    # which makes this column visible only when a profile is present.
    with gr.Column(visible=False) as open_pr_col:
        gr.Markdown(NOTE)
        # One entry per column of `df_pr_view`, in order ("Fix" + PR_VIEW_COLUMNS).
        df_pr = gr.Dataframe(
            value=df_pr_view,
            datatype=[
                "str",  # Fix
                "str",  # Title
                "str",  # Authors
                "markdown",  # openreview
                "str",  # arxiv_id
                "markdown",  # github
                "markdown",  # spaces
                "markdown",  # models
                "markdown",  # datasets
                "str",  # paper id
            ],
            column_widths=[
                "50px",  # Fix
                "40%",  # Title
                "20%",  # Authors
                None,  # openreview
                "100px",  # arxiv_id
                None,  # github
                None,  # spaces
                None,  # models
                None,  # datasets
                None,  # paper id
            ],
            type="polars",
            row_count=(0, "dynamic"),
            interactive=False,
            max_height=1000,
            show_search="search",
        )
        # Edit form; made visible by `df_pr_row_selected` when a "Fix" cell is clicked.
        with Modal(visible=False) as pr_modal:
            with gr.Group():
                title_pr = gr.Textbox(label="Title")
                authors_pr = gr.Textbox(label="Authors")
                arxiv_id_pr = gr.Textbox(label="arXiv ID")
                project_page_pr = gr.Textbox(label="Project page")
                github_pr = gr.Textbox(label="GitHub")
                spaces_pr = gr.Textbox(
                    label="Spaces",
                    info="Enter one space ID (e.g., 'org_name/space_name') per line.",
                )
                models_pr = gr.Textbox(
                    label="Models",
                    info="Enter one model ID (e.g., 'org_name/model_name') per line.",
                )
                datasets_pr = gr.Textbox(
                    label="Datasets",
                    info="Enter one dataset ID (e.g., 'org_name/dataset_name') per line.",
                )
            # Holds the original raw record of the selected paper for diffing.
            original_data = gr.State()
            preview_diff_button = gr.Button("Preview diff")
            diff_view = gr.Markdown()
            # Shown only after a diff has been previewed (see `preview_diff`).
            open_pr_button = gr.Button("Open PR", visible=False)
            pr_url = gr.Markdown(visible=False)

    # On the modal's blur event: clear the diff and hide the PR button and URL.
    pr_modal.blur(
        fn=lambda: (None, gr.Button(visible=False), gr.Markdown(visible=False)),
        outputs=[diff_view, open_pr_button, pr_url],
    )
    # Output order must match the tuple returned by `df_pr_row_selected`.
    df_pr.select(
        fn=df_pr_row_selected,
        outputs=[
            pr_modal,
            title_pr,
            authors_pr,
            arxiv_id_pr,
            project_page_pr,
            github_pr,
            spaces_pr,
            models_pr,
            datasets_pr,
            original_data,
        ],
    )
    # Input order must match the parameter order of `preview_diff`.
    preview_diff_button.click(
        fn=preview_diff,
        inputs=[
            title_pr,
            authors_pr,
            arxiv_id_pr,
            project_page_pr,
            github_pr,
            spaces_pr,
            models_pr,
            datasets_pr,
            original_data,
        ],
        outputs=[diff_view, open_pr_button],
    )
    # Input order must match `open_pr`; its trailing `oauth_token` parameter
    # is not listed here.
    open_pr_button.click(
        fn=open_pr,
        inputs=[
            title_pr,
            authors_pr,
            arxiv_id_pr,
            project_page_pr,
            github_pr,
            spaces_pr,
            models_pr,
            datasets_pr,
            original_data,
        ],
        outputs=pr_url,
    )
    # Decide on page load whether the PR column is shown.
    demo.load(fn=render_open_pr_page, outputs=open_pr_col)

if __name__ == "__main__":
    # Run the app with the queue enabled and the API surface switched off.
    demo.queue(api_open=False).launch(show_api=False)