Update app.py
Browse files
app.py
CHANGED
@@ -368,7 +368,7 @@ model_bi_encoder.bert_model.from_pretrained("models/friends_bi_encoder")
|
|
368 |
# Load question embeds
|
369 |
question_embeds = np.load("bi_bert_question.npy")
|
370 |
|
371 |
-
def chat_bi_bert(question):
|
372 |
question = encode(question, model_bi_encoder.bert_tokenizer, model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
|
373 |
cosine_similarities = cosine_similarity([question], question_embeds).flatten()
|
374 |
top_indice = np.argmax(cosine_similarities, axis=0)
|
@@ -398,7 +398,7 @@ class CrossEncoderBert(torch.nn.Module):
|
|
398 |
model_cross_encoder = CrossEncoderBert().to(device)
|
399 |
model_cross_encoder.bert_model.from_pretrained("models/friends_cross_encoder")
|
400 |
|
401 |
-
def chat_cross_bert(question):
|
402 |
|
403 |
question_encoded = encode(question, model_bi_encoder.bert_tokenizer, model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
|
404 |
cosine_similarities = cosine_similarity([question_encoded], question_embeds).flatten()
|
@@ -442,11 +442,11 @@ def echo(message, history, model):
|
|
442 |
return answer
|
443 |
|
444 |
elif model=="Bi-BERT-Encoder":
|
445 |
-
answer = chat_bi_bert(message)
|
446 |
return answer
|
447 |
|
448 |
elif model=="Bi+Cross-BERT-Encoder":
|
449 |
-
answer = chat_cross_bert(message)
|
450 |
return answer
|
451 |
|
452 |
|
|
|
368 |
# Load question embeds
|
369 |
question_embeds = np.load("bi_bert_question.npy")
|
370 |
|
371 |
+
def chat_bi_bert(question, history):
|
372 |
question = encode(question, model_bi_encoder.bert_tokenizer, model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
|
373 |
cosine_similarities = cosine_similarity([question], question_embeds).flatten()
|
374 |
top_indice = np.argmax(cosine_similarities, axis=0)
|
|
|
398 |
model_cross_encoder = CrossEncoderBert().to(device)
|
399 |
model_cross_encoder.bert_model.from_pretrained("models/friends_cross_encoder")
|
400 |
|
401 |
+
def chat_cross_bert(question, history):
|
402 |
|
403 |
question_encoded = encode(question, model_bi_encoder.bert_tokenizer, model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
|
404 |
cosine_similarities = cosine_similarity([question_encoded], question_embeds).flatten()
|
|
|
442 |
return answer
|
443 |
|
444 |
elif model=="Bi-BERT-Encoder":
|
445 |
+
answer = chat_bi_bert(message, history)
|
446 |
return answer
|
447 |
|
448 |
elif model=="Bi+Cross-BERT-Encoder":
|
449 |
+
answer = chat_cross_bert(message, history)
|
450 |
return answer
|
451 |
|
452 |
|