# Licensed under a 3-clause BSD style license - see LICENSE.rst
# This module contains tests of a class equivalent to pre-1.0 NDData.


import pytest
import numpy as np

from astropy.nddata.nddata import NDData
from astropy.nddata.compat import NDDataArray
from astropy.nddata.nduncertainty import StdDevUncertainty
from astropy.wcs import WCS
from astropy import units as u


# Attribute names that the attribute-presence test in this module checks for
# on an NDDataArray instance.
NDDATA_ATTRIBUTES = [
    'mask',
    'flags',
    'uncertainty',
    'unit',
    'shape',
    'size',
    'dtype',
    'ndim',
    'wcs',
    'convert_unit_to',
]


def test_nddataarray_has_attributes_of_old_nddata():
    """Check that an NDDataArray has every name in NDDATA_ATTRIBUTES."""
    instance = NDDataArray([1, 2, 3])
    # all() stops at the first missing attribute, like the explicit loop did.
    assert all(hasattr(instance, name) for name in NDDATA_ATTRIBUTES)


def test_nddata_simple():
    """Wrap a 10x10 float array and check shape, size and dtype."""
    zeros = np.zeros((10, 10))
    wrapped = NDDataArray(zeros)

    assert wrapped.shape == (10, 10)
    assert wrapped.size == 100
    assert wrapped.dtype == np.dtype(float)


def test_nddata_parameters():
    """Test for issue 4620: construct NDDataArray with varied argument styles.

    Three call forms are exercised: ``data`` as the only keyword, ``data``
    given after another keyword, and the uncertainty passed as the second
    positional argument.
    """
    def check_basic_properties(candidate):
        # Expectations shared by every construction below.
        assert candidate.shape == (10, 10)
        assert candidate.size == 100
        assert candidate.dtype == np.dtype(float)

    # ``data`` supplied by keyword only.
    check_basic_properties(NDDataArray(data=np.zeros((10, 10))))

    # Keyword order changed; ``data`` has to be given explicitly here.
    check_basic_properties(NDDataArray(meta={}, data=np.zeros((10, 10))))

    # Uncertainty passed as the second implicit (positional) argument.
    zeros = np.zeros((10, 10))
    std = StdDevUncertainty(0.1 + np.zeros_like(zeros))
    with_uncertainty = NDDataArray(zeros, std)
    check_basic_properties(with_uncertainty)
    assert with_uncertainty.uncertainty == std


def test_nddata_conversion():
    """Wrap a 2x3 integer array and check its size and integer dtype."""
    integers = np.array([[1, 2, 3], [4, 5, 6]])
    wrapped = NDDataArray(integers)

    assert wrapped.size == 6
    assert wrapped.dtype == np.dtype(int)


@pytest.mark.parametrize(
    'flags_in',
    [
        np.array([True, False]),
        np.array([1, 0]),
        [True, False],
        [1, 0],
        np.array(['a', 'b']),
        ['a', 'b'],
    ],
)
def test_nddata_flags_init_without_np_array(flags_in):
    """Flags given as arrays or plain lists compare equal element-wise.

    The parametrized inputs cover boolean, integer and string flags, each
    as a numpy array and as a plain Python list.
    """
    ndd = NDDataArray([1, 1], flags=flags_in)
    elementwise_match = ndd.flags == flags_in
    assert elementwise_match.all()


@pytest.mark.parametrize('shape', [(10,), (5, 5), (3, 10, 10)])
def test_nddata_flags_invalid_shape(shape):
    """Flags whose shape differs from the (10, 10) data raise ValueError."""
    data = np.zeros((10, 10))
    mismatched_flags = np.ones(shape)

    with pytest.raises(ValueError) as excinfo:
        NDDataArray(data, flags=mismatched_flags)

    # The message is compared for exact equality, not as a pattern.
    message = excinfo.value.args[0]
    assert message == 'dimensions of flags do not match data'


def test_convert_unit_to():
    """Convert km to m and check data, uncertainty, unit and mask copying.

    convert_unit_to should return a copy of its input, so editing the
    converted object's mask must leave the source mask untouched.
    """
    source = NDDataArray(np.ones((5, 5)))
    source.unit = 'km'
    source.uncertainty = StdDevUncertainty(0.1 + np.zeros_like(source))
    # workaround because zeros_like does not support dtype arg until v1.6
    # and NDData accepts only bool ndarray as mask
    source.mask = np.array(np.zeros_like(source.data), dtype=bool)

    converted = source.convert_unit_to('m')
    expected_uncertainty = 1000.0 * source.uncertainty.array
    assert np.all(converted.data == np.array(1000.0))
    assert np.all(converted.uncertainty.array == expected_uncertainty)
    assert converted.unit == u.m

    # changing the output mask should not change the original
    converted.mask[0, 0] = True
    assert source.mask[0, 0] != converted.mask[0, 0]

    # Convert once more with flags set. Nothing is asserted on the result;
    # presumably this only checks that the call does not raise.
    source.flags = np.zeros_like(source.data)
    source.convert_unit_to('m')


# check that subclasses can require wcs and/or unit to be present and use
# _arithmetic and convert_unit_to
class SubNDData(NDDataArray):
    """
    NDDataArray subclass that refuses to be built without a unit and a WCS.

    Used to test initialization of subclasses in NDData._arithmetic and
    NDData.convert_unit_to.
    """
    def __init__(self, *arg, **kwd):
        super().__init__(*arg, **kwd)
        # Checked in order; the first attribute that is None raises.
        required = (
            ('unit', "Unit for subclass must be specified"),
            ('wcs', "WCS for subclass must be specified"),
        )
        for attribute, message in required:
            if getattr(self, attribute) is None:
                raise ValueError(message)


def test_init_of_subclass_in_convert_unit_to():
    """Convert a SubNDData from m to km and compare the data values."""
    in_metres = SubNDData(np.ones([10, 10]), unit='m', wcs=WCS(naxis=2))
    in_kilometres = in_metres.convert_unit_to('km')

    np.testing.assert_array_equal(in_metres.data, 1000 * in_kilometres.data)


# Test for issue #4129:
def test_nddataarray_from_nddataarray():
    """Build an NDDataArray from another one and check attribute sharing."""
    original = NDDataArray(
        [1., 4., 9.],
        uncertainty=StdDevUncertainty([1., 2., 3.]),
        flags=[0, 1, 0],
    )
    rewrapped = NDDataArray(original)

    # Test that the 2 instances point to the same objects and aren't just
    # equal; this is explicitly documented for the main data array and we
    # probably want to catch any future change in behavior for the other
    # attributes too and ensure they are intentional.
    for attribute in ('data', 'uncertainty', 'flags'):
        assert getattr(rewrapped, attribute) is getattr(original, attribute)

    # meta is only compared for equality, not identity.
    assert rewrapped.meta == original.meta


# Test for issue #4137:
def test_nddataarray_from_nddata():
    """Build an NDDataArray from a plain NDData and check attribute sharing."""
    original = NDData(
        [1., 4., 9.],
        uncertainty=StdDevUncertainty([1., 2., 3.]),
    )
    rewrapped = NDDataArray(original)

    # data and uncertainty must be the very same objects, not just equal.
    for attribute in ('data', 'uncertainty'):
        assert getattr(rewrapped, attribute) is getattr(original, attribute)

    # meta is only compared for equality, not identity.
    assert rewrapped.meta == original.meta
