hysts HF Staff committed on
Commit
52de44b
·
1 Parent(s): f69892b
Files changed (2) hide show
  1. app.py +9 -5
  2. semantic_search.py +13 -6
app.py CHANGED
@@ -127,8 +127,9 @@ def update_df(
127
  except pl.exceptions.ComputeError as e:
128
  raise gr.Error(str(e)) from e
129
  else:
130
- paper_ids = semantic_search(search_query, candidate_pool_size, score_threshold)
131
- df = df.filter(pl.col("paper_id").is_in(paper_ids))
 
132
 
133
  if presentation_type != "(ALL)":
134
  df = df.filter(pl.col("Type").str.contains(presentation_type))
@@ -156,12 +157,15 @@ with gr.Blocks(css_paths="style.css") as demo:
156
  choices=["Semantic Search", "Title Search"],
157
  value="Semantic Search",
158
  show_label=False,
 
159
  )
160
  search_query = gr.Textbox(label="Search", submit_btn=True, show_label=False, placeholder="Enter query here")
161
- with gr.Accordion(label="Advanced Search Options", open=False, visible=True) as advanced_search_options:
162
  with gr.Row():
163
- candidate_pool_size = gr.Slider(label="Candidate Pool Size", minimum=1, maximum=200, step=1, value=100)
164
- score_threshold = gr.Slider(label="Score Threshold", minimum=0, maximum=1, step=0.01, value=0.7)
 
 
165
 
166
  presentation_type = gr.Radio(
167
  label="Presentation Type",
 
127
  except pl.exceptions.ComputeError as e:
128
  raise gr.Error(str(e)) from e
129
  else:
130
+ paper_ids, scores = semantic_search(search_query, candidate_pool_size, score_threshold)
131
+ df = pl.DataFrame({"paper_id": paper_ids, "score": scores}).join(df, on="paper_id", how="inner")
132
+ df = df.sort("score", descending=True).drop("score")
133
 
134
  if presentation_type != "(ALL)":
135
  df = df.filter(pl.col("Type").str.contains(presentation_type))
 
157
  choices=["Semantic Search", "Title Search"],
158
  value="Semantic Search",
159
  show_label=False,
160
+ info="Note: Semantic search consumes your ZeroGPU quota.",
161
  )
162
  search_query = gr.Textbox(label="Search", submit_btn=True, show_label=False, placeholder="Enter query here")
163
+ with gr.Accordion(label="Advanced Search Options", open=False) as advanced_search_options:
164
  with gr.Row():
165
+ candidate_pool_size = gr.Slider(
166
+ label="Candidate Pool Size", minimum=1, maximum=1000, step=1, value=300
167
+ )
168
+ score_threshold = gr.Slider(label="Score Threshold", minimum=0, maximum=1, step=0.01, value=0.5)
169
 
170
  presentation_type = gr.Radio(
171
  label="Presentation Type",
semantic_search.py CHANGED
@@ -16,7 +16,9 @@ reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
16
 
17
 
18
  @spaces.GPU(duration=5)
19
- def semantic_search(query: str, candidate_pool_size: int = 100, score_threshold: float = 0.7) -> list[int]:
 
 
20
  query_vec = model.encode(query)
21
  _, retrieved_data = ds.get_nearest_examples("embedding", query_vec, k=candidate_pool_size)
22
 
@@ -27,8 +29,13 @@ def semantic_search(query: str, candidate_pool_size: int = 100, score_threshold:
27
  rerank_scores = reranker.predict(rerank_inputs)
28
  sorted_indices = np.argsort(rerank_scores)[::-1]
29
 
30
- return [
31
- retrieved_data["paper_id"][i]
32
- for i in sorted_indices
33
- if scipy.special.expit(rerank_scores[i]) >= score_threshold
34
- ]
 
 
 
 
 
 
16
 
17
 
18
  @spaces.GPU(duration=5)
19
+ def semantic_search(
20
+ query: str, candidate_pool_size: int = 300, score_threshold: float = 0.5
21
+ ) -> tuple[list[int], list[float]]:
22
  query_vec = model.encode(query)
23
  _, retrieved_data = ds.get_nearest_examples("embedding", query_vec, k=candidate_pool_size)
24
 
 
29
  rerank_scores = reranker.predict(rerank_inputs)
30
  sorted_indices = np.argsort(rerank_scores)[::-1]
31
 
32
+ paper_ids = []
33
+ scores = []
34
+ for i in sorted_indices:
35
+ score = float(scipy.special.expit(rerank_scores[i]))
36
+ if score < score_threshold:
37
+ break
38
+ paper_ids.append(retrieved_data["paper_id"][i])
39
+ scores.append(score)
40
+
41
+ return paper_ids, scores