safiaa02 committed · verified
Commit 4d4ad35 · 1 Parent(s): 63b8d64

Update app.py

Files changed (1): app.py +3 -6
app.py CHANGED
@@ -6,7 +6,6 @@ import numpy as np
 import torch
 from sentence_transformers import SentenceTransformer
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from peft import PeftModel
 from reportlab.lib.pagesizes import A4
 from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
 from reportlab.lib.styles import getSampleStyleSheet
@@ -56,21 +55,19 @@ def retrieve_milestone(user_input):
     return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."

 # Initialize IBM Granite Model
-BASE_NAME = "ibm-granite/granite-3.0-8b-instruct"
-LORA_NAME = "ibm-granite/granite-rag-3.0-8b-lora"
+BASE_NAME = "ibm-granite/granite-3.0-2b-base"

 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 tokenizer = AutoTokenizer.from_pretrained(BASE_NAME, padding_side='left', trust_remote_code=True)
 model_base = AutoModelForCausalLM.from_pretrained(BASE_NAME, device_map="auto")
-model_rag = PeftModel.from_pretrained(model_base, LORA_NAME)

 def generate_response(user_input, child_age):
     relevant_milestone = retrieve_milestone(user_input)
     question_chat = [
         {
             "role": "system",
-            "content": "{\"instruction\": \"Respond to the user's latest question based solely on the information provided in the documents. Ensure that your response is strictly aligned with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data. Make sure that your response follows the attributes mentioned in the 'meta' field.\", \"documents\": [{\"doc_id\": 1, \"text\": \"The child is {child_age} months old. Based on the given traits: {user_input}, determine whether the child is meeting expected milestones. Relevant milestone: {relevant_milestone}. If there are any concerns, suggest steps the parents can take.\"}], \"meta\": {\"hallucination_tags\": true, \"citations\": true}}"
+            "content": f"The child is {child_age} months old. Based on the given traits: {user_input}, determine whether the child is meeting expected milestones. Relevant milestone: {relevant_milestone}. If there are any concerns, suggest steps the parents can take."
         },
         {
             "role": "user",
@@ -79,7 +76,7 @@ def generate_response(user_input, child_age):
     ]
     input_text = tokenizer.apply_chat_template(question_chat, tokenize=False, add_generation_prompt=True)
     inputs = tokenizer(input_text, return_tensors="pt")
-    output = model_rag.generate(inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device), max_new_tokens=500)
+    output = model_base.generate(inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device), max_new_tokens=500)
     output_text = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
     return output_text
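
The diff only shows the final line of retrieve_milestone, so for orientation here is a minimal sketch of how such a retrieval step is commonly wired with SentenceTransformer plus a FAISS index. The embedding model name, the index type, and the sample milestone descriptions are assumptions for illustration, not code from this commit.

# Hedged sketch of the retrieval step; the embedder, FAISS index, and sample
# milestone texts are assumptions. Only the final return line comes from app.py.
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("all-MiniLM-L6-v2")  # assumed embedding model

descriptions = [
    "Walks independently by around 15 months.",
    "Uses simple two-word phrases by around 24 months.",
]  # placeholder milestone descriptions

# Build an exact L2 nearest-neighbour index over the milestone embeddings.
embeddings = embedder.encode(descriptions).astype(np.float32)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

def retrieve_milestone(user_input):
    # Embed the query and look up the single closest milestone description.
    query = embedder.encode([user_input]).astype(np.float32)
    _, indices = index.search(query, 1)
    return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."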
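
On the new code path, generate_response builds a single system message, runs it through the tokenizer's chat template, and generates with model_base directly. A minimal usage sketch follows; the traits and age are invented inputs.

# Example call into the updated function; the inputs are made up.
answer = generate_response(
    user_input="not yet babbling, limited eye contact, sits without support",
    child_age=9,
)
print(answer)

The slice output[0][inputs["input_ids"].shape[1]:] in generate_response drops the prompt tokens before decoding, so only the newly generated advice is returned.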
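
The reportlab imports kept by this commit (A4, SimpleDocTemplate, Paragraph, Spacer, getSampleStyleSheet) point to a PDF export step elsewhere in app.py. A minimal sketch of that pattern is below; the helper name, report title, and output filename are assumptions.

# Hedged sketch of a PDF export using the reportlab imports shown in the diff;
# the function name, title, and file path are assumptions.
from reportlab.lib.pagesizes import A4
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
from reportlab.lib.styles import getSampleStyleSheet

def export_report(report_text, path="milestone_report.pdf"):
    styles = getSampleStyleSheet()
    doc = SimpleDocTemplate(path, pagesize=A4)
    story = [
        Paragraph("Developmental Milestone Report", styles["Title"]),
        Spacer(1, 12),
        Paragraph(report_text, styles["Normal"]),
    ]
    doc.build(story)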