import pandas as pd
import torch
from datetime import datetime
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import IterableDataset, Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.nn import TransformerEncoderLayer

from torch.distributed.pipeline.sync import Pipe
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

import numpy as np
import math

import idr_torch
import os
import tempfile

import argparse
from datetime import datetime
from torch.profiler import profile, tensorboard_trace_handler, ProfilerActivity, schedule

# For reproducibility
# Seed the global PyTorch and NumPy random number generators with the same fixed value (53).
torch.manual_seed(53)
np.random.seed(53)

# To get parameters
def parse_args():
    """Parse the command-line options of the training script.

    Besides the hyper-parameters, three options describe which GPUs this
    process uses for its pipeline:

    * ``--nb_part``: number of GPUs per process; defaults to the number of
      visible GPUs divided by ``SLURM_NTASKS_PER_NODE``.
    * ``--first_part`` / ``--last_part``: index of the first / last GPU of
      this process; the defaults are derived from ``idr_torch.local_rank``.

    Returns:
        argparse.Namespace: the parsed arguments.

    Raises:
        KeyError: if ``SLURM_NTASKS_PER_NODE`` is not set in the environment.
    """
    visible_gpus = torch.cuda.device_count()
    nb_part = visible_gpus // int(os.environ['SLURM_NTASKS_PER_NODE'])
    first_part = idr_torch.local_rank * nb_part
    last_part = first_part + nb_part - 1

    # (flag, keyword arguments for ``add_argument``), in the order they are registered.
    options = (
        ("--epochs", dict(type=int, default=1)),
        ("--batch_size_per_gpu", dict(type=int, default=32)),
        ("--lr", dict(type=float, default=5e-05)),
        ("--d_model", dict(type=int, default=768)),
        ("--n_head", dict(type=int, default=12)),
        ("--d_hid", dict(type=int, default=3072)),
        ("--dropout", dict(type=float, default=0.1)),
        ("--ngpu", dict(type=int, default=visible_gpus)),
        ("--nlayers", dict(type=int, default=8)),
        ("--chunks", dict(type=int, default=8)),
        ("--train_file", dict(type=str, default="/gpfswork/idris/sos/ssos022/datasets/imdb/dataset_train.csv")),
        ("--valid_file", dict(type=str, default="/gpfswork/idris/sos/ssos022/datasets/imdb/dataset_val.csv")),
        ("--profile", dict(action='store_true')),
        ("--timestamp", dict(type=int, default=int(datetime.now().timestamp()))),
        ("--nb_part", dict(type=int, default=nb_part)),
        ("--first_part", dict(type=int, default=first_part)),
        ("--last_part", dict(type=int, default=last_part)),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    return parser.parse_args()


def print_rank_0(*to_print, **kwargs):
    """Call ``print`` with the given arguments, but only when ``idr_torch.rank`` is 0.

    On every other rank the call does nothing.
    """
    if idr_torch.rank != 0:
        return
    print(*to_print, **kwargs)

########################## DATA PreProcess ##########################
# Initialize the preprocessing pipeline
# Not optimized for multi-process use
def init_voc_tok(train_file):
    """Build the vocabulary and the tokenizer from a training CSV file.

    The file is read with pandas and every value of its third column
    (position 2) is tokenized with torchtext's ``basic_english`` tokenizer.
    Tokens occurring at least 5 times are kept; ``<pad>`` and ``<unk>`` are
    registered as special tokens and ``<unk>`` is used for unknown tokens.

    Args:
        train_file: path of the CSV file to read.

    Returns:
        tuple: ``(vocab, tokenizer)``.
    """

    class IterDatasetIMDB(IterableDataset):
        """Iterable dataset yielding the third column of a DataFrame, row by row."""

        def __init__(self, df):
            self.df = df

        def __len__(self):
            n_rows, _ = self.df.shape
            return n_rows

        def parse_df(self):
            # Generator over the values stored at column position 2.
            row = 0
            n_rows = self.df.shape[0]
            while row < n_rows:
                yield self.df.iloc[row, 2]
                row += 1

        def __iter__(self):
            return self.parse_df()

    texts = IterDatasetIMDB(df=pd.read_csv(train_file))
    tokenizer = get_tokenizer('basic_english')

    token_stream = (tokenizer(text) for text in texts)
    vocab = build_vocab_from_iterator(token_stream, min_freq=5, specials=['<pad>','<unk>'])
    # Any token missing from the vocabulary maps to the '<unk>' index.
    vocab.set_default_index(vocab['<unk>'])
    return vocab, tokenizer

# Define the preprocessing pipeline
def process_batch(seq_text, vocab, tokenizer, max_len=200):
    """Turn a batch of raw texts into a padded id tensor and its padding mask.

    Args:
        seq_text: iterable of strings.
        vocab: callable mapping a list of tokens to a list of integer ids.
        tokenizer: callable mapping a string to a list of tokens.
        max_len: each text is truncated to at most this many tokens.

    Returns:
        tuple: ``(data, mask)`` where ``data`` is a ``torch.long`` tensor of
        shape ``(batch, longest_sequence)`` and ``mask`` is a boolean tensor of
        the same shape that is ``True`` wherever ``data`` equals 0 (the value
        ``pad_sequence`` fills with; presumably also the ``<pad>`` id -- to be
        confirmed against the vocabulary).
    """
    encoded = []
    for text in seq_text:
        tokens = tokenizer(text)[:max_len]
        encoded.append(torch.tensor(vocab(tokens), dtype=torch.long))

    data = pad_sequence(encoded, batch_first=True)
    mask = data == 0
    return data, mask


########################## DATA Training ##########################
# Define the PyTorch dataset used for training
class DatasetIMDB(Dataset):
    """Map-style dataset backed by a pandas DataFrame.

    Item ``idx`` is the pair ``(value at column 2, value at column 1)`` of row
    ``idx`` (presumably the review text and its label -- confirm against the CSV).
    """

    def __init__(self, df):
        self.df = df

    def __len__(self):
        n_rows, _ = self.df.shape
        return n_rows

    def __getitem__(self, idx):
        text = self.df.iloc[idx, 2]
        label = self.df.iloc[idx, 1]
        return text, label

# Define the pipeline to get datasets from a file
def get_dataset(source='/gpfswork/idris/sos/ssos022/datasets/imdb/dataset_train.csv', batch_size_per_gpu=32):
    """Create a DataLoader over the CSV file ``source`` for the current process.

    The file is wrapped in a ``DatasetIMDB`` and sampled through a
    ``DistributedSampler`` configured with ``idr_torch.size`` replicas and
    ``idr_torch.rank`` as the rank of this process.

    Args:
        source: path of the CSV file.
        batch_size_per_gpu: batch size given to the DataLoader.

    Returns:
        torch.utils.data.DataLoader: the loader for this process.
    """
    dataset = DatasetIMDB(df=pd.read_csv(source))
    shard_sampler = DistributedSampler(dataset, rank=idr_torch.rank, num_replicas=idr_torch.size)
    return DataLoader(dataset, sampler=shard_sampler, batch_size=batch_size_per_gpu)


############################## MODEL ##############################
# The final model will be wrapped in an nn.Sequential module to make it compatible with torch.distributed.pipeline.sync.Pipe
# We will cut the model into several parts and distribute these parts over several workers/GPUs (pipeline parallelism)
# Each part will be wrapped in an nn.Sequential module before being wrapped in a final nn.Sequential module

# Define positional encoding
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal positional encoding to embeddings, then apply dropout.

    Args:
        d_model: embedding dimension (size of the last axis of the input).
            Odd values are supported.
        dropout: dropout probability applied after the addition.
        max_len: number of positions in the pre-computed table.
        batch_first: layout of the input of ``forward``. ``False`` (default)
            means ``[seq_len, batch_size, embedding_dim]``; ``True`` means
            ``[batch_size, seq_len, embedding_dim]``.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, batch_first: bool = False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine columns,
        # so only the first d_model // 2 frequencies are used for the cosine part.
        pe[:, 0, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Buffer: follows the module across devices but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim], or
                [batch_size, seq_len, embedding_dim] when ``batch_first`` is set.
        """
        if self.batch_first:
            # pe[:seq_len] is (seq_len, 1, E); swap the first two axes to get
            # (1, seq_len, E) so that it broadcasts over the batch axis.
            x = x + self.pe[:x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[:x.size(0)]
        return self.dropout(x)

# Define the encoding module (ids lists -> tensors)
class Encoder(nn.Module):
    """Embed token ids, scale by ``sqrt(d_model)`` and add the positional encoding.

    The rest of the model is batch-first (``TransformerEncoderLayer`` is built
    with ``batch_first=True`` and the head reads ``inp[:, 0, :]``), so the
    positional encoding is configured with ``batch_first=True`` as well: the
    encoding must vary along the sequence axis, not along the batch axis.

    Args:
        ntoken: size of the vocabulary.
        d_model: embedding dimension.
        dropout: dropout probability of the positional encoding.
    """

    def __init__(self, ntoken, d_model, dropout=0.5):
        super(Encoder, self).__init__()
        self.pos_encoder = PositionalEncoding(d_model, dropout, batch_first=True)
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.init_weights()

    def init_weights(self):
        """Initialise the embedding weights uniformly in [-0.1, 0.1]."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, pad_mask):
        """Return ``(encoded, pad_mask)``; ``pad_mask`` is passed through untouched.

        Args:
            src: integer tensor of token ids, shape (N, S) (batch first).
            pad_mask: padding mask forwarded to the next pipeline stage.
        """
        src = self.encoder(src) * math.sqrt(self.d_model)
        return self.pos_encoder(src), pad_mask

# Define the transformer layer
# We need to wrap the TransformerEncoderLayer in a custom module to make it usable in an nn.Sequential module with 2 inputs (it works only when the model is wrapped in a Pipe module)
class Layer(nn.Module):
    """Pipeline-friendly wrapper around a batch-first ``TransformerEncoderLayer``.

    ``forward`` takes and returns the pair ``(tensor, pad_mask)`` so that the
    padding mask travels with the activations from one stage to the next.

    Args:
        d_model: embedding dimension.
        nhead: number of attention heads.
        d_hid: dimension of the feed-forward network.
        dropout: dropout probability.
    """

    def __init__(self, d_model, nhead, d_hid, dropout=0.1):
        super().__init__()
        self.layer = TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_hid,
            dropout=dropout,
            batch_first=True,
        )

    def forward(self, inp: Tensor, pad_mask: Tensor):
        # The mask is used as key padding mask and handed on unchanged.
        return self.layer(inp, src_key_padding_mask=pad_mask), pad_mask

# Define the head of the model
class Decoder(nn.Module):
    """Classification head: first sequence position -> linear -> linear -> sigmoid.

    Args:
        d_model: size of the incoming embeddings.
        d_hid: size of the hidden linear layer.
    """

    def __init__(self, d_model, d_hid):
        super().__init__()
        self.hid = nn.Linear(d_model, d_hid)
        self.decoder = nn.Linear(d_hid, 1)
        self.act = nn.Sigmoid()
        self.init_weights()

    def init_weights(self):
        """Zero the bias of the output layer and draw its weights uniformly in [-0.1, 0.1]."""
        bound = 0.1
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -bound, bound)

    def forward(self, inp, pad_mask):
        """Return one value per batch element as a 1-D tensor; ``pad_mask`` is not used."""
        first_position = inp[:, 0, :]
        scores = self.decoder(self.hid(first_position))
        return self.act(scores).view(-1)


# Define a function to instantiate the model with the right parameters
def get_model(args):
    """Build the pipelined model: encoder, ``args.nlayers`` transformer layers, head.

    The modules are grouped into one ``nn.Sequential`` per GPU and the groups
    are wrapped in a final ``nn.Sequential`` handed to ``Pipe``. The encoder
    is placed on GPU ``args.first_part``, the head on GPU ``args.last_part``
    and the transformer layers are spread over ``args.nb_part`` consecutive
    GPUs starting at ``args.first_part``.

    The partition size is an integer (ceil of ``nlayers / nb_part``) and a new
    group is started whenever the target GPU changes. This keeps every group
    on a single device even when ``nlayers`` is not a multiple of ``nb_part``
    or is smaller than it; for evenly divisible configurations the resulting
    structure is the same as with a plain ``nlayers / nb_part`` split.

    Args:
        args: namespace providing ``ntokens``, ``d_model``, ``n_head``,
            ``d_hid``, ``dropout``, ``nlayers``, ``nb_part``, ``first_part``,
            ``last_part`` and ``chunks``.

    Returns:
        Pipe: the pipelined model (no activation checkpointing).

    Raises:
        ValueError: if ``args.nb_part`` is smaller than 1.
    """
    if args.nb_part < 1:
        raise ValueError(f"nb_part must be at least 1, got {args.nb_part}")

    # Number of transformer layers hosted by each GPU (the last one may get fewer).
    partition_len = max(math.ceil(args.nlayers / args.nb_part), 1)

    partitions = []

    # Add encoder in the first gpu.
    current_device = args.first_part
    current_partition = [Encoder(args.ntokens, args.d_model, args.dropout).to(current_device)]

    # Add all the necessary transformer blocks.
    for i in range(args.nlayers):
        transformer_block = Layer(args.d_model, args.n_head, args.d_hid, args.dropout)
        device = args.first_part + i // partition_len
        if device != current_device:
            # Target GPU changed: close the current group and open a new one.
            partitions.append(nn.Sequential(*current_partition))
            current_partition = []
            current_device = device
        current_partition.append(transformer_block.to(device))

    # Add the head in the last gpu; it gets its own group when the last
    # transformer layer does not live on that GPU.
    if args.last_part != current_device:
        partitions.append(nn.Sequential(*current_partition))
        current_partition = []
    current_partition.append(Decoder(args.d_model, args.d_hid).to(args.last_part))
    partitions.append(nn.Sequential(*current_partition))

    model = Pipe(nn.Sequential(*partitions), chunks=args.chunks, checkpoint="never")
    return model


############################## TRAIN ##############################
# Define the evaluation metric in a function, here accuracy
def get_evaluation(y_true, y_prob):
    """Compute the accuracy of thresholded predictions.

    Values of ``y_prob`` strictly above 0.5 count as class 1, values less than
    or equal to 0.5 as class 0. The result is one minus the summed absolute
    difference between ``y_true`` and those classes, divided by ``len(y_true)``.

    Args:
        y_true: tensor of reference labels.
        y_prob: tensor of predicted values, same shape as ``y_true``.

    Returns:
        The accuracy as a NumPy floating-point scalar.
    """
    targets = y_true.detach().cpu().numpy()
    probs = y_prob.detach().cpu().numpy()

    # Threshold at 0.5; anything matching neither comparison is left as it is.
    predicted = np.where(probs > 0.5, 1, np.where(probs <= 0.5, 0, probs))

    mismatches = np.abs(targets - predicted).sum()
    return 1 - mismatches / len(targets)

# Train function
def _prepare_batch(args, vocab, tokenizer, texts, labels):
    """Tokenize and pad ``texts``; move inputs to the first GPU and labels to the last."""
    batch, mask = process_batch(texts, vocab, tokenizer)
    batch = batch.to(args.first_part, non_blocking=True)
    mask = mask.to(args.first_part, non_blocking=True)
    labels = labels.to(args.last_part, non_blocking=True)
    return batch, mask, labels


def train(args, vocab, tokenizer, train_loader, valid_loader, model, criterion, optimizer, prof=None):
    """Run the training loop, followed by a validation pass after every epoch.

    Args:
        args: namespace providing ``epochs``, ``first_part`` and ``last_part``.
        vocab: vocabulary passed to ``process_batch``.
        tokenizer: tokenizer passed to ``process_batch``.
        train_loader: loader yielding ``(texts, labels)`` training batches.
        valid_loader: loader yielding ``(texts, labels)`` validation batches.
        model: pipelined model; its forward returns an object exposing ``local_value()``.
        criterion: loss called as ``criterion(outputs, labels.float())``.
        optimizer: optimizer updating the parameters of the model.
        prof: optional profiler; ``prof.step()`` is called after each training step.

    Returns:
        The trained model (the object received in ``model``).
    """
    log_every = 1  # print the training loss every `log_every` steps

    for epoch in range(args.epochs):
        # A DistributedSampler shuffles with (seed + epoch); without this call
        # every epoch would iterate over the data in exactly the same order.
        sampler = getattr(train_loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        model.train()
        for i, (texts, labels) in enumerate(train_loader):
            # Process the data to make it usable by the model and send it to the first gpu
            batch, mask, labels = _prepare_batch(args, vocab, tokenizer, texts, labels)

            # Since the Pipe is only within a single host and process the ``RRef``
            # returned by forward method is local to this node and can simply
            # retrieved via ``RRef.local_value()``.
            outputs = model(batch, mask).local_value()
            loss = criterion(outputs, labels.float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if prof is not None:
                prof.step()

            if (i + 1) % log_every == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, args.epochs, i + 1, len(train_loader), loss.item()))

        model.eval()
        with torch.no_grad():
            accuracy = 0
            n = 0
            for i, (texts, labels) in enumerate(valid_loader):
                batch, mask, labels = _prepare_batch(args, vocab, tokenizer, texts, labels)

                outputs = model(batch, mask).local_value()
                accuracy += get_evaluation(labels, outputs)
                n += 1

                if (i + 1) % 10 == 0:
                    print('Validation Epoch [{}/{}], Step [{}/{}]'.format(epoch + 1, args.epochs, i + 1, len(valid_loader)))

            if n == 0:
                # Empty validation loader: nothing to average, avoid dividing by zero.
                print(' Accuracy: no validation batch')
            else:
                # Mean of the per-batch accuracies.
                accuracy = accuracy / n
                print(' Accuracy: ', accuracy)

    return model


def main(args, vocab, tokenizer):
    """Create the loaders, model, loss and optimizer, then run ``train``.

    When ``args.profile`` is set, training runs inside a ``torch.profiler``
    context whose traces are written for TensorBoard under
    ``./profiler/<SLURM_JOB_NAME>/<timestamp>_<SLURMD_NODENAME>``.

    Args:
        args: parsed command-line arguments (with ``ntokens`` already set).
        vocab: vocabulary forwarded to ``train``.
        tokenizer: tokenizer forwarded to ``train``.
    """
    per_gpu = args.batch_size_per_gpu
    train_loader = get_dataset(source=args.train_file, batch_size_per_gpu=per_gpu)
    valid_loader = get_dataset(source=args.valid_file, batch_size_per_gpu=per_gpu)

    model = DDP(get_model(args))
    criterion = nn.BCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    # Positional arguments shared by both calls to train().
    run = (args, vocab, tokenizer, train_loader, valid_loader, model, criterion, optimizer)

    if not args.profile:
        train(*run, prof=None)
        return

    trace_dir = f'./profiler/{os.environ["SLURM_JOB_NAME"]}/{args.timestamp}_{os.environ["SLURMD_NODENAME"]}'
    profiler = profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=schedule(wait=1, warmup=1, active=5, repeat=1),
        on_trace_ready=tensorboard_trace_handler(trace_dir),
        profile_memory=True,
    )
    with profiler as prof:
        train(*run, prof=prof)

def init_train():
    """Parse the arguments and initialise everything needed before training.

    In order: parse the command line, print the job layout, initialise the
    NCCL process group, initialise a single-worker RPC framework (required by
    ``Pipe``), build the vocabulary and store its size in ``args.ntokens``.

    Returns:
        tuple: ``(args, vocab, tokenizer)``.
    """
    args = parse_args()

    layout = f">>> Training on {len(idr_torch.hostnames)} nodes, {torch.cuda.device_count()} gpus"
    processes = f"and {idr_torch.size} processes, master node is {os.environ['MASTER_ADDR']} \n"
    print_rank_0(layout, processes)
    print(f"- Process {idr_torch.rank} corresponds to GPU {idr_torch.local_rank} of node {os.environ['SLURM_NODEID']}\n")

    dist.init_process_group(backend='nccl', init_method='env://', world_size=idr_torch.size, rank=idr_torch.rank)

    # Initialize RPC Framework, Pipe depends on it
    tmpfile = tempfile.NamedTemporaryFile()
    # Specifying _transports and _channels is a workaround that the PyTorch
    # documentation says is no longer needed for versions >= 1.8.1 (not true
    # on Jean Zay). On Jean Zay, _transports must be ["shm", "uv"] and not
    # ["ibv", "uv"] as in the PyTorch documentation.
    rpc_options = dist.rpc.TensorPipeRpcBackendOptions(
        init_method=f"file://{tmpfile.name}",
        _transports=["shm", "uv"],
        _channels=["cuda_ipc", "cuda_basic"],
    )
    dist.rpc.init_rpc(name="worker", rank=0, world_size=1, rpc_backend_options=rpc_options)

    vocab, tokenizer = init_voc_tok(args.train_file)
    # The vocabulary size is needed to build the embedding layer.
    args.ntokens = len(vocab)

    return args, vocab, tokenizer

if __name__ == '__main__':
    # Script entry point: initialise, train, then report memory usage and timing.
    args, vocab, tokenizer = init_train()

    is_rank_0 = idr_torch.rank == 0
    if is_rank_0:
        # Only rank 0 measures and reports the total duration.
        start = datetime.now()

    main(args, vocab, tokenizer)

    # display result of bench
    for part in range(args.nb_part):
        gpu_label = idr_torch.rank*args.nb_part+part
        peak = torch.cuda.max_memory_allocated(device=args.first_part+part)*1e-09
        print(f'max memory for gpu {gpu_label} : {peak}')

    if is_rank_0:
        print(">>> Training complete in: " + str(datetime.now() - start))