import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc
import os
import datetime
import time
import spaces  # Import spaces module for GPU acceleration

# --- Configuration ---
MODEL_ID = "naver-hyperclovax/HyperCLOVAX-SEED-Text-Instruct-0.5B"
MAX_NEW_TOKENS = 512
USE_GPU = True  # Enable GPU usage

# Hugging Face token - read from the environment variable
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("κ²½κ³ : HF_TOKEN ν™˜κ²½ λ³€μˆ˜κ°€ μ„€μ •λ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€. λΉ„κ³΅κ°œ λͺ¨λΈμ— μ ‘κ·Όν•  수 없을 수 μžˆμŠ΅λ‹ˆλ‹€.")

# --- Environment setup ---
print("--- Environment Setup ---")
device = torch.device("cuda" if torch.cuda.is_available() and USE_GPU else "cpu")
print(f"PyTorch version: {torch.__version__}")
print(f"Running on device: {device}")
print(f"Torch Threads: {torch.get_num_threads()}")
print(f"HF_TOKEN μ„€μ • μ—¬λΆ€: {'있음' if HF_TOKEN else 'μ—†μŒ'}")

# Custom CSS for improved UI
custom_css = """
.gradio-container {
    max-width: 850px !important;
    margin: auto;
}
.gr-chat {
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.user-message {
    background-color: #f0f7ff !important;
    border-radius: 8px;
}
.assistant-message {
    background-color: #f9f9f9 !important;
    border-radius: 8px;
}
.gr-button.primary-button {
    background-color: #1f4e79 !important;
}
.gr-form {
    padding: 20px;
    border-radius: 10px;
    box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
}
#intro-message {
    text-align: center;
    margin-bottom: 20px;
    padding: 15px;
    background: linear-gradient(135deg, #e8f4ff 0%, #f0f7ff 100%);
    border-radius: 10px;
    border-left: 4px solid #1f4e79;
}
.footer {
    text-align: center;
    margin-top: 20px;
    font-size: 0.8em;
    color: #666;
}
"""

# --- Model and Tokenizer Loading ---
print(f"--- Loading Model: {MODEL_ID} ---")
print("This might take a few minutes, especially on the first launch...")

model = None
tokenizer = None
load_successful = False
stop_token_ids_list = []  # Initialize stop_token_ids_list

try:
    start_load_time = time.time()
    
    # ν† ν¬λ‚˜μ΄μ € λ‘œλ”©
    tokenizer_kwargs = {
        "trust_remote_code": True
    }
    
    # Pass the token if HF_TOKEN is set
    if HF_TOKEN:
        tokenizer_kwargs["token"] = HF_TOKEN
    
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        **tokenizer_kwargs
    )
    
    # λͺ¨λΈ λ‘œλ”©
    model_kwargs = {
        "trust_remote_code": True,
        "device_map": "auto" if device.type == "cuda" else "cpu",
        "torch_dtype": torch.float16 if device.type == "cuda" else torch.float32,
    }
    
    # Pass the token if HF_TOKEN is set
    if HF_TOKEN:
        model_kwargs["token"] = HF_TOKEN
    
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        **model_kwargs
    )
    
    model.eval()
    load_time = time.time() - start_load_time
    print(f"--- Model and Tokenizer Loaded Successfully in {load_time:.2f} seconds ---")
    load_successful = True

    # --- Stop Token Configuration ---
    stop_token_strings = ["<|endofturn|>", "<|stop|>"]
    temp_stop_ids = [tokenizer.convert_tokens_to_ids(token) for token in stop_token_strings]

    if tokenizer.eos_token_id is not None and tokenizer.eos_token_id not in temp_stop_ids:
        temp_stop_ids.append(tokenizer.eos_token_id)
    elif tokenizer.eos_token_id is None:
         print("Warning: tokenizer.eos_token_id is None. Cannot add to stop tokens.")

    stop_token_ids_list = [tid for tid in temp_stop_ids if tid is not None]

    if not stop_token_ids_list:
        print("Warning: Could not find any stop token IDs. Using default EOS if available, otherwise generation might not stop correctly.")
        if tokenizer.eos_token_id is not None:
            stop_token_ids_list = [tokenizer.eos_token_id]
        else:
             print("Error: No stop tokens found, including default EOS. Generation may run indefinitely.")

    print(f"Using Stop Token IDs: {stop_token_ids_list}")

except Exception as e:
    print(f"!!! Error loading model: {e}")
    if 'model' in locals() and model is not None: del model
    if 'tokenizer' in locals() and tokenizer is not None: del tokenizer
    gc.collect()
    # Raise Gradio error to display in the Space UI if loading fails
    raise gr.Error(f"Failed to load the model {MODEL_ID}. Cannot start the application. Error: {e}")

# --- System Prompt Definition ---
def get_system_prompt():
    current_date = datetime.datetime.now().strftime("%Y-%m-%d (%A)")
    return (
        f"- AI μ–Έμ–΄λͺ¨λΈμ˜ 이름은 \"CLOVA X\" 이며 λ„€μ΄λ²„μ—μ„œ λ§Œλ“€μ—ˆλ‹€.\n"
        f"- μ˜€λŠ˜μ€ {current_date}이닀.\n"
        f"- μ‚¬μš©μžμ˜ μ§ˆλ¬Έμ— λŒ€ν•΄ μΉœμ ˆν•˜κ³  μžμ„Έν•˜κ²Œ ν•œκ΅­μ–΄λ‘œ λ‹΅λ³€ν•΄μ•Ό ν•œλ‹€."
    )

# --- Warm-up Function ---
def warmup_model():
    if not load_successful or model is None or tokenizer is None:
        print("Skipping warmup: Model not loaded successfully.")
        return

    print("--- Starting Model Warm-up ---")
    try:
        start_warmup_time = time.time()
        warmup_message = "μ•ˆλ…•ν•˜μ„Έμš”"
        system_prompt = get_system_prompt()
        warmup_chat = [
            {"role": "tool_list", "content": ""},
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": warmup_message}
        ]

        inputs = tokenizer.apply_chat_template(
            warmup_chat,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt"
        ).to(device)

        # Check if stop_token_ids_list is empty and handle appropriately
        gen_kwargs = {
            "max_new_tokens": 10,
            "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
            "do_sample": False
        }
        if stop_token_ids_list:
            gen_kwargs["eos_token_id"] = stop_token_ids_list
        else:
            print("Warmup Warning: No stop tokens defined for generation.")

        with torch.no_grad():
            output_ids = model.generate(**inputs, **gen_kwargs)

        del inputs
        del output_ids
        gc.collect()
        warmup_time = time.time() - start_warmup_time
        print(f"--- Model Warm-up Completed in {warmup_time:.2f} seconds ---")

    except Exception as e:
        print(f"!!! Error during model warm-up: {e}")
    finally:
        gc.collect()

# --- Inference Function with GPU decorator ---
@spaces.GPU()  # Important: Add the spaces.GPU() decorator for ZeroGPU
def predict(message, history):
    """
    Generates a response using HyperCLOVAX.
    Assumes 'history' is in the default Gradio ChatInterface format:
    a list of (user_message, assistant_message) tuples.
    """
    if model is None or tokenizer is None:
         return "였λ₯˜: λͺ¨λΈμ΄ λ‘œλ“œλ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€."

    system_prompt = get_system_prompt()

    # Start with system prompt
    chat_history_formatted = [
        {"role": "tool_list", "content": ""}, # As required by model card
        {"role": "system", "content": system_prompt}
    ]

    # Process history based on Gradio ChatInterface format (list of tuples)
    if isinstance(history, list):
        for user_msg, assistant_msg in history:
            chat_history_formatted.append({"role": "user", "content": user_msg})
            if assistant_msg:  # Check if not None or empty
                chat_history_formatted.append({"role": "assistant", "content": assistant_msg})

    # Append the latest user message
    chat_history_formatted.append({"role": "user", "content": message})

    inputs = None
    output_ids = None

    try:
        inputs = tokenizer.apply_chat_template(
            chat_history_formatted,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt"
        ).to(device)
        input_length = inputs['input_ids'].shape[1]
        print(f"\nInput tokens: {input_length}")

    except Exception as e:
        print(f"!!! Error applying chat template: {e}")
        return f"였λ₯˜: μž…λ ₯ ν˜•μ‹μ„ μ²˜λ¦¬ν•˜λŠ” 쀑 λ¬Έμ œκ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€. ({e})"

    try:
        print("Generating response...")
        generation_start_time = time.time()

        # Prepare generation arguments, handling empty stop_token_ids_list
        gen_kwargs = {
            "max_new_tokens": MAX_NEW_TOKENS,
            "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
        }
        if stop_token_ids_list:
            gen_kwargs["eos_token_id"] = stop_token_ids_list
        else:
            print("Generation Warning: No stop tokens defined.")

        with torch.no_grad():
            output_ids = model.generate(**inputs, **gen_kwargs)

        generation_time = time.time() - generation_start_time
        print(f"Generation complete in {generation_time:.2f} seconds.")

    except Exception as e:
        print(f"!!! Error during model generation: {e}")
        if inputs is not None: del inputs
        if output_ids is not None: del output_ids
        gc.collect()
        return f"였λ₯˜: 응닡을 μƒμ„±ν•˜λŠ” 쀑 λ¬Έμ œκ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€. ({e})"

    # Decode the response
    response = "였λ₯˜: 응닡 생성에 μ‹€νŒ¨ν–ˆμŠ΅λ‹ˆλ‹€."
    if output_ids is not None:
        try:
            new_tokens = output_ids[0, input_length:]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True)
            print(f"Output tokens: {len(new_tokens)}")
            del new_tokens
        except Exception as e:
            print(f"!!! Error decoding response: {e}")
            response = "였λ₯˜: 응닡을 λ””μ½”λ”©ν•˜λŠ” 쀑 λ¬Έμ œκ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€."

    # Clean up memory
    if inputs is not None: del inputs
    if output_ids is not None: del output_ids
    gc.collect()
    print("Memory cleaned.")

    return response

# --- Gradio Interface Setup ---
print("--- Setting up Gradio Interface ---")

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("""
    # NAVER hyperclovax: HyperCLOVAX-SEED-Text-Instruct-0.5B 
    
    """, elem_id="intro-message")
    
    # Using standard ChatInterface (compatible with all Gradio versions)
    chatbot = gr.ChatInterface(
        fn=predict,
        examples=[
            ["넀이버 ν΄λ‘œλ°”XλŠ” λ¬΄μ—‡μΈκ°€μš”?"],
            ["μŠˆλ’°λ”©κ±° 방정식과 μ–‘μžμ—­ν•™μ˜ 관계λ₯Ό μ„€λͺ…ν•΄μ£Όμ„Έμš”."],
            ["λ”₯λŸ¬λ‹ λͺ¨λΈ ν•™μŠ΅ 과정을 λ‹¨κ³„λ³„λ‘œ μ•Œλ €μ€˜."],
            ["μ œμ£Όλ„ μ—¬ν–‰ κ³„νšμ„ μ„Έμš°κ³  μžˆλŠ”λ°, 3λ°• 4일 μΆ”μ²œ μ½”μŠ€ μ’€ μ§œμ€„λž˜?"],
            ["ν•œκ΅­ μ—­μ‚¬μ—μ„œ κ°€μž₯ μ€‘μš”ν•œ 사건 5κ°€μ§€λŠ” λ¬΄μ—‡μΈκ°€μš”?"],
            ["인곡지λŠ₯ μœ€λ¦¬μ— λŒ€ν•΄ μ„€λͺ…ν•΄μ£Όμ„Έμš”."],
        ],
        cache_examples=False,
    )
    
    with gr.Accordion("λͺ¨λΈ 정보", open=False):
        gr.Markdown(f"""
        - **λͺ¨λΈ**: {MODEL_ID}
        - **ν™˜κ²½**: ZeroGPU 곡유 ν™˜κ²½μ—μ„œ μ‹€ν–‰ 쀑
        - **토큰 μ œν•œ**: μ΅œλŒ€ 생성 토큰 μˆ˜λŠ” {MAX_NEW_TOKENS}개둜 μ œν•œλ©λ‹ˆλ‹€.
        - **ν•˜λ“œμ›¨μ–΄**: {"GPU" if device.type == "cuda" else "CPU"} ν™˜κ²½μ—μ„œ μ‹€ν–‰ 쀑
        """)
    
    gr.Markdown(
        "Β© 2025 넀이버 HyperCLOVA X 데λͺ¨ | Powered by Hugging Face & ZeroGPU", 
        elem_classes="footer"
    )

# --- Application Launch ---
if __name__ == "__main__":
    if load_successful:
        warmup_model()
    else:
        print("Skipping warm-up because model loading failed.")

    print("--- Launching Gradio App ---")
    demo.queue().launch(
        # share=True,  # Uncomment for a public link
        server_name="0.0.0.0" # Enable external access
    )