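"""Gradio app for fine-tuning microsoft/phi-2 on a Hugging Face Hub dataset.

Paste a dataset URL, press "Start Training", and a causal-LM fine-tuning
run is launched on a background thread (CPU only in this configuration).
"""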
import torch
import gradio as gr
import threading
import logging
import sys
from urllib.parse import urlparse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import load_dataset

# Configure logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

def parse_hf_dataset_url(url: str) -> tuple[str, str | None]:
    """Parse a Hugging Face dataset URL into (dataset_name, config)."""
    parsed = urlparse(url)
    path_parts = parsed.path.split('/')
    try:
        # Find 'datasets' in the path
        datasets_idx = path_parts.index('datasets')
    except ValueError:
        raise ValueError("Invalid Hugging Face dataset URL")
    # Drop empty segments (e.g. from trailing slashes)
    dataset_parts = [p for p in path_parts[datasets_idx + 1:] if p]
    # Try to find a config (common pattern for datasets with a viewer page);
    # everything before 'viewer' is the dataset name (org/name, or just name)
    try:
        viewer_idx = dataset_parts.index('viewer')
        config = dataset_parts[viewer_idx + 1] if viewer_idx + 1 < len(dataset_parts) else None
        dataset_name = "/".join(dataset_parts[:viewer_idx][:2])
    except ValueError:
        config = None
        dataset_name = "/".join(dataset_parts[:2])
    return dataset_name, config
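
# Illustrative examples of what the parser above returns for the usual
# hub URL shapes (not exhaustive):
#   https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0
#       -> ("mozilla-foundation/common_voice_11_0", None)
#   https://huggingface.co/datasets/squad/viewer/plain_text
#       -> ("squad", "plain_text")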

def train(dataset_url: str):
    try:
        # Parse dataset URL
        dataset_name, dataset_config = parse_hf_dataset_url(dataset_url)
        logging.info(f"Loading dataset: {dataset_name} (config: {dataset_config})")

        # Load model and tokenizer
        model_name = "microsoft/phi-2"
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", trust_remote_code=True)

        # Add a padding token (Phi-2's tokenizer ships without one)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load dataset from the Hugging Face Hub
        dataset = load_dataset(
            dataset_name,
            dataset_config,
            trust_remote_code=True
        )

        # Handle dataset splits
        if "train" not in dataset:
            raise ValueError("Dataset must have a 'train' split")
        train_dataset = dataset["train"]
        eval_dataset = dataset.get("validation", dataset.get("test", None))

        # Carve out a validation split if the dataset doesn't provide one
        if eval_dataset is None:
            split = train_dataset.train_test_split(test_size=0.1, seed=42)
            train_dataset = split["train"]
            eval_dataset = split["test"]
        # Tokenization function. Outputs are left as Python lists:
        # return_tensors="pt" is unnecessary inside datasets.map, which
        # stores the mapped output as Arrow-backed lists regardless.
        def tokenize_function(examples):
            return tokenizer(
                examples["text"],  # Adjust column name to match your dataset
                padding="max_length",
                truncation=True,
                max_length=256,
            )
        # Tokenize datasets, dropping the raw columns
        tokenized_train = train_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=train_dataset.column_names
        )
        tokenized_eval = eval_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=eval_dataset.column_names
        )

        # Data collator (mlm=False: labels are copies of input_ids, i.e. causal LM)
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False
        )

        # Training arguments (kept small since the model runs on CPU)
        training_args = TrainingArguments(
            output_dir="./phi2-results",
            per_device_train_batch_size=2,
            per_device_eval_batch_size=2,
            num_train_epochs=3,
            logging_dir="./logs",
            logging_steps=10,
            fp16=False,
        )

        # Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_train,
            eval_dataset=tokenized_eval,
            data_collator=data_collator,
        )

        # Start training
        logging.info("Training started...")
        trainer.train()
        trainer.save_model("./phi2-trained-model")
        logging.info("Training completed!")
        return "βœ… Training succeeded! Model saved."
    except Exception as e:
        logging.error(f"Training failed: {str(e)}")
        return f"❌ Training failed: {str(e)}"

# Gradio interface
with gr.Blocks(title="Phi-2 Training") as demo:
    gr.Markdown("# πŸš€ Train Phi-2 with HF Hub Data")

    with gr.Row():
        dataset_url = gr.Textbox(
            label="Dataset URL",
            value="https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0"
        )

    start_btn = gr.Button("Start Training", variant="primary")
    status_output = gr.Textbox(label="Status", interactive=False)

    def start_training(url: str) -> str:
        # Run train() on a daemon thread so the UI stays responsive.
        # Thread.start() returns None, so return an explicit status message;
        # the training result itself is only visible in the logs.
        threading.Thread(target=train, args=(url,), daemon=True).start()
        return "πŸƒ Training started in the background - follow progress in the logs."

    start_btn.click(
        fn=start_training,
        inputs=[dataset_url],
        outputs=status_output
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
    )