import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    LlamaForCausalLM,
    GenerationConfig,
)
from peft import PeftModel
# ------------------------------------------------------------------------------
# CONFIGURE MODEL & PIPELINE
# ------------------------------------------------------------------------------
BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
FINETUNED_ADAPTER = "cheberle/autotrain-llama-milch"
# Generation hyperparameters
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_K = 50
DEFAULT_TOP_P = 0.9
# Load the tokenizer from the base model. AutoTokenizer resolves the correct (fast)
# tokenizer class; the slow LlamaTokenizer would fail on this Llama-3-based checkpoint.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Load the base model
base_model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",           # automatically use GPU if available
    torch_dtype=torch.float16,   # half precision to save memory
)
# Load the PEFT (LoRA) adapter on top of the base model
model = PeftModel.from_pretrained(
    base_model,
    FINETUNED_ADAPTER,
    torch_dtype=torch.float16,
)
model.eval()  # inference only: disables dropout
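# Precaution (assumption: no pad token is set upstream): Llama-style tokenizers ship
# without a pad token, so reuse EOS to avoid the missing pad_token_id warning in generate().
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token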
# ------------------------------------------------------------------------------
# GENERATION FUNCTION
# ------------------------------------------------------------------------------
def generate_text(prompt,
                  max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
                  temperature=DEFAULT_TEMPERATURE,
                  top_k=DEFAULT_TOP_K,
                  top_p=DEFAULT_TOP_P):
    """Generate text from the finetuned model using the given parameters."""
    # Tokenize the prompt and move it to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Set up the generation configuration
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=temperature > 0,  # fall back to greedy decoding when temperature is 0
        repetition_penalty=1.1,     # adjust if needed
    )
    # Generate
    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            generation_config=generation_config
        )
    # Decode only the newly generated tokens (everything after the prompt);
    # this is more reliable than string-matching the prompt in the decoded output.
    new_tokens = output_tokens[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
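# Quick smoke test (hypothetical prompt; run only after the model has finished loading):
#   print(generate_text("What is LoRA fine-tuning?", max_new_tokens=64))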
# ------------------------------------------------------------------------------
# GRADIO APP
# ------------------------------------------------------------------------------
def clear_inputs():
    return "", ""
with gr.Blocks(css=".gradio-container {max-width: 800px; margin: auto;}") as demo:
gr.Markdown("## DeepSeek R1 Distill-Llama 8B + LoRA from `cheberle/autotrain-llama-milch`")
gr.Markdown(
"This app uses a base **DeepSeek R1 Distill-Llama 8B** model with "
"the **LoRA/PEFT adapter** from [`cheberle/autotrain-llama-milch`].\n\n"
"Type in a prompt, adjust generation parameters if you wish, and click 'Generate'."
)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Ask me anything...",
                lines=5
            )
            with gr.Accordion("Advanced Generation Settings", open=False):
                max_new_tokens = gr.Slider(
                    16, 1024,
                    value=DEFAULT_MAX_NEW_TOKENS,
                    step=1,
                    label="Max New Tokens"
                )
                temperature = gr.Slider(
                    0.0, 2.0,
                    value=DEFAULT_TEMPERATURE,
                    step=0.1,
                    label="Temperature"
                )
                top_k = gr.Slider(
                    0, 100,
                    value=DEFAULT_TOP_K,
                    step=1,
                    label="Top-k"
                )
                top_p = gr.Slider(
                    0.0, 1.0,
                    value=DEFAULT_TOP_P,
                    step=0.05,
                    label="Top-p"
                )
            generate_btn = gr.Button("Generate", variant="primary")
            clear_btn = gr.Button("Clear")
        with gr.Column():
            output = gr.Textbox(
                label="Model Output",
                lines=12
            )
    # Button actions
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt, max_new_tokens, temperature, top_k, top_p],
        outputs=output
    )
    clear_btn.click(fn=clear_inputs, inputs=[], outputs=[prompt, output])
demo.queue(default_concurrency_limit=1)  # one request at a time (Gradio 4 API; Gradio 3 used concurrency_count=1)
demo.launch()