import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    LlamaForCausalLM,
    GenerationConfig,
)
from peft import PeftModel
# ------------------------------------------------------------------------------
# CONFIGURE MODEL & PIPELINE
# ------------------------------------------------------------------------------
BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
FINETUNED_ADAPTER = "cheberle/autotrain-llama-milch"

# Generation hyperparameters
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_K = 50
DEFAULT_TOP_P = 0.9
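# max_new_tokens caps the length of the generated continuation;
# temperature rescales the logits (lower = more deterministic);
# top_k keeps only the k most probable tokens at each step;
# top_p (nucleus sampling) keeps the smallest token set whose probability mass exceeds p.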
# Load the tokenizer from the base model.
# AutoTokenizer picks the appropriate (fast) tokenizer class for the checkpoint;
# the slow LlamaTokenizer expects a SentencePiece file that this repo may not ship.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Load the base model
base_model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",          # automatically use GPU if available
    torch_dtype=torch.float16,  # use half-precision to save memory
)

# Load the PEFT (LoRA) adapter on top of the base model
model = PeftModel.from_pretrained(
    base_model,
    FINETUNED_ADAPTER,
    torch_dtype=torch.float16,
)
model.eval()  # put in eval mode
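# Optional: merging the LoRA weights into the base model (PEFT's merge_and_unload)
# can make inference slightly faster when the adapter never needs to be detached.
# Left here as a commented-out sketch:
#   model = model.merge_and_unload()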
# ------------------------------------------------------------------------------
# GENERATION FUNCTION
# ------------------------------------------------------------------------------
def generate_text(prompt,
                  max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
                  temperature=DEFAULT_TEMPERATURE,
                  top_k=DEFAULT_TOP_K,
                  top_p=DEFAULT_TOP_P):
    """Generate text from the finetuned model using the given parameters."""
    # Tokenize the prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Set up generation configuration
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=1.1,  # adjust if needed
    )

    # Generate
    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            generation_config=generation_config,
        )

    # Decode the generated tokens
    generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)

    # Remove the original prompt from the beginning to return only new text
    if generated_text.startswith(prompt):
        return generated_text[len(prompt):].strip()
    else:
        return generated_text
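# Illustrative direct call (hypothetical prompt, not wired into the UI):
#   print(generate_text("Summarize the benefits of LoRA fine-tuning.", max_new_tokens=64))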
# ------------------------------------------------------------------------------
# GRADIO APP
# ------------------------------------------------------------------------------
def clear_inputs():
    return "", ""
with gr.Blocks(css=".gradio-container {max-width: 800px; margin: auto;}") as demo:
    gr.Markdown("## DeepSeek R1 Distill-Llama 8B + LoRA from `cheberle/autotrain-llama-milch`")
    gr.Markdown(
        "This app uses a base **DeepSeek R1 Distill-Llama 8B** model with "
        "the **LoRA/PEFT adapter** from "
        "[`cheberle/autotrain-llama-milch`](https://huggingface.co/cheberle/autotrain-llama-milch).\n\n"
        "Type in a prompt, adjust the generation parameters if you wish, and click 'Generate'."
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Ask me anything...",
                lines=5,
            )
            with gr.Accordion("Advanced Generation Settings", open=False):
                max_new_tokens = gr.Slider(
                    16, 1024,
                    value=DEFAULT_MAX_NEW_TOKENS,
                    step=1,
                    label="Max New Tokens",
                )
                temperature = gr.Slider(
                    0.0, 2.0,
                    value=DEFAULT_TEMPERATURE,
                    step=0.1,
                    label="Temperature",
                )
                top_k = gr.Slider(
                    0, 100,
                    value=DEFAULT_TOP_K,
                    step=1,
                    label="Top-k",
                )
                top_p = gr.Slider(
                    0.0, 1.0,
                    value=DEFAULT_TOP_P,
                    step=0.05,
                    label="Top-p",
                )
            generate_btn = gr.Button("Generate", variant="primary")
            clear_btn = gr.Button("Clear")

        with gr.Column():
            output = gr.Textbox(
                label="Model Output",
                lines=12,
            )
    # Button Actions
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt, max_new_tokens, temperature, top_k, top_p],
        outputs=output,
    )
    clear_btn.click(fn=clear_inputs, inputs=[], outputs=[prompt, output])
# Limit to one generation at a time (the kwarg is named default_concurrency_limit
# in Gradio 4+; older Gradio 3.x releases called it concurrency_count).
demo.queue(default_concurrency_limit=1)
demo.launch()
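# When running locally, a temporary public URL can be exposed instead, e.g.:
#   demo.launch(share=True)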