# Licensed under a 3-clause BSD style license - see LICENSE.rst


# Note: This file incldues code dervived from pywcsgrid2
#
# This file contains Matplotlib transformation objects (e.g. from pixel to world
# coordinates, but also world-to-world).

import abc

import numpy as np

from matplotlib.path import Path
from matplotlib.transforms import Transform

from astropy import units as u
from astropy.coordinates import (SkyCoord, frame_transform_graph,
                                 UnitSphericalRepresentation,
                                 BaseCoordinateFrame)

__all__ = ['CurvedTransform', 'CoordinateTransform',
           'World2PixelTransform', 'Pixel2WorldTransform']


class CurvedTransform(Transform, metaclass=abc.ABCMeta):
    """
    Abstract base class for non-affine curved transforms
    """

    input_dims = 2
    output_dims = 2
    is_separable = False

    def transform_path(self, path):
        """
        Transform a Matplotlib Path

        Parameters
        ----------
        path : :class:`~matplotlib.path.Path`
            The path to transform

        Returns
        -------
        path : :class:`~matplotlib.path.Path`
            The resulting path
        """
        return Path(self.transform(path.vertices), path.codes)

    transform_path_non_affine = transform_path

    def transform(self, input):
        raise NotImplementedError("")

    def inverted(self):
        raise NotImplementedError("")


class CoordinateTransform(CurvedTransform):

    has_inverse = True

    def __init__(self, input_system, output_system):
        super().__init__()
        self._input_system_name = input_system
        self._output_system_name = output_system

        if isinstance(self._input_system_name, str):
            frame_cls = frame_transform_graph.lookup_name(self._input_system_name)
            if frame_cls is None:
                raise ValueError(f"Frame {self._input_system_name} not found")
            else:
                self.input_system = frame_cls()
        elif isinstance(self._input_system_name, BaseCoordinateFrame):
            self.input_system = self._input_system_name
        else:
            raise TypeError("input_system should be a WCS instance, string, or a coordinate frame instance")

        if isinstance(self._output_system_name, str):
            frame_cls = frame_transform_graph.lookup_name(self._output_system_name)
            if frame_cls is None:
                raise ValueError(f"Frame {self._output_system_name} not found")
            else:
                self.output_system = frame_cls()

        elif isinstance(self._output_system_name, BaseCoordinateFrame):
            self.output_system = self._output_system_name
        else:
            raise TypeError("output_system should be a WCS instance, string, or a coordinate frame instance")

        if self.output_system == self.input_system:
            self.same_frames = True
        else:
            self.same_frames = False

    @property
    def same_frames(self):
        return self._same_frames

    @same_frames.setter
    def same_frames(self, same_frames):
        self._same_frames = same_frames

    def transform(self, input_coords):
        """
        Transform one set of coordinates to another
        """
        if self.same_frames:
            return input_coords

        input_coords = input_coords*u.deg
        x_in, y_in = input_coords[:, 0], input_coords[:, 1]

        c_in = SkyCoord(UnitSphericalRepresentation(x_in, y_in),
                        frame=self.input_system)

        # We often need to transform arrays that contain NaN values, and filtering
        # out the NaN values would have a performance hit, so instead we just pass
        # on all values and just ignore Numpy warnings
        with np.errstate(all='ignore'):
            c_out = c_in.transform_to(self.output_system)

        lon = c_out.spherical.lon.deg
        lat = c_out.spherical.lat.deg

        return np.concatenate((lon[:, np.newaxis], lat[:, np.newaxis]), axis=1)

    transform_non_affine = transform

    def inverted(self):
        """
        Return the inverse of the transform
        """
        return CoordinateTransform(self._output_system_name, self._input_system_name)


class World2PixelTransform(CurvedTransform, metaclass=abc.ABCMeta):
    """
    Base transformation from world to pixel coordinates
    """

    has_inverse = True
    frame_in = None

    @property
    @abc.abstractmethod
    def input_dims(self):
        """
        The number of input world dimensions
        """

    @abc.abstractmethod
    def transform(self, world):
        """
        Transform world to pixel coordinates. You should pass in a NxM array
        where N is the number of points to transform, and M is the number of
        dimensions. This then returns the (x, y) pixel coordinates
        as a Nx2 array.
        """

    @abc.abstractmethod
    def inverted(self):
        """
        Return the inverse of the transform
        """


class Pixel2WorldTransform(CurvedTransform, metaclass=abc.ABCMeta):
    """
    Base transformation from pixel to world coordinates
    """

    has_inverse = True
    frame_out = None

    @property
    @abc.abstractmethod
    def output_dims(self):
        """
        The number of output world dimensions
        """

    @abc.abstractmethod
    def transform(self, pixel):
        """
        Transform pixel to world coordinates. You should pass in a Nx2 array
        of (x, y) pixel coordinates to transform to world coordinates. This
        will then return an NxM array where M is the number of dimensions.
        """

    @abc.abstractmethod
    def inverted(self):
        """
        Return the inverse of the transform
        """
