# Hugging Face Space: fine-tune DialoGPT-medium on a persona-based chat
# dataset and serve the resulting model through a Gradio chat interface.
import gradio as gr
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
print("Loading dataset...")
dataset = load_dataset("nazlicanto/persona-based-chat")

# Base checkpoint to fine-tune; DialoGPT is a GPT-2-style causal LM.
model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# DialoGPT ships without a pad token; reuse EOS so batched padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
def preprocess_data(batch):
    """Tokenize dialogue/reference pairs and attach causal-LM labels.

    Args:
        batch: dataset batch with ``dialogue`` (list of turn strings) and
            ``reference`` (the target bot reply) columns.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` — without
        explicit labels ``Trainer`` cannot compute a loss and training fails.
    """
    inputs = [
        "\n".join(dialogue) + "\nBot: " + reference
        for dialogue, reference in zip(batch["dialogue"], batch["reference"])
    ]
    tokenized = tokenizer(inputs, truncation=True, padding="max_length", max_length=128)
    # Copy input_ids as labels, masking padding positions with -100 so
    # CrossEntropyLoss ignores them instead of learning to emit pad tokens.
    tokenized["labels"] = [
        [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
        for ids, attn in zip(tokenized["input_ids"], tokenized["attention_mask"])
    ]
    return tokenized
# Tokenize every split; drop the raw text columns so only tensor-ready
# fields (input_ids / attention_mask / labels) remain.
tokenized_dataset = dataset.map(
    preprocess_data, batched=True, remove_columns=dataset["train"].column_names
)

# Prefer the dataset's own validation split; otherwise carve 10% off train.
if "validation" in tokenized_dataset:
    train_dataset = tokenized_dataset["train"]
    eval_dataset = tokenized_dataset["validation"]
else:
    split = tokenized_dataset["train"].train_test_split(test_size=0.1)
    train_dataset = split["train"]
    eval_dataset = split["test"]
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="steps",  # evaluate every eval_steps during training
    eval_steps=200,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    save_strategy="steps",
    save_steps=500,
    logging_steps=100,
    learning_rate=5e-5,
    num_train_epochs=5,
    warmup_steps=200,
    # fp16 requires CUDA; a hard-coded True crashes on CPU-only hardware
    # (e.g. a free HF Space), so enable it only when a GPU is present.
    fp16=torch.cuda.is_available(),
)
# Hugging Face Trainer wires the model, hyperparameters and splits together.
# Passing the tokenizer lets it be saved alongside checkpoints.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)
def train_model():
    """Run fine-tuning and persist the model + tokenizer to disk.

    Returns:
        A short status string (suitable for display in a UI callback).
    """
    print("Starting training...")
    trainer.train()
    # Save both artifacts so the directory is directly loadable with
    # AutoModelForCausalLM.from_pretrained / AutoTokenizer.from_pretrained.
    model.save_pretrained("trained_chatbot")
    tokenizer.save_pretrained("trained_chatbot")
    return "Training Complete!"
def chatbot(user_input):
    """Generate a single-turn reply from the model for the Gradio UI.

    Args:
        user_input: raw text typed by the user.

    Returns:
        The model's decoded reply with surrounding whitespace stripped.
    """
    prompt = f"User: {user_input}\nBot:"
    # Use the tokenizer call (not .encode) so we also get the attention
    # mask; since pad_token == eos_token, generate() cannot infer the mask
    # itself and warns / may mishandle padding without it.
    encoded = tokenizer(prompt + tokenizer.eos_token, return_tensors="pt")
    outputs = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=200,
        pad_token_id=tokenizer.eos_token_id,
        temperature=0.9,        # higher temperature -> more diverse outputs
        top_p=0.95,             # nucleus sampling
        repetition_penalty=1.2, # discourage repetitive loops
        do_sample=True,
    )
    # Slice off the prompt tokens so only the newly generated reply remains.
    reply_ids = outputs[:, encoded["input_ids"].shape[-1]:][0]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()
# Minimal chat UI; live=True re-runs generation as the user types.
iface = gr.Interface(fn=chatbot, inputs="text", outputs="text", live=True)

if __name__ == "__main__":
    print("Launching Gradio interface...")
    iface.launch()