Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI
|
2 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
3 |
+
import faiss
|
4 |
+
import torch
|
5 |
+
from sentence_transformers import SentenceTransformer
|
6 |
+
|
7 |
+
app = FastAPI()
|
8 |
+
|
9 |
+
# Load models
|
10 |
+
embed_model = SentenceTransformer('all-MiniLM-L6-v2')
|
11 |
+
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
|
12 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
|
13 |
+
|
14 |
+
# Sample documents
|
15 |
+
documents = [
|
16 |
+
"Startup India provides funding and tax benefits for new startups in India.",
|
17 |
+
"Angel investors are individuals who invest in early-stage startups in exchange for equity.",
|
18 |
+
"A pitch deck is a presentation that startups use to attract investors.",
|
19 |
+
"The government offers startup grants through various schemes.",
|
20 |
+
"Networking events connect entrepreneurs with investors and mentors."
|
21 |
+
]
|
22 |
+
|
23 |
+
# Convert documents to embeddings and store in FAISS
|
24 |
+
doc_vectors = embed_model.encode(documents, convert_to_numpy=True)
|
25 |
+
dimension = doc_vectors.shape[1]
|
26 |
+
index = faiss.IndexFlatL2(dimension)
|
27 |
+
index.add(doc_vectors)
|
28 |
+
|
29 |
+
def retrieve_docs(query, top_k=2):
|
30 |
+
query_vector = embed_model.encode([query], convert_to_numpy=True)
|
31 |
+
distances, indices = index.search(query_vector, top_k)
|
32 |
+
return [documents[i] for i in indices[0]]
|
33 |
+
|
34 |
+
def generate_response(query):
|
35 |
+
retrieved_docs = retrieve_docs(query)
|
36 |
+
context = " ".join(retrieved_docs)
|
37 |
+
prompt = f"Context: {context}\nQuestion: {query}\nAnswer:"
|
38 |
+
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
|
39 |
+
outputs = model.generate(**inputs, max_new_tokens=100)
|
40 |
+
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
41 |
+
|
42 |
+
@app.get("/chat")
|
43 |
+
def chat(query: str):
|
44 |
+
return {"response": generate_response(query)}
|