import argparse
import functools
import glob
import os
import random
import string
import json
import sys 
sys.path.append('../')
from tqdm import tqdm
import yaml
from collections import defaultdict
import io
import warnings
import subprocess
import pickle

import numpy as np
import torch

from data.data import get_audiotext_dataloader
from src.factory import create_model_and_transforms
from train.train_utils import Dict2Class, get_autocast, get_cast_dtype

def _resolve_checkpoint_path(expdir, run_name, ckpt):
    """Return the checkpoint file to load from ``{expdir}/{run_name}``.

    ``ckpt == -1`` selects the ``checkpoint_<N>.pt`` with the largest integer
    ``N``; any other value is formatted into the path directly (existence is
    not checked here).

    Raises:
        FileNotFoundError: if ``ckpt == -1`` and no checkpoint files exist.
    """
    run_dir = f"{expdir}/{run_name}"
    if ckpt != -1:
        return f"{run_dir}/checkpoint_{ckpt}.pt"
    checkpoint_list = glob.glob(f"{run_dir}/checkpoint_*.pt")
    if not checkpoint_list:
        raise FileNotFoundError(f"no checkpoint_*.pt files found in {run_dir}")
    # Numeric (not lexicographic) sort so checkpoint_10 ranks after checkpoint_9.
    return sorted(checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1]


def _clean_decoded(text, tokenizer):
    """Keep the text after the last SEP token and strip EOS/pad/``<|endofchunk|>`` markers."""
    return (
        text.split(tokenizer.sep_token)[-1]
        .replace(tokenizer.eos_token, '')
        .replace(tokenizer.pad_token, '')
        .replace('<|endofchunk|>', '')
    )


def _save_tmp_state(tmp_file, tmp_data):
    """Pickle ``tmp_data`` to ``tmp_file`` via a side file and ``os.replace``.

    Dumping to a side file first means an interruption mid-write leaves the
    previous ``tmp_file`` intact instead of a truncated pickle that would make
    the next resume fail.
    """
    partial_file = tmp_file + '.partial'
    with open(partial_file, 'wb') as pickle_file:
        pickle.dump(tmp_data, pickle_file)
    os.replace(partial_file, tmp_file)


def inference_this(
    args, data_config, clap_config, model_config, test_dataset_name, tmp_file,
    temperature=1.0, num_beams=3, ckpt=-1, end_batch_idx=-2, verbose=False,
):
    """Generate model outputs for every sample of one test dataset.

    Builds the model, loads a checkpoint from ``{args.expdir}/{args.run_name}``
    and runs ``model.generate`` on each sample's prompt (the tokens up to and
    including the last SEP token). Progress is pickled to ``tmp_file`` after
    every batch so an interrupted run resumes at the next unfinished batch.

    Args:
        args: train-config object; reads ``offline``, ``gradient_checkpointing``,
            ``freeze_lm_embeddings``, ``expdir``, ``run_name``, ``precision``
            and ``batch_size``.
        data_config: dataset config dict. A shallow copy is used, with
            ``valid_dataset_config`` narrowed to ``test_dataset_name``; the
            caller's dict is not modified.
        clap_config: passed to model and dataloader construction.
        model_config: keyword arguments for ``create_model_and_transforms``.
        test_dataset_name: name of the test set to evaluate.
        tmp_file: resume-state pickle path; its parent directory is created.
        temperature: forwarded to ``model.generate``.
        num_beams: accepted for compatibility; NOT forwarded to ``model.generate``.
        ckpt: checkpoint index, ``-1`` for the latest.
        end_batch_idx: when > 0, stop before processing this batch index.
        verbose: accepted for compatibility; currently unused.

    Returns:
        list of ``(filename, prompt, ground_truth, output)`` tuples. For the
        captioning tasks listed in ``deduplicate_tasks``, samples sharing
        ``(filename, prompt)`` are generated once and their ground truths are
        joined with ``'|'``.

    Raises:
        FileNotFoundError: no checkpoint found when ``ckpt == -1``.
        ValueError: ``test_dataset_name`` is not produced by the test dataloader.
    """
    from itertools import islice

    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning
    model, tokenizer = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=args.offline,
        gradient_checkpointing=args.gradient_checkpointing,
        freeze_lm_embeddings=args.freeze_lm_embeddings,
    )

    device_id = 0
    model = model.to(device_id)
    model.eval()

    resume_from_checkpoint = _resolve_checkpoint_path(args.expdir, args.run_name, ckpt)
    checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")
    # Drop "module." from key names (presumably a distributed-wrapper prefix)
    # so they match the unwrapped model; loading is non-strict.
    msd = {k.replace("module.", ""): v for k, v in checkpoint["model_state_dict"].items()}
    missing_keys, unexpected_keys = model.load_state_dict(msd, False)
    print('missing keys:', missing_keys)
    print('unexpected keys:', unexpected_keys)

    # NOTE(review): inputs are cast to cast_dtype, but the model itself is not
    # cast and generation does not run under autocast; confirm this is valid
    # for the configured precision.
    cast_dtype = get_cast_dtype(args.precision)

    # Narrow validation config to the single requested test set, on a copy.
    data_config = dict(data_config)
    valid_config = data_config["valid_dataset_config"]
    if test_dataset_name in valid_config:
        data_config["valid_dataset_config"] = {test_dataset_name: valid_config[test_dataset_name]}
    else:
        data_config["valid_dataset_config"] = {test_dataset_name: True}

    all_test_AudioTextDataInfo = get_audiotext_dataloader(data_config, clap_config, tokenizer, args.batch_size, split='test')
    if test_dataset_name not in all_test_AudioTextDataInfo:
        raise ValueError("{} not a test set".format(test_dataset_name))
    dataloader = all_test_AudioTextDataInfo[test_dataset_name].dataloader

    deduplicate_tasks = (
        "Clotho-v2-AudioCaptioning", "audiocaps-AudioCaptioning", "MACS-AudioCaptioning",
        "LP-MusicCaps-MSD-AudioCaptioning", "LP-MusicCaps-MC-AudioCaptioning",
    )
    deduplicate = test_dataset_name.startswith(deduplicate_tasks)

    if os.path.exists(tmp_file):
        # tmp_file is written by this function; pickle must only load trusted files.
        with open(tmp_file, 'rb') as pickle_file:
            tmp_data = pickle.load(pickle_file)
        results_dic = tmp_data['results_dic']
        results = tmp_data['results']
        finished_batches = tmp_data['finished_batches']
        print('reading tmp data from {}: {} batches already computed'.format(tmp_file, finished_batches + 1))
    else:
        tmp_data = {}
        results_dic = {}  # for deduplicate
        results = []  # for non-deduplicate
        finished_batches = -1
        print('no tmp data found; will store tmp data to {}'.format(tmp_file))
    os.makedirs(os.path.dirname(tmp_file) or '.', exist_ok=True)

    sep_token_id = tokenizer.sep_token_id
    # First batch index not yet recorded (0 on a fresh run; islice requires a
    # non-negative start). Skipping batches on resume assumes the test
    # dataloader yields batches in the same order on every run.
    start_batch = finished_batches + 1

    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(islice(dataloader, start_batch, None), start=start_batch)):
            if 0 < end_batch_idx <= batch_idx:
                break

            audio_clips = batch["audio_clips"].to(device_id, dtype=cast_dtype, non_blocking=True)
            audio_embed_mask = batch["audio_embed_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)
            input_ids = batch["input_ids"].to(device_id, non_blocking=True)
            filenames = batch["filenames"]

            for idx, input_id in enumerate(input_ids):
                filename = filenames[idx]
                if isinstance(filename, list):
                    # interleaved data: key the sample on its last file
                    filename = filename[-1]

                # Prompt = tokens up to and including the last SEP; a row with
                # no SEP falls back to its first token only.
                sep_positions = (input_id == sep_token_id).nonzero(as_tuple=True)[0]
                sep_location = int(sep_positions[-1]) if len(sep_positions) > 0 else 0
                prompt = input_id[:sep_location + 1]

                prompt_decoded = tokenizer.decode(prompt).replace(tokenizer.sep_token, '')
                ground_truth_decoded = _clean_decoded(tokenizer.decode(input_id), tokenizer)
                key = (filename, prompt_decoded)

                if deduplicate and key in results_dic:
                    # Same audio + prompt already generated: just collect the extra reference.
                    results_dic[key]['ground_truth'].append(ground_truth_decoded)
                    continue

                output = model.generate(
                    audio_x=audio_clips[idx].unsqueeze(0),
                    audio_x_mask=audio_embed_mask[idx].unsqueeze(0),
                    lang_x=prompt.unsqueeze(0),
                    eos_token_id=tokenizer.eos_token_id,
                    max_new_tokens=256,
                    temperature=temperature,
                )[0]
                output_decoded = _clean_decoded(tokenizer.decode(output), tokenizer)

                if deduplicate:
                    results_dic[key] = {
                        'ground_truth': [ground_truth_decoded],
                        'output': output_decoded,
                    }
                else:
                    results.append((filename, prompt_decoded, ground_truth_decoded, output_decoded))

            tmp_data['results_dic'] = results_dic
            tmp_data['results'] = results
            tmp_data['finished_batches'] = batch_idx
            _save_tmp_state(tmp_file, tmp_data)

    if deduplicate:
        for (filename, prompt), entry in results_dic.items():
            results.append((filename, prompt, '|'.join(entry['ground_truth']), entry['output']))

    return results


def main():
    """CLI entry point: run inference for one task and write a results log.

    The log goes to
    ``../outputs/<task with '/' replaced by '-'>/<config stem>-ckpt<ckpt>-sft.log``.
    Resume state is kept next to it as ``<same stem>.tmp.pickle``. If the log
    already exists, generation is skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='../config/config.yaml', help='yaml config path')
    parser.add_argument('-t', '--task', type=str, required=True, help='which task to inference')
    parser.add_argument('-temp', '--temperature', type=float, default=1.0, help='temperature')
    parser.add_argument('-nb', '--num_beams', type=int, default=1, help='num beams for beam search')
    parser.add_argument('--ckpt', type=int, default=-1, help='checkpoint idx, -1 means latest')
    parsed_args = parser.parse_args()

    print(parsed_args)

    test_dataset_name = parsed_args.task
    # Config file name without its extension (works for .yaml and .yml).
    config_stem = os.path.splitext(os.path.basename(parsed_args.config))[0]

    output_file = os.path.join(
        '../outputs/',
        test_dataset_name.replace('/', '-'),
        '{}-ckpt{}-{}.log'.format(config_stem, parsed_args.ckpt, "sft"),
    )
    tmp_file = os.path.splitext(output_file)[0] + '.tmp.pickle'
    print('output file:', output_file)

    if os.path.exists(output_file):
        print('log file already exists; skipping generation')
        return
    print('no previous log file; generating samples')
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    with open(parsed_args.config) as config_fp:
        config = yaml.load(config_fp, Loader=yaml.FullLoader)
    data_config = config['data_config']
    model_config = config['model_config']
    print(model_config)
    clap_config = config['clap_config']
    args = Dict2Class(config['train_config'])

    results = inference_this(
        args, data_config, clap_config, model_config, test_dataset_name,
        temperature=float(parsed_args.temperature),
        num_beams=int(parsed_args.num_beams),
        ckpt=parsed_args.ckpt,
        verbose=True,
        tmp_file=tmp_file,
    )

    # Write to a side file then rename, so an interrupted write never leaves a
    # partial log that the existence check above would treat as finished.
    partial_file = output_file + '.partial'
    with open(partial_file, 'w', encoding='utf-8') as log_fp:
        for filename, prompt, ground_truth, output in results:
            log_fp.write('-' * 30 + '\n')
            log_fp.write('filename: {}\n'.format(filename))
            log_fp.write('prompt: {}\n'.format(prompt))
            log_fp.write('ground_truth: {}\n'.format(ground_truth))
            log_fp.write('output: {}\n'.format(output))
    os.replace(partial_file, output_file)
    print('wrote {} results to {}'.format(len(results), output_file))

if __name__ == '__main__':
    # Run the CLI only when executed as a script, never on import.
    main()