import torch
import torch.nn as nn

from .ProbUNet_utils import make_onehot as make_onehot_segmentation, make_slices, match_to


def is_conv(op):
    """Return True if ``op`` is a convolution class or a convolution module.

    Recognised types are ``nn.Conv1d/2d/3d`` and ``nn.ConvTranspose1d/2d/3d``
    together with their subclasses.

    Args:
        op: Either a class or an instance. Anything else yields ``False``.

    Returns:
        bool: ``True`` for a (subclass of a) convolution class or an instance
        of one, ``False`` otherwise.
    """
    conv_types = (nn.Conv1d,
                  nn.Conv2d,
                  nn.Conv3d,
                  nn.ConvTranspose1d,
                  nn.ConvTranspose2d,
                  nn.ConvTranspose3d)
    if isinstance(op, type):
        return issubclass(op, conv_types)
    # isinstance (rather than an exact type match) so that instances of
    # user-defined conv subclasses are treated like their classes are.
    return isinstance(op, conv_types)



class ConvModule(nn.Module):
    """``nn.Module`` base class with helpers to (re-)initialise conv layers.

    Constructor arguments are accepted and discarded, so subclasses can hand
    their leftover ``*args``/``**kwargs`` to this class without an error.
    """

    def __init__(self, *args, **kwargs):

        super(ConvModule, self).__init__()

    def init_weights(self, init_fn, *args, **kwargs):
        """Apply ``init_fn`` to the weight of every convolution in this module.

        ``init_fn`` is called as ``init_fn(weight, *args, **kwargs)`` and its
        return value is assigned back to ``module.weight``, so it has to
        return the initialised parameter.
        """

        def _reinit_weight(module):
            if is_conv(type(module)):
                module.weight = init_fn(module.weight, *args, **kwargs)

        # nn.Module.apply visits every submodule as well as the module itself.
        self.apply(_reinit_weight)

    def init_bias(self, init_fn, *args, **kwargs):
        """Apply ``init_fn`` to the bias of every convolution that has one.

        ``init_fn`` is called as ``init_fn(bias, *args, **kwargs)`` and its
        return value is assigned back to ``module.bias``. Convolutions whose
        bias is ``None`` are left untouched.
        """

        def _reinit_bias(module):
            if is_conv(type(module)) and module.bias is not None:
                module.bias = init_fn(module.bias, *args, **kwargs)

        self.apply(_reinit_bias)



class ConcatCoords(nn.Module):
    """Append one coordinate channel per spatial axis to the input.

    Every dimension after the first two is treated as a spatial axis. The
    channel added for an axis holds values running linearly from -0.5 to 0.5
    along that axis and is constant along all other axes.
    """

    def forward(self, input_):
        """Return ``input_`` with the coordinate channels concatenated on dim 1."""

        spatial_shape = list(input_.shape[2:])
        num_axes = len(spatial_shape)

        grids = []
        for axis, size in enumerate(spatial_shape):
            # Put the ramp along `axis`, singleton everywhere else ...
            ramp_shape = [-1 if k == axis else 1 for k in range(num_axes)]
            # ... and tile it over every other axis to fill the spatial grid.
            tiles = [1 if k == axis else extent for k, extent in enumerate(spatial_shape)]
            ramp = torch.linspace(-0.5, 0.5, size)
            grid = ramp.view(*ramp_shape).repeat(*tiles)
            grids.append(grid.to(device=input_.device, dtype=input_.dtype))

        # (axes, *spatial) -> (1, axes, *spatial) -> (batch, axes, *spatial)
        coords = torch.stack(grids).unsqueeze(0)
        batch_tiles = [input_.shape[0]] + [1] * (input_.dim() - 1)
        coords = coords.repeat(*batch_tiles).contiguous()

        return torch.cat([input_, coords], 1)



class InjectionConvEncoder(ConvModule):
    """Convolutional encoder that can concatenate an injected tensor at one depth.

    The network consists of ``depth`` blocks registered as ``encode_0`` ...
    ``encode_{depth-1}``. Every block except the first starts with a pooling
    layer, followed by ``block_depth`` convolutions, each optionally followed
    by normalisation, activation and dropout. The last block ends with a
    bias-free 1x1 convolution to ``out_channels``. If ``global_pool_op`` is
    given, a global pooling layer with output size 1 is applied at the end.

    Args:
        in_channels: Number of input channels.
        out_channels: Output channels of the final 1x1 convolution.
        depth: Number of blocks.
        injection_depth: Index of the block after which the injection is
            concatenated; ``"last"`` means ``depth - 1``. With ``"last"`` the
            concatenation happens after the final 1x1 convolution, so the
            output then has ``out_channels + injection_channels`` channels.
        injection_channels: Channels of the injected tensor; 0 disables it.
        block_depth: Convolutions per block.
        num_feature_maps: Feature maps of the first block.
        feature_map_multiplier: Feature map growth factor per block.
        activation_op / activation_kwargs: Activation class (or None) and
            keyword overrides for ``_default_activation_kwargs``.
        norm_op / norm_kwargs: Normalisation class (or None) and keyword
            overrides for ``_default_norm_kwargs``.
        norm_depth: Normalisation is only used in blocks with index below
            this value; ``"full"`` means ``depth``. The default 0 disables it.
        conv_op / conv_kwargs: Convolution class and keyword overrides for
            ``_default_conv_kwargs``.
        pool_op / pool_kwargs: Pooling class and keyword overrides for
            ``_default_pool_kwargs``.
        dropout_op / dropout_kwargs: Dropout class (or None) and keyword
            overrides for ``_default_dropout_kwargs``.
        global_pool_op / global_pool_kwargs: Global pooling class (or None),
            constructed as ``global_pool_op(1, **global_pool_kwargs)``.
        **kwargs: Forwarded to ``ConvModule``.
    """

    _default_activation_kwargs = dict(inplace=True)
    _default_norm_kwargs = dict()
    _default_conv_kwargs = dict(kernel_size=3, padding=1)
    _default_pool_kwargs = dict(kernel_size=2)
    _default_dropout_kwargs = dict()
    _default_global_pool_kwargs = dict()

    @staticmethod
    def _merge_kwargs(defaults, overrides):
        """Return a new dict holding ``defaults`` updated with ``overrides``.

        A fresh dict is required: the defaults are class attributes, and
        updating them in place would leak one instance's overrides into every
        other instance (and subclass) created afterwards.
        """
        merged = dict(defaults)
        if overrides is not None:
            merged.update(overrides)
        return merged

    def __init__(self,
                 in_channels=1,
                 out_channels=6,
                 depth=4,
                 injection_depth="last",
                 injection_channels=0,
                 block_depth=2,
                 num_feature_maps=24,
                 feature_map_multiplier=2,
                 activation_op=nn.LeakyReLU,
                 activation_kwargs=None,
                 norm_op=nn.InstanceNorm2d,
                 norm_kwargs=None,
                 norm_depth=0,
                 conv_op=nn.Conv2d,
                 conv_kwargs=None,
                 pool_op=nn.AvgPool2d,
                 pool_kwargs=None,
                 dropout_op=None,
                 dropout_kwargs=None,
                 global_pool_op=nn.AdaptiveAvgPool2d,
                 global_pool_kwargs=None,
                 **kwargs):

        super(InjectionConvEncoder, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.depth = depth
        self.injection_depth = depth - 1 if injection_depth == "last" else injection_depth
        self.injection_channels = injection_channels
        self.block_depth = block_depth
        self.num_feature_maps = num_feature_maps
        self.feature_map_multiplier = feature_map_multiplier

        self.activation_op = activation_op
        self.activation_kwargs = self._merge_kwargs(self._default_activation_kwargs, activation_kwargs)

        self.norm_op = norm_op
        self.norm_kwargs = self._merge_kwargs(self._default_norm_kwargs, norm_kwargs)
        self.norm_depth = depth if norm_depth == "full" else norm_depth

        self.conv_op = conv_op
        self.conv_kwargs = self._merge_kwargs(self._default_conv_kwargs, conv_kwargs)

        self.pool_op = pool_op
        self.pool_kwargs = self._merge_kwargs(self._default_pool_kwargs, pool_kwargs)

        self.dropout_op = dropout_op
        self.dropout_kwargs = self._merge_kwargs(self._default_dropout_kwargs, dropout_kwargs)

        self.global_pool_op = global_pool_op
        self.global_pool_kwargs = self._merge_kwargs(self._default_global_pool_kwargs, global_pool_kwargs)

        for d in range(self.depth):

            in_ = self.in_channels if d == 0 else self.num_feature_maps * (self.feature_map_multiplier**(d-1))
            out_ = self.num_feature_maps * (self.feature_map_multiplier**d)

            # The block right after the injection point receives the
            # concatenated tensor and therefore needs the extra input channels.
            if d == self.injection_depth + 1:
                in_ += self.injection_channels

            layers = []
            if d > 0:
                layers.append(self.pool_op(**self.pool_kwargs))
            for b in range(self.block_depth):
                current_in = in_ if b == 0 else out_
                layers.append(self.conv_op(current_in, out_, **self.conv_kwargs))
                if self.norm_op is not None and d < self.norm_depth:
                    layers.append(self.norm_op(out_, **self.norm_kwargs))
                if self.activation_op is not None:
                    layers.append(self.activation_op(**self.activation_kwargs))
                if self.dropout_op is not None:
                    layers.append(self.dropout_op(**self.dropout_kwargs))
            if d == self.depth - 1:
                # Final projection to out_channels: 1x1, unpadded, no bias.
                current_conv_kwargs = self.conv_kwargs.copy()
                current_conv_kwargs["kernel_size"] = 1
                current_conv_kwargs["padding"] = 0
                current_conv_kwargs["bias"] = False
                layers.append(self.conv_op(out_, out_channels, **current_conv_kwargs))

            self.add_module("encode_{}".format(d), nn.Sequential(*layers))

        if self.global_pool_op is not None:
            self.add_module("global_pool", self.global_pool_op(1, **self.global_pool_kwargs))

    def forward(self, x, injection=None):
        """Encode ``x``, concatenating ``injection`` after block ``injection_depth``.

        Args:
            x: Input tensor.
            injection: Tensor to inject. Only used when
                ``injection_channels > 0``; it is passed through ``match_to``
                (imported helper, presumably to fit it to ``x`` - confirm)
                and concatenated to ``x`` along dim 1.

        Returns:
            The output of the last block, globally pooled if a global pooling
            layer was configured.
        """

        for d in range(self.depth):
            x = self._modules["encode_{}".format(d)](x)
            if d == self.injection_depth and self.injection_channels > 0:
                # NOTE(review): InjectionUNet calls match_to with the tuple
                # (0, 1) as third argument, here an int is passed - confirm
                # against match_to's signature.
                injection = match_to(injection, x, self.injection_channels)
                x = torch.cat([x, injection], 1)
        if hasattr(self, "global_pool"):
            x = self.global_pool(x)

        return x


class InjectionConvEncoder3D(InjectionConvEncoder):
    """``InjectionConvEncoder`` that defaults to the 3D torch ops.

    Keyword arguments given by the caller take precedence over these defaults.
    """

    def __init__(self, *args, **kwargs):

        # setdefault only fills in ops the caller did not specify.
        kwargs.setdefault("norm_op", nn.InstanceNorm3d)
        kwargs.setdefault("conv_op", nn.Conv3d)
        kwargs.setdefault("pool_op", nn.AvgPool3d)
        kwargs.setdefault("global_pool_op", nn.AdaptiveAvgPool3d)

        super(InjectionConvEncoder3D, self).__init__(*args, **kwargs)

class InjectionConvEncoder2D(InjectionConvEncoder):  # Created by Soumick
    """``InjectionConvEncoder`` that defaults to the 2D torch ops.

    Keyword arguments given by the caller take precedence over these defaults.
    """

    def __init__(self, *args, **kwargs):

        # setdefault only fills in ops the caller did not specify.
        kwargs.setdefault("norm_op", nn.InstanceNorm2d)
        kwargs.setdefault("conv_op", nn.Conv2d)
        kwargs.setdefault("pool_op", nn.AvgPool2d)
        kwargs.setdefault("global_pool_op", nn.AdaptiveAvgPool2d)

        super(InjectionConvEncoder2D, self).__init__(*args, **kwargs)

class InjectionUNet(ConvModule):
    """U-Net that concatenates an injected tensor at the bottom or at the end.

    Modules are registered by name:

    * ``encode-{d}``: encoder block at depth ``d``; blocks with ``d > 0``
      start with a pooling layer. Block ``d`` outputs
      ``num_feature_maps * 2**d`` channels.
    * ``decode-{d}``: decoder block at depth ``d``; blocks with ``d > 0`` end
      with a stride-2 transposed convolution that halves the channels.
    * ``reduce-{i}``: ``num_1x1_at_end`` convolutions with kernel size 1, the
      last one mapping to ``out_channels``.
    * ``output-activation``: only if ``output_activation_op`` is given.

    The bottom encoder and decoder blocks always contain a single convolution,
    all others ``block_depth`` convolutions.

    Args:
        depth: Number of resolution levels.
        in_channels: Input channels.
        out_channels: Output channels of the last 1x1 convolution.
        kernel_size: Kernel size of encoder/decoder (transposed) convolutions.
        dilation: Dilation of encoder/decoder (transposed) convolutions.
        num_feature_maps: Feature maps at depth 0.
        block_depth: Convolutions per block (except the bottom blocks).
        num_1x1_at_end: Number of 1x1 convolutions after the decoder.
        injection_channels: Channels of the injected tensor.
        injection_at: ``"end"`` concatenates the injection to the decoder
            output, ``"bottom"`` to the bottom encoder output.
        activation_op / activation_kwargs: Activation class and kwargs.
        pool_op / pool_kwargs: Pooling class and kwargs.
        dropout_op / dropout_kwargs: Dropout class (or None) and kwargs.
        norm_op / norm_kwargs: Normalisation class (or None) and kwargs.
        conv_op / conv_kwargs: Convolution class and extra kwargs.
        upconv_op / upconv_kwargs: Transposed convolution class and kwargs.
        output_activation_op / output_activation_kwargs: Optional final
            activation class and kwargs.
        return_bottom: If True, ``forward`` may also return the bottom
            representation (see ``forward``).
        coords: ``False`` for no coordinate channels, ``True`` for all depths,
            or a two-element sequence whose first element lists the depths at
            which ``ConcatCoords`` is inserted before the first convolution.
        coords_dim: Number of channels ``ConcatCoords`` adds (one per spatial
            axis of the data).
        **kwargs: Forwarded to ``ConvModule``.
    """

    def __init__(
        self,
        depth=5,
        in_channels=4,
        out_channels=4,
        kernel_size=3,
        dilation=1,
        num_feature_maps=24,
        block_depth=2,
        num_1x1_at_end=3,
        injection_channels=3,
        injection_at="end",
        activation_op=nn.LeakyReLU,
        activation_kwargs=None,
        pool_op=nn.AvgPool2d,
        pool_kwargs=dict(kernel_size=2),
        dropout_op=None,
        dropout_kwargs=None,
        norm_op=nn.InstanceNorm2d,
        norm_kwargs=None,
        conv_op=nn.Conv2d,
        conv_kwargs=None,
        upconv_op=nn.ConvTranspose2d,
        upconv_kwargs=None,
        output_activation_op=None,
        output_activation_kwargs=None,
        return_bottom=False,
        coords=False,
        coords_dim=2,
        **kwargs
    ):

        super(InjectionUNet, self).__init__(**kwargs)

        self.depth = depth
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        # Half of the dilated kernel extent.
        self.padding = (self.kernel_size + (self.kernel_size-1) * (self.dilation-1)) // 2
        self.num_feature_maps = num_feature_maps
        self.block_depth = block_depth
        self.num_1x1_at_end = num_1x1_at_end
        self.injection_channels = injection_channels
        self.injection_at = injection_at
        # All kwargs dicts are copied: the pool_kwargs default is a single
        # dict shared by every call, and caller-owned dicts should not be
        # aliased by the module either.
        self.activation_op = activation_op
        self.activation_kwargs = {} if activation_kwargs is None else dict(activation_kwargs)
        self.pool_op = pool_op
        self.pool_kwargs = {} if pool_kwargs is None else dict(pool_kwargs)
        self.dropout_op = dropout_op
        self.dropout_kwargs = {} if dropout_kwargs is None else dict(dropout_kwargs)
        self.norm_op = norm_op
        self.norm_kwargs = {} if norm_kwargs is None else dict(norm_kwargs)
        self.conv_op = conv_op
        self.conv_kwargs = {} if conv_kwargs is None else dict(conv_kwargs)
        self.upconv_op = upconv_op
        self.upconv_kwargs = {} if upconv_kwargs is None else dict(upconv_kwargs)
        self.output_activation_op = output_activation_op
        self.output_activation_kwargs = {} if output_activation_kwargs is None else dict(output_activation_kwargs)
        self.return_bottom = return_bottom
        if not coords:
            self.coords = [[], []]
        elif coords is True:
            self.coords = [list(range(depth)), []]
        else:
            self.coords = coords
        self.coords_dim = coords_dim

        self.last_activations = None

        # BUILD ENCODER
        for d in range(self.depth):

            block = []
            if d > 0:
                block.append(self.pool_op(**self.pool_kwargs))

            for i in range(self.block_depth):

                # bottom block fixed to have depth 1
                if d == self.depth - 1 and i > 0:
                    continue

                out_size = self.num_feature_maps * 2**d
                if d == 0 and i == 0:
                    in_size = self.in_channels
                elif i == 0:
                    in_size = self.num_feature_maps * 2**(d - 1)
                else:
                    in_size = out_size

                # check for coord appending at this depth
                if d in self.coords[0] and i == 0:
                    block.append(ConcatCoords())
                    in_size += self.coords_dim

                block.append(self.conv_op(in_size,
                                          out_size,
                                          self.kernel_size,
                                          padding=self.padding,
                                          dilation=self.dilation,
                                          **self.conv_kwargs))
                if self.dropout_op is not None:
                    block.append(self.dropout_op(**self.dropout_kwargs))
                if self.norm_op is not None:
                    block.append(self.norm_op(out_size, **self.norm_kwargs))
                block.append(self.activation_op(**self.activation_kwargs))

            self.add_module("encode-{}".format(d), nn.Sequential(*block))

        # BUILD DECODER
        for d in reversed(range(self.depth)):

            block = []

            for i in range(self.block_depth):

                # bottom block fixed to have depth 1
                if d == self.depth - 1 and i > 0:
                    continue

                out_size = self.num_feature_maps * 2**(d)
                if i == 0 and d < self.depth - 1:
                    # skip connection (out_size) + upsampled features (out_size)
                    in_size = self.num_feature_maps * 2**(d+1)
                elif i == 0 and self.injection_at == "bottom":
                    in_size = out_size + self.injection_channels
                else:
                    in_size = out_size

                # check for coord appending at this depth
                # NOTE(review): this reads coords[0] like the encoder does;
                # coords[1] is never used anywhere in this class - confirm
                # whether the decoder was meant to use coords[1].
                if d in self.coords[0] and i == 0 and d < self.depth - 1:
                    block.append(ConcatCoords())
                    in_size += self.coords_dim

                block.append(self.conv_op(in_size,
                                          out_size,
                                          self.kernel_size,
                                          padding=self.padding,
                                          dilation=self.dilation,
                                          **self.conv_kwargs))
                if self.dropout_op is not None:
                    block.append(self.dropout_op(**self.dropout_kwargs))
                if self.norm_op is not None:
                    block.append(self.norm_op(out_size, **self.norm_kwargs))
                block.append(self.activation_op(**self.activation_kwargs))

            if d > 0:
                block.append(self.upconv_op(out_size,
                                            out_size // 2,
                                            self.kernel_size,
                                            2,
                                            padding=self.padding,
                                            dilation=self.dilation,
                                            output_padding=1,
                                            **self.upconv_kwargs))

            self.add_module("decode-{}".format(d), nn.Sequential(*block))

        # BUILD 1x1 HEAD
        # out_size still holds the channel count of decode-0 at this point.
        if self.injection_at == "end":
            out_size += self.injection_channels
        in_size = out_size
        for i in range(self.num_1x1_at_end):
            if i == self.num_1x1_at_end - 1:
                out_size = self.out_channels
            current_conv_kwargs = self.conv_kwargs.copy()
            current_conv_kwargs["bias"] = True
            self.add_module("reduce-{}".format(i), self.conv_op(in_size, out_size, 1, **current_conv_kwargs))
            if i != self.num_1x1_at_end - 1:
                self.add_module("reduce-{}-nonlin".format(i), self.activation_op(**self.activation_kwargs))
        if self.output_activation_op is not None:
            self.add_module("output-activation", self.output_activation_op(**self.output_activation_kwargs))

    def reset(self):
        """Discard the stored decoder activations."""

        self.last_activations = None

    def forward(self, x, injection=None, reuse_last_activations=False, store_activations=False):
        """Run the U-Net and combine the result with ``injection``.

        Args:
            x: Input tensor. Ignored when stored activations are reused.
            injection: Tensor to inject; used when ``injection_channels > 0``.
            reuse_last_activations: If True and activations are stored, skip
                encoder and decoder and start from the stored decoder output.
            store_activations: If True, keep a detached copy of the decoder
                output in ``last_activations``.

        Both flags are forced to False when ``injection_at == "bottom"``.

        Returns:
            The output tensor, or ``(output, bottom_rep)`` if ``return_bottom``
            is set and ``reuse_last_activations`` is False.
        """

        if self.injection_at == "bottom":  # not worth it for now
            reuse_last_activations = False
            store_activations = False

        if self.last_activations is None or reuse_last_activations is False:

            # enc[0] is the input, enc[k] the output of encode-{k-1}.
            enc = [x]

            for i in range(self.depth - 1):
                enc.append(self._modules["encode-{}".format(i)](enc[-1]))

            bottom_rep = self._modules["encode-{}".format(self.depth - 1)](enc[-1])

            if self.injection_at == "bottom" and self.injection_channels > 0:
                injection = match_to(injection, bottom_rep, (0, 1))
                bottom_rep = torch.cat((bottom_rep, injection), 1)

            x = self._modules["decode-{}".format(self.depth - 1)](bottom_rep)

            for i in reversed(range(self.depth - 1)):
                # enc[-(depth - 1 - i)] is the output of encode-{i}.
                x = self._modules["decode-{}".format(i)](torch.cat((enc[-(self.depth - 1 - i)], x), 1))

            if store_activations:
                self.last_activations = x.detach()

        else:

            x = self.last_activations

        if self.injection_at == "end" and self.injection_channels > 0:
            injection = match_to(injection, x, (0, 1))
            x = torch.cat((x, injection), 1)

        # NOTE(review): the "reduce-{i}-nonlin" modules registered in __init__
        # are never applied here, so the 1x1 convolutions are chained without
        # a nonlinearity. Left unchanged because applying them would alter the
        # outputs of already trained models - confirm the intended behaviour.
        for i in range(self.num_1x1_at_end):
            x = self._modules["reduce-{}".format(i)](x)
        if self.output_activation_op is not None:
            x = self._modules["output-activation"](x)

        if self.return_bottom and not reuse_last_activations:
            return x, bottom_rep
        else:
            return x



class InjectionUNet3D(InjectionUNet):
    """``InjectionUNet`` that defaults to the 3D torch ops and ``coords_dim=3``.

    Keyword arguments given by the caller take precedence over these defaults.
    """

    def __init__(self, *args, **kwargs):

        # setdefault only fills in settings the caller did not specify.
        kwargs.setdefault("pool_op", nn.AvgPool3d)
        kwargs.setdefault("norm_op", nn.InstanceNorm3d)
        kwargs.setdefault("conv_op", nn.Conv3d)
        kwargs.setdefault("upconv_op", nn.ConvTranspose3d)
        kwargs.setdefault("coords_dim", 3)

        super(InjectionUNet3D, self).__init__(*args, **kwargs)

class InjectionUNet2D(InjectionUNet):  # Created by Soumick
    """``InjectionUNet`` that defaults to the 2D torch ops and ``coords_dim=2``.

    Keyword arguments given by the caller take precedence over these defaults.
    """

    def __init__(self, *args, **kwargs):

        # setdefault only fills in settings the caller did not specify.
        kwargs.setdefault("pool_op", nn.AvgPool2d)
        kwargs.setdefault("norm_op", nn.InstanceNorm2d)
        kwargs.setdefault("conv_op", nn.Conv2d)
        kwargs.setdefault("upconv_op", nn.ConvTranspose2d)
        kwargs.setdefault("coords_dim", 2)

        super(InjectionUNet2D, self).__init__(*args, **kwargs)

class ProbabilisticSegmentationNet(ConvModule):
    """Segmentation network conditioned on a latent variable.

    Three sub-modules are registered: ``prior_net`` (sees the input),
    ``posterior_net`` (sees input and segmentation) and ``task_net`` (sees
    the input plus a latent sample as injection). Prior and posterior nets
    produce mean and log-variance of ``latent_distribution``.

    Args:
        in_channels: Input channels.
        out_channels: Output channels / number of segmentation classes.
        num_feature_maps: Base number of feature maps for all sub-networks.
        latent_size: Dimensionality of the latent variable.
        depth: Depth passed to all sub-networks.
        latent_distribution: Distribution class called as
            ``latent_distribution(mean, std)``.
        task_op / prior_op / posterior_op: Either a class, instantiated with
            the corresponding kwargs, or a ready-made module used as is.
        task_kwargs / prior_kwargs / posterior_kwargs: Overrides for the
            default constructor kwargs derived from the arguments above.
        **kwargs: Forwarded to ``ConvModule``.
    """

    def __init__(self,
                 in_channels=4,
                 out_channels=4,
                 num_feature_maps=24,
                 latent_size=3,
                 depth=5,
                 latent_distribution=torch.distributions.Normal,
                 task_op=InjectionUNet3D,
                 task_kwargs=None,
                 prior_op=InjectionConvEncoder3D,
                 prior_kwargs=None,
                 posterior_op=InjectionConvEncoder3D,
                 posterior_kwargs=None,
                 **kwargs):

        super(ProbabilisticSegmentationNet, self).__init__(**kwargs)

        self.task_op = task_op
        self.task_kwargs = {} if task_kwargs is None else task_kwargs
        self.prior_op = prior_op
        self.prior_kwargs = {} if prior_kwargs is None else prior_kwargs
        self.posterior_op = posterior_op
        self.posterior_kwargs = {} if posterior_kwargs is None else posterior_kwargs

        default_task_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            num_feature_maps=num_feature_maps,
            injection_size=latent_size,
            depth=depth
        )
        # InjectionUNet calls this parameter ``injection_channels``; the
        # ``injection_size`` entry above is not one of its parameters and ends
        # up discarded by ConvModule. Without the line below the task net kept
        # its own default regardless of latent_size. Restricted to
        # InjectionUNet subclasses so other task classes receive exactly the
        # kwargs they received before.
        if isinstance(task_op, type) and issubclass(task_op, InjectionUNet):
            default_task_kwargs["injection_channels"] = latent_size

        default_prior_kwargs = dict(
            in_channels=in_channels,
            out_channels=latent_size*2, #Soumick
            num_feature_maps=num_feature_maps,
            z_dim=latent_size,
            depth=depth
        )

        default_posterior_kwargs = dict(
            in_channels=in_channels+out_channels,
            out_channels=latent_size*2, #Soumick
            num_feature_maps=num_feature_maps,
            z_dim=latent_size,
            depth=depth
        )

        # User-supplied kwargs win over the derived defaults.
        default_task_kwargs.update(self.task_kwargs)
        self.task_kwargs = default_task_kwargs
        default_prior_kwargs.update(self.prior_kwargs)
        self.prior_kwargs = default_prior_kwargs
        default_posterior_kwargs.update(self.posterior_kwargs)
        self.posterior_kwargs = default_posterior_kwargs

        self.latent_distribution = latent_distribution
        self._prior = None
        self._posterior = None

        self.make_modules()

    def make_modules(self):
        """Register ``task_net``, ``prior_net`` and ``posterior_net``.

        An op given as a class is instantiated with its kwargs; anything else
        is registered as is.
        """

        if isinstance(self.task_op, type):
            self.add_module("task_net", self.task_op(**self.task_kwargs))
        else:
            self.add_module("task_net", self.task_op)
        if isinstance(self.prior_op, type):
            self.add_module("prior_net", self.prior_op(**self.prior_kwargs))
        else:
            self.add_module("prior_net", self.prior_op)
        if isinstance(self.posterior_op, type):
            self.add_module("posterior_net", self.posterior_op(**self.posterior_kwargs))
        else:
            self.add_module("posterior_net", self.posterior_op)

    @property
    def prior(self):
        """Most recently encoded prior distribution, or None."""
        return self._prior

    @property
    def posterior(self):
        """Most recently encoded posterior distribution, or None."""
        return self._posterior

    @property
    def last_activations(self):
        """Activations currently stored in ``task_net``."""
        return self.task_net.last_activations

    def train(self, mode=True):
        """Switch training mode and drop all cached state.

        Returns ``self`` like ``nn.Module.train`` does; ``nn.Module.eval``
        returns the result of ``train(False)``, so without the return value
        ``model.eval()`` would evaluate to None.
        """

        super(ProbabilisticSegmentationNet, self).train(mode)
        self.reset()
        return self

    def reset(self):
        """Clear the stored task net activations, prior and posterior."""

        self.task_net.reset()
        self._prior = None
        self._posterior = None

    def forward(self, input_, seg=None, make_onehot=True, make_onehot_classes=None, newaxis=False, distlossN=0):
        """Forward pass includes reparametrization sampling during training, otherwise it'll just take the prior mean.

        The prior is always encoded. In training mode the posterior is encoded
        from ``input_`` and ``seg`` and sampled with ``rsample``: once if
        ``distlossN == 0`` (a single output is returned), otherwise
        ``distlossN`` times (a list of outputs is returned). In evaluation
        mode ``distlossN`` has no effect and the task net activations are
        stored for later reuse.
        """

        self.encode_prior(input_)

        if not self.training:
            return self.task_net(input_, self.prior.loc, store_activations=True)

        self.encode_posterior(input_, seg, make_onehot, make_onehot_classes, newaxis)
        if distlossN == 0:
            return self.task_net(input_, self.posterior.rsample(), store_activations=False)
        return [self.task_net(input_, self.posterior.rsample(), store_activations=False)
                for _ in range(distlossN)]

    def _make_distribution(self, rep):
        """Build the latent distribution from an encoder output.

        ``rep`` is either a ``(mean, logvar)`` tuple or a tensor that is split
        into two chunks of ``rep.shape[1] // 2`` channels along dim 1.

        Raises:
            TypeError: If ``rep`` is neither a tuple nor a tensor.
        """

        if isinstance(rep, tuple):
            mean, logvar = rep
        elif torch.is_tensor(rep):
            mean, logvar = torch.split(rep, rep.shape[1] // 2, dim=1)
        else:
            raise TypeError("Encoder output must be a tuple or a tensor, but got {}".format(type(rep)))
        # std = exp(0.5 * logvar)
        return self.latent_distribution(mean, logvar.mul(0.5).exp())

    def encode_prior(self, input_):
        """Encode ``input_`` with ``prior_net``, store and return the prior."""

        self._prior = self._make_distribution(self.prior_net(input_))
        return self._prior

    def encode_posterior(self, input_, seg, make_onehot=True, make_onehot_classes=None, newaxis=False):
        """Encode input and segmentation, store and return the posterior.

        If ``make_onehot`` is set, ``seg`` is converted with the imported
        ``make_onehot`` helper first; the classes default to
        ``range(posterior_net.in_channels - input_.shape[1])``. The (converted)
        segmentation is cast to float and concatenated to ``input_`` on dim 1.
        """

        if make_onehot:
            if make_onehot_classes is None:
                make_onehot_classes = tuple(range(self.posterior_net.in_channels - input_.shape[1]))
            seg = make_onehot_segmentation(seg, make_onehot_classes, newaxis=newaxis)
        rep = self.posterior_net(torch.cat((input_, seg.float()), 1))
        self._posterior = self._make_distribution(rep)
        return self._posterior

    def sample_prior(self, N=1, out_device=None, input_=None, pred_with_mean=False):
        """Draw multiple samples from the current prior.

        * input_ is required if no activations are stored in task_net.
        * If input_ is given, prior will automatically be encoded again.
        * If pred_with_mean is set, a prediction for the prior mean is
          appended after the N samples (it is computed from input_).
        * Returns either a single sample or a list of samples.
        * No gradients are computed.

        """

        if out_device is None:
            if self.last_activations is not None:
                out_device = self.last_activations.device
            elif input_ is not None:
                out_device = input_.device
            else:
                out_device = next(self.task_net.parameters()).device
        with torch.no_grad():
            if self.prior is None or input_ is not None:
                self.encode_prior(input_)
            result = []

            # The first pass with a fresh input stores the activations so the
            # remaining samples can reuse them.
            if input_ is not None:
                result.append(self.task_net(input_, self.prior.sample(), reuse_last_activations=False, store_activations=True).to(device=out_device))
            while len(result) < N:
                result.append(self.task_net(input_,
                                            self.prior.sample(),
                                            reuse_last_activations=self.last_activations is not None,
                                            store_activations=False).to(device=out_device))
            if pred_with_mean:
                result.append(self.task_net(input_, self.prior.mean, reuse_last_activations=False, store_activations=True).to(device=out_device))

            if len(result) == 1:
                return result[0]
            else:
                return result

    def reconstruct(self, sample=None, use_posterior_mean=True, out_device=None, input_=None):
        """Reconstruct a sample or the current posterior mean. Will not compute gradients!

        If ``sample`` is None, the posterior mean (``use_posterior_mean``) or a
        posterior sample is used. The task net is asked to reuse its stored
        activations; ``input_`` is passed along for the case that it has to
        recompute them.

        Raises:
            ValueError: If neither a sample nor a posterior is available.
        """

        if self.posterior is None and sample is None:
            raise ValueError("'posterior' is currently None. Please pass an input and a segmentation first.")
        if out_device is None:
            out_device = next(self.task_net.parameters()).device
        if sample is None:
            if use_posterior_mean:
                sample = self.posterior.loc
            else:
                sample = self.posterior.sample()
        else:
            sample = sample.to(next(self.task_net.parameters()).device)
        with torch.no_grad():
            return self.task_net(input_, sample, reuse_last_activations=True).to(device=out_device)

    def kl_divergence(self):
        """Compute current KL, requires existing prior and posterior."""

        if self.posterior is None or self.prior is None:
            raise ValueError("'prior' and 'posterior' must not be None, but prior={} and posterior={}".format(self.prior, self.posterior))
        return torch.distributions.kl_divergence(self.posterior, self.prior).sum()

    def elbo(self, seg, input_=None, nll_reduction="sum", beta=1.0, make_onehot=True, make_onehot_classes=None, newaxis=False):
        """Compute the ELBO with seg as ground truth.

        * Prior is expected and will not be encoded.
        * If input_ is given, posterior will automatically be encoded
          (without gradients).
        * Stored activations must be available in task_net; a ValueError is
          raised otherwise, even if input_ is given.
        * The reconstruction of the posterior mean is fed to ``nn.NLLLoss``
          (so the task net presumably has to output log-probabilities -
          confirm).

        Returns ``-(beta * nll + kl)``.

        """

        if self.last_activations is None:
            raise ValueError("'last_activations' is currently None. Please pass an input first.")
        if input_ is not None:
            with torch.no_grad():
                self.encode_posterior(input_, seg, make_onehot=make_onehot, make_onehot_classes=make_onehot_classes, newaxis=newaxis)
        if make_onehot and newaxis:
            pass  # seg will already be (B x SPACE)
        elif make_onehot and not newaxis:
            seg = seg[:, 0]  # in this case seg will hopefully be (B x 1 x SPACE)
        else:
            seg = torch.argmax(seg, 1, keepdim=False)  # seg is already onehot
        kl = self.kl_divergence()
        nll = nn.NLLLoss(reduction=nll_reduction)(self.reconstruct(sample=None, use_posterior_mean=True, out_device=None), seg.long())
        return - (beta * nll + kl)