Spaces:
Sleeping
Sleeping
File size: 2,935 Bytes
ec46241 593a8ea ec46241 9df76b8 ec46241 593a8ea ec46241 9df76b8 593a8ea ec46241 9df76b8 593a8ea ec46241 593a8ea ec46241 9df76b8 ec46241 593a8ea ec46241 593a8ea ec46241 3daea5f 9df76b8 3daea5f 9df76b8 3daea5f 9df76b8 3daea5f 9df76b8 ec46241 593a8ea ec46241 4b0ec5b 3daea5f 9df76b8 ec46241 3daea5f 9df76b8 3daea5f 9df76b8 3daea5f 593a8ea 9df76b8 3daea5f 9df76b8 3daea5f 9df76b8 3daea5f ec46241 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# Pfade zu deinem Basismodell und dem feingetunten LoRA-Adapter
BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
ADAPTER = "cheberle/autotrain-llama-milch"
print("Lade Tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
BASE_MODEL,
trust_remote_code=True
)
print("Lade Basismodell...")
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
trust_remote_code=True,
device_map="auto",
torch_dtype=torch.float16
)
print("Lade feingetunten Adapter...")
model = PeftModel.from_pretrained(
base_model,
ADAPTER,
torch_dtype=torch.float16
)
model.eval()
def klassifiziere_lebensmittel_fewshot(produkt_text):
"""
Verwendet einen Few-Shot-Prompt mit Beispielen auf Deutsch,
um das Modell zu einer einzigen, kurzen Lebensmittel-Kategorie
ohne zusätzliche Erklärungen zu führen.
"""
# Beispiele (Few-Shot).
# Du kannst die Beispiele anpassen, wenn du andere demonstrieren willst.
beispiele = (
"1) Produkt: \"Cailler Branches Milch, 44 x 46 g\"\n Kategorie: Schokolade\n\n"
"2) Produkt: \"Aeschbach Trinkschokolade Milch, 1 kg\"\n Kategorie: Trinkschokolade\n\n"
"3) Produkt: \"Biedermann Bio Vollmilch 3,8%, pasteurisiert\"\n Kategorie: Milch\n\n"
)
# Prompt mit Few-Shot und neuer Eingabe
prompt = (
"Du bist ein Modell zur Klassifikation von Lebensmitteln in deutsche Kategorien.\n"
"Hier sind einige Beispiele:\n\n"
f"{beispiele}"
f"Neues Produkt: \"{produkt_text}\"\n"
"Kategorie (NUR das Wort und keine Erklärung):"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=200, # Begrenze die Antwort auf wenige Tokens
temperature=0.0, # So wenig "kreatives" Rauschen wie möglich
top_p=1.0,
do_sample=False
)
# Modell-Antwort dekodieren
decoded = tokenizer.decode(output[0], skip_special_tokens=True).strip()
# Oft wiederholt das Modell das Prompt - wir nehmen daher nur die letzte Zeile
lines = decoded.split("\n")
label = lines[-1].strip()
return label
# Gradio-Interface
with gr.Blocks() as demo:
produkt_box = gr.Textbox(
lines=2,
label="Produktbeschreibung",
placeholder="z.B. 'Biedermann Bio Jogurt Schafmilch Himbeer, 5 x 120 g'"
)
output_box = gr.Textbox(
lines=1,
label="Predizierte Kategorie",
placeholder="Hier erscheint das Ergebnis"
)
classify_button = gr.Button("Kategorie bestimmen (Few-Shot)")
classify_button.click(
fn=klassifiziere_lebensmittel_fewshot,
inputs=produkt_box,
outputs=output_box
)
demo.launch() |