# AI_digitaltwin / app.py — Hugging Face Space (commit 457fe17, "Update app.py",
# uploaded by Hemavathineelirothu).  These header lines were page-scrape residue;
# converted to comments so the file parses as Python.
import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
# --- One-time setup executed at import time: dataset, tokenizer, base model. ---
print("Loading dataset...")
# Persona-grounded chat dataset; preprocessing below reads its "dialogue" and
# "reference" columns.
dataset = load_dataset("nazlicanto/persona-based-chat")
model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# GPT-2-family tokenizers ship without a pad token; reuse EOS so that the
# fixed-length padding in preprocess_data works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
def preprocess_data(batch):
    """Tokenize a batch of persona-chat examples for causal-LM fine-tuning.

    Each example's dialogue turns are joined with newlines and the gold reply
    is appended after a "Bot: " marker, then tokenized to a fixed length.

    Also copies ``input_ids`` into ``labels``: without labels the Trainer's
    default collator produces none, the model returns no loss, and
    ``trainer.train()`` fails.  For causal LM training the targets are the
    inputs themselves (the model shifts them internally).

    Args:
        batch: dict of columns with "dialogue" (list of turn strings) and
            "reference" (gold reply string) per example.

    Returns:
        dict with "input_ids", "attention_mask", and "labels".
    """
    inputs = [
        "\n".join(dialogue) + "\nBot: " + reference
        for dialogue, reference in zip(batch["dialogue"], batch["reference"])
    ]
    tokenized = tokenizer(inputs, truncation=True, padding="max_length", max_length=128)
    # Targets for next-token prediction are the input ids themselves.
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized
# Tokenize every split, dropping the raw text columns once they are encoded.
tokenized_dataset = dataset.map(
    preprocess_data,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

# Prefer the dataset's own validation split; otherwise carve 10% off train.
if "validation" in tokenized_dataset:
    train_dataset = tokenized_dataset["train"]
    eval_dataset = tokenized_dataset["validation"]
else:
    split = tokenized_dataset["train"].train_test_split(test_size=0.1)
    train_dataset, eval_dataset = split["train"], split["test"]
# Fine-tuning hyperparameters.  fp16 mixed precision is only supported on
# CUDA devices; enabling it unconditionally crashes on CPU-only hardware
# (the default for a free HF Space), so gate it on GPU availability.
training_args = TrainingArguments(
    output_dir="./results",          # checkpoints and logs
    evaluation_strategy="steps",
    eval_steps=200,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    save_strategy="steps",
    save_steps=500,
    logging_steps=100,
    learning_rate=5e-5,
    num_train_epochs=5,
    warmup_steps=200,
    fp16=torch.cuda.is_available(),  # was fp16=True: raises without a GPU
)
# Wire the model, hyperparameters, and splits into the HF Trainer.  Passing
# the tokenizer lets the Trainer save it with checkpoints and select a
# padding-aware default data collator.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer
)
def train_model():
    """Fine-tune the model, save weights and tokenizer, and report completion.

    Returns:
        A short status string suitable for display in a UI.
    """
    print("Starting training...")
    trainer.train()
    # Persist both the weights and the tokenizer so the pair can be reloaded.
    save_dir = "trained_chatbot"
    model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
    return "Training Complete!"
def chatbot(user_input):
    """Generate one bot reply for *user_input* with the in-memory model.

    Builds a "User: ...\\nBot:" prompt (mirroring the training format),
    samples a continuation, and decodes only the newly generated tokens.

    Args:
        user_input: the user's message as plain text.

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    input_text = f"User: {user_input}\nBot:"
    # tokenizer(...) also returns the attention mask; generate() needs it to
    # tell real tokens from padding (and to avoid the HF missing-mask warning).
    encoded = tokenizer(input_text + tokenizer.eos_token, return_tensors="pt")
    # Inference only — disable autograd to save memory and time.
    with torch.no_grad():
        outputs = model.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=200,
            pad_token_id=tokenizer.eos_token_id,
            temperature=0.9,         # higher -> more diverse outputs
            top_p=0.95,              # nucleus sampling
            repetition_penalty=1.2,  # discourage repetitive outputs
            do_sample=True,
        )
    # Slice off the prompt so only the generated reply is decoded.
    prompt_len = encoded["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[:, prompt_len:][0], skip_special_tokens=True)
    return response.strip()
# Minimal text-in/text-out UI; live=True re-invokes chatbot() as the user types.
iface = gr.Interface(fn=chatbot, inputs="text", outputs="text", live=True)

if __name__ == "__main__":
    print("Launching Gradio interface...")
    iface.launch()