# Licensed under a 3-clause BSD style license - see PYFITS.rst

import sys
import numpy as np

from .base import DTYPE2BITPIX, DELAYED
from .image import PrimaryHDU
from .table import _TableLikeHDU
from astropy.io.fits.column import Column, ColDefs, FITS2NUMPY
from astropy.io.fits.fitsrec import FITS_rec, FITS_record
from astropy.io.fits.util import _is_int, _is_pseudo_unsigned, _unsigned_zero

from astropy.utils import lazyproperty


class Group(FITS_record):
    """
    One group of the random group data.
    """

    def __init__(self, input, row=0, start=None, end=None, step=None,
                 base=None):
        super().__init__(input, row, start, end, step, base)

    @property
    def parnames(self):
        return self.array.parnames

    @property
    def data(self):
        # The last column in the coldefs is the data portion of the group
        return self.field(self.array._coldefs.names[-1])

    @lazyproperty
    def _unique(self):
        return _par_indices(self.parnames)

    def par(self, parname):
        """
        Get the group parameter value.
        """

        if _is_int(parname):
            result = self.array[self.row][parname]
        else:
            indx = self._unique[parname.upper()]
            if len(indx) == 1:
                result = self.array[self.row][indx[0]]

            # if more than one group parameter have the same name
            else:
                result = self.array[self.row][indx[0]].astype('f8')
                for i in indx[1:]:
                    result += self.array[self.row][i]

        return result

    def setpar(self, parname, value):
        """
        Set the group parameter value.
        """

        # TODO: It would be nice if, instead of requiring a multi-part value to
        # be an array, there were an *option* to automatically split the value
        # into multiple columns if it doesn't already fit in the array data
        # type.

        if _is_int(parname):
            self.array[self.row][parname] = value
        else:
            indx = self._unique[parname.upper()]
            if len(indx) == 1:
                self.array[self.row][indx[0]] = value

            # if more than one group parameter have the same name, the
            # value must be a list (or tuple) containing arrays
            else:
                if isinstance(value, (list, tuple)) and \
                   len(indx) == len(value):
                    for i in range(len(indx)):
                        self.array[self.row][indx[i]] = value[i]
                else:
                    raise ValueError('Parameter value must be a sequence with '
                                     '{} arrays/numbers.'.format(len(indx)))


class GroupData(FITS_rec):
    """
    Random groups data object.

    Allows structured access to FITS Group data in a manner analogous
    to tables.
    """

    _record_type = Group

    def __new__(cls, input=None, bitpix=None, pardata=None, parnames=[],
                bscale=None, bzero=None, parbscales=None, parbzeros=None):
        """
        Parameters
        ----------
        input : array or FITS_rec instance
            input data, either the group data itself (a
            `numpy.ndarray`) or a record array (`FITS_rec`) which will
            contain both group parameter info and the data.  The rest
            of the arguments are used only for the first case.

        bitpix : int
            data type as expressed in FITS ``BITPIX`` value (8, 16, 32,
            64, -32, or -64)

        pardata : sequence of array
            parameter data, as a list of (numeric) arrays.

        parnames : sequence of str
            list of parameter names.

        bscale : int
            ``BSCALE`` of the data

        bzero : int
            ``BZERO`` of the data

        parbscales : sequence of int
            list of bscales for the parameters

        parbzeros : sequence of int
            list of bzeros for the parameters
        """

        if not isinstance(input, FITS_rec):
            if pardata is None:
                npars = 0
            else:
                npars = len(pardata)

            if parbscales is None:
                parbscales = [None] * npars
            if parbzeros is None:
                parbzeros = [None] * npars

            if parnames is None:
                parnames = [f'PAR{idx + 1}' for idx in range(npars)]

            if len(parnames) != npars:
                raise ValueError('The number of parameter data arrays does '
                                 'not match the number of parameters.')

            unique_parnames = _unique_parnames(parnames + ['DATA'])

            if bitpix is None:
                bitpix = DTYPE2BITPIX[input.dtype.name]

            fits_fmt = GroupsHDU._bitpix2tform[bitpix]  # -32 -> 'E'
            format = FITS2NUMPY[fits_fmt]  # 'E' -> 'f4'
            data_fmt = f'{str(input.shape[1:])}{format}'
            formats = ','.join(([format] * npars) + [data_fmt])
            gcount = input.shape[0]

            cols = [Column(name=unique_parnames[idx], format=fits_fmt,
                           bscale=parbscales[idx], bzero=parbzeros[idx])
                    for idx in range(npars)]
            cols.append(Column(name=unique_parnames[-1], format=fits_fmt,
                               bscale=bscale, bzero=bzero))

            coldefs = ColDefs(cols)

            self = FITS_rec.__new__(cls,
                                    np.rec.array(None,
                                                 formats=formats,
                                                 names=coldefs.names,
                                                 shape=gcount))

            # By default the data field will just be 'DATA', but it may be
            # uniquified if 'DATA' is already used by one of the group names
            self._data_field = unique_parnames[-1]

            self._coldefs = coldefs
            self.parnames = parnames

            for idx, name in enumerate(unique_parnames[:-1]):
                column = coldefs[idx]
                # Note: _get_scale_factors is used here and in other cases
                # below to determine whether the column has non-default
                # scale/zero factors.
                # TODO: Find a better way to do this than using this interface
                scale, zero = self._get_scale_factors(column)[3:5]
                if scale or zero:
                    self._cache_field(name, pardata[idx])
                else:
                    np.rec.recarray.field(self, idx)[:] = pardata[idx]

            column = coldefs[self._data_field]
            scale, zero = self._get_scale_factors(column)[3:5]
            if scale or zero:
                self._cache_field(self._data_field, input)
            else:
                np.rec.recarray.field(self, npars)[:] = input
        else:
            self = FITS_rec.__new__(cls, input)
            self.parnames = None
        return self

    def __array_finalize__(self, obj):
        super().__array_finalize__(obj)
        if isinstance(obj, GroupData):
            self.parnames = obj.parnames
        elif isinstance(obj, FITS_rec):
            self.parnames = obj._coldefs.names

    def __getitem__(self, key):
        out = super().__getitem__(key)
        if isinstance(out, GroupData):
            out.parnames = self.parnames
        return out

    @property
    def data(self):
        """
        The raw group data represented as a multi-dimensional `numpy.ndarray`
        array.
        """

        # The last column in the coldefs is the data portion of the group
        return self.field(self._coldefs.names[-1])

    @lazyproperty
    def _unique(self):
        return _par_indices(self.parnames)

    def par(self, parname):
        """
        Get the group parameter values.
        """

        if _is_int(parname):
            result = self.field(parname)
        else:
            indx = self._unique[parname.upper()]
            if len(indx) == 1:
                result = self.field(indx[0])

            # if more than one group parameter have the same name
            else:
                result = self.field(indx[0]).astype('f8')
                for i in indx[1:]:
                    result += self.field(i)

        return result


class GroupsHDU(PrimaryHDU, _TableLikeHDU):
    """
    FITS Random Groups HDU class.

    See the :ref:`random-groups` section in the Astropy documentation for more
    details on working with this type of HDU.
    """

    _bitpix2tform = {8: 'B', 16: 'I', 32: 'J', 64: 'K', -32: 'E', -64: 'D'}
    _data_type = GroupData
    _data_field = 'DATA'
    """
    The name of the table record array field that will contain the group data
    for each group; 'DATA' by default, but may be preceded by any number of
    underscores if 'DATA' is already a parameter name
    """

    def __init__(self, data=None, header=None):
        super().__init__(data=data, header=header)
        if data is not DELAYED:
            self.update_header()

        # Update the axes; GROUPS HDUs should always have at least one axis
        if len(self._axes) <= 0:
            self._axes = [0]
            self._header['NAXIS'] = 1
            self._header.set('NAXIS1', 0, after='NAXIS')

    @classmethod
    def match_header(cls, header):
        keyword = header.cards[0].keyword
        return (keyword == 'SIMPLE' and 'GROUPS' in header and
                header['GROUPS'] is True)

    @lazyproperty
    def data(self):
        """
        The data of a random group FITS file will be like a binary table's
        data.
        """

        if self._axes == [0]:
            return

        data = self._get_tbdata()
        data._coldefs = self.columns
        data.parnames = self.parnames
        del self.columns
        return data

    @lazyproperty
    def parnames(self):
        """The names of the group parameters as described by the header."""

        pcount = self._header['PCOUNT']
        # The FITS standard doesn't really say what to do if a parname is
        # missing, so for now just assume that won't happen
        return [self._header['PTYPE' + str(idx + 1)] for idx in range(pcount)]

    @lazyproperty
    def columns(self):
        if self._has_data and hasattr(self.data, '_coldefs'):
            return self.data._coldefs

        format = self._bitpix2tform[self._header['BITPIX']]
        pcount = self._header['PCOUNT']
        parnames = []
        bscales = []
        bzeros = []

        for idx in range(pcount):
            bscales.append(self._header.get('PSCAL' + str(idx + 1), None))
            bzeros.append(self._header.get('PZERO' + str(idx + 1), None))
            parnames.append(self._header['PTYPE' + str(idx + 1)])

        formats = [format] * len(parnames)
        dim = [None] * len(parnames)

        # Now create columns from collected parameters, but first add the DATA
        # column too, to contain the group data.
        parnames.append('DATA')
        bscales.append(self._header.get('BSCALE'))
        bzeros.append(self._header.get('BZEROS'))
        data_shape = self.shape[:-1]
        formats.append(str(int(np.prod(data_shape))) + format)
        dim.append(data_shape)
        parnames = _unique_parnames(parnames)

        self._data_field = parnames[-1]

        cols = [Column(name=name, format=fmt, bscale=bscale, bzero=bzero,
                       dim=dim)
                for name, fmt, bscale, bzero, dim in
                zip(parnames, formats, bscales, bzeros, dim)]

        coldefs = ColDefs(cols)
        return coldefs

    @property
    def _nrows(self):
        if not self._data_loaded:
            # The number of 'groups' equates to the number of rows in the table
            # representation of the data
            return self._header.get('GCOUNT', 0)
        else:
            return len(self.data)

    @lazyproperty
    def _theap(self):
        # Only really a lazyproperty for symmetry with _TableBaseHDU
        return 0

    @property
    def is_image(self):
        return False

    @property
    def size(self):
        """
        Returns the size (in bytes) of the HDU's data part.
        """

        size = 0
        naxis = self._header.get('NAXIS', 0)

        # for random group image, NAXIS1 should be 0, so we skip NAXIS1.
        if naxis > 1:
            size = 1
            for idx in range(1, naxis):
                size = size * self._header['NAXIS' + str(idx + 1)]
            bitpix = self._header['BITPIX']
            gcount = self._header.get('GCOUNT', 1)
            pcount = self._header.get('PCOUNT', 0)
            size = abs(bitpix) * gcount * (pcount + size) // 8
        return size

    def update_header(self):
        old_naxis = self._header.get('NAXIS', 0)

        if self._data_loaded:
            if isinstance(self.data, GroupData):
                self._axes = list(self.data.data.shape)[1:]
                self._axes.reverse()
                self._axes = [0] + self._axes
                field0 = self.data.dtype.names[0]
                field0_code = self.data.dtype.fields[field0][0].name
            elif self.data is None:
                self._axes = [0]
                field0_code = 'uint8'  # For lack of a better default
            else:
                raise ValueError('incorrect array type')

            self._header['BITPIX'] = DTYPE2BITPIX[field0_code]

        self._header['NAXIS'] = len(self._axes)

        # add NAXISi if it does not exist
        for idx, axis in enumerate(self._axes):
            if (idx == 0):
                after = 'NAXIS'
            else:
                after = 'NAXIS' + str(idx)

            self._header.set('NAXIS' + str(idx + 1), axis, after=after)

        # delete extra NAXISi's
        for idx in range(len(self._axes) + 1, old_naxis + 1):
            try:
                del self._header['NAXIS' + str(idx)]
            except KeyError:
                pass

        if self._has_data and isinstance(self.data, GroupData):
            self._header.set('GROUPS', True,
                             after='NAXIS' + str(len(self._axes)))
            self._header.set('PCOUNT', len(self.data.parnames), after='GROUPS')
            self._header.set('GCOUNT', len(self.data), after='PCOUNT')

            column = self.data._coldefs[self._data_field]
            scale, zero = self.data._get_scale_factors(column)[3:5]
            if scale:
                self._header.set('BSCALE', column.bscale)
            if zero:
                self._header.set('BZERO', column.bzero)

            for idx, name in enumerate(self.data.parnames):
                self._header.set('PTYPE' + str(idx + 1), name)
                column = self.data._coldefs[idx]
                scale, zero = self.data._get_scale_factors(column)[3:5]
                if scale:
                    self._header.set('PSCAL' + str(idx + 1), column.bscale)
                if zero:
                    self._header.set('PZERO' + str(idx + 1), column.bzero)

        # Update the position of the EXTEND keyword if it already exists
        if 'EXTEND' in self._header:
            if len(self._axes):
                after = 'NAXIS' + str(len(self._axes))
            else:
                after = 'NAXIS'
            self._header.set('EXTEND', after=after)

    def _writedata_internal(self, fileobj):
        """
        Basically copy/pasted from `_ImageBaseHDU._writedata_internal()`, but
        we have to get the data's byte order a different way...

        TODO: Might be nice to store some indication of the data's byte order
        as an attribute or function so that we don't have to do this.
        """

        size = 0

        if self.data is not None:
            self.data._scale_back()

            # Based on the system type, determine the byteorders that
            # would need to be swapped to get to big-endian output
            if sys.byteorder == 'little':
                swap_types = ('<', '=')
            else:
                swap_types = ('<',)
            # deal with unsigned integer 16, 32 and 64 data
            if _is_pseudo_unsigned(self.data.dtype):
                # Convert the unsigned array to signed
                output = np.array(
                    self.data - _unsigned_zero(self.data.dtype),
                    dtype=f'>i{self.data.dtype.itemsize}')
                should_swap = False
            else:
                output = self.data
                fname = self.data.dtype.names[0]
                byteorder = self.data.dtype.fields[fname][0].str[0]
                should_swap = (byteorder in swap_types)

            if should_swap:
                if output.flags.writeable:
                    output.byteswap(True)
                    try:
                        fileobj.writearray(output)
                    finally:
                        output.byteswap(True)
                else:
                    # For read-only arrays, there is no way around making
                    # a byteswapped copy of the data.
                    fileobj.writearray(output.byteswap(False))
            else:
                fileobj.writearray(output)

            size += output.size * output.itemsize
        return size

    def _verify(self, option='warn'):
        errs = super()._verify(option=option)

        # Verify locations and values of mandatory keywords.
        self.req_cards('NAXIS', 2,
                       lambda v: (_is_int(v) and 1 <= v <= 999), 1,
                       option, errs)
        self.req_cards('NAXIS1', 3, lambda v: (_is_int(v) and v == 0), 0,
                       option, errs)

        after = self._header['NAXIS'] + 3
        pos = lambda x: x >= after

        self.req_cards('GCOUNT', pos, _is_int, 1, option, errs)
        self.req_cards('PCOUNT', pos, _is_int, 0, option, errs)
        self.req_cards('GROUPS', pos, lambda v: (v is True), True, option,
                       errs)
        return errs

    def _calculate_datasum(self):
        """
        Calculate the value for the ``DATASUM`` card in the HDU.
        """

        if self._has_data:

            # We have the data to be used.

            # Check the byte order of the data.  If it is little endian we
            # must swap it before calculating the datasum.
            # TODO: Maybe check this on a per-field basis instead of assuming
            # that all fields have the same byte order?
            byteorder = \
                self.data.dtype.fields[self.data.dtype.names[0]][0].str[0]

            if byteorder != '>':
                if self.data.flags.writeable:
                    byteswapped = True
                    d = self.data.byteswap(True)
                    d.dtype = d.dtype.newbyteorder('>')
                else:
                    # If the data is not writeable, we just make a byteswapped
                    # copy and don't bother changing it back after
                    d = self.data.byteswap(False)
                    d.dtype = d.dtype.newbyteorder('>')
                    byteswapped = False
            else:
                byteswapped = False
                d = self.data

            byte_data = d.view(type=np.ndarray, dtype=np.ubyte)

            cs = self._compute_checksum(byte_data)

            # If the data was byteswapped in this method then return it to
            # its original little-endian order.
            if byteswapped:
                d.byteswap(True)
                d.dtype = d.dtype.newbyteorder('<')

            return cs
        else:
            # This is the case where the data has not been read from the file
            # yet.  We can handle that in a generic manner so we do it in the
            # base class.  The other possibility is that there is no data at
            # all.  This can also be handled in a generic manner.
            return super()._calculate_datasum()

    def _summary(self):
        summary = super()._summary()
        name, ver, classname, length, shape, format, gcount = summary

        # Drop the first axis from the shape
        if shape:
            shape = shape[1:]

            if shape and all(shape):
                # Update the format
                format = self.columns[0].dtype.name

        # Update the GCOUNT report
        gcount = f'{self._gcount} Groups  {self._pcount} Parameters'
        return (name, ver, classname, length, shape, format, gcount)


def _par_indices(names):
    """
    Given a list of objects, returns a mapping of objects in that list to the
    index or indices at which that object was found in the list.
    """

    unique = {}
    for idx, name in enumerate(names):
        # Case insensitive
        name = name.upper()
        if name in unique:
            unique[name].append(idx)
        else:
            unique[name] = [idx]
    return unique


def _unique_parnames(names):
    """
    Given a list of parnames, including possible duplicates, returns a new list
    of parnames with duplicates prepended by one or more underscores to make
    them unique.  This is also case insensitive.
    """

    upper_names = set()
    unique_names = []

    for name in names:
        name_upper = name.upper()
        while name_upper in upper_names:
            name = '_' + name
            name_upper = '_' + name_upper

        unique_names.append(name)
        upper_names.add(name_upper)

    return unique_names
