key-life committed on
Commit 97e8ea6 · verified · 1 Parent(s): 4b353e6

Create app.py

Files changed (1)
  1. app.py +44 -0
app.py ADDED
@@ -0,0 +1,44 @@
+ from fastapi import FastAPI
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ import faiss
+ import torch
+ from sentence_transformers import SentenceTransformer
+
+ app = FastAPI()
+
+ # Load models
+ embed_model = SentenceTransformer('all-MiniLM-L6-v2')
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
+ model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
+
+ # Sample documents
+ documents = [
+     "Startup India provides funding and tax benefits for new startups in India.",
+     "Angel investors are individuals who invest in early-stage startups in exchange for equity.",
+     "A pitch deck is a presentation that startups use to attract investors.",
+     "The government offers startup grants through various schemes.",
+     "Networking events connect entrepreneurs with investors and mentors."
+ ]
+
+ # Convert documents to embeddings and store in FAISS
+ doc_vectors = embed_model.encode(documents, convert_to_numpy=True)
+ dimension = doc_vectors.shape[1]
+ index = faiss.IndexFlatL2(dimension)
+ index.add(doc_vectors)
+
+ def retrieve_docs(query, top_k=2):
+     query_vector = embed_model.encode([query], convert_to_numpy=True)
+     distances, indices = index.search(query_vector, top_k)
+     return [documents[i] for i in indices[0]]
+
+ def generate_response(query):
+     retrieved_docs = retrieve_docs(query)
+     context = " ".join(retrieved_docs)
+     prompt = f"Context: {context}\nQuestion: {query}\nAnswer:"
+     inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
+     outputs = model.generate(**inputs, max_new_tokens=100)
+     return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ @app.get("/chat")
+ def chat(query: str):
+     return {"response": generate_response(query)}