import os
import sys
import logging
import time

import faiss
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from datasets import load_dataset
from huggingface_hub import login
from sentence_transformers import SentenceTransformer
from keybert import KeyBERT

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

HF_API_TOKEN = os.getenv("HF_API_TOKEN")

if HF_API_TOKEN:
    try:
        logger.info("Logging in to Hugging Face...")
        login(token=HF_API_TOKEN)
        logger.info("Hugging Face login succeeded!")
    except Exception as e:
        logger.error(f"Hugging Face login failed: {e}")
        sys.exit("Exiting: a valid HF_API_TOKEN is required.")
else:
    logger.error("The 'HF_API_TOKEN' environment variable is not set.")
    sys.exit("Exiting: please set the HF_API_TOKEN environment variable.")

app = FastAPI(title="KeyBERT-based FAISS Search API", version="1.0")

embedding_model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
keyword_model = KeyBERT(embedding_model)
logger.info("KeyBERT keyword-extraction model loaded!")
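# For reference, extract_keywords returns (keyword, score) pairs, e.g.
# (the scores below are illustrative, not real output):
#   keyword_model.extract_keywords("blue wireless earbuds", keyphrase_ngram_range=(1, 2), top_n=2)
#   -> [("wireless earbuds", 0.72), ("blue wireless", 0.55)]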


def load_huggingface_jsonl(dataset_name, split="train"):
    """Load a dataset from the Hugging Face Hub and return it as a cleaned DataFrame."""
    try:
        repo_id = f"aikobay/{dataset_name}"
        dataset = load_dataset(repo_id, split=split)
        df = dataset.to_pandas().dropna()
        return df
    except Exception as e:
        logger.error(f"Error while loading dataset: {e}")
        return pd.DataFrame()
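# Note: the code below assumes the dataset exposes an "ITEMNAME" column;
# rebuild_faiss_index() indexes exactly that field.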

try:
    active_sale_items = load_huggingface_jsonl("initial_saleitem_dataset")
    logger.info(f"Active auction item data loaded! {len(active_sale_items)} items in total")
except Exception as e:
    logger.error(f"Error while loading item data: {e}")
    active_sale_items = pd.DataFrame()

# 384 is the output dimension of paraphrase-multilingual-MiniLM-L12-v2.
faiss_index = faiss.IndexFlatL2(384)
indexed_items = []


def rebuild_faiss_index():
    """Reload the sale-item dataset and rebuild the FAISS index from scratch."""
    global faiss_index, indexed_items, active_sale_items
    logger.info("Rebuilding the FAISS index with fresh sale_item data...")

    active_sale_items = load_huggingface_jsonl("initial_saleitem_dataset")
    if active_sale_items.empty or "ITEMNAME" not in active_sale_items.columns:
        logger.error("No item data available; keeping the existing index.")
        return
    item_names = active_sale_items["ITEMNAME"].tolist()
    indexed_items = item_names

    logger.info(f"Vectorizing {len(item_names)} items...")
    item_vectors = embedding_model.encode(item_names, convert_to_numpy=True).astype("float32")

    faiss_index = faiss.IndexFlatL2(item_vectors.shape[1])
    faiss_index.add(item_vectors)
    logger.info(f"FAISS index rebuilt with {len(indexed_items)} items.")


def generate_similar_keywords(query: str, num_keywords: int = 5):
    """Extract keywords similar to the query using the KeyBERT model."""
    try:
        keywords = keyword_model.extract_keywords(
            query, keyphrase_ngram_range=(1, 2), stop_words=None, top_n=num_keywords
        )
        keywords = [kw[0] for kw in keywords]
        logger.info(f"Generated similar keywords: {keywords}")
        return keywords
    except Exception as e:
        logger.error(f"Error during KeyBERT keyword extraction: {e}")
        # Fall back to the raw query so the search can still proceed.
        return [query]


def search_faiss_with_keywords(query: str, top_k: int = 5):
    """Expand the query into keywords, search FAISS per keyword, and merge the hits."""
    start_time = time.time()
    keywords = generate_similar_keywords(query)
    keyword_vectors = embedding_model.encode(keywords, convert_to_numpy=True).astype("float32")

    all_results = []
    for vec in keyword_vectors:
        _, indices = faiss_index.search(vec.reshape(1, -1), top_k)
        # FAISS pads results with -1 when the index holds fewer than top_k vectors.
        all_results.extend(int(i) for i in indices[0] if i >= 0)

    # Deduplicate while preserving the order in which hits were found.
    unique_results = list(dict.fromkeys(all_results))
    recommendations = [indexed_items[i] for i in unique_results]

    elapsed = time.time() - start_time
    logger.info(f"Search complete! Elapsed: {elapsed:.4f}s")
    return recommendations
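# A minimal sketch of exercising the search directly, without the HTTP layer
# (the query string is hypothetical):
#   rebuild_faiss_index()
#   print(search_faiss_with_keywords("wireless earbuds", top_k=3))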


class RecommendRequest(BaseModel):
    search_query: str
    top_k: int = 5


@app.post("/api/recommend")
async def recommend(request: RecommendRequest):
    try:
        recommendations = search_faiss_with_keywords(request.search_query, request.top_k)
        return {"query": request.search_query, "recommendations": recommendations}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Recommendation error: {str(e)}")


@app.post("/api/update_index")
async def update_index():
    rebuild_faiss_index()
    return {"message": "FAISS index update complete!"}
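# A minimal smoke test against the endpoints above, assuming FastAPI's
# TestClient (sketch only; the example payload is hypothetical):
#   from fastapi.testclient import TestClient
#   client = TestClient(app)
#   resp = client.post("/api/recommend", json={"search_query": "wireless earbuds", "top_k": 3})
#   print(resp.json())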


if __name__ == "__main__":
    # Build the index once before the server starts accepting requests.
    rebuild_faiss_index()
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)