import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# from retriever.vectordb_rerank import search_documents  # RAG retriever (currently unused)
from services.rag_pipeline import rag_pipeline

model_name = "dasomaru/gemma-3-4bit-it-demo"

# Load the tokenizer and the model once, at import time.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# device_map="auto" leaves device placement to the loader. The original note
# here said the model lands on the CPU first because no GPU is attached yet at
# import time -- NOTE(review): confirm against the actual deployment target.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # original note: chosen because this is a 4-bit model
    device_map="auto",
    trust_remote_code=True,
)

# Retrieval results keyed by the raw query string.
search_cache = {}

# Lazily created copy of the model used when the import-time instance is not
# on a CUDA device (see _get_generation_model).
_gpu_model = None


def _get_generation_model():
    """Return the model to generate with, loading a second copy at most once.

    The import-time ``model`` is reused when it already sits on a CUDA device,
    or when CUDA is not available at all (generation then runs wherever that
    instance was placed). Otherwise the checkpoint is loaded again with
    ``device_map="auto"`` and the result is kept in ``_gpu_model`` so that
    later calls in the same process do not repeat the load.
    """
    global _gpu_model
    if model.device.type == "cuda" or not torch.cuda.is_available():
        return model
    if _gpu_model is None:
        _gpu_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
    return _gpu_model


@spaces.GPU(duration=300)
def generate_response(query: str):
    """Answer ``query`` by retrieving context and sampling a completion.

    The text returned by ``rag_pipeline`` (joined with blank lines when it is
    a list) is used as the prompt and is memoised in ``search_cache`` per
    query. A cache hit only skips the retrieval step; generation always runs,
    so repeated queries return the same kind of output as the first call.

    Args:
        query: The user's question as entered in the UI.

    Returns:
        The decoded output sequence with special tokens stripped.
        NOTE(review): the full sequence is decoded, which presumably still
        contains the prompt text -- confirm whether that is intended.
    """
    if query in search_cache:
        print(f"โšก ์บ์‹œ ์‚ฌ์šฉ: '{query}'")
        results = search_cache[query]
    else:
        top_k = 5
        results = rag_pipeline(query, top_k=top_k)

        # The pipeline may hand back a list of passages; flatten it to text.
        if isinstance(results, list):
            results = "\n\n".join(results)

        search_cache[query] = results

    generation_model = _get_generation_model()

    # Put the input tensors on the same device as the model.
    inputs = tokenizer(results, return_tensors="pt").to(generation_model.device)
    outputs = generation_model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        do_sample=True,
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Gradio interface: a single free-text input wired to a single text output.
# (A gr.Textbox with a placeholder was previously tried for the input.)
_interface_options = {
    "fn": generate_response,
    "inputs": "text",
    "outputs": "text",
    "title": "Law RAG Assistant",
    "description": "๋ฒ•๋ น ๊ธฐ๋ฐ˜ RAG ํŒŒ์ดํ”„๋ผ์ธ ํ…Œ์ŠคํŠธ",
}
demo = gr.Interface(**_interface_options)

# Start the app with Gradio's debug mode enabled. An explicit
# server_name="0.0.0.0" / server_port=7860 launch was used earlier as the
# API-deployment variant.
demo.launch(debug=True)