Hemavathineelirothu committed on
Commit
457fe17
·
verified ·
1 Parent(s): 39b27b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -10
app.py CHANGED
@@ -2,17 +2,14 @@ import gradio as gr
2
  from datasets import load_dataset
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
4
 
5
- # Load dataset
6
  print("Loading dataset...")
7
 
8
  dataset = load_dataset("nazlicanto/persona-based-chat")
9
 
10
- # Choose a base model (DialoGPT)
11
  model_name = "microsoft/DialoGPT-medium"
12
  tokenizer = AutoTokenizer.from_pretrained(model_name)
13
  model = AutoModelForCausalLM.from_pretrained(model_name)
14
 
15
- # Ensure pad_token is set
16
  if tokenizer.pad_token is None:
17
  tokenizer.pad_token = tokenizer.eos_token
18
 
def preprocess_data(batch):
    """Build training texts from a batch of persona-chat examples and tokenize them.

    For each example, the dialogue turns are joined with newlines and the
    reference reply is appended after a "Bot: " prompt; the resulting strings
    are encoded to fixed-length (128) padded/truncated token ids.
    """
    texts = []
    for dialogue, reference in zip(batch["dialogue"], batch["reference"]):
        texts.append("\n".join(dialogue) + "\nBot: " + reference)
    return tokenizer(texts, truncation=True, padding="max_length", max_length=128)
22
 
23
- # Apply preprocessing
24
  tokenized_dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset["train"].column_names)
25
 
26
- # Use validation if available; otherwise, split the train dataset
27
  if "validation" in tokenized_dataset:
28
  train_dataset = tokenized_dataset["train"]
29
  eval_dataset = tokenized_dataset["validation"]
@@ -32,7 +27,6 @@ else:
32
  train_dataset = train_test_split["train"]
33
  eval_dataset = train_test_split["test"]
34
 
35
- # Training arguments
36
  training_args = TrainingArguments(
37
  output_dir="./results",
38
  evaluation_strategy="steps",
@@ -49,7 +43,6 @@ training_args = TrainingArguments(
49
  )
50
 
51
 
52
- # Trainer
53
  trainer = Trainer(
54
  model=model,
55
  args=training_args,
@@ -58,7 +51,6 @@ trainer = Trainer(
58
  tokenizer=tokenizer
59
  )
60
 
61
- # Train model
62
  def train_model():
63
  print("Starting training...")
64
  trainer.train()
@@ -66,7 +58,6 @@ def train_model():
66
  tokenizer.save_pretrained("trained_chatbot")
67
  return "Training Complete!"
68
 
69
- # Chatbot interface
70
  def chatbot(user_input):
71
  input_text = f"User: {user_input}\nBot:"
72
  inputs = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors="pt")
@@ -84,7 +75,6 @@ def chatbot(user_input):
84
  response = tokenizer.decode(outputs[:, inputs.shape[-1]:][0], skip_special_tokens=True)
85
  return response.strip()
86
 
87
- # Gradio UI
88
  iface = gr.Interface(fn=chatbot, inputs="text", outputs="text", live=True)
89
 
90
  if __name__ == "__main__":
 
2
  from datasets import load_dataset
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
4
 
 
5
  print("Loading dataset...")
6
 
7
  dataset = load_dataset("nazlicanto/persona-based-chat")
8
 
 
9
  model_name = "microsoft/DialoGPT-medium"
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
  model = AutoModelForCausalLM.from_pretrained(model_name)
12
 
 
13
  if tokenizer.pad_token is None:
14
  tokenizer.pad_token = tokenizer.eos_token
15
 
 
17
  inputs = ["\n".join(dialogue) + "\nBot: " + reference for dialogue, reference in zip(batch["dialogue"], batch["reference"])]
18
  return tokenizer(inputs, truncation=True, padding="max_length", max_length=128)
19
 
 
20
  tokenized_dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset["train"].column_names)
21
 
 
22
  if "validation" in tokenized_dataset:
23
  train_dataset = tokenized_dataset["train"]
24
  eval_dataset = tokenized_dataset["validation"]
 
27
  train_dataset = train_test_split["train"]
28
  eval_dataset = train_test_split["test"]
29
 
 
30
  training_args = TrainingArguments(
31
  output_dir="./results",
32
  evaluation_strategy="steps",
 
43
  )
44
 
45
 
 
46
  trainer = Trainer(
47
  model=model,
48
  args=training_args,
 
51
  tokenizer=tokenizer
52
  )
53
 
 
54
  def train_model():
55
  print("Starting training...")
56
  trainer.train()
 
58
  tokenizer.save_pretrained("trained_chatbot")
59
  return "Training Complete!"
60
 
 
61
  def chatbot(user_input):
62
  input_text = f"User: {user_input}\nBot:"
63
  inputs = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors="pt")
 
75
  response = tokenizer.decode(outputs[:, inputs.shape[-1]:][0], skip_special_tokens=True)
76
  return response.strip()
77
 
 
78
  iface = gr.Interface(fn=chatbot, inputs="text", outputs="text", live=True)
79
 
80
  if __name__ == "__main__":