traning / searchHybrid.py
aikobay's picture
Update searchHybrid.py
8b22dd0 verified
raw
history blame contribute delete
25.3 kB
import os
import torch
import pandas as pd
import logging
import faiss
import numpy as np
import time
import gensim
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from datasets import load_dataset
from huggingface_hub import login, hf_hub_download, HfApi, create_repo
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer
from joblib import Parallel, delayed
from tqdm import tqdm
import tempfile
import re
import sys
# ✅ 로그 설정
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ✅ FastAPI 인스턴스 생성
app = FastAPI(title="🚀 KeyBERT + Word2Vec 기반 FAISS 검색 API", version="1.1")
# ✅ GPU 사용 여부 확인
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"🚀 실행 디바이스: {device.upper()}")
# ✅ Hugging Face 로그인
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
if HF_API_TOKEN:
logger.info("🔑 Hugging Face API 로그인 중...")
login(token=HF_API_TOKEN)
else:
logger.error("❌ HF_API_TOKEN이 설정되지 않았습니다. 일부 기능이 제한될 수 있습니다.")
# ✅ Word2Vec 모델 로드
word2vec_model = None
try:
logger.info("🔄 Word2Vec 모델 로드 중...")
MODEL_REPO = "aikobay/item-model"
model_path = hf_hub_download(repo_id=MODEL_REPO, filename="item_vectors.bin", repo_type="dataset")
word2vec_model = gensim.models.KeyedVectors.load_word2vec_format(model_path, binary=True)
logger.info(f"✅ Word2Vec 모델 로드 완료! 단어 수: {len(word2vec_model.key_to_index)}")
except Exception as e:
logger.error(f"❌ Word2Vec 모델 로드 실패: {e}")
# ✅ KeyBERT 모델 로드
logger.info("🔄 KeyBERT 모델 로드 중...")
kw_model = KeyBERT("paraphrase-multilingual-MiniLM-L12-v2")
original_embedding_model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
logger.info("✅ KeyBERT 모델 로드 완료!")
# ✅ 한국어 특화 임베딩 모델로 교체
embedding_model = None
try:
logger.info("🔄 한국어 특화 임베딩 모델로 교체 시도...")
# 한국어 특화 모델 로드 시도 (실패시 기존 모델 유지)
embedding_model = SentenceTransformer("jhgan/ko-sroberta-multitask")
logger.info("✅ 한국어 특화 임베딩 모델 로드 완료!")
except Exception as e:
logger.warning(f"⚠️ 한국어 특화 모델 로드 실패, 기존 모델 유지: {e}")
embedding_model = original_embedding_model
# ✅ 진행 중인 경매 상품 데이터 로드
def load_huggingface_jsonl(dataset_name, split="train"):
"""Hugging Face Hub에서 데이터셋 로드"""
try:
repo_id = f"aikobay/{dataset_name}"
dataset = load_dataset(repo_id, split=split)
df = dataset.to_pandas().dropna()
return df
except Exception as e:
logger.error(f"❌ 데이터 로드 중 오류 발생: {e}")
return pd.DataFrame()
try:
active_sale_items = load_huggingface_jsonl("initial_saleitem_dataset")
if active_sale_items.empty:
logger.error("❌ 데이터셋이 비어 있습니다. 프로그램을 종료합니다.")
exit(1)
logger.info(f"✅ 경매 상품 데이터 로드 완료! 총 {len(active_sale_items)}개 상품")
except Exception as e:
logger.error(f"❌ 상품 데이터 로드 실패: {e}")
exit(1)
# ✅ FAISS 인덱스 초기화
faiss_index = None
indexed_items = []
# ✅ 멀티코어 벡터화 함수
def encode_texts_parallel(texts, batch_size=512):
"""멀티 프로세싱을 활용한 벡터화 속도 최적화"""
num_cores = os.cpu_count() # CPU 개수 확인
logger.info(f"🔄 멀티코어 벡터화 진행 (코어 수: {num_cores})")
def encode_batch(batch):
return embedding_model.encode(batch, convert_to_numpy=True)
# 배치 단위로 병렬 처리
text_batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
embeddings = Parallel(n_jobs=num_cores)(delayed(encode_batch)(batch) for batch in text_batches)
return np.vstack(embeddings).astype("float32")
# ✅ FAISS 인덱스 저장 함수 (Hugging Face Hub)
def save_faiss_index():
"""FAISS 인덱스를 Hugging Face Hub에 저장"""
global faiss_index, indexed_items
if faiss_index is None or not indexed_items:
logger.error("❌ 저장할 FAISS 인덱스가 없습니다.")
return False
try:
# 레포지토리 ID
repo_id = os.getenv("HF_INDEX_REPO", "aikobay/saleitem_faiss_index")
# HfApi 객체 생성
api = HfApi()
# 레포지토리 존재 여부 확인 및 생성
try:
api.repo_info(repo_id=repo_id, repo_type="dataset")
logger.info(f"✅ 기존 레포지토리 사용: {repo_id}")
except Exception:
logger.info(f"🔄 레포지토리가 존재하지 않아 새로 생성합니다: {repo_id}")
create_repo(
repo_id=repo_id,
repo_type="dataset",
private=True,
exist_ok=True
)
logger.info(f"✅ 레포지토리 생성 완료: {repo_id}")
# 임시 파일로 먼저 로컬에 저장
with tempfile.TemporaryDirectory() as temp_dir:
index_path = os.path.join(temp_dir, "faiss_index.bin")
items_path = os.path.join(temp_dir, "indexed_items.txt")
# FAISS 인덱스 저장
faiss.write_index(faiss_index, index_path)
# 아이템 목록 저장
with open(items_path, "w", encoding="utf-8") as f:
f.write("\n".join(indexed_items))
# README 파일 생성
readme_path = os.path.join(temp_dir, "README.md")
with open(readme_path, "w", encoding="utf-8") as f:
f.write(f"""# FAISS 인덱스 저장소
이 저장소는 상품 검색을 위한 FAISS 인덱스와 관련 데이터를 포함하고 있습니다.
- 최종 업데이트: {pd.Timestamp.now()}
- 인덱스 항목 수: {len(indexed_items)}
- 모델: KeyBERT + Word2Vec
이 저장소는 'aikobay/initial_saleitem_dataset'의 상품 데이터를 기반으로 생성된 벡터 인덱스를 저장하기 위해 자동 생성되었습니다.
""")
# 파일 업로드
for file_path, file_name in [
(index_path, "faiss_index.bin"),
(items_path, "indexed_items.txt"),
(readme_path, "README.md")
]:
api.upload_file(
path_or_fileobj=file_path,
path_in_repo=file_name,
repo_id=repo_id,
repo_type="dataset"
)
logger.info(f"✅ FAISS 인덱스가 Hugging Face Hub에 저장되었습니다. 레포: {repo_id}")
return True
except Exception as e:
logger.error(f"❌ FAISS 인덱스 Hub 저장 중 오류 발생: {e}")
# 로컬에 백업 저장 시도
try:
local_path = os.path.join(os.getcwd(), "faiss_index.bin")
faiss.write_index(faiss_index, local_path)
with open("indexed_items.txt", "w", encoding="utf-8") as f:
f.write("\n".join(indexed_items))
logger.info(f"✅ FAISS 인덱스가 로컬에 백업 저장되었습니다: {local_path}")
return True
except Exception as local_err:
logger.error(f"❌ 로컬 백업 저장도 실패: {local_err}")
return False
# ✅ FAISS 인덱스 로드 함수 (Hugging Face Hub)
def load_faiss_index():
"""Hugging Face Hub에서 FAISS 인덱스를 로드"""
global faiss_index, indexed_items
# 레포지토리 ID
repo_id = os.getenv("HF_INDEX_REPO", "aikobay/saleitem_faiss_index")
try:
# 레포지토리 존재 확인
api = HfApi()
try:
api.repo_info(repo_id=repo_id, repo_type="dataset")
logger.info(f"✅ FAISS 인덱스 레포지토리 확인: {repo_id}")
except Exception as repo_err:
logger.warning(f"⚠️ 레포지토리가 존재하지 않습니다: {repo_err}")
raise FileNotFoundError("Hub 레포지토리가 존재하지 않습니다")
# Hub에서 파일 다운로드
index_path = hf_hub_download(
repo_id=repo_id,
filename="faiss_index.bin",
repo_type="dataset"
)
items_path = hf_hub_download(
repo_id=repo_id,
filename="indexed_items.txt",
repo_type="dataset"
)
# 파일 로드
faiss_index = faiss.read_index(index_path)
with open(items_path, "r", encoding="utf-8") as f:
indexed_items = f.read().splitlines()
logger.info(f"✅ FAISS 인덱스가 Hub에서 로드되었습니다. 총 {len(indexed_items)}개 상품")
return True
except Exception as e:
logger.warning(f"⚠️ Hub에서 FAISS 인덱스 로드 중 오류 발생: {e}")
# 로컬 파일 확인
try:
local_index_path = "faiss_index.bin"
local_items_path = "indexed_items.txt"
if os.path.exists(local_index_path) and os.path.exists(local_items_path):
faiss_index = faiss.read_index(local_index_path)
with open(local_items_path, "r", encoding="utf-8") as f:
indexed_items = f.read().splitlines()
logger.info(f"✅ 로컬 FAISS 인덱스 로드 성공. 총 {len(indexed_items)}개 상품")
return True
else:
logger.warning("⚠️ 로컬 FAISS 인덱스 파일이 존재하지 않습니다.")
return False
except Exception as local_err:
logger.error(f"❌ 로컬 FAISS 인덱스 로드 중 오류: {local_err}")
return False
# ✅ FAISS 인덱스 구축
def rebuild_faiss_index():
"""FAISS 인덱스를 새롭게 구축"""
global faiss_index, indexed_items, active_sale_items
logger.info("🔄 FAISS 인덱스를 재구축 중...")
# 최신 상품 데이터 로드
active_sale_items = load_huggingface_jsonl("initial_saleitem_dataset")
if active_sale_items.empty:
logger.error("❌ 상품 데이터를 로드할 수 없습니다.")
raise RuntimeError("상품 데이터 로드 실패")
# 상품명 목록 추출
item_names = active_sale_items["ITEMNAME"].tolist()
indexed_items = item_names
logger.info(f"🔹 총 {len(item_names)}개 상품 벡터화 시작...")
# 벡터화 및 인덱스 구축 - 코사인 유사도 사용
item_vectors = encode_texts_parallel(item_names)
# 벡터 정규화 (코사인 유사도를 위해)
norms = np.linalg.norm(item_vectors, axis=1, keepdims=True)
normalized_vectors = item_vectors / norms
# Inner Product 기반 인덱스 사용 (코사인 유사도를 위해)
faiss_index = faiss.IndexFlatIP(item_vectors.shape[1])
faiss_index.add(normalized_vectors)
logger.info(f"✅ FAISS 인덱스 구축 완료! 총 {len(indexed_items)}개 항목.")
# 구축 후 Hub에 저장
save_faiss_index()
return True
# ✅ FAISS 인덱스 상태 확인 및 필요시에만 구축
def check_faiss_index():
"""FAISS 인덱스가 존재하는지 확인하고 없으면 구축"""
global faiss_index
if faiss_index is None:
# Hub에서 로드 시도
if not load_faiss_index():
# 로드 실패 시 새로 구축
logger.warning("⚠️ 저장된 인덱스가 없어 새로 구축합니다.")
rebuild_faiss_index()
# 모든 과정 후에도 인덱스가 None이면 오류
if faiss_index is None:
raise RuntimeError("FAISS 인덱스 초기화에 실패했습니다.")
# ✅ KeyBERT 기반 핵심 키워드 추출
def extract_keywords(query: str, top_n: int = 3):
"""KeyBERT를 사용한 핵심 키워드 추출"""
keywords = kw_model.extract_keywords(query, keyphrase_ngram_range=(1,2), top_n=top_n)
return [k[0] for k in keywords]
# ✅ Word2Vec 기반 키워드 확장 함수
def expand_keywords_with_word2vec(keywords: list, max_new=5):
"""Word2Vec 모델을 사용한 키워드 확장"""
if word2vec_model is None:
logger.warning("⚠️ Word2Vec 모델이 로드되지 않아 확장을 수행하지 않습니다.")
return keywords
expanded_keywords = list(keywords) # 복사본 생성
try:
for keyword in keywords:
# 단어가 모델에 있는지 확인
if keyword in word2vec_model:
# 유사 단어 찾기
similar_words = word2vec_model.most_similar(keyword, topn=max_new)
expanded_keywords.extend([word for word, _ in similar_words])
elif len(keyword.split()) > 1:
# 복합어인 경우 개별 단어로 시도
for word in keyword.split():
if word in word2vec_model and len(word) > 1:
similar_words = word2vec_model.most_similar(word, topn=2)
expanded_keywords.extend([w for w, _ in similar_words])
# 중복 제거
expanded_keywords = list(set(expanded_keywords))
logger.info(f"🔍 Word2Vec 확장 키워드: {expanded_keywords}")
return expanded_keywords
except Exception as e:
logger.error(f"❌ Word2Vec 키워드 확장 중 오류 발생: {e}")
return keywords
# ✅ FAISS 검색 함수
def search_faiss_with_keywords(query: str, top_k: int = 5, keywords=None, expanded_keywords=None):
"""키워드 기반 FAISS 검색 수행"""
# FAISS 인덱스 확인 (없으면 로드하거나 구축)
check_faiss_index()
start_time = time.time()
# 키워드가 없으면 KeyBERT로 추출
if keywords is None:
keywords = extract_keywords(query)
logger.info(f"🔍 KeyBERT 추출 키워드: {keywords}")
# 확장 키워드가 없으면 Word2Vec으로 확장
if expanded_keywords is None:
expanded_keywords = expand_keywords_with_word2vec(keywords)
# 원본 쿼리 벡터
query_vector = embedding_model.encode(query, convert_to_numpy=True).astype("float32")
# 벡터 정규화
query_vector = query_vector / np.linalg.norm(query_vector)
query_vector = np.array([query_vector])
# 원본 쿼리로 검색 - 가중치를 더 줌
distances, query_indices = faiss_index.search(query_vector, top_k * 2) # 더 많은 후보 검색
# 결과 매핑 - 거리에 따라 가중치 부여
recommendations = []
scored_results = {}
# 원본 쿼리 결과 처리 (가중치 2배)
for idx, (i, dist) in enumerate(zip(query_indices[0], distances[0])):
if i < len(indexed_items):
item_name = indexed_items[i]
score = 2.0 * dist # 코사인 유사도는 값이 높을수록 유사, 원본 쿼리 가중치 2배
scored_results[item_name] = score
# 각 확장 키워드별 검색 수행 (원본 쿼리보다 낮은 가중치)
for keyword in expanded_keywords:
keyword_vector = embedding_model.encode(keyword, convert_to_numpy=True).astype("float32")
keyword_vector = keyword_vector / np.linalg.norm(keyword_vector)
keyword_vector = np.array([keyword_vector])
k_distances, k_indices = faiss_index.search(keyword_vector, top_k)
# 결과 처리
for idx, (i, dist) in enumerate(zip(k_indices[0], k_distances[0])):
if i < len(indexed_items):
item_name = indexed_items[i]
# 이미 점수가 있으면 가중치를 적용한 평균 내기, 없으면 새로 추가
if item_name in scored_results:
scored_results[item_name] += 0.5 * dist # 확장 키워드는 가중치 0.5배
else:
scored_results[item_name] = 0.5 * dist
# 점수에 따라 정렬
sorted_results = sorted(scored_results.items(), key=lambda x: x[1], reverse=True)
# 최종 결과 생성
# 점수 임계값 설정 - 너무 낮은 점수의 결과는 제외
min_score_threshold = 0.3 # 임계값 설정
for item_name, score in sorted_results:
# 너무 낮은 점수는 건너뜀
if score < min_score_threshold:
continue
try:
item_seq = active_sale_items.loc[active_sale_items["ITEMNAME"] == item_name, "ITEMSEQ"].values[0]
recommendations.append({"ITEMSEQ": item_seq, "ITEMNAME": item_name, "score": float(score)})
except (IndexError, KeyError) as e:
continue # 매핑 실패 항목은 건너뜀
# 결과가 부족하면 쿼리 단어가 포함된 항목 직접 검색
if len(recommendations) < top_k:
direct_matches = []
for idx, item_name in enumerate(indexed_items):
# 쿼리 단어가 상품명에 포함되어 있는지 확인
if query.lower() in item_name.lower():
try:
item_seq = active_sale_items.loc[active_sale_items["ITEMNAME"] == item_name, "ITEMSEQ"].values[0]
# 이미 결과에 있는지 확인
if not any(r["ITEMNAME"] == item_name for r in recommendations):
direct_matches.append({"ITEMSEQ": item_seq, "ITEMNAME": item_name, "score": 1.0})
except (IndexError, KeyError):
continue
# 직접 매치 결과 추가
recommendations.extend(direct_matches)
logger.info(f"🔍 검색 수행 완료! 걸린 시간: {time.time() - start_time:.4f}초, 추천 {len(recommendations)}개")
return recommendations[:top_k] # 요청된 top_k개 이하로 제한
# ✅ API 요청 모델
class RecommendRequest(BaseModel):
search_query: str
top_k: int = 5
use_expansion: bool = True # 키워드 확장 사용 여부
# ✅ 추천 API 엔드포인트
@app.post("/api/recommend")
async def recommend(request: RecommendRequest):
"""Word2Vec 기반 FAISS 검색/추천 API"""
try:
# 로그에 요청 정보 기록
logger.info(f"📝 검색 요청: '{request.search_query}' (top_k: {request.top_k}, 확장: {request.use_expansion})")
# 키워드 추출
keywords = extract_keywords(request.search_query)
# 키워드 확장 사용 여부에 따라 처리
if request.use_expansion and word2vec_model is not None:
expanded_keywords = expand_keywords_with_word2vec(keywords)
else:
expanded_keywords = keywords
logger.info(f"🔍 키워드 확장 없이 진행: {keywords}")
# FAISS 검색 수행
recommendations = search_faiss_with_keywords(
request.search_query,
request.top_k,
keywords,
expanded_keywords
)
# 결과 로깅 강화
logger.info(f"🔍 검색 결과: {[r['ITEMNAME'] for r in recommendations]}")
return {
"query": request.search_query,
"recommendations": recommendations,
"keywords": keywords,
"expanded_keywords": expanded_keywords
}
except Exception as e:
logger.error(f"❌ 추천 처리 중 오류: {str(e)}")
raise HTTPException(status_code=500, detail=f"추천 오류: {str(e)}")
# ✅ 유사 단어 검색 API
@app.post("/api/similar_words")
async def similar_words(word: str, top_k: int = 10):
"""Word2Vec 모델을 사용한 유사 단어 검색 API"""
try:
if word2vec_model is None:
return {"error": "Word2Vec 모델이 로드되지 않았습니다."}
if word not in word2vec_model:
return {"word": word, "similar_words": [], "message": "단어가 모델에 없습니다."}
similar = word2vec_model.most_similar(word, topn=top_k)
result = [{"word": w, "similarity": float(s)} for w, s in similar]
return {"word": word, "similar_words": result}
except Exception as e:
logger.error(f"❌ 유사 단어 검색 중 오류: {str(e)}")
raise HTTPException(status_code=500, detail=f"유사 단어 검색 오류: {str(e)}")
# ✅ FAISS 인덱스 갱신 API (명시적으로 요청할 때만 실행)
@app.post("/api/update_index")
async def update_index():
"""FAISS 인덱스를 새롭게 구축 (명시적 요청 시에만)"""
try:
# 인덱스 재구축
rebuild_faiss_index()
return {"message": "✅ FAISS 인덱스 업데이트 및 저장 완료!"}
except Exception as e:
logger.exception("❌ [API] 인덱스 업데이트 처리 중 예외 발생")
raise HTTPException(status_code=500, detail=f"인덱스 업데이트 실패: {str(e)}")
# ✅ 인덱스 디버깅 API
@app.get("/api/debug_index")
async def debug_index(query: str, top_k: int = 20):
"""인덱스 디버깅을 위한 API"""
try:
check_faiss_index()
# 원본 벡터 생성
vector = embedding_model.encode(query, convert_to_numpy=True).astype("float32")
# 벡터 정규화
norm = np.linalg.norm(vector)
normalized_vector = vector / norm
# 원본 쿼리로 검색
distances, indices = faiss_index.search(np.array([normalized_vector]), top_k)
# 결과 매핑
results = []
for i, (idx, dist) in enumerate(zip(indices[0], distances[0])):
if idx < len(indexed_items):
item_name = indexed_items[idx]
results.append({
"rank": i + 1,
"index": int(idx),
"item_name": item_name,
"distance/score": float(dist)
})
# 데이터셋에 해당 단어가 있는지 확인
contains_query = [item for item in indexed_items if query.lower() in item.lower()]
exact_matches = [item for item in indexed_items if query.lower() == item.lower()]
return {
"query": query,
"vector_norm": float(norm),
"contains_query": contains_query[:5], # 처음 5개만
"exact_matches": exact_matches,
"results": results
}
except Exception as e:
logger.error(f"❌ 인덱스 디버깅 중 오류: {str(e)}")
raise HTTPException(status_code=500, detail=f"인덱스 디버깅 오류: {str(e)}")
# ✅ 문자열 포함 검색 API
@app.get("/api/text_search")
async def text_search(query: str, top_k: int = 10):
"""단순 텍스트 포함 검색 API (FAISS 검색 결과가 이상할 때 대체용)"""
try:
# 단순 텍스트 포함 검색
matched_items = []
for idx, item_name in enumerate(indexed_items):
if query.lower() in item_name.lower():
try:
item_seq = active_sale_items.loc[active_sale_items["ITEMNAME"] == item_name, "ITEMSEQ"].values[0]
matched_items.append({"ITEMSEQ": item_seq, "ITEMNAME": item_name, "match_type": "contains"})
except (IndexError, KeyError):
continue
# 정확히 일치하는 항목을 앞으로
exact_matches = []
partial_matches = []
for item in matched_items:
if query.lower() == item["ITEMNAME"].lower():
item["match_type"] = "exact"
exact_matches.append(item)
else:
partial_matches.append(item)
# 결합 및 제한
results = exact_matches + partial_matches
logger.info(f"🔍 텍스트 검색 결과: {len(results)}개 찾음, 쿼리: '{query}'")
return {
"query": query,
"total_matches": len(results),
"results": results[:top_k]
}
except Exception as e:
logger.error(f"❌ 텍스트 검색 중 오류: {str(e)}")
raise HTTPException(status_code=500, detail=f"텍스트 검색 오류: {str(e)}")
# ✅ FastAPI 실행
if __name__ == "__main__":
# 서버 시작 시 저장된 인덱스 로드 시도
if not load_faiss_index():
logger.warning("⚠️ 기존 인덱스 로드에 실패했습니다. 즉시 새 인덱스를 구축합니다.")
try:
# 인덱스 즉시 재구축
rebuild_faiss_index()
logger.info("✅ FAISS 인덱스 생성 완료!")
except Exception as e:
logger.error(f"❌ 인덱스 초기 구축 실패: {e}")
logger.warning("⚠️ 인덱스 없이 시작합니다. 검색 기능이 제한될 수 있습니다.")
else:
logger.info("✅ 기존 인덱스를 성공적으로 로드했습니다.")
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)