import os
import torch
import pandas as pd
import logging
import faiss
import numpy as np
import time
import gensim
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from datasets import load_dataset
from huggingface_hub import login, hf_hub_download, HfApi, create_repo
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer
from joblib import Parallel, delayed
from tqdm import tqdm
import tempfile
import re
import sys
import asyncio
import multiprocessing
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


thread_pool = ThreadPoolExecutor(max_workers=min(64, os.cpu_count() * 4))
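
# Note on pool sizing (informational): the cap of min(64, cpu_count * 4) assumes the
# work submitted here is I/O-bound or releases the GIL (Hub downloads, FAISS calls,
# SentenceTransformer.encode), so oversubscribing CPU cores is generally acceptable.
# os.cpu_count() can return None on some platforms; guard accordingly if needed.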

app = FastAPI(title="KeyBERT + Word2Vec based FAISS search API", version="1.2")


device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Running on device: {device.upper()}")


# Log in to the Hugging Face Hub once, from the main process only
# (spawned uvicorn worker processes skip the login).
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
if multiprocessing.current_process().name == "MainProcess":
    if HF_API_TOKEN and HF_API_TOKEN.startswith("hf_"):
        logger.info("Logging in to the Hugging Face API...")
        login(token=HF_API_TOKEN)
    else:
        logger.warning("HF_API_TOKEN is not set or is not in a valid token format.")

# Load the Word2Vec item-vector model from the Hub.
word2vec_model = None
try:
    logger.info("Loading Word2Vec model...")
    MODEL_REPO = "aikobay/item-model"
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename="item_vectors.bin", repo_type="dataset")
    word2vec_model = gensim.models.KeyedVectors.load_word2vec_format(model_path, binary=True)
    logger.info(f"Word2Vec model loaded. Vocabulary size: {len(word2vec_model.key_to_index)}")
except Exception as e:
    logger.error(f"Failed to load Word2Vec model: {e}")


logger.info("Loading KeyBERT model...")
kw_model = KeyBERT("paraphrase-multilingual-MiniLM-L12-v2")
logger.info("KeyBERT model loaded.")

# The sentence-embedding model is lazy-loaded on first use.
embedding_model = None


def get_embedding_model():
    global embedding_model
    if embedding_model is None:
        try:
            logger.info("Loading embedding model (lazy-load)...")
            embedding_model = SentenceTransformer("jhgan/ko-sroberta-multitask")
            logger.info("Embedding model loaded.")
        except Exception as e:
            logger.warning(f"Failed to load the Korean-specific model, falling back to the default model: {e}")
            embedding_model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
    return embedding_model

async def load_huggingface_jsonl(dataset_name, split="train"):
    """Asynchronously load a dataset from the Hugging Face Hub."""
    try:
        loop = asyncio.get_event_loop()

        def _load_dataset():
            repo_id = f"aikobay/{dataset_name}"
            dataset = load_dataset(repo_id, split=split)
            return dataset.to_pandas().dropna()

        df = await loop.run_in_executor(thread_pool, _load_dataset)
        return df
    except Exception as e:
        logger.error(f"Error while loading data: {e}")
        return pd.DataFrame()

# Load the active auction item data at startup.
active_sale_items = None
try:
    loop = asyncio.new_event_loop()
    active_sale_items = loop.run_until_complete(load_huggingface_jsonl("initial_saleitem_dataset"))
    loop.close()

    if active_sale_items.empty:
        logger.error("Dataset is empty. Exiting.")
        sys.exit(1)
    logger.info(f"Auction item data loaded. Total items: {len(active_sale_items)}")
except Exception as e:
    logger.error(f"Failed to load item data: {e}")
    sys.exit(1)

faiss_index = None
indexed_items = []

async def encode_texts_parallel(texts, batch_size=1024):
    """Vectorize texts on the GPU with an optimized batch size (large-scale speedup)."""
    if not texts:
        return np.array([]).astype("float32")

    loop = asyncio.get_event_loop()

    def _encode_efficiently():
        model = get_embedding_model()
        return model.encode(
            texts,
            batch_size=batch_size,
            convert_to_numpy=True,
            show_progress_bar=False,
            device=device
        )

    embeddings = await loop.run_in_executor(thread_pool, _encode_efficiently)
    return embeddings.astype("float32")
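
# Usage sketch (illustrative only; run inside a coroutine, item names are made up):
#   vectors = await encode_texts_parallel(["wrist watch", "gaming laptop"], batch_size=256)
#   vectors.shape   # -> (2, embedding_dim), dtype float32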

async def save_faiss_index():
    """Save the FAISS index to the Hugging Face Hub (async)."""
    global faiss_index, indexed_items

    if faiss_index is None or not indexed_items:
        logger.error("There is no FAISS index to save.")
        return False

    try:
        repo_id = os.getenv("HF_INDEX_REPO", "aikobay/saleitem_faiss_index")

        loop = asyncio.get_event_loop()

        def _save_index():
            api = HfApi()

            # Reuse the repository if it exists, otherwise create it.
            try:
                api.repo_info(repo_id=repo_id, repo_type="dataset")
                logger.info(f"Using existing repository: {repo_id}")
            except Exception:
                logger.info(f"Repository does not exist, creating it: {repo_id}")
                create_repo(
                    repo_id=repo_id,
                    repo_type="dataset",
                    private=True,
                    exist_ok=True
                )
                logger.info(f"Repository created: {repo_id}")

            with tempfile.TemporaryDirectory() as temp_dir:
                index_path = os.path.join(temp_dir, "faiss_index.bin")
                items_path = os.path.join(temp_dir, "indexed_items.txt")

                faiss.write_index(faiss_index, index_path)

                with open(items_path, "w", encoding="utf-8") as f:
                    f.write("\n".join(indexed_items))

                readme_path = os.path.join(temp_dir, "README.md")
                with open(readme_path, "w", encoding="utf-8") as f:
                    f.write(f"""# FAISS index repository

This repository contains the FAISS index and related data used for item search.

- Last updated: {pd.Timestamp.now()}
- Number of indexed items: {len(indexed_items)}
- Models: KeyBERT + Word2Vec

This repository was generated automatically to store the vector index built from the item data in 'aikobay/initial_saleitem_dataset'.
""")

                for file_path, file_name in [
                    (index_path, "faiss_index.bin"),
                    (items_path, "indexed_items.txt"),
                    (readme_path, "README.md")
                ]:
                    api.upload_file(
                        path_or_fileobj=file_path,
                        path_in_repo=file_name,
                        repo_id=repo_id,
                        repo_type="dataset"
                    )

            logger.info(f"FAISS index saved to the Hugging Face Hub. Repo: {repo_id}")
            return True

        result = await loop.run_in_executor(thread_pool, _save_index)
        return result

    except Exception as e:
        logger.error(f"Error while saving the FAISS index to the Hub: {e}")

        # Fall back to a local backup when the Hub upload fails.
        try:
            loop = asyncio.get_event_loop()

            def _local_backup():
                local_path = os.path.join(os.getcwd(), "faiss_index.bin")
                faiss.write_index(faiss_index, local_path)
                with open("indexed_items.txt", "w", encoding="utf-8") as f:
                    f.write("\n".join(indexed_items))
                logger.info(f"FAISS index backed up locally: {local_path}")
                return True

            result = await loop.run_in_executor(thread_pool, _local_backup)
            return result
        except Exception as local_err:
            logger.error(f"Local backup also failed: {local_err}")
            return False

async def load_faiss_index():
    """Load the FAISS index from the Hugging Face Hub (async)."""
    global faiss_index, indexed_items

    repo_id = os.getenv("HF_INDEX_REPO", "aikobay/saleitem_faiss_index")

    try:
        loop = asyncio.get_event_loop()

        def _load_index():
            api = HfApi()
            try:
                api.repo_info(repo_id=repo_id, repo_type="dataset")
                logger.info(f"FAISS index repository found: {repo_id}")
            except Exception as repo_err:
                logger.warning(f"Repository does not exist: {repo_err}")
                raise FileNotFoundError("The Hub repository does not exist")

            index_path = hf_hub_download(
                repo_id=repo_id,
                filename="faiss_index.bin",
                repo_type="dataset"
            )

            items_path = hf_hub_download(
                repo_id=repo_id,
                filename="indexed_items.txt",
                repo_type="dataset"
            )

            loaded_index = faiss.read_index(index_path)
            with open(items_path, "r", encoding="utf-8") as f:
                loaded_items = f.read().splitlines()

            return loaded_index, loaded_items

        loaded_index, loaded_items = await loop.run_in_executor(thread_pool, _load_index)

        faiss_index = loaded_index
        indexed_items = loaded_items

        logger.info(f"FAISS index loaded from the Hub. Total items: {len(indexed_items)}")
        return True

    except Exception as e:
        logger.warning(f"Error while loading the FAISS index from the Hub: {e}")

        # Fall back to local files if the Hub is unavailable.
        try:
            loop = asyncio.get_event_loop()

            def _load_local():
                local_index_path = "faiss_index.bin"
                local_items_path = "indexed_items.txt"

                if os.path.exists(local_index_path) and os.path.exists(local_items_path):
                    loaded_index = faiss.read_index(local_index_path)
                    with open(local_items_path, "r", encoding="utf-8") as f:
                        loaded_items = f.read().splitlines()
                    return loaded_index, loaded_items
                else:
                    logger.warning("Local FAISS index files do not exist.")
                    return None, None

            result = await loop.run_in_executor(thread_pool, _load_local)

            if result[0] is not None:
                faiss_index, indexed_items = result
                logger.info(f"Local FAISS index loaded. Total items: {len(indexed_items)}")
                return True
            else:
                return False

        except Exception as local_err:
            logger.error(f"Error while loading the local FAISS index: {local_err}")
            return False

async def rebuild_faiss_index():
    """Rebuild the FAISS index on an IVF structure (speed-optimized)."""
    global faiss_index, indexed_items, active_sale_items

    logger.info("Rebuilding the FAISS index on a fast IVF structure...")

    active_sale_items = await load_huggingface_jsonl("initial_saleitem_dataset")
    if active_sale_items.empty:
        logger.error("Unable to load item data.")
        raise RuntimeError("Failed to load item data")

    item_names = active_sale_items["ITEMNAME"].tolist()
    indexed_items = item_names

    total_items = len(item_names)
    logger.info(f"Starting fast vectorization of {total_items} items...")

    item_vectors = await encode_texts_parallel(item_names, batch_size=1024)

    # L2-normalize so inner-product search behaves like cosine similarity.
    norms = np.linalg.norm(item_vectors, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    normalized_vectors = item_vectors / norms

    loop = asyncio.get_event_loop()

    def _build_ivf_index():
        dimension = item_vectors.shape[1]

        # Rule of thumb: nlist ~ 4 * sqrt(N), clamped to [32, 1024].
        nlist = int(np.sqrt(total_items) * 4)
        nlist = max(32, min(nlist, 1024))

        # Product-quantization parameters (used only for larger corpora).
        M = min(64, dimension // 2)
        nbits = 8

        # Pass METRIC_INNER_PRODUCT explicitly so returned scores are similarities
        # (higher is better), matching the IndexFlatIP quantizer and the downstream
        # score handling.
        if total_items > 10000:
            quantizer = faiss.IndexFlatIP(dimension)
            index = faiss.IndexIVFPQ(quantizer, dimension, nlist, M, nbits, faiss.METRIC_INNER_PRODUCT)
        else:
            quantizer = faiss.IndexFlatIP(dimension)
            index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_INNER_PRODUCT)

        index.train(normalized_vectors)
        index.add(normalized_vectors)

        # Probe a quarter of the clusters per query, capped at 32.
        index.nprobe = min(32, nlist // 4)

        logger.info(f"IVF index built: clusters={nlist}, nprobe={index.nprobe}")
        return index

    faiss_index = await loop.run_in_executor(thread_pool, _build_ivf_index)

    logger.info(f"Fast FAISS index built. Total entries: {len(indexed_items)}")

    await save_faiss_index()
    return True
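
# Worked example for the clustering parameters above (illustrative numbers): with
# total_items = 10,000, nlist = int(sqrt(10000) * 4) = 400 and
# nprobe = min(32, 400 // 4) = 32, so each query scans roughly
# nprobe / nlist = 32 / 400 = 8% of the clusters instead of the full corpus.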

async def check_faiss_index():
    """Check whether the FAISS index exists and build it if it does not (async)."""
    global faiss_index

    if faiss_index is None:
        if not await load_faiss_index():
            logger.warning("No saved index found; building a new one.")
            await rebuild_faiss_index()

    if faiss_index is None:
        raise RuntimeError("Failed to initialize the FAISS index.")

async def extract_keywords(query: str, top_n: int = 2):
    """Optimized KeyBERT keyword extraction (performance-focused)."""
    # Very short queries are used as-is.
    if len(query) <= 3:
        return [query]

    loop = asyncio.get_event_loop()

    def _optimized_extract():
        return kw_model.extract_keywords(
            query,
            keyphrase_ngram_range=(1, 1),
            stop_words=["이", "그", "저", "은", "를", "을", "에서", "의", "는"],  # Korean particle stop words
            use_mmr=True,
            diversity=0.5,
            top_n=top_n
        )

    try:
        keywords = await loop.run_in_executor(thread_pool, _optimized_extract)
        # Keep only keywords with a relevance score above 0.2.
        return [kw for kw, score in keywords if score > 0.2]
    except Exception as e:
        logger.error(f"Keyword extraction error: {str(e)}")
        # Fall back to the first two whitespace tokens of the query.
        return query.split()[:2]
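
# For reference, kw_model.extract_keywords() returns (keyword, score) pairs, e.g.
# [("keyword1", 0.61), ("keyword2", 0.34)] (illustrative values), which is why the
# filter above unpacks (kw, score) tuples before thresholding on the score.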

async def expand_keywords_with_word2vec(keywords: list, max_new=2):
    """Optimized Word2Vec keyword expansion."""
    global word2vec_model

    if word2vec_model is None or not keywords:
        return keywords

    expanded = set(keywords)

    loop = asyncio.get_event_loop()

    def _expand_keywords():
        for keyword in keywords:
            # Keyword present in the vocabulary: take its nearest neighbors.
            if keyword in word2vec_model:
                similar_words = word2vec_model.most_similar(keyword, topn=max_new)
                for word, score in similar_words:
                    if score > 0.7:
                        expanded.add(word)
            # Multi-token keyword: try expanding only its first token.
            elif len(keyword.split()) > 1:
                word = keyword.split()[0]
                if word in word2vec_model and len(word) > 1:
                    similar = word2vec_model.most_similar(word, topn=1)
                    if similar and similar[0][1] > 0.8:
                        expanded.add(similar[0][0])

        result = list(expanded)

        # Cap the expanded list at five keywords.
        if len(result) > 5:
            return keywords + result[len(keywords):5]
        return result

    try:
        expanded_keywords = await loop.run_in_executor(thread_pool, _expand_keywords)
        return expanded_keywords
    except Exception as e:
        logger.error(f"Word2Vec expansion error: {str(e)}")
        return keywords
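
# Expansion sketch (illustrative values only): for keywords = ["watch"],
# most_similar("watch", topn=2) might return [("wristwatch", 0.83), ("bracelet", 0.65)];
# only the first entry clears the 0.7 similarity threshold, so the expanded list
# becomes ["watch", "wristwatch"].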

async def search_faiss_with_keywords(query: str, top_k: int = 5, keywords=None, expanded_keywords=None):
    """Run the fast keyword-based FAISS search."""
    global faiss_index, indexed_items

    if faiss_index is None:
        await check_faiss_index()

    start_time = time.time()

    loop = asyncio.get_event_loop()

    if keywords is None:
        keywords = await extract_keywords(query)

    if expanded_keywords is None:
        expanded_keywords = await expand_keywords_with_word2vec(keywords)

    # Encode the raw query plus every expanded keyword in a single batch.
    search_texts = [query] + expanded_keywords

    all_vectors = await encode_texts_parallel(search_texts)

    def normalize_batch(vectors):
        if vectors.size == 0:
            return vectors
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        return vectors / norms

    all_vectors = await loop.run_in_executor(thread_pool, lambda: normalize_batch(all_vectors))

    if len(all_vectors) > 0:
        query_vector = all_vectors[0:1]
        keyword_vectors = all_vectors[1:] if len(all_vectors) > 1 else np.array([])
    else:
        return []

    def _optimized_batch_search():
        all_results = {}

        # The full query contributes with a 3.0 weight.
        if query_vector.shape[0] > 0:
            distances, indices = faiss_index.search(query_vector, top_k * 2)
            for idx, dist in zip(indices[0], distances[0]):
                if 0 <= idx < len(indexed_items):  # FAISS pads missing results with -1
                    all_results[idx] = dist * 3.0

        # Each keyword hit contributes with a rank-decayed 0.6 weight.
        if keyword_vectors.shape[0] > 0:
            k_distances, k_indices = faiss_index.search(keyword_vectors, top_k)

            for i in range(keyword_vectors.shape[0]):
                for j, (idx, dist) in enumerate(zip(k_indices[i], k_distances[i])):
                    if 0 <= idx < len(indexed_items):
                        rank_weight = 1.0 / (1 + j * 0.2)
                        weight = 0.6 * rank_weight
                        all_results[idx] = all_results.get(idx, 0) + dist * weight

        return all_results

    result_scores = await loop.run_in_executor(thread_pool, _optimized_batch_search)

    def _process_results():
        # Drop weak matches and sort by the combined score.
        filtered_items = [(idx, score) for idx, score in result_scores.items()
                          if score >= 0.3]

        sorted_items = sorted(filtered_items, key=lambda x: x[1], reverse=True)

        recommendations = []
        for idx, score in sorted_items[:top_k]:
            item_name = indexed_items[idx]
            try:
                mask = active_sale_items["ITEMNAME"] == item_name
                if mask.any():
                    item_seq = active_sale_items.loc[mask, "ITEMSEQ"].values[0]
                    recommendations.append({
                        "ITEMSEQ": item_seq,
                        "ITEMNAME": item_name,
                        "score": float(score)
                    })
            except (IndexError, KeyError):
                continue

        return recommendations

    recommendations = await loop.run_in_executor(thread_pool, _process_results)

    # Top up with direct substring matches when the vector search comes back short.
    if len(recommendations) < top_k:
        direct_matches = await find_direct_matches(query,
                                                   top_k - len(recommendations),
                                                   [r["ITEMNAME"] for r in recommendations])
        if direct_matches:
            recommendations.extend(direct_matches)

    elapsed = time.time() - start_time
    if elapsed > 1.0:
        logger.info(f"Search complete | elapsed: {elapsed:.2f}s | results: {len(recommendations)}")

    return recommendations[:top_k]
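
# Scoring example for the weights above (illustrative numbers): if an item is the
# query's top hit with inner-product 0.8 and also the second-ranked hit (j=1) for one
# keyword with 0.7, its combined score is 0.8*3.0 + 0.7*0.6/(1 + 1*0.2) = 2.4 + 0.35 = 2.75.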

async def find_direct_matches(query, limit=5, existing_names=None):
    """Direct text-match search (split out for optimization)."""
    loop = asyncio.get_event_loop()

    def _find_matches():
        matches = []
        query_lower = query.lower()
        existing = set(existing_names or [])

        # Collect item names that contain the query as a substring.
        item_dict = {}
        for idx, item_name in enumerate(indexed_items):
            if len(matches) >= limit:
                break

            if item_name in existing:
                continue

            if query_lower in item_name.lower():
                item_dict[item_name] = idx

        # Look up the matching rows in a single pandas pass.
        if item_dict:
            mask = active_sale_items["ITEMNAME"].isin(item_dict.keys())
            filtered_items = active_sale_items[mask]

            for _, row in filtered_items.iterrows():
                if len(matches) >= limit:
                    break

                matches.append({
                    "ITEMSEQ": row["ITEMSEQ"],
                    "ITEMNAME": row["ITEMNAME"],
                    "score": 1.0
                })

        return matches

    return await loop.run_in_executor(thread_pool, _find_matches)

class RecommendRequest(BaseModel):
    search_query: str
    top_k: int = 5
    use_expansion: bool = True
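
# Example request for POST /api/recommend (illustrative values; the host, port, and
# query text below are placeholders, not defined by the application itself):
#   curl -X POST "http://localhost:7860/api/recommend" \
#        -H "Content-Type: application/json" \
#        -d '{"search_query": "wrist watch", "top_k": 5, "use_expansion": true}'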

@app.post("/api/recommend")
async def recommend(request: RecommendRequest, background_tasks: BackgroundTasks):
    """Fast recommendation API (parallelized I/O, unnecessary work removed)."""
    try:
        start_time = time.time()

        search_query = request.search_query.strip()
        if not search_query:
            raise HTTPException(status_code=400, detail="Please enter a search query")

        top_k = min(max(1, request.top_k), 20)

        # Run keyword extraction and Word2Vec expansion concurrently.
        # asyncio.gather cannot take None, so the no-expansion case is handled separately.
        if request.use_expansion:
            keywords, expanded_keywords = await asyncio.gather(
                extract_keywords(search_query),
                expand_keywords_with_word2vec(
                    [search_query.split()[0]] if search_query.split() else [search_query],
                    max_new=2
                )
            )
        else:
            keywords = await extract_keywords(search_query)
            expanded_keywords = None

        recommendations = await search_faiss_with_keywords(
            search_query,
            top_k,
            keywords,
            expanded_keywords
        )

        result = {
            "query": search_query,
            "recommendations": recommendations,
            "keywords": keywords if len(keywords) > 0 else None,
            "expanded_keywords": expanded_keywords if expanded_keywords and len(expanded_keywords) > 0 else None
        }

        elapsed = time.time() - start_time
        if elapsed > 1.0:
            logger.info(f"API response time: {elapsed:.2f}s | query: '{search_query}'")

        return result

    except HTTPException:
        # Propagate explicit HTTP errors (e.g. the 400 above) unchanged.
        raise
    except Exception as e:
        logger.error(f"Recommendation error: {str(e)}")
        raise HTTPException(status_code=500, detail="An error occurred while processing the recommendation")

async def check_index_health():
    """Background task that periodically checks the index state."""
    try:
        if faiss_index is None:
            logger.warning("Background check: the FAISS index is not loaded.")
            await check_faiss_index()

        logger.debug("Index health check complete")
    except Exception as e:
        logger.error(f"Error during background index check: {str(e)}")

@app.post("/api/similar_words")
async def similar_words(word: str, top_k: int = 10):
    """Similar-word lookup API backed by the Word2Vec model (async)."""
    try:
        if word2vec_model is None:
            return {"error": "The Word2Vec model is not loaded."}

        loop = asyncio.get_event_loop()

        def _get_similar():
            if word not in word2vec_model:
                return []

            similar = word2vec_model.most_similar(word, topn=top_k)
            return [{"word": w, "similarity": float(s)} for w, s in similar]

        result = await loop.run_in_executor(thread_pool, _get_similar)

        if not result:
            return {"word": word, "similar_words": [], "message": "The word is not in the model."}

        return {"word": word, "similar_words": result}
    except Exception as e:
        logger.error(f"Error during similar-word lookup: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Similar-word lookup error: {str(e)}")

@app.post("/api/update_index")
async def update_index(background_tasks: BackgroundTasks):
    """Rebuild the FAISS index (only on explicit request, handled asynchronously)."""
    try:
        background_tasks.add_task(rebuild_and_log_index)
        return {"message": "FAISS index update started in the background."}
    except Exception as e:
        logger.exception("[API] Exception while handling the index update")
        raise HTTPException(status_code=500, detail=f"Index update failed: {str(e)}")

async def rebuild_and_log_index():
    """Rebuild the index in the background and log the result."""
    try:
        logger.info("Starting index rebuild in the background")
        start_time = time.time()
        await rebuild_faiss_index()
        elapsed = time.time() - start_time
        logger.info(f"Background index rebuild complete. Elapsed: {elapsed:.2f}s")
    except Exception as e:
        logger.error(f"Error during background index rebuild: {str(e)}")

@app.get("/api/debug_index")
async def debug_index(query: str, top_k: int = 20):
    """API for debugging the index (async)."""
    try:
        await check_faiss_index()

        loop = asyncio.get_event_loop()

        def _get_vector():
            model = get_embedding_model()
            vector = model.encode(query, convert_to_numpy=True).astype("float32")
            norm = np.linalg.norm(vector)
            normalized_vector = vector / norm
            return normalized_vector, norm

        normalized_vector, norm = await loop.run_in_executor(thread_pool, _get_vector)

        def _search():
            return faiss_index.search(np.array([normalized_vector]), top_k)

        distances, indices = await loop.run_in_executor(thread_pool, _search)

        results = []
        for i, (idx, dist) in enumerate(zip(indices[0], distances[0])):
            if 0 <= idx < len(indexed_items):
                item_name = indexed_items[idx]
                results.append({
                    "rank": i + 1,
                    "index": int(idx),
                    "item_name": item_name,
                    "distance/score": float(dist)
                })

        # Plain substring / exact-match checks for comparison with the vector results.
        def _find_matches():
            contains = [item for item in indexed_items if query.lower() in item.lower()][:5]
            exact = [item for item in indexed_items if query.lower() == item.lower()]
            return contains, exact

        contains_query, exact_matches = await loop.run_in_executor(thread_pool, _find_matches)

        return {
            "query": query,
            "vector_norm": float(norm),
            "contains_query": contains_query,
            "exact_matches": exact_matches,
            "results": results
        }
    except Exception as e:
        logger.error(f"Error while debugging the index: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Index debugging error: {str(e)}")

@app.get("/api/text_search")
async def text_search(query: str, top_k: int = 10):
    """Simple text-containment search API (async)."""
    try:
        loop = asyncio.get_event_loop()

        def _text_search():
            # Collect every indexed item whose name contains the query.
            matched_items = []
            for idx, item_name in enumerate(indexed_items):
                if query.lower() in item_name.lower():
                    try:
                        item_seq = active_sale_items.loc[active_sale_items["ITEMNAME"] == item_name, "ITEMSEQ"].values[0]
                        matched_items.append({"ITEMSEQ": item_seq, "ITEMNAME": item_name, "match_type": "contains"})
                    except (IndexError, KeyError):
                        continue

            # Exact matches are returned ahead of partial matches.
            exact_matches = []
            partial_matches = []

            for item in matched_items:
                if query.lower() == item["ITEMNAME"].lower():
                    item["match_type"] = "exact"
                    exact_matches.append(item)
                else:
                    partial_matches.append(item)

            return exact_matches + partial_matches

        results = await loop.run_in_executor(thread_pool, _text_search)

        logger.info(f"Text search results: {len(results)} found, query: '{query}'")

        return {
            "query": query,
            "total_matches": len(results),
            "results": results[:top_k]
        }
    except Exception as e:
        logger.error(f"Error during text search: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Text search error: {str(e)}")

if __name__ == "__main__":
    try:
        # Load an existing index if possible; otherwise build one before serving.
        loop = asyncio.new_event_loop()
        if not loop.run_until_complete(load_faiss_index()):
            logger.warning("Failed to load an existing index. Building a new index now.")
            loop.run_until_complete(rebuild_faiss_index())
            logger.info("FAISS index created.")
        else:
            logger.info("Existing index loaded successfully.")
        loop.close()
    except Exception as e:
        logger.error(f"Initial index build failed: {e}")
        logger.warning("Starting without an index. Search functionality may be limited.")

    import uvicorn
    uvicorn.run("searchAsyncTunning:app",
                host="0.0.0.0",
                port=7860,
                workers=8,
                loop="uvloop",
                http="httptools",
                limit_concurrency=200,
                timeout_keep_alive=5)
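
# Run sketch (assumes this file is saved as searchAsyncTunning.py, matching the
# uvicorn app string above, and that uvloop and httptools are installed):
#   python searchAsyncTunning.py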