Spaces:
Running
on
Zero
Running
on
Zero
Update
Browse files- app.py +9 -5
- semantic_search.py +13 -6
app.py
CHANGED
@@ -127,8 +127,9 @@ def update_df(
|
|
127 |
except pl.exceptions.ComputeError as e:
|
128 |
raise gr.Error(str(e)) from e
|
129 |
else:
|
130 |
-
paper_ids = semantic_search(search_query, candidate_pool_size, score_threshold)
|
131 |
-
df =
|
|
|
132 |
|
133 |
if presentation_type != "(ALL)":
|
134 |
df = df.filter(pl.col("Type").str.contains(presentation_type))
|
@@ -156,12 +157,15 @@ with gr.Blocks(css_paths="style.css") as demo:
|
|
156 |
choices=["Semantic Search", "Title Search"],
|
157 |
value="Semantic Search",
|
158 |
show_label=False,
|
|
|
159 |
)
|
160 |
search_query = gr.Textbox(label="Search", submit_btn=True, show_label=False, placeholder="Enter query here")
|
161 |
-
with gr.Accordion(label="Advanced Search Options", open=False
|
162 |
with gr.Row():
|
163 |
-
candidate_pool_size = gr.Slider(
|
164 |
-
|
|
|
|
|
165 |
|
166 |
presentation_type = gr.Radio(
|
167 |
label="Presentation Type",
|
|
|
127 |
except pl.exceptions.ComputeError as e:
|
128 |
raise gr.Error(str(e)) from e
|
129 |
else:
|
130 |
+
paper_ids, scores = semantic_search(search_query, candidate_pool_size, score_threshold)
|
131 |
+
df = pl.DataFrame({"paper_id": paper_ids, "score": scores}).join(df, on="paper_id", how="inner")
|
132 |
+
df = df.sort("score", descending=True).drop("score")
|
133 |
|
134 |
if presentation_type != "(ALL)":
|
135 |
df = df.filter(pl.col("Type").str.contains(presentation_type))
|
|
|
157 |
choices=["Semantic Search", "Title Search"],
|
158 |
value="Semantic Search",
|
159 |
show_label=False,
|
160 |
+
info="Note: Semantic search consumes your ZeroGPU quota.",
|
161 |
)
|
162 |
search_query = gr.Textbox(label="Search", submit_btn=True, show_label=False, placeholder="Enter query here")
|
163 |
+
with gr.Accordion(label="Advanced Search Options", open=False) as advanced_search_options:
|
164 |
with gr.Row():
|
165 |
+
candidate_pool_size = gr.Slider(
|
166 |
+
label="Candidate Pool Size", minimum=1, maximum=1000, step=1, value=300
|
167 |
+
)
|
168 |
+
score_threshold = gr.Slider(label="Score Threshold", minimum=0, maximum=1, step=0.01, value=0.5)
|
169 |
|
170 |
presentation_type = gr.Radio(
|
171 |
label="Presentation Type",
|
semantic_search.py
CHANGED
@@ -16,7 +16,9 @@ reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
|
|
16 |
|
17 |
|
18 |
@spaces.GPU(duration=5)
|
19 |
-
def semantic_search(
|
|
|
|
|
20 |
query_vec = model.encode(query)
|
21 |
_, retrieved_data = ds.get_nearest_examples("embedding", query_vec, k=candidate_pool_size)
|
22 |
|
@@ -27,8 +29,13 @@ def semantic_search(query: str, candidate_pool_size: int = 100, score_threshold:
|
|
27 |
rerank_scores = reranker.predict(rerank_inputs)
|
28 |
sorted_indices = np.argsort(rerank_scores)[::-1]
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
@spaces.GPU(duration=5)
|
19 |
+
def semantic_search(
|
20 |
+
query: str, candidate_pool_size: int = 300, score_threshold: float = 0.5
|
21 |
+
) -> tuple[list[int], list[float]]:
|
22 |
query_vec = model.encode(query)
|
23 |
_, retrieved_data = ds.get_nearest_examples("embedding", query_vec, k=candidate_pool_size)
|
24 |
|
|
|
29 |
rerank_scores = reranker.predict(rerank_inputs)
|
30 |
sorted_indices = np.argsort(rerank_scores)[::-1]
|
31 |
|
32 |
+
paper_ids = []
|
33 |
+
scores = []
|
34 |
+
for i in sorted_indices:
|
35 |
+
score = float(scipy.special.expit(rerank_scores[i]))
|
36 |
+
if score < score_threshold:
|
37 |
+
break
|
38 |
+
paper_ids.append(retrieved_data["paper_id"][i])
|
39 |
+
scores.append(score)
|
40 |
+
|
41 |
+
return paper_ids, scores
|