import astropy_healpix as ah
from astropy import table
import numpy as np
import pytest

from .. import moc


@pytest.mark.parametrize('order', [0, 1, 2])
@pytest.mark.parametrize('ipix', [-1, -2, -3])
def test_nest2uniq_invalid(order, ipix):
    """Check that nest2uniq returns -1 for negative NESTED pixel indices.

    Every parametrized order is combined with every negative ``ipix``; the
    order is passed as ``np.int8``.
    """
    result = moc.nest2uniq(np.int8(order), ipix)
    assert result == -1


@pytest.mark.parametrize('uniq', [-1, 0, 2, 3])
def test_uniq2order_invalid(uniq):
    """Check that uniq2order returns -1 for each invalid UNIQ value."""
    order = moc.uniq2order(uniq)
    assert order == -1


@pytest.mark.parametrize('uniq', [-1, 0, 2, 3])
def test_uniq2pixarea_invalid(uniq):
    """Check that uniq2pixarea returns NaN for each invalid UNIQ value."""
    area = moc.uniq2pixarea(uniq)
    assert np.isnan(area)


@pytest.mark.parametrize('uniq', [-1, 0, 2, 3])
def test_uniq2nest_invalid(uniq):
    """Check that uniq2nest returns -1 for both the order and the NESTED
    pixel index for each invalid UNIQ value.
    """
    result_order, result_ipix = moc.uniq2nest(uniq)
    # Order first, then pixel index: both must carry the -1 sentinel.
    for value in (result_order, result_ipix):
        assert value == -1


@pytest.mark.parametrize('uniq', [-1, 0, 2, 3])
def test_uniq2ang_invalid(uniq):
    """Check that uniq2ang returns NaN for both angles for each invalid UNIQ
    value.
    """
    theta_out, phi_out = moc.uniq2ang(uniq)
    # theta first, then phi: neither may be a finite angle.
    for angle in (theta_out, phi_out):
        assert np.isnan(angle)


def input_skymap(order1, d_order, fraction):
    """Build a mixed-resolution test sky map whose values are proportional
    to the NESTED pixel index at the coarse order.

    The leading rows keep pixels at ``order1``. The remaining rows are taken
    from a full map at ``order1 + d_order`` in which every pixel carries the
    index of its coarse parent, so that several resolutions are mixed in one
    table.

    Parameters
    ----------
    order1 : int
        The HEALPix resolution order of the coarse pixels.
    d_order : int
        The increase in order for the refined part of the sky map.
    fraction : float
        The fraction of the coarse pixels to refine.

    Returns
    -------
    astropy.table.Table
        Table with columns ``UNIQ``, ``VALUE`` (the coarse NESTED pixel index
        as a float) and ``VALUE2`` (``VALUE`` multiplied by pi).

    """
    fine_order = order1 + d_order
    coarse_npix = ah.nside_to_npix(ah.level_to_nside(order1))
    fine_npix = ah.nside_to_npix(ah.level_to_nside(fine_order))

    coarse_ipix = np.arange(coarse_npix)
    fine_ipix = np.arange(fine_npix)

    # Full-sky map at the coarse order.
    coarse_values = coarse_ipix.astype(float)
    coarse = table.Table({
        'UNIQ': moc.nest2uniq(order1, coarse_ipix),
        'VALUE': coarse_values,
        'VALUE2': np.pi * coarse_values,
    })

    # Full-sky map at the fine order; each coarse index is repeated once per
    # fine pixel so that children inherit their parent's value.
    fine_values = np.repeat(coarse_ipix, fine_npix // coarse_npix).astype(float)
    fine = table.Table({
        'UNIQ': moc.nest2uniq(fine_order, fine_ipix),
        'VALUE': fine_values,
        'VALUE2': np.pi * fine_values,
    })

    # Keep the first n_keep coarse pixels and fill the rest of the sky with
    # the fine pixels that start where the kept coarse pixels end.
    n_keep = int(coarse_npix * (1 - fraction))
    fine_start = n_keep * fine_npix // coarse_npix
    return table.vstack((coarse[:n_keep], fine[fine_start:]))


def test_rasterize_oom():
    """Test that _rasterize() raises a MemoryError when given a sky map
    containing a single pixel at order 29.
    """
    # Order 29 is the highest possible 64-bit HEALPix resolution.
    finest_uniq = moc.nest2uniq(np.int8(29), 0)
    skymap = table.Table({'UNIQ': [finest_uniq], 'VALUE': [0]})
    with pytest.raises(MemoryError):
        moc._rasterize(skymap)


@pytest.mark.parametrize('order_in', [6])
@pytest.mark.parametrize('d_order_in', range(3))
@pytest.mark.parametrize('fraction_in', [0, 0.25, 0.5, 1])
@pytest.mark.parametrize('order_out', range(6))
def test_rasterize_downsample(order_in, d_order_in, fraction_in, order_out):
    """Test rasterize() to an order coarser than the input sky map.

    The output must have one row per pixel at ``order_out``, and each
    ``VALUE`` must equal the mean of the ``order_in`` NESTED pixel indices
    grouped into that output pixel.
    """
    pix_count_in = ah.nside_to_npix(ah.level_to_nside(order_in))
    pix_count_out = ah.nside_to_npix(ah.level_to_nside(order_out))

    rasterized = moc.rasterize(
        input_skymap(order_in, d_order_in, fraction_in), order_out)
    assert len(rasterized) == pix_count_out

    # Consecutive runs of input pixels form one output pixel each.
    per_output = pix_count_in // pix_count_out
    grouped = np.arange(pix_count_in).reshape(-1, per_output)
    np.testing.assert_array_equal(rasterized['VALUE'], grouped.mean(axis=1))


@pytest.mark.parametrize('order_in', [2])
@pytest.mark.parametrize('d_order_in', range(3))
@pytest.mark.parametrize('fraction_in', [0, 0.25, 0.5, 1])
@pytest.mark.parametrize('order_out', range(3, 9))
def test_rasterize_upsample(order_in, d_order_in, fraction_in, order_out):
    """Test rasterize() to an order finer than the coarse input order.

    The output must have one row per pixel at ``order_out``, and every
    consecutive run of ``reps`` output pixels must carry the NESTED index of
    the single ``order_in`` pixel it was refined from.
    """
    npix_in = ah.nside_to_npix(ah.level_to_nside(order_in))
    npix_out = ah.nside_to_npix(ah.level_to_nside(order_out))
    skymap_in = input_skymap(order_in, d_order_in, fraction_in)
    skymap_out = moc.rasterize(skymap_in, order_out)

    assert len(skymap_out) == npix_out
    reps = npix_out // npix_in
    # One vectorized comparison instead of `reps` strided ones: requiring
    # out[i::reps] == arange(npix_in) for every i in range(reps) is the same
    # as requiring out == repeat(arange(npix_in), reps).
    expected = np.repeat(np.arange(npix_in), reps)
    np.testing.assert_array_equal(skymap_out['VALUE'], expected)


@pytest.mark.parametrize('order', range(3))
def test_rasterize_default(order):
    """Test rasterize() without an explicit output order.

    The input is a single-resolution sky map at ``order`` (no refinement);
    the output must have one row per pixel at that same order.
    """
    flat_skymap = input_skymap(order, 0, 0)
    rasterized = moc.rasterize(flat_skymap)
    expected_len = ah.nside_to_npix(ah.level_to_nside(order))
    assert len(rasterized) == expected_len
