# load the requirements
import torch
import os
from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    Seq2SeqTrainingArguments,
    WhisperForConditionalGeneration,
    TrainerCallback,
    Seq2SeqTrainer,
)
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from torch.utils.data import IterableDataset
import evaluate
from datasets import load_dataset, Audio
from dataclasses import dataclass
import pandas as pd
import subprocess
import datetime
import csv
# define the model id
model_id = "openai/insert_model_id"
# specify the output file path of the wrong predictions
output_file_path = "path/to/your/output/wrong_predictions.csv"
# specify the output file path of the computational resources data
output_file_path_gpu = "path/to/your/output/efficiency_data.csv"
# load and define the feature extractor and the tokenizer
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)
tokenizer = WhisperTokenizer.from_pretrained(model_id, language="English", task="transcribe")
# load audio dataset
audio_dataset_train = load_dataset("audiofolder", data_dir="/path/to/dataset/train")
audio_dataset_test = load_dataset("audiofolder", data_dir="/path/to/dataset/test")
# load the processor
processor = WhisperProcessor.from_pretrained(model_id, language="English", task="transcribe")
# preprocess the data
audio_dataset_train = audio_dataset_train.cast_column("audio", Audio(sampling_rate=16000))
audio_dataset_test = audio_dataset_test.cast_column("audio", Audio(sampling_rate=16000))
do_lower_case = False
do_remove_punctuation = False
normalizer = BasicTextNormalizer()
def prepare_dataset(batch):
    # load the resampled audio data
    audio = batch["audio"]
    # compute log-Mel input features from the audio array
    batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # compute the input length of the audio sample in seconds
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]
    # optionally lower-case and normalise the transcription, then encode it to label ids
    transcription = batch["transcription"]
    if do_lower_case:
        transcription = transcription.lower()
    if do_remove_punctuation:
        transcription = normalizer(transcription).strip()
    batch["labels"] = processor.tokenizer(transcription).input_ids
    return batch
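# note: each prepared sample is a dict with "input_features" (a log-Mel spectrogram covering
# 30 s of padded/truncated audio), "input_length" (the clip duration in seconds) and
# "labels" (the token ids of the transcription); the number of Mel bins depends on the chosen checkpoint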
# apply "prepare dataset" function to each sample in the dataset
vectorized_audio_dataset_train = audio_dataset_train.map(
prepare_dataset,
remove_columns=list(next(iter(audio_dataset_train.values())).features)).with_format("torch")
vectorized_audio_dataset_test = audio_dataset_test.map(
prepare_dataset,
remove_columns=list(next(iter(audio_dataset_test.values())).features)).with_format("torch")
# shuffle the training split
vectorized_audio_dataset_train["train"] = vectorized_audio_dataset_train["train"].shuffle(
    seed=0,
    load_from_cache_file=False,
).shard(num_shards=1, index=0, contiguous=True)
# training and evaluation
# define a data collator
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: WhisperProcessor

    def __call__(self, features):
        # pad the log-Mel input features to the longest sequence in the batch
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        # pad the tokenised label sequences and replace padding with -100 so it is ignored by the loss
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # if the BOS token was prepended during tokenisation, cut it off here (the model re-adds it)
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
# evaluation matrix WER
metric = evaluate.load("wer")
do_normalize_eval = True
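# note: the "wer" metric returns a fraction, e.g.
# metric.compute(predictions=["the cat sat"], references=["the cat sat down"]) gives 0.25
# (one deletion out of four reference words); it is scaled to a percentage in compute_metrics below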
# store filenames, predictions and references
predicted_words_list = []
target_words_list = []
filenames = []
def compute_metrics(pred, specific_vocab=None):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    # decode predictions and references, skipping special tokens
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    if do_normalize_eval:
        pred_str = [normalizer(pred) for pred in pred_str]
        label_str = [normalizer(label) for label in label_str]
        # filtering step to only evaluate the samples which correspond to non-empty references
        pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
        label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    # append wrong predictions, their references and file names to the respective lists
    for pred_word, target_word, filename in zip(pred_str, label_str, audio_dataset_test["train"]["audio"]):
        if pred_word != target_word:
            predicted_words_list.append(pred_word)
            target_words_list.append(target_word)
            # the decoded audio column is a dict that still carries the original file path
            filenames.append(os.path.basename(filename["path"]))
    print(f"WER: {wer}")
    return {"wer": wer}
# load a pre-trained checkpoint
model = WhisperForConditionalGeneration.from_pretrained(model_id).to(torch.device(0))
# disable forced decoder ids, token suppression and the cache
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.config.use_cache = False
# define the training parameters
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    save_total_limit=2,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    eval_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=100,
    max_steps=1000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=25,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)
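# note: with per_device_train_batch_size=32 and gradient_accumulation_steps=1 on a single GPU,
# the effective batch size is 32, so max_steps=1000 corresponds to roughly 32,000 training samples seen;
# warmup_steps and max_steps should be rescaled if the batch size is changed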
# trainer callback to reinitialise and reshuffle the datasets at the beginning of each epoch
class ShuffleCallback(TrainerCallback):
    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        if not isinstance(train_dataloader.dataset, IterableDataset):
            train_dataloader.dataset.shuffle()
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=vectorized_audio_dataset_train["train"],
    eval_dataset=vectorized_audio_dataset_test["train"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
    callbacks=[ShuffleCallback()],
)
model.save_pretrained(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)
# log start and end time of the training
start_time = datetime.datetime.now()
# launch training
trainer.train()
end_time = datetime.datetime.now()
# fill in missing values with empty strings to ensure equal lengths
max_length = max(len(filenames), len(predicted_words_list), len(target_words_list))
filenames += [""] * (max_length - len(filenames))
predicted_words_list += [""] * (max_length - len(predicted_words_list))
target_words_list += [""] * (max_length - len(target_words_list))
# save the wrong predictions
df_wrong_predictions = pd.DataFrame({
    "File Name": filenames,
    "Predictions": predicted_words_list,
    "References": target_words_list,
})
# split the stored predictions and references into words
pred_words_split = [pred.split() for pred in predicted_words_list]
target_words_split = [target.split() for target in target_words_list]
# keep only the words that differ between prediction and reference
filtered_pred_words = [" ".join([word for word in pred if word not in target]) for pred, target in zip(pred_words_split, target_words_split)]
filtered_target_words = [" ".join([word for word in target if word not in pred]) for target, pred in zip(target_words_split, pred_words_split)]
# update the DataFrame with the filtered words
df_wrong_predictions["Predictions"] = filtered_pred_words
df_wrong_predictions["References"] = filtered_target_words
df_wrong_predictions = df_wrong_predictions[df_wrong_predictions["Predictions"] != df_wrong_predictions["References"]]
# save the DataFrame as a CSV file
df_wrong_predictions.to_csv(output_file_path, index=False)
# get training speed
duration = end_time - start_time
duration_hours = duration.total_seconds() / 3600 # Convert duration to hours
# get the GPU info
def get_gpu_info():
    try:
        output = subprocess.check_output(["nvidia-smi", "--query-gpu=index,name,memory.used", "--format=csv,noheader,nounits"])
        gpu_info = [line.strip().split(", ") for line in output.decode("utf-8").split("\n") if line.strip()]
        return gpu_info
    except Exception:
        return []
gpu_info = get_gpu_info()
if gpu_info:
    gpu_name = gpu_info[0][1]
    gpu_memory_used = int(gpu_info[0][2])
    with open(output_file_path_gpu, mode="w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(["Training Duration (hours)", "GPU Name", "GPU Memory Used (MB)"])
        writer.writerow([duration_hours, gpu_name, gpu_memory_used])
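# optional sanity check after training: transcribe one test sample with the fine-tuned model.
# this is a minimal sketch (it assumes training finished above and that the test split has at
# least one sample); nothing is written to disk, the prediction is only printed next to the reference
sample = vectorized_audio_dataset_test["train"][0]
input_features = sample["input_features"].unsqueeze(0).to(model.device)
with torch.no_grad():
    generated_ids = model.generate(input_features, max_new_tokens=225)
print("prediction:", processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
print("reference: ", processor.decode(sample["labels"], skip_special_tokens=True))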