import faiss
from fastapi import FastAPI
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

app = FastAPI()

# Load the embedding model (for retrieval) and the seq2seq model (for generation)
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

# Sample documents
documents = [
    "Startup India provides funding and tax benefits for new startups in India.",
    "Angel investors are individuals who invest in early-stage startups in exchange for equity.",
    "A pitch deck is a presentation that startups use to attract investors.",
    "The government offers startup grants through various schemes.",
    "Networking events connect entrepreneurs with investors and mentors."
]

# Embed the documents and index them in FAISS for exact L2 (Euclidean) search
doc_vectors = embed_model.encode(documents, convert_to_numpy=True)
dimension = doc_vectors.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(doc_vectors)

def retrieve_docs(query, top_k=2):
    """Return the top_k documents closest to the query in embedding space."""
    query_vector = embed_model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_vector, top_k)
    return [documents[i] for i in indices[0]]

def generate_response(query):
    """Answer a query by packing the retrieved documents into the prompt."""
    retrieved_docs = retrieve_docs(query)
    context = " ".join(retrieved_docs)
    prompt = f"Context: {context}\nQuestion: {query}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    outputs = model.generate(**inputs, max_new_tokens=100)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

@app.get("/chat")
def chat(query: str):
    return {"response": generate_response(query)}
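
# A minimal entry point for local testing: a sketch that assumes this file is
# saved as app.py and that uvicorn (the ASGI server commonly used with
# FastAPI) is installed. Once running, try:
#   curl "http://127.0.0.1:8000/chat?query=What+is+a+pitch+deck"
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8000)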