# Licensed under a 3-clause BSD style license - see LICENSE.rst

import warnings

import pytest
import numpy as np

from .test_table import SetupData
from astropy.table.bst import BST
from astropy.table.sorted_array import SortedArray
from astropy.table.soco import SCEngine
from astropy.table import QTable, Row, Table, Column, hstack
from astropy import units as u
from astropy.time import Time
from astropy.table.column import BaseColumn
from astropy.table.index import get_index, SlicedIndex
from astropy.utils.compat.optional_deps import HAS_SORTEDCONTAINERS

# Index engine classes exercised by the engine-parametrized tests; SCEngine
# is only added when HAS_SORTEDCONTAINERS reports the optional dependency.
available_engines = [BST, SortedArray]
available_engines += [SCEngine] if HAS_SORTEDCONTAINERS else []


@pytest.fixture(params=available_engines)
def engine(request):
    """Parametrized fixture returning each entry of ``available_engines`` in turn."""
    return request.param


# Base values for the indexed column 'a'; the ``main_col`` fixture wraps them
# as a plain list, a Quantity and a Time.
_col = [1, 2, 3, 4, 5]


@pytest.fixture(params=[
    _col,
    u.Quantity(_col),
    Time(_col, format='jyear'),
])
def main_col(request):
    """Parametrized fixture returning ``_col`` as a list, a Quantity, or a Time."""
    return request.param


def assert_col_equal(col, array):
    """Assert that ``col`` matches ``array`` element by element.

    ``array`` is converted before comparing: to a ``Time`` with format
    ``'jyear'`` when ``col`` is a ``Time``, otherwise to the class of ``col``.
    """
    if isinstance(col, Time):
        expected = Time(array, format='jyear')
    else:
        expected = col.__class__(array)
    assert np.all(col == expected)


@pytest.mark.usefixtures('table_types')
class TestIndex(SetupData):
    def _setup(self, main_col, table_types):
        """Configure this test instance for one ``main_col`` variant.

        Delegates to the parent ``_setup`` first, then:

        * stores ``main_col`` on the instance,
        * switches ``_table_type`` to ``QTable`` for Quantity input,
        * replaces ``_column_type`` with an identity function for any
          non-list input so the mixin type is preserved,
        * records in ``mutable`` whether ``main_col`` is a list or Quantity.
        """
        super()._setup(table_types)
        self.main_col = main_col

        is_list = isinstance(main_col, list)
        is_quantity = isinstance(main_col, u.Quantity)

        if is_quantity:
            self._table_type = QTable
        if not is_list:
            self._column_type = lambda x: x  # don't change mixin type
        self.mutable = is_list or is_quantity

    def make_col(self, name, lst):
        """Build a column called ``name`` from ``lst`` via ``self._column_type``.

        NOTE(review): ``_setup`` installs a one-argument identity lambda as
        ``_column_type`` for non-list columns; that lambda does not accept
        ``name`` -- confirm before calling this for mixin columns.
        """
        return self._column_type(lst, name=name)

    def make_val(self, val):
        """Return ``val`` as a lookup key suitable for the current ``main_col``.

        When ``main_col`` is a ``Time`` the value is wrapped in a ``Time``
        with format ``'jyear'``; otherwise it is returned unchanged.
        """
        needs_time = isinstance(self.main_col, Time)
        return Time(val, format='jyear') if needs_time else val

    @property
    def t(self):
        """Table under test, built on first access and cached in ``self._t``."""
        if hasattr(self, '_t'):
            return self._t

        # Column order matters here: 'a' is deliberately added last so that
        # the indexed column is not the first column of the table (see
        # #10025).  Most tests index on 'a' or on ('a', 'b').
        table = self._t = self._table_type()
        table['b'] = self._column_type([4.0, 5.1, 6.2, 7.0, 1.1])
        table['c'] = self._column_type(['7', '8', '9', '10', '11'])
        table['a'] = self._column_type(self.main_col)
        return table

    @pytest.mark.parametrize("composite", [False, True])
    def test_table_index(self, main_col, table_types, composite, engine):
        """Check that an index tracks in-place edits, row additions and removals.

        Runs with a single-column index on 'a' and, when ``composite`` is
        true, with a two-column index on ('a', 'b').  The modification part
        is skipped for column types that are not flagged as mutable.
        """
        self._setup(main_col, table_types)
        t = self.t
        t.add_index(('a', 'b') if composite else 'a', engine=engine)
        # column 'a' starts out ascending, so the sorted order is the row order
        assert np.all(t.indices[0].sorted_data() == [0, 1, 2, 3, 4])

        if not self.mutable:
            return

        # test altering table columns
        t['a'][0] = 4
        t.add_row((6.0, '7', 6))
        t['a'][3] = 10
        t.remove_row(2)
        t.add_row((5.0, '9', 4))

        # the column data reflect every modification above ...
        assert_col_equal(t['a'], np.array([4, 2, 10, 5, 6, 4]))
        assert np.allclose(t['b'], np.array([4.0, 5.1, 7.0, 1.1, 6.0, 5.0]))
        assert np.all(t['c'].data == np.array(['7', '8', '10', '11', '7', '9']))
        # ... and so do the (key, rows) pairs stored in the index
        index = t.indices[0]
        ll = list(index.data.items())

        if composite:
            assert np.all(ll == [((2, 5.1), [1]),
                                 ((4, 4.0), [0]),
                                 ((4, 5.0), [5]),
                                 ((5, 1.1), [3]),
                                 ((6, 6.0), [4]),
                                 ((10, 7.0), [2])])
        else:
            assert np.all(ll == [((2,), [1]),
                                 ((4,), [0, 5]),
                                 ((5,), [3]),
                                 ((6,), [4]),
                                 ((10,), [2])])
        # removing the indices for 'a' leaves the table without any index
        t.remove_indices('a')
        assert len(t.indices) == 0

    def test_table_slicing(self, main_col, table_types, engine):
        """Selecting rows by item list keeps a renumbered index on the result."""
        self._setup(main_col, table_types)
        t = self.t
        t.add_index('a', engine=engine)
        full_order = [0, 1, 2, 3, 4]
        assert np.all(t.indices[0].sorted_data() == full_order)

        # pick rows 0 and 2, once with a list and once with an ndarray
        for selector in ([0, 2], np.array([0, 2])):
            subset = t[selector]
            # the subset still carries one index, on column 'a'
            assert len(subset.indices) == 1
            assert_col_equal(subset['a'], [1, 3])

            # row numbers in the subset's index are renumbered from zero
            assert np.all(subset.indices[0].sorted_data() == [0, 1])
            # while the index of the parent table is left as it was
            assert np.all(t.indices[0].sorted_data() == full_order)

    def test_remove_rows(self, main_col, table_types, engine):
        """Removing rows updates both the indexed column and its index."""
        self._setup(main_col, table_types)
        if not self.mutable:
            return
        t = self.t
        t.add_index('a', engine=engine)

        # drop a single row given as an integer
        trimmed = t.copy()
        trimmed.remove_rows(2)
        assert_col_equal(trimmed['a'], [1, 2, 4, 5])
        assert np.all(trimmed.indices[0].sorted_data() == [0, 1, 2, 3])

        # drop rows 0, 2 and 4 given as a list, an ndarray, or a slice
        selections = ([0, 2, 4], np.array([0, 2, 4]), slice(0, 5, 2))
        for selection in selections:
            trimmed = t.copy()
            trimmed.remove_rows(selection)
            assert_col_equal(trimmed['a'], [2, 4])
            assert np.all(trimmed.indices[0].sorted_data() == [0, 1])

        # a tuple of row numbers is expected to be rejected
        with pytest.raises(ValueError):
            t.remove_rows((0, 2, 4))

    def test_col_get_slice(self, main_col, table_types, engine):
        """Check indices on table slices, column slices and slices of slices."""
        self._setup(main_col, table_types)
        t = self.t
        t.add_index('a', engine=engine)

        # get slice
        t2 = t[1:3]  # table slice
        assert_col_equal(t2['a'], [2, 3])
        assert np.all(t2.indices[0].sorted_data() == [0, 1])

        col_slice = t['a'][1:3]
        assert_col_equal(col_slice, [2, 3])
        # true column slices discard indices
        if isinstance(t['a'], BaseColumn):
            assert len(col_slice.info.indices) == 0

        # take slice of slice
        t2 = t[::2]
        assert_col_equal(t2['a'], np.array([1, 3, 5]))
        t3 = t2[::-1]
        assert_col_equal(t3['a'], np.array([5, 3, 1]))
        assert np.all(t3.indices[0].sorted_data() == [2, 1, 0])
        t3 = t2[:2]
        assert_col_equal(t3['a'], np.array([1, 3]))
        assert np.all(t3.indices[0].sorted_data() == [0, 1])
        # out-of-bound (and empty) slices give empty columns and empty indices
        for t_empty in (t2[3:], t2[2:1], t3[2:]):
            assert len(t_empty['a']) == 0
            assert np.all(t_empty.indices[0].sorted_data() == [])

        if self.mutable:
            # get boolean mask
            mask = t['a'] % 2 == 1
            t2 = t[mask]
            assert_col_equal(t2['a'], [1, 3, 5])
            assert np.all(t2.indices[0].sorted_data() == [0, 1, 2])

    def test_col_set_slice(self, main_col, table_types, engine):
        """Check that assigning to column slices and masks keeps indices in sync.

        Skipped for column types that are not flagged as mutable.
        """
        self._setup(main_col, table_types)
        if not self.mutable:
            return
        t = self.t
        t.add_index('a', engine=engine)

        # set slice
        t2 = t.copy()
        t2['a'][1:3] = np.array([6, 7])
        assert_col_equal(t2['a'], np.array([1, 6, 7, 4, 5]))
        assert np.all(t2.indices[0].sorted_data() == [0, 3, 4, 1, 2])

        # change original table via slice reference
        t2 = t.copy()
        t3 = t2[1:3]
        assert_col_equal(t3['a'], np.array([2, 3]))
        assert np.all(t3.indices[0].sorted_data() == [0, 1])
        t3['a'][0] = 5
        # the edit is visible in the slice and in the table it was taken from
        assert_col_equal(t3['a'], np.array([5, 3]))
        assert_col_equal(t2['a'], np.array([1, 5, 3, 4, 5]))
        assert np.all(t3.indices[0].sorted_data() == [1, 0])
        assert np.all(t2.indices[0].sorted_data() == [0, 2, 3, 1, 4])

        # set boolean mask
        t2 = t.copy()
        mask = t['a'] % 2 == 1
        t2['a'][mask] = 0.
        assert_col_equal(t2['a'], [0, 2, 0, 4, 0])
        assert np.all(t2.indices[0].sorted_data() == [0, 2, 4, 1, 3])

    def test_multiple_slices(self, main_col, table_types, engine):
        """Check indices on repeatedly sliced tables, including writes through them.

        Skipped for column types that are not flagged as mutable.
        """
        self._setup(main_col, table_types)
        if not self.mutable:
            return

        t = self.t
        t.add_index('a', engine=engine)

        # extend the table so that column 'a' holds 1..50
        for value in range(6, 51):
            t.add_row((1.0, 'A', value))

        assert_col_equal(t['a'], list(range(1, 51)))
        assert np.all(t.indices[0].sorted_data() == list(range(50)))

        evens = t[::2]
        assert np.all(evens.indices[0].sorted_data() == list(range(25)))
        backwards = evens[::-1]
        backwards_index = backwards.indices[0]
        bounds = (backwards_index.start, backwards_index.stop,
                  backwards_index.step)
        assert bounds == (48, -2, -2)
        assert np.all(backwards_index.sorted_data() == list(range(24, -1, -1)))

        # modify slice of slice
        backwards[-10:] = 0
        expected = np.arange(1, 51)
        head = expected[:20]  # a view: the edit below lands in ``expected``
        head[head % 2 == 1] = 0
        assert_col_equal(t['a'], expected)
        assert_col_equal(evens['a'], expected[::2])
        assert_col_equal(backwards['a'], expected[::2][::-1])
        # the first ten rows of ``evens`` (rows 0, 2, ..., 18 of ``t``) are
        # now zero, so they sort ahead of everything else
        zeroed_rows = list(range(0, 20, 2))
        untouched_head = list(range(1, 20, 2))
        tail = list(range(20, 50))
        assert np.all(t.indices[0].sorted_data()
                      == zeroed_rows + untouched_head + tail)
        assert np.all(evens.indices[0].sorted_data() == list(range(25)))
        assert np.all(backwards.indices[0].sorted_data()
                      == list(range(24, -1, -1)))

        # try different step sizes of slice
        every_other = t[1:20:2]
        assert_col_equal(every_other['a'],
                         [2, 4, 6, 8, 10, 12, 14, 16, 18, 20])
        assert np.all(every_other.indices[0].sorted_data() == list(range(10)))
        every_third = every_other[::3]
        assert_col_equal(every_third['a'], [2, 8, 14, 20])
        assert np.all(every_third.indices[0].sorted_data() == [0, 1, 2, 3])
        first_three_reversed = every_third[2::-1]
        assert_col_equal(first_three_reversed['a'], [14, 8, 2])
        assert np.all(first_three_reversed.indices[0].sorted_data()
                      == [2, 1, 0])

    def test_sort(self, main_col, table_types, engine):
        """Sorting an indexed table re-orders the index along with the rows."""
        self._setup(main_col, table_types)
        t = self.t[::-1]  # reverse table
        assert_col_equal(t['a'], [5, 4, 3, 2, 1])
        t.add_index('a', engine=engine)
        assert np.all(t.indices[0].sorted_data() == [4, 3, 2, 1, 0])

        if not self.mutable:
            return

        # First sort explicitly by column 'a', then by the primary key
        # (no argument); both must leave 'a' ascending.
        for sort_args in (('a',), ()):
            ordered = t.copy()
            ordered.sort(*sort_args)
            assert_col_equal(ordered['a'], [1, 2, 3, 4, 5])
            assert np.all(ordered.indices[0].sorted_data() == [0, 1, 2, 3, 4])

    def test_insert_row(self, main_col, table_types, engine):
        """Inserting rows mid-table shifts the row numbers held by the index."""
        self._setup(main_col, table_types)
        if not self.mutable:
            return

        t = self.t
        t.add_index('a', engine=engine)

        # (insert position, new row, expected 'a' column, expected index order)
        # applied one after the other to the same table
        cases = (
            (2, (1.0, '12', 6),
             [1, 2, 6, 3, 4, 5], [0, 1, 3, 4, 5, 2]),
            (1, (4.0, '13', 0),
             [1, 0, 2, 6, 3, 4, 5], [1, 0, 2, 4, 5, 6, 3]),
        )
        for position, new_row, expected_a, expected_order in cases:
            t.insert_row(position, new_row)
            assert_col_equal(t['a'], expected_a)
            assert np.all(t.indices[0].sorted_data() == expected_order)

    def test_index_modes(self, main_col, table_types, engine):
        """Exercise the ``index_mode`` context manager of an indexed table.

        Covers 'discard_on_copy' (slices and copies made inside the block
        carry no indices), 'freeze' (the index is not updated while the table
        is modified inside the block) and 'copy_on_getitem' (item-list column
        access inside the block carries an index).  In each case a second
        table ``t2`` verifies that the mode only affects ``t``.
        """
        self._setup(main_col, table_types)
        t = self.t
        t.add_index('a', engine=engine)

        # first, no special mode
        assert len(t[[1, 3]].indices) == 1
        assert len(t[::-1].indices) == 1
        assert len(self._table_type(t).indices) == 1
        assert np.all(t.indices[0].sorted_data() == [0, 1, 2, 3, 4])
        t2 = t.copy()

        # non-copy mode
        with t.index_mode('discard_on_copy'):
            assert len(t[[1, 3]].indices) == 0
            assert len(t[::-1].indices) == 0
            assert len(self._table_type(t).indices) == 0
            assert len(t2.copy().indices) == 1  # mode should only affect t

        # make sure non-copy mode is exited correctly
        assert len(t[[1, 3]].indices) == 1

        if not self.mutable:
            return

        # non-modify mode
        with t.index_mode('freeze'):
            assert np.all(t.indices[0].sorted_data() == [0, 1, 2, 3, 4])
            t['a'][0] = 6
            assert np.all(t.indices[0].sorted_data() == [0, 1, 2, 3, 4])
            t.add_row((1.5, '12', 2))
            assert np.all(t.indices[0].sorted_data() == [0, 1, 2, 3, 4])
            t.remove_rows([1, 3])
            assert np.all(t.indices[0].sorted_data() == [0, 1, 2, 3, 4])
            # the column data themselves did change while frozen
            assert_col_equal(t['a'], [6, 3, 5, 2])
            # mode should only affect t
            assert np.all(t2.indices[0].sorted_data() == [0, 1, 2, 3, 4])
            t2['a'][0] = 6
            assert np.all(t2.indices[0].sorted_data() == [1, 2, 3, 4, 0])

        # make sure non-modify mode is exited correctly
        assert np.all(t.indices[0].sorted_data() == [3, 1, 2, 0])

        # the column-level checks only apply to true columns, not mixins
        if isinstance(t['a'], BaseColumn):
            assert len(t['a'][::-1].info.indices) == 0
            with t.index_mode('copy_on_getitem'):
                assert len(t['a'][[1, 2]].info.indices) == 1
                # mode should only affect t
                assert len(t2['a'][[1, 2]].info.indices) == 0

            assert len(t['a'][::-1].info.indices) == 0
            assert len(t2['a'][::-1].info.indices) == 0

    def test_index_retrieval(self, main_col, table_types, engine):
        """Indices can be looked up on ``t.indices`` by their column name(s)."""
        self._setup(main_col, table_types)
        t = self.t
        for colnames in ('a', ['a', 'c']):
            t.add_index(colnames, engine=engine)

        assert len(t.indices) == 2
        single = t.indices['a']
        assert len(single.columns) == 1
        pair = t.indices['a', 'c']
        assert len(pair.columns) == 2

        # no index was defined on 'b' alone
        with pytest.raises(IndexError):
            t.indices['b']

    def test_col_rename(self, main_col, table_types, engine):
        """Regression test: copying a Table under new column names.

        Initializing a table from an indexed one while renaming the columns
        used to raise an exception; the copy must also keep its index.
        """
        self._setup(main_col, table_types)
        source = self.t
        source.add_index('a', engine=engine)
        renamed = self._table_type(source, names=['d', 'e', 'f'])
        assert len(renamed.indices) == 1

    def test_table_loc(self, main_col, table_types, engine):
        """Check row retrieval through ``Table.loc`` and ``Table.iloc``.

        Covers single-label, list, ndarray (skipped for Time columns), range
        and full-slice lookups on the primary index 'a', as well as lookups
        on the secondary index 'b' selected by name.
        """
        self._setup(main_col, table_types)
        t = self.t

        t.add_index('a', engine=engine)
        t.add_index('b', engine=engine)

        t2 = t.loc[self.make_val(3)]  # single label, with primary key 'a'
        assert_col_equal(t2['a'], [3])
        assert isinstance(t2, Row)

        # list search
        t2 = t.loc[[self.make_val(1), self.make_val(4), self.make_val(2)]]
        assert_col_equal(t2['a'], [1, 4, 2])  # same order as input list
        if not isinstance(main_col, Time):
            # ndarray search
            t2 = t.loc[np.array([1, 4, 2])]
            assert_col_equal(t2['a'], [1, 4, 2])
        t2 = t.loc[self.make_val(3): self.make_val(5)]  # range search
        assert_col_equal(t2['a'], [3, 4, 5])
        t2 = t.loc['b', 5.0:7.0]
        assert_col_equal(t2['b'], [5.1, 6.2, 7.0])
        # search by sorted index
        t2 = t.iloc[0:2]  # two smallest rows by column 'a'
        assert_col_equal(t2['a'], [1, 2])
        t2 = t.iloc['b', 2:]  # exclude two smallest rows in column 'b'
        assert_col_equal(t2['b'], [5.1, 6.2, 7.0])

        # an open slice returns every row, ordered by the primary key
        for t2 in (t.loc[:], t.iloc[:]):
            assert_col_equal(t2['a'], [1, 2, 3, 4, 5])

    def test_table_loc_indices(self, main_col, table_types, engine):
        """Check that ``Table.loc_indices`` returns row numbers for given keys."""
        self._setup(main_col, table_types)
        t = self.t

        t.add_index('a', engine=engine)
        t.add_index('b', engine=engine)

        t2 = t.loc_indices[self.make_val(3)]  # single label, with primary key 'a'
        assert t2 == 2

        # list search
        labels = [1, 4, 2]
        t2 = t.loc_indices[[self.make_val(label) for label in labels]]
        # Column 'a' holds 1..5 in row order, so label p lives in row p - 1.
        # Compare the complete result (same order as the input list) instead
        # of zipping, which would silently accept a truncated or empty answer.
        assert list(t2) == [label - 1 for label in labels]

    def test_invalid_search(self, main_col, table_types, engine):
        """``.loc`` and ``.loc_indices`` raise ``KeyError`` for an absent value."""
        # using .loc and .loc_indices with a value not present should raise an exception
        self._setup(main_col, table_types)
        t = self.t

        # Pass the parametrized engine through; without it every
        # parametrization would only exercise the default engine.
        t.add_index('a', engine=engine)
        with pytest.raises(KeyError):
            t.loc[self.make_val(6)]
        with pytest.raises(KeyError):
            t.loc_indices[self.make_val(6)]

    def test_copy_index_references(self, main_col, table_types, engine):
        """Indices of a copied table must reference the copy's own columns."""
        # check against a bug in which indices were given an incorrect
        # column reference when copied
        self._setup(main_col, table_types)
        t = self.t

        # Pass the parametrized engine through; without it every
        # parametrization would only exercise the default engine.
        t.add_index('a', engine=engine)
        t.add_index('b', engine=engine)
        t2 = t.copy()
        assert t2.indices['a'].columns[0] is t2['a']
        assert t2.indices['b'].columns[0] is t2['b']

    def test_unique_index(self, main_col, table_types, engine):
        """A unique index rejects a new row whose key is already present."""
        self._setup(main_col, table_types)
        t = self.t

        t.add_index('a', engine=engine, unique=True)
        assert np.all(t.indices['a'].sorted_data() == [0, 1, 2, 3, 4])

        if not self.mutable:
            return

        # the value 5 already exists in column 'a'
        with pytest.raises(ValueError):
            t.add_row((5.0, '9', 5))

    def test_copy_indexed_table(self, table_types):
        """Copies of an indexed table carry indices with the same content."""
        self._setup(_col, table_types)
        t = self.t
        t.add_index('a')
        t.add_index(['a', 'b'])

        # copy once through the table initializer and once through copy()
        for clone in (self._table_type(t), t.copy()):
            assert len(t.indices) == len(clone.indices)
            for original_index, cloned_index in zip(t.indices, clone.indices):
                original_data = original_index.data.data
                cloned_data = cloned_index.data.data
                assert np.all(original_data == cloned_data)
                assert original_data.colnames == cloned_data.colnames

    def test_updating_row_byindex(self, main_col, table_types, engine):
        """Rows can be overwritten by assigning through ``Table.loc``."""
        self._setup(main_col, table_types)
        columns = [['a', 'b', 'c', 'd'], [2, 3, 4, 5], [3, 4, 5, 6]]
        t = Table(columns, names=('a', 'b', 'c'), meta={'name': 'first table'})

        for colname in ('a', 'b'):
            t.add_index(colname, engine=engine)

        # single label, with primary key 'a': the row labelled 'c' is row 2
        t.loc['c'] = ['g', 40, 50]
        assert list(t[2]) == ['g', 40, 50]

        # list search: rows labelled 'a', 'd', 'b' are rows 0, 3 and 1
        t.loc[['a', 'd', 'b']] = [['a', 20, 30], ['d', 50, 60], ['b', 30, 40]]
        new_rows = [['a', 20, 30], ['d', 50, 60], ['b', 30, 40]]
        for new_row, row_number in zip(new_rows, (0, 3, 1)):
            # same order as input list
            assert list(t[row_number]) == new_row

    def test_invalid_updates(self, main_col, table_types, engine):
        """Assigning through ``.loc`` with an ill-shaped right-hand side raises.

        Each case below supplies a value whose number of rows, or number of
        entries within a row, does not match the rows selected by the key.
        """
        # assigning via .loc a value whose shape does not match the selected
        # rows should raise an exception
        self._setup(main_col, table_types)
        t = Table([[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]],
                  names=('a', 'b', 'c'), meta={'name': 'first table'})

        # Pass the parametrized engine through; without it every
        # parametrization would only exercise the default engine.
        t.add_index('a', engine=engine)
        # one selected row, but the value is a nested single-element list
        with pytest.raises(ValueError):
            t.loc[3] = [[1, 2, 3]]
        # three selected rows, only two replacement rows
        with pytest.raises(ValueError):
            t.loc[[1, 4, 2]] = [[1, 2, 3], [4, 5, 6]]
        # right number of rows, but the last one is too short
        with pytest.raises(ValueError):
            t.loc[[1, 4, 2]] = [[1, 2, 3], [4, 5, 6], [2, 3]]
        # right number of rows, but the last two are too short
        with pytest.raises(ValueError):
            t.loc[[1, 4, 2]] = [[1, 2, 3], [4, 5], [2, 3]]


def test_get_index():
    """``get_index`` finds an index by column names or by a table copy."""
    a = [1, 4, 5, 2, 7, 4, 45]
    b = [2.0, 5.0, 8.2, 3.7, 4.3, 6.5, 3.3]
    t = Table([a, b], names=('a', 'b'), meta={'name': 'first table'})
    t.add_index(['a'])

    def check_is_index_on_a(found):
        # the result must be the single-column index on the 7-row column 'a'
        assert isinstance(found, SlicedIndex)
        assert len(found.columns) == 1
        only_column = found.columns[0]
        assert len(only_column) == 7
        assert only_column.info.name == 'a'

    # Getting the values of index using names
    check_is_index_on_a(get_index(t, names=['a']))
    # Getting the values of index using table_copy
    check_is_index_on_a(get_index(t, table_copy=t[['a']]))

    # supplying both selectors, or neither of them, is an error
    with pytest.raises(ValueError):
        get_index(t, names=['a'], table_copy=t[['a']])
    with pytest.raises(ValueError):
        get_index(t, names=None, table_copy=None)


def test_table_index_time_warning(engine):
    """Indexing on a Time column with a non-default scale emits no warning."""
    # Make sure that no ERFA warnings are emitted when indexing a table by
    # a Time column with a non-default time scale
    tab = Table()
    tab['a'] = Time([1, 2, 3], format='jyear', scale='tai')
    tab['b'] = [4, 3, 2]
    with warnings.catch_warnings(record=True) as caught:
        tab.add_index(('a', 'b'), engine=engine)
    assert not caught


@pytest.mark.parametrize('col', [
    Column(np.arange(50000, 50005)),
    np.arange(50000, 50005) * u.m,
    Time(np.arange(50000, 50005), format='mjd')])
def test_table_index_does_not_propagate_to_column_slices(col):
    """Column slices drop index information; table slices keep it."""
    # A sliced column has lost contact with its parent table, so it should
    # not carry information on the indices either; this helps prevent large
    # memory usage if, e.g., a large time column is turned into an object
    # array; see gh-10688.
    tab = QTable()
    tab['t'] = col
    tab.add_index('t')

    column = tab['t']
    assert column.info.indices
    column_slice = column[1:]
    assert not column_slice.info.indices

    # a column taken from a sliced table still knows about the index
    table_slice = tab[1:]
    column = table_slice['t']
    assert column.info.indices


def test_hstack_qtable_table():
    """Stacking an indexed QTable with a Table yields a result without indices."""
    # Check in particular that indices are initialized or copied correctly
    # for a Column that is being converted to a Quantity.
    indexed = QTable([np.arange(5.)*u.m], names=['s'])
    indexed.add_index('s')
    plain = Table([Column(np.arange(5.), unit=u.s)], names=['t'])
    stacked = hstack([indexed, plain])
    assert stacked['t'].info.indices == []
    assert stacked.indices == []


def test_index_slice_exception():
    """``SlicedIndex(None, None)`` must raise ``TypeError`` with the expected message."""
    with pytest.raises(TypeError, match='index_slice must be tuple or slice'):
        SlicedIndex(None, None)
