docker-api / app_gemma.py
dasomaru's picture
Upload folder using huggingface_hub
06696b5 verified
raw
history blame contribute delete
2.44 kB
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# from retriever.vectordb_rerank import search_documents  # RAG retriever import (currently disabled)
from services.rag_pipeline import rag_pipeline
# Hugging Face Hub id of the checkpoint served by this app.
model_name = "dasomaru/gemma-3-4bit-it-demo"

# 1. Load the tokenizer and the model once, at import time.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map="auto",  # important: device placement is chosen automatically
    torch_dtype=torch.float16,  # original note: chosen because the checkpoint is 4-bit
)

# 2. Cache management: maps a query string to what was retrieved for it.
search_cache = {}
@spaces.GPU(duration=300)
def generate_response(query: str):
    """Retrieve context for *query* with the RAG pipeline and generate text from it.

    The module-level ``tokenizer`` and ``model`` are reused instead of being
    reloaded on every request. Retrieved context is memoised per query in
    ``search_cache``; generation always runs, so a cache hit and a cache miss
    return the same kind of value (decoded model output).

    Args:
        query: The text submitted through the Gradio interface.

    Returns:
        The decoded first generated sequence with special tokens removed.
    """
    # Look in the cache before doing any expensive work.
    if query in search_cache:
        print(f"โšก ์บ์‹œ ์‚ฌ์šฉ: '{query}'")
        context = search_cache[query]
    else:
        context = rag_pipeline(query, top_k=5)
        # The pipeline may return a list; merge it into one prompt string.
        if isinstance(context, list):
            context = "\n\n".join(context)
        search_cache[query] = context

    # Put the inputs on whatever device the already-loaded model lives on.
    # No explicit model.to("cuda"): placement was decided by device_map="auto".
    inputs = tokenizer(context, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        do_sample=True,
    )
    # NOTE(review): for a causal LM the generated sequence presumably starts
    # with the prompt tokens, so the returned text likely echoes the context.
    # Kept as in the original; slice by inputs["input_ids"].shape[-1] if only
    # the newly generated text is wanted.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# 3. Gradio interface: one plain text input, one plain text output.
# (A richer input was tried earlier: a two-line gr.Textbox with a placeholder.)
demo = gr.Interface(
    generate_response,
    "text",
    "text",
    description="๋ฒ•๋ น ๊ธฐ๋ฐ˜ RAG ํŒŒ์ดํ”„๋ผ์ธ ํ…Œ์ŠคํŠธ",
    title="Law RAG Assistant",
)

# Alternative launch modes kept for reference:
#   demo.launch(server_name="0.0.0.0", server_port=7860)  # expose as an API
#   demo.launch()
demo.launch(debug=True)