# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""
This module tests some of the methods related to YAML serialization.

Requires `pyyaml <https://pyyaml.org/>`_ to be installed.
"""

from io import StringIO

import pytest
import numpy as np

from astropy.coordinates import (SkyCoord, EarthLocation, Angle, Longitude, Latitude,
                                 SphericalRepresentation, UnitSphericalRepresentation,
                                 CartesianRepresentation, SphericalCosLatDifferential,
                                 SphericalDifferential, CartesianDifferential)
from astropy import units as u
from astropy.time import Time
from astropy.table import QTable, SerializedColumn
from astropy.coordinates.tests.test_representation import representation_equal

yaml = pytest.importorskip('yaml', minversion='3.12')

from astropy.io.misc.yaml import load, load_all, dump  # noqa


# ``np.complex128`` is used instead of the ``np.complex_`` alias: the alias
# named the same type but was removed in NumPy 2.0, and because this list is
# evaluated at import time it would make the whole module fail to import.
@pytest.mark.parametrize('c', [True, np.uint8(8), np.int16(4),
                               np.int32(1), np.int64(3), np.int64(2**63 - 1),
                               2.0, np.float64(),
                               3+4j, np.complex128(3 + 4j),
                               np.complex64(3 + 4j),
                               np.complex128(1. - 2**-52 + 1j * (1. - 2**-52))])
def test_numpy_types(c):
    """Python and NumPy scalars compare equal after a YAML dump/load cycle."""
    cy = load(dump(c))
    assert c == cy


@pytest.mark.parametrize('c', [u.m, u.m / u.s, u.hPa, u.dimensionless_unscaled])
def test_unit(c):
    """Units survive a YAML round trip.

    Composite units only need to compare equal; any other unit must come
    back as the very same object.
    """
    roundtripped = load(dump(c))
    if not isinstance(c, u.CompositeUnit):
        assert c is roundtripped
    else:
        assert c == roundtripped


@pytest.mark.parametrize('c', [u.Unit('bakers_dozen', 13*u.one),
                               u.def_unit('magic')])
def test_custom_unit(c):
    """User-defined units load as unrecognized unless they are enabled."""
    serialized = dump(c)

    # Not registered: loading emits exactly one UnitsWarning and yields an
    # UnrecognizedUnit whose string form matches the original.
    with pytest.warns(u.UnitsWarning, match=f"'{c!s}' did not parse") as record:
        unknown = load(serialized)
    assert len(record) == 1
    assert isinstance(unknown, u.UnrecognizedUnit)
    assert str(unknown) == str(c)

    # Registered: loading hands back the original unit object itself.
    with u.add_enabled_units(c):
        known = load(serialized)
        assert known is c


@pytest.mark.parametrize('c', [Angle('1 2 3', unit='deg'),
                               Longitude('1 2 3', unit='deg'),
                               Latitude('1 2 3', unit='deg'),
                               [[1], [3]] * u.m,
                               np.array([[1, 2], [3, 4]], order='F'),
                               np.array([[1, 2], [3, 4]], order='C'),
                               np.array([1, 2, 3, 4])[::2]])
def test_ndarray_subclasses(c):
    """Arrays and Quantity subclasses keep value, shape, type, layout and unit."""
    restored = load(dump(c))

    assert np.all(c == restored)
    assert c.shape == restored.shape
    assert type(c) is type(restored)

    c_flag, f_flag = 'C_CONTIGUOUS', 'F_CONTIGUOUS'
    if not (c.flags[c_flag] or c.flags[f_flag]):
        # A non-contiguous input (such as the strided slice above) is
        # expected to come back as a C-contiguous array.
        assert restored.flags[c_flag]
    else:
        # Contiguous inputs must preserve both memory-order flags exactly.
        for flag in (c_flag, f_flag):
            assert c.flags[flag] == restored.flags[flag]

    if hasattr(c, 'unit'):
        assert c.unit == restored.unit


def compare_coord(c, cy):
    """Assert that coordinate ``cy`` matches the reference coordinate ``c``.

    Compares shape, frame name, frame attribute names and values, and
    representation component names and (element-wise) values.
    """
    assert c.shape == cy.shape
    assert c.frame.name == cy.frame.name

    assert list(c.get_frame_attr_names()) == list(cy.get_frame_attr_names())
    for attr in c.get_frame_attr_names():
        assert getattr(c, attr) == getattr(cy, attr)

    assert (list(c.representation_component_names) ==
            list(cy.representation_component_names))
    for name in c.representation_component_names:
        # Look up each component by its own name; components can be arrays,
        # so reduce the element-wise comparison with np.all.
        assert np.all(getattr(c, name) == getattr(cy, name))


@pytest.mark.parametrize('frame', ['fk4', 'altaz'])
def test_skycoord(frame):
    """A 2x2 SkyCoord carrying obstime and location survives a round trip."""
    obstime = Time('2016-01-02')
    site = EarthLocation(1000, 2000, 3000, unit=u.km)
    original = SkyCoord([[1, 2], [3, 4]], [[5, 6], [7, 8]],
                        unit='deg', frame=frame,
                        obstime=obstime, location=site)
    compare_coord(original, load(dump(original)))


@pytest.mark.parametrize('rep', [
    CartesianRepresentation(1*u.m, 2.*u.m, 3.*u.m),
    SphericalRepresentation([[1, 2], [3, 4]]*u.deg,
                            [[5, 6], [7, 8]]*u.deg,
                            10*u.pc),
    UnitSphericalRepresentation(0*u.deg, 10*u.deg),
    SphericalCosLatDifferential([[1.], [2.]]*u.mas/u.yr,
                                [4., 5.]*u.mas/u.yr,
                                [[[10]], [[20]]]*u.km/u.s),
    CartesianDifferential([10, 20, 30]*u.km/u.s),
    CartesianRepresentation(
        [1, 2, 3]*u.m,
        differentials=CartesianDifferential([10, 20, 30]*u.km/u.s)),
    SphericalRepresentation(
        [[1, 2], [3, 4]]*u.deg, [[5, 6], [7, 8]]*u.deg, 10*u.pc,
        differentials={
            's': SphericalDifferential([[0., 1.], [2., 3.]]*u.mas/u.yr,
                                       [[4., 5.], [6., 7.]]*u.mas/u.yr,
                                       10*u.km/u.s)})])
def test_representations(rep):
    """Representations and differentials (attached or standalone) round-trip."""
    roundtripped = load(dump(rep))
    matches = representation_equal(roundtripped, rep)
    assert np.all(matches)


def _get_time():
    """Build a 2x1 Time with location, custom format/precision, deltas and subfmt."""
    site = EarthLocation(1000, 2000, 3000, unit=u.km)
    t = Time([[1], [2]], format='cxcsec', location=site)

    # Switch away from the input format and populate the optional
    # attributes so they are all present on the returned object.
    t.format = 'iso'
    t.precision = 5
    t.delta_ut1_utc = np.array([[3.0], [4.0]])
    t.delta_tdb_tt = np.array([[5.0], [6.0]])
    t.out_subfmt = 'date_hm'
    return t


def compare_time(t, ty):
    """Assert that ``ty`` matches ``t`` in type, value and serialized attributes."""
    checked_attrs = ('shape', 'jd1', 'jd2', 'format', 'scale', 'precision',
                     'in_subfmt', 'out_subfmt', 'location', 'delta_ut1_utc',
                     'delta_tdb_tt')
    assert type(t) is type(ty)
    assert np.all(t == ty)
    for attr_name in checked_attrs:
        expected = getattr(t, attr_name)
        actual = getattr(ty, attr_name)
        assert np.all(expected == actual)


def test_time():
    """A fully configured Time round-trips with all checked attributes intact."""
    original = _get_time()
    compare_time(original, load(dump(original)))


def test_timedelta():
    """The difference of two Times plus a Quantity offset round-trips."""
    t = _get_time()
    dt = t - t + 0.1234556 * u.s
    dty = load(dump(dt))

    assert type(dt) is type(dty)
    for attr_name in ('shape', 'jd1', 'jd2', 'format', 'scale'):
        original_value = getattr(dt, attr_name)
        loaded_value = getattr(dty, attr_name)
        assert np.all(original_value == loaded_value)


def test_serialized_column():
    """A SerializedColumn mapping compares equal after a YAML round trip."""
    payload = {'name': 'hello', 'other': 1, 'other2': 2.0}
    sc = SerializedColumn(payload)
    assert sc == load(dump(sc))


def test_load_all():
    """``load_all`` returns each document of a multi-document YAML stream."""
    t = _get_time()
    unit = u.m / u.s
    c = SkyCoord([[1, 2], [3, 4]], [[5, 6], [7, 8]],
                 unit='deg', frame='fk4',
                 obstime=Time('2016-01-02'),
                 location=EarthLocation(1000, 2000, 3000, unit=u.km))

    # One '---'-prefixed document per object, concatenated into one stream.
    stream = ''.join('---\n' + dump(obj) for obj in (t, unit, c))

    ty, unity, cy = list(load_all(stream))

    compare_time(t, ty)
    compare_coord(c, cy)
    assert unity == unit


def test_ecsv_astropy_objects_in_meta():
    """
    Astropy core objects stored in a table's ``meta`` survive an ECSV
    write/read cycle.
    """
    tm = _get_time()
    c = SkyCoord([[1, 2], [3, 4]], [[5, 6], [7, 8]],
                 unit='deg', frame='fk4',
                 obstime=Time('2016-01-02'),
                 location=EarthLocation(1000, 2000, 3000, unit=u.km))
    unit = u.m / u.s

    t = QTable([[1, 2] * u.m, [4, 5]], names=['a', 'b'])
    t.meta = {'tm': tm, 'c': c, 'unit': unit}

    buffer = StringIO()
    t.write(buffer, format='ascii.ecsv')
    meta = QTable.read(buffer.getvalue(), format='ascii.ecsv').meta

    compare_time(tm, meta['tm'])
    compare_coord(c, meta['c'])
    assert meta['unit'] == unit
