# Source: coex-prj / run.py (Hugging Face Hub file-viewer header; not part of the program)
# Uploaded by harheem using huggingface_hub -- commit d2941e6 (verified), 11 kB
import pandas as pd
import numpy as np
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document
from rank_bm25 import BM25Okapi
from kiwipiepy import Kiwi
from typing import List
import gradio as gr
class ProductSearchSystem:
    """Hybrid product search combining BM25 keyword scores with vector scores.

    Rows of a product CSV are flattened into one text per row. Each text is
    indexed twice: tokenized with Kiwi for ``BM25Okapi`` and embedded with a
    HuggingFace sentence model for a FAISS vector store. ``search`` blends
    the two normalized scores using ``bm25_weight`` and ``vector_weight``.
    """

    # Kiwi POS tags kept as index/query terms: general noun, proper noun,
    # verb, adjective, plus 'SL' (presumably Latin-alphabet tokens in Kiwi's
    # tag set -- confirm against the Kiwi documentation).
    _INDEXED_POS_TAGS = frozenset({'NNG', 'NNP', 'VV', 'VA', 'SL'})

    # Columns every row must provide; the order is the order used when the
    # searchable text is assembled.
    _TEXT_FIELDS = ('company_name', 'category', 'company_info',
                    'product_name', 'description')

    def __init__(self,
                 model_name: str = "snunlp/KR-SBERT-V40K-klueNLI-augSTS",
                 bm25_weight: float = 0.3,
                 vector_weight: float = 0.7):
        """Initialize the search system.

        Args:
            model_name: HuggingFace model used to embed documents and queries.
            bm25_weight: Weight of the normalized BM25 score in the final score.
            vector_weight: Weight of the normalized vector score in the final
                score.
        """
        self.embeddings = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )
        self.bm25_weight = bm25_weight
        self.vector_weight = vector_weight
        self.vector_store = None
        self.bm25 = None
        self.documents = []
        self.df = None
        # Initialize the Kiwi tokenizer.
        self.kiwi = Kiwi()

    def _tokenize_text(self, text: str) -> List[str]:
        """Tokenize *text* with Kiwi and keep only the indexed POS tags.

        Returns:
            The surface forms of the tokens whose tag is in
            ``_INDEXED_POS_TAGS``, in their original order.
        """
        tokens = self.kiwi.tokenize(text)
        return [token.form for token in tokens
                if token.tag in self._INDEXED_POS_TAGS]

    def load_sample_data(self, path: str = "sample_data.csv"):
        """Load the product CSV, clean it and build both search indexes.

        Args:
            path: CSV file to read. Defaults to the bundled sample file, so
                existing callers that pass no argument behave as before.

        Returns:
            ``True`` once the data has been loaded and indexed.
        """
        self.df = pd.read_csv(path)
        self._preprocess_data()
        self._create_search_index()
        return True

    def _preprocess_data(self):
        """Clean ``self.df`` in place before indexing."""
        df = self.df

        # Missing categories get the placeholder label.
        df['category'] = df['category'].fillna('λ―ΈλΆ„λ₯˜')

        # Missing names would otherwise be rendered as the literal text "nan"
        # in the indexed content and become a searchable token.
        for col in ('company_name', 'product_name'):
            df[col] = df[col].fillna('')

        # '_x000D_' is replaced literally (regex=False) with a newline.
        for col in ('company_info', 'description'):
            df[col] = df[col].fillna('').str.replace('_x000D_', '\n', regex=False)

        # Trim surrounding whitespace. Only actual ``str`` values are touched:
        # ``.str.strip()`` would turn any non-string value in an object column
        # into NaN, and a bare ``dtype == 'object'`` check would skip columns
        # stored with a dedicated string dtype.
        for col in df.columns:
            column = df[col]
            if (pd.api.types.is_object_dtype(column)
                    or pd.api.types.is_string_dtype(column)):
                df[col] = column.map(
                    lambda value: value.strip() if isinstance(value, str) else value
                )

    def _create_search_index(self):
        """Build the BM25 index and the FAISS vector store from ``self.df``.

        With no rows, both indexes are reset to ``None`` so that ``search``
        returns no results instead of failing inside the index builders.
        """
        self.documents = []
        tokenized_documents = []  # Tokenized documents for BM25.

        for _, row in self.df.iterrows():
            metadata = {field: row[field] for field in self._TEXT_FIELDS}
            content = (
                f"{row['company_name']} {row['category']} {row['company_info']} "
                f"{row['product_name']} {row['description']}"
            )
            # Tokenize with the Kiwi tokenizer.
            tokenized_documents.append(self._tokenize_text(content))
            self.documents.append(
                Document(page_content=content, metadata=metadata)
            )

        if not self.documents:
            self.bm25 = None
            self.vector_store = None
            return

        # Build the BM25 index.
        self.bm25 = BM25Okapi(tokenized_documents)
        # Build the vector store.
        self.vector_store = FAISS.from_documents(self.documents, self.embeddings)

    def search(self, query: str, top_k: int = 3) -> List[dict]:
        """Run a hybrid search and return the best-scoring products.

        BM25 scores are divided by their maximum. Vector-store scores are
        treated as distances (smaller is closer) and mapped to
        ``1 - distance / max_distance``. The final score is
        ``bm25_weight * bm25 + vector_weight * vector``.

        Args:
            query: Free-text query. Empty or whitespace-only input yields
                no results.
            top_k: Maximum number of results. Integral floats (as UI sliders
                may deliver) are accepted; values <= 0 yield no results.

        Returns:
            Up to ``top_k`` dicts, best first, with the product metadata and
            ``bm25_score``, ``vector_score`` and ``final_score`` rounded to
            three decimals. At most one entry per (company, product) pair is
            returned: the highest-scoring one.
        """
        if not query or not query.strip():
            return []
        top_k = int(top_k)
        if top_k <= 0:
            return []
        # Nothing indexed yet (or an empty data set): no results.
        if self.bm25 is None or self.vector_store is None or not self.documents:
            return []

        # BM25 search using the Kiwi tokenizer.
        bm25_scores = self.bm25.get_scores(self._tokenize_text(query))
        # Vector search over every document so each one gets a score.
        vector_hits = self.vector_store.similarity_search_with_score(
            query, k=len(self.documents)
        )

        # Maxima used for normalization.
        max_bm25 = max(bm25_scores) if len(bm25_scores) > 0 else 1
        max_vector = max(score for _, score in vector_hits) if vector_hits else 1

        # One pass to map content -> distance; the first hit wins, which is
        # what a linear scan with ``break`` would have selected.
        distance_by_content = {}
        for hit_doc, distance in vector_hits:
            distance_by_content.setdefault(hit_doc.page_content, distance)

        candidates = []
        for i, doc in enumerate(self.documents):
            distance = distance_by_content.get(doc.page_content)
            if distance is None:
                continue
            bm25_score = bm25_scores[i] / max_bm25 if max_bm25 > 0 else 0
            vector_score = (1 - (distance / max_vector)) if max_vector > 0 else 0
            final_score = (self.bm25_weight * bm25_score) + (self.vector_weight * vector_score)
            candidates.append({
                'company_name': doc.metadata['company_name'],
                'category': doc.metadata['category'],
                'company_info': doc.metadata['company_info'],
                'product_name': doc.metadata['product_name'],
                'description': doc.metadata['description'],
                'bm25_score': round(float(bm25_score), 3),
                'vector_score': round(float(vector_score), 3),
                'final_score': round(float(final_score), 3)
            })

        # Sort by final score first, then de-duplicate, so the entry kept for
        # a product is its best-scoring one rather than the first in the file.
        candidates.sort(key=lambda item: item['final_score'], reverse=True)

        results = []
        seen_products = set()
        for item in candidates:
            product_key = (str(item['company_name']), str(item['product_name']))
            if product_key in seen_products:
                continue
            seen_products.add(product_key)
            results.append(item)
            if len(results) == top_k:
                break
        return results
def create_gradio_interface():
    """Build the Gradio UI around a freshly loaded ``ProductSearchSystem``.

    The search system is created and its sample data indexed once, when this
    function is called; every search request reuses that instance.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    # Initialize the search system and load the sample data.
    search_system = ProductSearchSystem()
    search_system.load_sample_data()

    def search_products(query: str,
                        top_k: int,
                        bm25_weight: float) -> tuple:
        """Run a search and format the results for the UI.

        Args:
            query: Text from the query box.
            top_k: Requested number of results (slider value).
            bm25_weight: Keyword weight; the vector weight is set to
                ``1 - bm25_weight``.

        Returns:
            ``(html_table, detailed_text)`` for the HTML and text outputs.
        """
        # Update the weights. NOTE(review): this mutates the shared
        # search_system, so the weights are global to all requests.
        search_system.bm25_weight = bm25_weight
        search_system.vector_weight = 1 - bm25_weight

        # Run the search. Slider values may arrive as floats, and the value
        # ends up in a list slice, so coerce to int first.
        results = search_system.search(query, top_k=int(top_k))

        # Convert the results to a table.
        if results:
            # Column order to display.
            columns_order = ['company_name', 'category', 'company_info', 'product_name', 'bm25_score', 'vector_score', 'final_score', 'description']
            df_results = pd.DataFrame(results)[columns_order]
            # Localized column headers.
            df_results.columns = ['νšŒμ‚¬λͺ…', 'μΉ΄ν…Œκ³ λ¦¬', 'νšŒμ‚¬ μ„€λͺ…', 'μ œν’ˆλͺ…', 'ν‚€μ›Œλ“œ 점수', '벑터 점수', 'μ΅œμ’… 점수', 'μ„€λͺ…']
            # NOTE(review): escape=False renders CSV text as raw HTML. Confirm
            # the data is trusted, or switch to escape=True.
            html_table = df_results.to_html(
                classes=['table', 'table-striped'],
                escape=False,
                index=False,
                float_format='{:.3f}'.format  # Show three decimal places.
            )
        else:
            html_table = "<p>검색 κ²°κ³Όκ°€ μ—†μŠ΅λ‹ˆλ‹€.</p>"

        # Build the detailed result text.
        detailed_results = [
            f"""
=== 검색결과 #{i} ===
νšŒμ‚¬λͺ…: {result['company_name']}
μΉ΄ν…Œκ³ λ¦¬: {result['category']}
νšŒμ‚¬ μ„€λͺ…: {result['company_info']}
μ œν’ˆλͺ…: {result['product_name']}
ν‚€μ›Œλ“œ 점수: {result['bm25_score']:.3f}
벑터 점수: {result['vector_score']:.3f}
μ΅œμ’… 점수: {result['final_score']:.3f}
μ„€λͺ…: {result['description']}
"""
            for i, result in enumerate(results, 1)
        ]
        detailed_text = "\n".join(detailed_results) if detailed_results else "검색 κ²°κ³Όκ°€ μ—†μŠ΅λ‹ˆλ‹€."
        return html_table, detailed_text

    # Define the Gradio interface.
    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown("""
# πŸ” μ½”μ—‘μŠ€ λΆ€μŠ€ μΆ”μ²œ μ‹œμŠ€ν…œ
ν•˜μ΄λΈŒλ¦¬λ“œ 방식을 μ΄μš©ν•œ κΈ°μ—… 및 μ œν’ˆ 검색/μΆ”μ²œ μ‹œμŠ€ν…œμž…λ‹ˆλ‹€.
""")
        with gr.Row():
            with gr.Column(scale=4):
                query_input = gr.Textbox(
                    label="검색어λ₯Ό μž…λ ₯ν•˜μ„Έμš”",
                    placeholder="예: AI 기술 νšŒμ‚¬, μ„Όμ„œ, μžλ™ν™” λ“±",
                )
            with gr.Column(scale=1):
                top_k = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="검색 κ²°κ³Ό 수",
                )
        with gr.Row():
            bm25_weight = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.3,
                step=0.1,
                label="ν‚€μ›Œλ“œ 검색 κ°€μ€‘μΉ˜",
            )
        with gr.Row():
            search_button = gr.Button("검색", variant="primary")
        with gr.Row():
            with gr.Column():
                results_table = gr.HTML(label="검색 κ²°κ³Ό ν…Œμ΄λΈ”")
            with gr.Column():
                results_text = gr.Textbox(
                    label="상세 κ²°κ³Ό",
                    show_label=True,
                    interactive=False,
                    lines=10
                )
        # Wire up the event handler.
        search_button.click(
            fn=search_products,
            inputs=[query_input, top_k, bm25_weight],
            outputs=[results_table, results_text],
        )
        gr.Markdown("""
### μ‚¬μš© 방법
1. 검색어 μž…λ ₯: 찾고자 ν•˜λŠ” κΈ°μ—…, μ œν’ˆ, 기술 λ“±μ˜ ν‚€μ›Œλ“œλ₯Ό μž…λ ₯ν•˜μ„Έμš”
2. 검색 κ²°κ³Ό 수 μ‘°μ •: μ›ν•˜λŠ” κ²°κ³Ό 수λ₯Ό μ„ νƒν•˜μ„Έμš”
3. κ°€μ€‘μΉ˜ μ‘°μ •: ν‚€μ›Œλ“œ λ§€μΉ­κ³Ό 의미적 μœ μ‚¬λ„ κ°„μ˜ κ°€μ€‘μΉ˜λ₯Ό μ‘°μ ˆν•˜μ„Έμš”
### 점수 μ„€λͺ…
- ν‚€μ›Œλ“œ 점수: Kiwi ν† ν¬λ‚˜μ΄μ €λ₯Ό μ‚¬μš©ν•œ ν‚€μ›Œλ“œ 기반 λ§€μΉ­ 점수 (0~1)
- 벑터 점수: 의미적 μœ μ‚¬λ„ 점수 (0~1)
- μ΅œμ’… 점수: ν‚€μ›Œλ“œ μ μˆ˜μ™€ 벑터 점수의 가쀑 평균
""")
    return demo
def main():
    """Build the Gradio app and launch it with a public share link."""
    create_gradio_interface().launch(share=True)


if __name__ == "__main__":
    main()
# TODO
# OCR deep learning vs. OCR processing
# Test the tokenizer's processing results
# Check the part-of-speech tagging results