cheberle commited on
Commit
9df76b8
·
1 Parent(s): 593a8ea
Files changed (1) hide show
  1. app.py +44 -14
app.py CHANGED
@@ -3,16 +3,17 @@ import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from peft import PeftModel
5
 
 
6
  BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
7
  ADAPTER = "cheberle/autotrain-llama-milch"
8
 
9
- print("Loading tokenizer...")
10
  tokenizer = AutoTokenizer.from_pretrained(
11
  BASE_MODEL,
12
  trust_remote_code=True
13
  )
14
 
15
- print("Loading base model...")
16
  base_model = AutoModelForCausalLM.from_pretrained(
17
  BASE_MODEL,
18
  trust_remote_code=True,
@@ -20,7 +21,7 @@ base_model = AutoModelForCausalLM.from_pretrained(
20
  torch_dtype=torch.float16
21
  )
22
 
23
- print("Loading finetuned adapter...")
24
  model = PeftModel.from_pretrained(
25
  base_model,
26
  ADAPTER,
@@ -28,23 +29,52 @@ model = PeftModel.from_pretrained(
28
  )
29
  model.eval()
30
 
31
- def generate_text(prompt):
 
 
 
 
 
 
 
 
 
 
 
 
32
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
33
  with torch.no_grad():
34
  output = model.generate(
35
  **inputs,
36
- max_new_tokens=128,
37
- temperature=0.7,
38
- top_p=0.9,
39
- top_k=50,
40
- do_sample=True
41
  )
42
- return tokenizer.decode(output[0], skip_special_tokens=True)
43
 
 
 
 
 
 
 
 
 
 
 
44
  with gr.Blocks() as demo:
45
- prompt_box = gr.Textbox(lines=4, label="Prompt")
46
- output_box = gr.Textbox(lines=6, label="Output")
47
- btn = gr.Button("Generate")
48
- btn.click(fn=generate_text, inputs=prompt_box, outputs=output_box)
 
 
 
 
 
 
 
 
 
49
 
50
  demo.launch()
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from peft import PeftModel
5
 
6
+ # Pfade zu deinem Basismodell und dem feingetunten LoRA-Adapter
7
  BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
8
  ADAPTER = "cheberle/autotrain-llama-milch"
9
 
10
+ print("Lade Tokenizer...")
11
  tokenizer = AutoTokenizer.from_pretrained(
12
  BASE_MODEL,
13
  trust_remote_code=True
14
  )
15
 
16
+ print("Lade Basismodell...")
17
  base_model = AutoModelForCausalLM.from_pretrained(
18
  BASE_MODEL,
19
  trust_remote_code=True,
 
21
  torch_dtype=torch.float16
22
  )
23
 
24
+ print("Lade feingetunten Adapter...")
25
  model = PeftModel.from_pretrained(
26
  base_model,
27
  ADAPTER,
 
29
  )
30
  model.eval()
31
 
32
+ def klassifiziere_lebensmittel(produkt_text):
33
+ """
34
+ Diese Funktion erstellt ein Prompt auf Deutsch, das das Modell anweist,
35
+ eine Lebensmittel-Kategorie (als einzelnes Label) für den eingegebenen
36
+ Produkttext zurückzugeben.
37
+ Temperatur=0.0 und do_sample=False sorgen für deterministischen Output.
38
+ """
39
+ prompt = (
40
+ f"Du bist ein Modell zur Klassifikation von Lebensmitteln. "
41
+ f"Analysiere die Produktbeschreibung auf Deutsch: \"{produkt_text}\".\n"
42
+ f"Gib bitte nur eine einzige passende Lebensmittel-Kategorie (auf Deutsch) zurück."
43
+ )
44
+
45
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
46
  with torch.no_grad():
47
  output = model.generate(
48
  **inputs,
49
+ max_new_tokens=30, # Begrenze die Länge des Modell-Antwort
50
+ temperature=0.0, # Keine "kreativen" Abweichungen
51
+ top_p=1.0,
52
+ do_sample=False
 
53
  )
 
54
 
55
+ # Ausgabe dekodieren und bereinigen
56
+ decoded = tokenizer.decode(output[0], skip_special_tokens=True).strip()
57
+
58
+ # Falls das Modell mehrzeiligen Text ausgibt, nehmen wir die letzte Zeile
59
+ lines = decoded.split("\n")
60
+ label = lines[-1].strip()
61
+
62
+ return label
63
+
64
+ # Gradio-Interface aufbauen
65
  with gr.Blocks() as demo:
66
+ produkt_box = gr.Textbox(
67
+ lines=2,
68
+ label="Produktbeschreibung",
69
+ placeholder="z.B. 'Aeschbach Trinkschokolade Milch, 1 kg'"
70
+ )
71
+ output_box = gr.Textbox(
72
+ lines=1,
73
+ label="Prediziertes Lebensmittel-Label",
74
+ placeholder="Hier erscheint das Ergebnis"
75
+ )
76
+
77
+ classify_button = gr.Button("Kategorie bestimmen")
78
+ classify_button.click(fn=klassifiziere_lebensmittel, inputs=produkt_box, outputs=output_box)
79
 
80
  demo.launch()