Update app.py
app.py (CHANGED)
@@ -6,7 +6,6 @@ import numpy as np
 import torch
 from sentence_transformers import SentenceTransformer
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from peft import PeftModel
 from reportlab.lib.pagesizes import A4
 from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
 from reportlab.lib.styles import getSampleStyleSheet
@@ -56,21 +55,19 @@ def retrieve_milestone(user_input):
     return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."

 # Initialize IBM Granite Model
-BASE_NAME = "ibm-granite/granite-3.0-
-LORA_NAME = "ibm-granite/granite-rag-3.0-8b-lora"
+BASE_NAME = "ibm-granite/granite-3.0-2b-base"

 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 tokenizer = AutoTokenizer.from_pretrained(BASE_NAME, padding_side='left', trust_remote_code=True)
 model_base = AutoModelForCausalLM.from_pretrained(BASE_NAME, device_map="auto")
-model_rag = PeftModel.from_pretrained(model_base, LORA_NAME)

 def generate_response(user_input, child_age):
     relevant_milestone = retrieve_milestone(user_input)
     question_chat = [
         {
             "role": "system",
-            "content": "
+            "content": f"The child is {child_age} months old. Based on the given traits: {user_input}, determine whether the child is meeting expected milestones. Relevant milestone: {relevant_milestone}. If there are any concerns, suggest steps the parents can take."
         },
         {
             "role": "user",
@@ -79,7 +76,7 @@ def generate_response(user_input, child_age):
     ]
     input_text = tokenizer.apply_chat_template(question_chat, tokenize=False, add_generation_prompt=True)
     inputs = tokenizer(input_text, return_tensors="pt")
-    output =
+    output = model_base.generate(inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device), max_new_tokens=500)
     output_text = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
     return output_text
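With this change the granite-rag LoRA adapter is dropped and generation runs directly on the base Granite 3.0 2B model, with the milestone context folded into the system prompt. A minimal sketch of exercising the updated generate_response(), assuming app.py exposes it as a module-level function; the import path and the example traits/age are placeholder values, not part of this commit:

    # hypothetical import; the Space normally invokes generate_response via its UI
    from app import generate_response

    # illustrative inputs: observed traits as free text, child age in months
    report = generate_response(
        user_input="not yet combining two words, points to ask for objects",
        child_age=18,
    )
    print(report)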