StKirill commited on
Commit
b8b02a2
·
verified ·
1 Parent(s): 6a3f0d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -368,7 +368,7 @@ model_bi_encoder.bert_model.from_pretrained("models/friends_bi_encoder")
368
  # Load question embeds
369
  question_embeds = np.load("bi_bert_question.npy")
370
 
371
- def chat_bi_bert(question):
372
  question = encode(question, model_bi_encoder.bert_tokenizer, model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
373
  cosine_similarities = cosine_similarity([question], question_embeds).flatten()
374
  top_indice = np.argmax(cosine_similarities, axis=0)
@@ -398,7 +398,7 @@ class CrossEncoderBert(torch.nn.Module):
398
  model_cross_encoder = CrossEncoderBert().to(device)
399
  model_cross_encoder.bert_model.from_pretrained("models/friends_cross_encoder")
400
 
401
- def chat_cross_bert(question):
402
 
403
  question_encoded = encode(question, model_bi_encoder.bert_tokenizer, model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
404
  cosine_similarities = cosine_similarity([question_encoded], question_embeds).flatten()
@@ -442,11 +442,11 @@ def echo(message, history, model):
442
  return answer
443
 
444
  elif model=="Bi-BERT-Encoder":
445
- answer = chat_bi_bert(message)
446
  return answer
447
 
448
  elif model=="Bi+Cross-BERT-Encoder":
449
- answer = chat_cross_bert(message)
450
  return answer
451
 
452
 
 
368
  # Load question embeds
369
  question_embeds = np.load("bi_bert_question.npy")
370
 
371
+ def chat_bi_bert(question, history):
372
  question = encode(question, model_bi_encoder.bert_tokenizer, model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
373
  cosine_similarities = cosine_similarity([question], question_embeds).flatten()
374
  top_indice = np.argmax(cosine_similarities, axis=0)
 
398
  model_cross_encoder = CrossEncoderBert().to(device)
399
  model_cross_encoder.bert_model.from_pretrained("models/friends_cross_encoder")
400
 
401
+ def chat_cross_bert(question, history):
402
 
403
  question_encoded = encode(question, model_bi_encoder.bert_tokenizer, model_bi_encoder.bert_model, device).squeeze().cpu().detach().numpy()
404
  cosine_similarities = cosine_similarity([question_encoded], question_embeds).flatten()
 
442
  return answer
443
 
444
  elif model=="Bi-BERT-Encoder":
445
+ answer = chat_bi_bert(message, history)
446
  return answer
447
 
448
  elif model=="Bi+Cross-BERT-Encoder":
449
+ answer = chat_cross_bert(message, history)
450
  return answer
451
 
452