cheberle commited on
Commit
593a8ea
·
1 Parent(s): ec46241
Files changed (1) hide show
  1. app.py +27 -124
app.py CHANGED
@@ -1,147 +1,50 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import (
4
- LlamaForCausalLM,
5
- LlamaTokenizer,
6
- GenerationConfig
7
- )
8
  from peft import PeftModel
9
 
10
- # ------------------------------------------------------------------------------
11
- # CONFIGURE MODEL & PIPELINE
12
- # ------------------------------------------------------------------------------
13
  BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
14
- FINETUNED_ADAPTER = "cheberle/autotrain-llama-milch"
15
-
16
- # Generation hyperparameters
17
- DEFAULT_MAX_NEW_TOKENS = 256
18
- DEFAULT_TEMPERATURE = 0.7
19
- DEFAULT_TOP_K = 50
20
- DEFAULT_TOP_P = 0.9
21
 
22
- # Load tokenizer from base model
23
- tokenizer = LlamaTokenizer.from_pretrained(
24
- BASE_MODEL
 
25
  )
26
 
27
- # Load the base model
28
- base_model = LlamaForCausalLM.from_pretrained(
29
  BASE_MODEL,
30
- device_map="auto", # Automatically use GPU if available
31
- torch_dtype=torch.float16 # Use half-precision to save memory
 
32
  )
33
 
34
- # Load the PEFT (LoRA) adapter on top of the base model
35
  model = PeftModel.from_pretrained(
36
  base_model,
37
- FINETUNED_ADAPTER,
38
  torch_dtype=torch.float16
39
  )
 
40
 
41
- model.eval() # put in eval mode
42
-
43
- # ------------------------------------------------------------------------------
44
- # GENERATION FUNCTION
45
- # ------------------------------------------------------------------------------
46
- def generate_text(prompt,
47
- max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
48
- temperature=DEFAULT_TEMPERATURE,
49
- top_k=DEFAULT_TOP_K,
50
- top_p=DEFAULT_TOP_P):
51
- """Generate text from the finetuned model using the given parameters."""
52
- # Tokenize the prompt
53
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
54
-
55
- # Set up generation configuration
56
- generation_config = GenerationConfig(
57
- max_new_tokens=max_new_tokens,
58
- temperature=temperature,
59
- top_k=top_k,
60
- top_p=top_p,
61
- do_sample=True,
62
- repetition_penalty=1.1, # adjust if needed
63
- )
64
-
65
- # Generate
66
  with torch.no_grad():
67
- output_tokens = model.generate(
68
  **inputs,
69
- generation_config=generation_config
 
 
 
 
70
  )
71
-
72
- # Decode the generated tokens
73
- generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
74
-
75
- # Remove the original prompt from the beginning to return only new text
76
- if generated_text.startswith(prompt):
77
- return generated_text[len(prompt):].strip()
78
- else:
79
- return generated_text
80
-
81
- # ------------------------------------------------------------------------------
82
- # GRADIO APP
83
- # ------------------------------------------------------------------------------
84
- def clear_inputs():
85
- return "", ""
86
-
87
- with gr.Blocks(css=".gradio-container {max-width: 800px; margin: auto;}") as demo:
88
- gr.Markdown("## DeepSeek R1 Distill-Llama 8B + LoRA from `cheberle/autotrain-llama-milch`")
89
- gr.Markdown(
90
- "This app uses a base **DeepSeek R1 Distill-Llama 8B** model with "
91
- "the **LoRA/PEFT adapter** from [`cheberle/autotrain-llama-milch`].\n\n"
92
- "Type in a prompt, adjust generation parameters if you wish, and click 'Generate'."
93
- )
94
 
95
- with gr.Row():
96
- with gr.Column():
97
- prompt = gr.Textbox(
98
- label="Prompt",
99
- placeholder="Ask me anything...",
100
- lines=5
101
- )
102
- with gr.Accordion("Advanced Generation Settings", open=False):
103
- max_new_tokens = gr.Slider(
104
- 16, 1024,
105
- value=DEFAULT_MAX_NEW_TOKENS,
106
- step=1,
107
- label="Max New Tokens"
108
- )
109
- temperature = gr.Slider(
110
- 0.0, 2.0,
111
- value=DEFAULT_TEMPERATURE,
112
- step=0.1,
113
- label="Temperature"
114
- )
115
- top_k = gr.Slider(
116
- 0, 100,
117
- value=DEFAULT_TOP_K,
118
- step=1,
119
- label="Top-k"
120
- )
121
- top_p = gr.Slider(
122
- 0.0, 1.0,
123
- value=DEFAULT_TOP_P,
124
- step=0.05,
125
- label="Top-p"
126
- )
127
-
128
- generate_btn = gr.Button("Generate", variant="primary")
129
- clear_btn = gr.Button("Clear")
130
-
131
- with gr.Column():
132
- output = gr.Textbox(
133
- label="Model Output",
134
- lines=12
135
- )
136
-
137
- # Button Actions
138
- generate_btn.click(
139
- fn=generate_text,
140
- inputs=[prompt, max_new_tokens, temperature, top_k, top_p],
141
- outputs=output
142
- )
143
-
144
- clear_btn.click(fn=clear_inputs, inputs=[], outputs=[prompt, output])
145
 
146
- demo.queue(concurrency_count=1)
147
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
 
4
  from peft import PeftModel
5
 
 
 
 
6
  BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
7
+ ADAPTER = "cheberle/autotrain-llama-milch"
 
 
 
 
 
 
8
 
9
+ print("Loading tokenizer...")
10
+ tokenizer = AutoTokenizer.from_pretrained(
11
+ BASE_MODEL,
12
+ trust_remote_code=True
13
  )
14
 
15
+ print("Loading base model...")
16
+ base_model = AutoModelForCausalLM.from_pretrained(
17
  BASE_MODEL,
18
+ trust_remote_code=True,
19
+ device_map="auto",
20
+ torch_dtype=torch.float16
21
  )
22
 
23
+ print("Loading finetuned adapter...")
24
  model = PeftModel.from_pretrained(
25
  base_model,
26
+ ADAPTER,
27
  torch_dtype=torch.float16
28
  )
29
+ model.eval()
30
 
31
+ def generate_text(prompt):
 
 
 
 
 
 
 
 
 
 
 
32
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
 
 
 
 
 
 
 
 
 
 
 
33
  with torch.no_grad():
34
+ output = model.generate(
35
  **inputs,
36
+ max_new_tokens=128,
37
+ temperature=0.7,
38
+ top_p=0.9,
39
+ top_k=50,
40
+ do_sample=True
41
  )
42
+ return tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ with gr.Blocks() as demo:
45
+ prompt_box = gr.Textbox(lines=4, label="Prompt")
46
+ output_box = gr.Textbox(lines=6, label="Output")
47
+ btn = gr.Button("Generate")
48
+ btn.click(fn=generate_text, inputs=prompt_box, outputs=output_box)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
 
50
  demo.launch()