# traning/searchKeybert.py
import os
import sys
import time
import logging
import faiss
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from datasets import load_dataset
from huggingface_hub import login
from sentence_transformers import SentenceTransformer
from keybert import KeyBERT  # ✅ KeyBERT added
# ✅ Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

HF_API_TOKEN = os.getenv("HF_API_TOKEN")
if HF_API_TOKEN:
    try:
        logger.info("🔐 Logging in to Hugging Face...")
        login(token=HF_API_TOKEN)
        logger.info("✅ Hugging Face login successful!")
    except Exception as e:
        logger.error(f"❌ Hugging Face login failed: {e}")
        sys.exit("🚫 Exiting: a valid HF_API_TOKEN is required.")
else:
    logger.error("❌ The 'HF_API_TOKEN' environment variable is not set.")
    sys.exit("🚫 Exiting: please set the HF_API_TOKEN environment variable.")
# ✅ Create the FastAPI instance
app = FastAPI(title="🚀 KeyBERT-based FAISS Search API", version="1.0")

# ✅ Load the embedding and KeyBERT models
embedding_model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
keyword_model = KeyBERT(embedding_model)  # ✅ KeyBERT on top of the sentence embedder
logger.info("✅ KeyBERT-based keyword extraction model loaded!")
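
# Illustrative note (not part of the original file): KeyBERT's extract_keywords
# returns (phrase, score) tuples sorted by relevance, roughly of the form
# [("wireless earphones", 0.71), ("earphones", 0.65)]; the phrases and scores
# shown here are hypothetical and depend on the model and the input text.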
# ✅ Load the active auction item data
def load_huggingface_jsonl(dataset_name, split="train"):
    """Load a Hugging Face dataset and return it as a cleaned DataFrame."""
    try:
        repo_id = f"aikobay/{dataset_name}"
        dataset = load_dataset(repo_id, split=split)
        df = dataset.to_pandas().dropna()
        return df
    except Exception as e:
        logger.error(f"❌ Error while loading data: {e}")
        return pd.DataFrame()
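
# Usage sketch (assumes the aikobay/initial_saleitem_dataset repo exists and
# exposes an ITEMNAME column, as the code below relies on):
#   df = load_huggingface_jsonl("initial_saleitem_dataset")
#   print(df["ITEMNAME"].head())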
try:
    active_sale_items = load_huggingface_jsonl("initial_saleitem_dataset")
    logger.info(f"✅ Active auction item data loaded! {len(active_sale_items)} items in total")
except Exception as e:
    logger.error(f"❌ Error while loading item data: {e}")
    active_sale_items = pd.DataFrame()
# ✅ Initialize the FAISS index (384 = embedding dimension of the MiniLM model)
faiss_index = faiss.IndexFlatL2(384)
indexed_items = []
# ✅ Build the FAISS index
def rebuild_faiss_index():
    """Reload the sale item data and rebuild the FAISS index from scratch."""
    global faiss_index, indexed_items, active_sale_items

    logger.info("🔄 Rebuilding the FAISS index from fresh sale_item data...")
    active_sale_items = load_huggingface_jsonl("initial_saleitem_dataset")
    if active_sale_items.empty or "ITEMNAME" not in active_sale_items.columns:
        logger.error("❌ No item data available; keeping the existing index.")
        return

    item_names = active_sale_items["ITEMNAME"].tolist()
    indexed_items = item_names

    logger.info(f"🔹 Vectorizing {len(item_names)} items...")
    item_vectors = embedding_model.encode(item_names, convert_to_numpy=True).astype("float32")

    faiss_index = faiss.IndexFlatL2(item_vectors.shape[1])
    faiss_index.add(item_vectors)
    logger.info(f"✅ FAISS index rebuilt with {len(indexed_items)} items.")
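
# Alternative sketch (an assumption, not part of the original design): IndexFlatL2
# ranks by squared Euclidean distance; with sentence embeddings, cosine similarity
# is a common alternative. On unit-normalized vectors, inner product equals cosine
# similarity, so IndexFlatIP can be used. This helper is illustrative only and is
# not wired into rebuild_faiss_index().
def build_cosine_index(vectors: np.ndarray) -> faiss.Index:
    vecs = np.ascontiguousarray(vectors.astype("float32"))
    faiss.normalize_L2(vecs)                  # in-place row normalization
    index = faiss.IndexFlatIP(vecs.shape[1])  # inner product == cosine on unit vectors
    index.add(vecs)
    return index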
# ✅ KeyBERT-based keyword extraction function
def generate_similar_keywords(query: str, num_keywords: int = 5):
    """Extract keywords similar to the query using the KeyBERT model."""
    try:
        keywords = keyword_model.extract_keywords(
            query, keyphrase_ngram_range=(1, 2), stop_words=None, top_n=num_keywords
        )
        keywords = [kw[0] for kw in keywords]  # keep the phrase, drop the score
        logger.info(f"🔍 Generated similar keywords: {keywords}")
        return keywords
    except Exception as e:
        logger.error(f"❌ Error during KeyBERT keyword extraction: {e}")
        return [query]  # fall back to the raw query
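
# Usage sketch (hypothetical query; actual output depends on the loaded model):
#   generate_similar_keywords("wireless earphones", num_keywords=3)
#   -> up to 3 extracted phrases, or ["wireless earphones"] on failure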
# ✅ FAISS search function (similar-keyword-based search)
def search_faiss_with_keywords(query: str, top_k: int = 5):
    """Search the FAISS index with the query's extracted keywords and merge the hits."""
    start_time = time.time()

    keywords = generate_similar_keywords(query)
    keyword_vectors = embedding_model.encode(keywords, convert_to_numpy=True).astype("float32")

    all_results = []
    for vec in keyword_vectors:
        _, indices = faiss_index.search(vec.reshape(1, -1), top_k)
        all_results.extend(idx for idx in indices[0] if idx != -1)  # -1 means no hit

    # dict.fromkeys deduplicates while preserving retrieval order (set() would not)
    unique_results = list(dict.fromkeys(all_results))
    recommendations = [indexed_items[i] for i in unique_results]

    end_time = time.time()
    logger.info(f"🔍 Search complete! Elapsed time: {end_time - start_time:.4f}s")
    return recommendations
# ✅ API request model
class RecommendRequest(BaseModel):
    search_query: str
    top_k: int = 5
# ✅ Recommendation API endpoint
@app.post("/api/recommend")
async def recommend(request: RecommendRequest):
    try:
        recommendations = search_faiss_with_keywords(request.search_query, request.top_k)
        return {"query": request.search_query, "recommendations": recommendations}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Recommendation error: {str(e)}")
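
# Request sketch (the port matches the uvicorn call at the bottom of this file;
# the query string is hypothetical):
#   curl -X POST http://localhost:7860/api/recommend \
#        -H "Content-Type: application/json" \
#        -d '{"search_query": "earphones", "top_k": 5}'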
# ✅ FAISS index refresh API
@app.post("/api/update_index")
async def update_index():
    rebuild_faiss_index()
    return {"message": "✅ FAISS index update complete!"}
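
# Refresh sketch: a POST with an empty body re-reads the dataset and rebuilds the index.
#   curl -X POST http://localhost:7860/api/update_index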
# ✅ Run FastAPI
if __name__ == "__main__":
    rebuild_faiss_index()
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)