import gradio as gr
import os
import whisper
import torch
from gtts import gTTS
from sentence_transformers import SentenceTransformer
import faiss
import pandas as pd
from datasets import load_dataset
from deep_translator import GoogleTranslator
from langdetect import detect
#from groq import Groq # Correct import for Groq API
# Set up Whisper; smaller checkpoints ("tiny", "base", "small") run faster and work on CPU
model_name = "small"
whisper_model = whisper.load_model(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
whisper_model.to(device)
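# Note: whisper.load_model() already places the model on CUDA when available,
# so the explicit .to(device) above is a harmless no-op in that case.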
# Initialize the GoogleTranslator from deep-translator
translator = GoogleTranslator(source='auto', target='en')
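# Note: this default (auto -> en) instance is not actually used below;
# translate_text() constructs a per-target GoogleTranslator on each call.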
# Load and prepare the dataset for retrieval
dataset = load_dataset("qgyd2021/e_commerce_customer_service", "faq")
train_dataset = dataset['train']
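# The "faq" config exposes 'question', 'answer', and 'url' columns, which the
# retrieval code below relies on.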
# Initialize the SentenceTransformer model
embedder = SentenceTransformer('paraphrase-MiniLM-L6-v2')
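# paraphrase-MiniLM-L6-v2 produces 384-dimensional sentence embeddings; that
# dimension is what the FAISS index below is built with (via dataset_embeddings.shape[1]).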
# Encode the questions from the dataset and set up FAISS
dataset_embeddings = embedder.encode(train_dataset['question'], convert_to_tensor=True)
index = faiss.IndexFlatL2(dataset_embeddings.shape[1]) # Create an index based on L2 distance
index.add(dataset_embeddings.cpu().numpy()) # Add the embeddings to the index
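# IndexFlatL2 performs exact (brute-force) nearest-neighbor search, which is
# fine at FAQ scale; for much larger corpora an approximate index such as
# faiss.IndexIVFFlat could be swapped in at the cost of a training step.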
# Read the Groq API key from the environment
api_key = os.getenv("api_key")
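# Note: the key is currently unused because the Groq client import above is
# commented out; it is kept for when that integration is re-enabled.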
from transformers import pipeline
# Initialize the sentiment analysis pipeline (GPU if available, otherwise CPU);
# a distinct variable name avoids shadowing the Whisper `device` string above
pipeline_device = 0 if torch.cuda.is_available() else -1
sentiment_analyzer = None  # guarded at call time in case loading fails
try:
    sentiment_analyzer = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", device=pipeline_device)
except Exception as e:
    print(f"Error loading sentiment analysis model: {e}")
# Function to detect the language
def detect_language(text):
    try:
        return detect(text)
    except Exception as e:
        print(f"Error during language detection: {e}")
        return "en"  # Default to English if detection fails
# Translate text using deep-translator
def translate_text(text, dest_lang):
    try:
        return GoogleTranslator(source='auto', target=dest_lang).translate(text)
    except Exception as e:
        print(f"Error during translation: {e}")
        return text  # Return original text if translation fails
# Function to generate a greeting based on sentiment, in the caller's language.
# Note: the SST-2 model configured above is binary (POSITIVE/NEGATIVE), so the
# NEUTRAL branch only applies if a different sentiment model is swapped in.
def generate_greeting(sentiment, lang):
    try:
        if sentiment == 'NEGATIVE':
            if lang in ['ur', 'hi']:
                # "Don't worry, I'm here to help you."
                return "پریشان نہ ہوں، میں آپ کی مدد کے لئے یہاں ہوں."
            else:
                return "Please don't be sad, I'm here to solve your problem."
        elif sentiment == 'NEUTRAL':
            if lang in ['ur', 'hi']:
                # "Let's solve your problem, don't worry."
                return "آپ کا مسئلہ حل کرتے ہیں، آپ فکر نہ کریں."
            else:
                return "I understand your concern, let's get that sorted out."
        elif sentiment == 'POSITIVE':
            if lang in ['ur', 'hi']:
                # "It's great that you're happy! Let's make it even better."
                return "یہ خوشی کی بات ہے کہ آپ خوش ہیں! آئیں، ہم اسے بہتر بناتے ہیں."
            else:
                return "I'm glad you're feeling positive! Let's make things even better."
        else:
            if lang in ['ur', 'hi']:
                # "Hello! How can I help you today?"
                return "ہیلو! میں آج تمہاری مدد کیسے کر سکتا ہوں؟"
            else:
                return "Hello! How can I assist you today?"
    except Exception as e:
        print(f"Error generating greeting: {e}")
        return "Hello!"
# Function to transcribe audio using Whisper
def transcribe_audio(audio_path):
    try:
        result = whisper_model.transcribe(audio_path)
        transcription = result['text']
        print(f"Transcription result: {transcription}")
        return transcription
    except Exception as e:
        print(f"Error during transcription: {e}")
        return "Error during transcription"
# Function to generate a chatbot response based on transcription
def generate_chatbot_response(transcription):
    try:
        # Detect language of the transcription
        detected_language = detect_language(transcription)
        # Translate to English if necessary
        if detected_language in ['ur', 'hi']:
            transcription = translate_text(transcription, 'en')
        # Perform sentiment analysis (skip gracefully if the model failed to load)
        if sentiment_analyzer is not None:
            sentiment_result = sentiment_analyzer(transcription)[0]
            sentiment = sentiment_result['label'].upper()
        else:
            sentiment = 'UNKNOWN'  # falls through to the generic greeting
        # Generate a greeting based on sentiment
        greeting = generate_greeting(sentiment, detected_language)
        # Retrieve the closest FAQ entry using FAISS
        transcription_embedding = embedder.encode([transcription], convert_to_tensor=True)
        _, indices = index.search(transcription_embedding.cpu().numpy(), k=1)
        best_match_index = int(indices[0][0])
        context = train_dataset['answer'][best_match_index]
        url = train_dataset['url'][best_match_index]
        # Assemble the full response
        response = f"{greeting}\n\n{context}\n\nPlease visit this link for your query: {url}"
        # Translate the response back to Urdu if the query was in Urdu/Hindi
        if detected_language in ['ur', 'hi']:
            response = translate_text(response, 'ur')
        return response
    except Exception as e:
        print(f"Error during chatbot response generation: {e}")
        return "Error during response generation"
# Function to convert text to speech using gTTS
def text_to_speech(text, lang='en'):
    try:
        tts = gTTS(text=text, lang=lang)
        tts.save("response.mp3")
        return "response.mp3"
    except Exception as e:
        print(f"Error during text-to-speech conversion: {e}")
        return "Error during text-to-speech conversion"
# Main function for the Gradio interface
def chatbot(text_input=None, audio_input=None):
    if audio_input:
        # Step 1: Transcribe audio to text if audio input is provided
        input_text = transcribe_audio(audio_input)
    elif text_input:
        # Use the text input directly if provided
        input_text = text_input
    else:
        # Neither input was provided
        return "", "Please type a question or record one.", None
    # Step 2: Generate a chatbot response based on the input text
    response = generate_chatbot_response(input_text)
    # Step 3: Convert the response text to speech if the original input was audio
    if audio_input:
        lang = 'ur' if detect_language(input_text) in ['ur', 'hi'] else 'en'
        audio_path = text_to_speech(response, lang=lang)
        return input_text, response, audio_path
    return input_text, response, None
# Custom CSS for styling the interface and buttons
custom_css = """
body {
    font-family: 'Arial', sans-serif;
    background-color: #1e1e1e; /* Dark background */
    color: white; /* White text */
}
h1 {
    font-size: 36px;
    color: white;
    text-align: center;
    margin-bottom: 20px;
}
h2 {
    font-size: 24px;
    color: white;
    text-align: center;
    margin-bottom: 10px;
}
.instructions {
    font-size: 16px; /* Smaller font size for instructions */
    color: #cccccc; /* Light gray color for instructions */
    text-align: center;
    margin-bottom: 20px;
}
.gradio-container {
    background-color: #1e1e1e;
    padding: 20px;
    border-radius: 10px;
}
.gr-box {
    border-radius: 5px;
    border: 1px solid #333;
    padding: 10px;
    margin-bottom: 10px;
}
.gr-button {
    border-radius: 5px;
    padding: 10px;
    font-weight: bold;
    font-size: 16px;
    transition: background-color 0.3s;
}
.gr-button-submit {
    background-color: #28a745; /* Green submit button */
    color: white;
}
.gr-button-submit:hover {
    background-color: #218838;
}
.gr-button-clear {
    background-color: #dc3545; /* Red clear button */
    color: white;
}
.gr-button-clear:hover {
    background-color: #c82333;
}
.gr-textbox, .gr-audio {
    border-radius: 5px;
    border: 1px solid #0056b3; /* Blue border */
    padding: 8px;
    background-color: #2e2e2e;
    color: white;
}
.gr-textbox {
    background-color: #0056b3; /* Blue background for textboxes */
    color: white;
}
.gr-container {
    max-width: 900px;
    margin: auto;
}
"""
# Gradio interface setup with the custom CSS
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("<h1>Multilingual Customer Service Chatbot</h1>")
    gr.Markdown("<h2>Ask your questions</h2>")
    gr.Markdown("<p class='instructions'>If you type in Urdu, it will respond in Urdu; if in English, it will respond in English. The same applies to voice input.</p>")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(lines=2, placeholder="Type your query here...", label="Text Input (Optional)")
            audio_input = gr.Audio(type="filepath", label="Audio Input (Optional)")
        with gr.Column():
            transcription_output = gr.Textbox(label="Transcription")
            response_text = gr.Textbox(label="Chatbot Response")
            response_audio = gr.Audio(label="Response Audio (if applicable)")
    with gr.Row():
        submit_btn = gr.Button("Submit", elem_id="submit-btn", variant="primary")
        clear_btn = gr.Button("Clear", elem_id="clear-btn", variant="secondary")
    submit_btn.click(chatbot, inputs=[text_input, audio_input], outputs=[transcription_output, response_text, response_audio])
    # Clearing resets both inputs and all three outputs
    clear_btn.click(lambda: (None, None, None, None, None), inputs=[], outputs=[text_input, audio_input, transcription_output, response_text, response_audio])
# Launch the Gradio interface
iface.launch()