Ajit Panday committed on
Commit
0680865
·
1 Parent(s): d538a8c

Initial commit: Customer Support Chatbot with DialoGPT-medium

Browse files
Files changed (3) hide show
  1. app.py +16 -7
  2. requirements.txt +4 -1
  3. train.py +87 -0
app.py CHANGED
@@ -3,18 +3,25 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
  from datasets import load_dataset
5
  import random
 
6
 
7
- # Load the model and tokenizer
8
- model_name = "microsoft/DialoGPT-medium"
9
- tokenizer = AutoTokenizer.from_pretrained(model_name)
10
- model = AutoModelForCausalLM.from_pretrained(model_name)
11
 
12
  # Load the customer support dataset
13
  dataset = load_dataset("Victorano/customer-support-1k")
14
 
15
  def generate_response(message, history):
16
- # Encode the user message
17
- input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors='pt')
 
 
 
 
 
 
18
 
19
  # Generate response
20
  with torch.no_grad():
@@ -31,13 +38,15 @@ def generate_response(message, history):
31
 
32
  # Decode and return the response
33
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 
34
  return response
35
 
36
  # Create the Gradio interface
37
  with gr.Blocks(css="footer {display: none !important}") as demo:
38
  gr.Markdown("""
39
  # 🤖 Customer Support Chatbot
40
- This chatbot is powered by DialoGPT-medium and trained on customer support conversations.
41
  """)
42
 
43
  chatbot = gr.Chatbot(
 
3
  import torch
4
  from datasets import load_dataset
5
  import random
6
+ import os
7
 
8
+ # Check if fine-tuned model exists, otherwise use base model
9
+ model_path = "./customer_support_chatbot" if os.path.exists("./customer_support_chatbot") else "microsoft/DialoGPT-medium"
10
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
11
+ model = AutoModelForCausalLM.from_pretrained(model_path)
12
 
13
  # Load the customer support dataset
14
  dataset = load_dataset("Victorano/customer-support-1k")
15
 
16
  def generate_response(message, history):
17
+ # Format the input with conversation history
18
+ conversation = ""
19
+ for user_msg, bot_msg in history:
20
+ conversation += f"Customer: {user_msg}\nSupport: {bot_msg}\n"
21
+ conversation += f"Customer: {message}\nSupport:"
22
+
23
+ # Encode the conversation
24
+ input_ids = tokenizer.encode(conversation, return_tensors='pt')
25
 
26
  # Generate response
27
  with torch.no_grad():
 
38
 
39
  # Decode and return the response
40
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
41
+ # Extract only the last response (after "Support:")
42
+ response = response.split("Support:")[-1].strip()
43
  return response
44
 
45
  # Create the Gradio interface
46
  with gr.Blocks(css="footer {display: none !important}") as demo:
47
  gr.Markdown("""
48
  # 🤖 Customer Support Chatbot
49
+ This chatbot is fine-tuned on customer support conversations using DialoGPT-medium.
50
  """)
51
 
52
  chatbot = gr.Chatbot(
requirements.txt CHANGED
@@ -1,4 +1,7 @@
1
  gradio==4.19.2
2
  transformers==4.37.2
3
  torch==2.2.0
4
- datasets==2.17.1
 
 
 
 
1
  gradio==4.19.2
2
  transformers==4.37.2
3
  torch==2.2.0
4
+ datasets==2.17.1
5
+ accelerate==0.27.2
6
+ evaluate==0.4.1
7
+ scikit-learn==1.4.0
train.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
3
+ from datasets import load_dataset
4
+ import numpy as np
5
+ from typing import Dict, List
6
+ import os
7
+
8
def load_and_prepare_data():
    """Load the customer-support dataset and tokenize it for causal-LM training.

    Each example's ``question`` and ``answer`` fields are merged into a single
    two-line string ("Customer: <question>" followed by "Support: <answer>"),
    which is then tokenized into fixed-length sequences of 512 tokens.

    Returns:
        tuple: ``(tokenized_dataset, tokenizer)`` -- the tokenized dataset
        (with the same splits as the loaded dataset) and the tokenizer that
        produced it.
    """
    dataset = load_dataset("Victorano/customer-support-1k")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")

    # padding="max_length" below needs a pad token. GPT-2-style tokenizers
    # typically do not define one, in which case tokenization raises; fall
    # back to the EOS token so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def format_conversation(example):
        # Combine question and answer into a single training string.
        conversation = f"Customer: {example['question']}\nSupport: {example['answer']}"
        return {"text": conversation}

    # Replace the raw columns with the single formatted "text" column
    # in every split.
    formatted_dataset = dataset.map(
        format_conversation,
        remove_columns=dataset["train"].column_names,
    )

    def tokenize_function(examples):
        # Plain Python lists are returned here: the batched map stores its
        # output as Arrow columns, so building tensors first adds nothing.
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=512,
        )

    # Drop "text" so only the tokenizer outputs remain for the Trainer.
    tokenized_dataset = formatted_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=formatted_dataset["train"].column_names,
    )

    return tokenized_dataset, tokenizer
+
45
def train_model():
    """Fine-tune DialoGPT-medium on the prepared dataset and save the result.

    Runs three epochs with evaluation and checkpointing at the end of each
    epoch, then writes the model and tokenizer to
    ``./customer_support_chatbot`` and prints a completion message.
    """
    save_dir = "./customer_support_chatbot"

    # Data first, then the base model -- same order as before.
    splits, tok = load_and_prepare_data()
    base_model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

    hyperparams = dict(
        output_dir=save_dir,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=100,
        # Save and evaluation strategies must match for
        # load_best_model_at_end to be accepted.
        save_strategy="epoch",
        evaluation_strategy="epoch",
        load_best_model_at_end=True,
        push_to_hub=False,
    )
    run_args = TrainingArguments(**hyperparams)

    # mlm=False selects causal (next-token) language-modeling labels.
    collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)

    runner = Trainer(
        model=base_model,
        args=run_args,
        train_dataset=splits["train"],
        eval_dataset=splits["test"],
        data_collator=collator,
    )
    runner.train()

    # Persist the model and the tokenizer side by side.
    for artifact in (base_model, tok):
        artifact.save_pretrained(save_dir)

    print("Training completed! Model saved to ./customer_support_chatbot")
+
86
# Start fine-tuning only when this file is executed as a script, so that
# importing the module does not trigger a training run.
if __name__ == "__main__":
    train_model()