File size: 2,935 Bytes
ec46241
 
593a8ea
ec46241
 
9df76b8
ec46241
593a8ea
ec46241
9df76b8
593a8ea
 
 
ec46241
 
9df76b8
593a8ea
ec46241
593a8ea
 
 
ec46241
 
9df76b8
ec46241
 
593a8ea
ec46241
 
593a8ea
ec46241
3daea5f
9df76b8
3daea5f
 
 
9df76b8
3daea5f
 
 
 
 
 
 
 
 
 
9df76b8
3daea5f
 
 
 
 
9df76b8
 
ec46241
 
593a8ea
ec46241
4b0ec5b
3daea5f
9df76b8
 
ec46241
 
3daea5f
9df76b8
 
3daea5f
9df76b8
 
 
 
 
3daea5f
593a8ea
9df76b8
 
 
3daea5f
9df76b8
 
 
3daea5f
9df76b8
 
 
3daea5f
 
 
 
 
 
ec46241
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Pfade zu deinem Basismodell und dem feingetunten LoRA-Adapter
BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
ADAPTER = "cheberle/autotrain-llama-milch"

print("Lade Tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True
)

print("Lade Basismodell...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16
)

print("Lade feingetunten Adapter...")
model = PeftModel.from_pretrained(
    base_model,
    ADAPTER,
    torch_dtype=torch.float16
)
model.eval()

def klassifiziere_lebensmittel_fewshot(produkt_text):
    """
    Verwendet einen Few-Shot-Prompt mit Beispielen auf Deutsch, 
    um das Modell zu einer einzigen, kurzen Lebensmittel-Kategorie 
    ohne zusätzliche Erklärungen zu führen.
    """

    # Beispiele (Few-Shot). 
    # Du kannst die Beispiele anpassen, wenn du andere demonstrieren willst.
    beispiele = (
        "1) Produkt: \"Cailler Branches Milch, 44 x 46 g\"\n   Kategorie: Schokolade\n\n"
        "2) Produkt: \"Aeschbach Trinkschokolade Milch, 1 kg\"\n   Kategorie: Trinkschokolade\n\n"
        "3) Produkt: \"Biedermann Bio Vollmilch 3,8%, pasteurisiert\"\n   Kategorie: Milch\n\n"
    )

    # Prompt mit Few-Shot und neuer Eingabe
    prompt = (
        "Du bist ein Modell zur Klassifikation von Lebensmitteln in deutsche Kategorien.\n"
        "Hier sind einige Beispiele:\n\n"
        f"{beispiele}"
        f"Neues Produkt: \"{produkt_text}\"\n"
        "Kategorie (NUR das Wort und keine Erklärung):"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,   # Begrenze die Antwort auf wenige Tokens
            temperature=0.0,     # So wenig "kreatives" Rauschen wie möglich
            top_p=1.0,
            do_sample=False
        )

    # Modell-Antwort dekodieren
    decoded = tokenizer.decode(output[0], skip_special_tokens=True).strip()

    # Oft wiederholt das Modell das Prompt - wir nehmen daher nur die letzte Zeile
    lines = decoded.split("\n")
    label = lines[-1].strip()

    return label

# Gradio-Interface
with gr.Blocks() as demo:
    produkt_box = gr.Textbox(
        lines=2,
        label="Produktbeschreibung",
        placeholder="z.B. 'Biedermann Bio Jogurt Schafmilch Himbeer, 5 x 120 g'"
    )
    output_box = gr.Textbox(
        lines=1,
        label="Predizierte Kategorie",
        placeholder="Hier erscheint das Ergebnis"
    )

    classify_button = gr.Button("Kategorie bestimmen (Few-Shot)")
    classify_button.click(
        fn=klassifiziere_lebensmittel_fewshot,
        inputs=produkt_box,
        outputs=output_box
    )

demo.launch()