alinasrinazif committed on
Commit
1580220
·
verified ·
1 Parent(s): b172b5a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+
5
+ # Load model and tokenizer
6
+ device = "cuda" if torch.cuda.is_available() else "cpu"
7
+ model = AutoModelForCausalLM.from_pretrained(
8
+ "universitytehran/PersianMind-v1.0",
9
+ torch_dtype=torch.bfloat16,
10
+ low_cpu_mem_usage=True,
11
+ device_map={"": device},
12
+ )
13
+ tokenizer = AutoTokenizer.from_pretrained("universitytehran/PersianMind-v1.0")
14
+
15
+ # Conversation template
16
+ TEMPLATE = "{context}\nYou: {prompt}\nPersianMind: "
17
+ CONTEXT = "This is a conversation with PersianMind. It is an artificial intelligence model designed by a team of " \
18
+ "NLP experts at the University of Tehran to help you with various tasks such as answering questions, " \
19
+ "providing recommendations, and helping with decision making. You can ask it anything you want and " \
20
+ "it will do its best to give you accurate and relevant information."
21
+
22
+ # Streamlit app
23
+ st.title("PersianMind Chat")
24
+ st.markdown("Chat with **PersianMind**, an AI model by the University of Tehran.")
25
+
26
+ # User input
27
+ prompt = st.text_input("Enter your question (in Persian):")
28
+
29
+ if st.button("Get Response"):
30
+ if prompt.strip():
31
+ with st.spinner("Generating response..."):
32
+ model_input = TEMPLATE.format(context=CONTEXT, prompt=prompt)
33
+ input_tokens = tokenizer(model_input, return_tensors="pt").to(device)
34
+ generate_ids = model.generate(**input_tokens, max_new_tokens=512, do_sample=False, repetition_penalty=1.1)
35
+ model_output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
36
+ response = model_output[len(model_input):]
37
+
38
+ st.text_area("PersianMind's Response:", response, height=200)
39
+ else:
40
+ st.warning("Please enter a question.")