# -*- coding: utf-8 -*-
# Licensed under a 3-clause BSD style license - see LICENSE.rst


from astropy import table
from astropy.table import pprint


class MyRow(table.Row):
    """Row subclass whose string form is that of its ``as_void()`` value."""

    def __str__(self):
        """Return ``str()`` of the value produced by ``as_void()``."""
        void_row = self.as_void()
        return str(void_row)


class MyColumn(table.Column):
    """Trivial ``Column`` subclass used to check ``MyTable.Column``."""


class MyMaskedColumn(table.MaskedColumn):
    """Trivial ``MaskedColumn`` subclass used to check ``MyTable.MaskedColumn``."""


class MyTableColumns(table.TableColumns):
    """Trivial ``TableColumns`` subclass used to check ``MyTable.TableColumns``."""


class MyTableFormatter(pprint.TableFormatter):
    """Trivial ``TableFormatter`` subclass used to check ``MyTable.TableFormatter``."""


class MyTable(table.Table):
    """Table subclass that substitutes every customizable component class."""

    # Each attribute names the class the table should use for that component;
    # the assignments are independent of one another.
    TableFormatter = MyTableFormatter
    TableColumns = MyTableColumns
    MaskedColumn = MyMaskedColumn
    Column = MyColumn
    Row = MyRow


def test_simple_subclass():
    """Check that ``MyTable`` uses the component classes it declares."""
    # Plain (unmasked) subclass instance.
    plain = MyTable([[1, 2], [3, 4]])
    first_row = plain[0]
    assert isinstance(first_row, MyRow)
    assert isinstance(plain['col0'], MyColumn)
    assert isinstance(plain.columns, MyTableColumns)
    assert isinstance(plain.formatter, MyTableFormatter)

    # Building a MyTable from a MyTable keeps the custom row class.
    rebuilt = MyTable(plain)
    rebuilt_row = rebuilt[0]
    assert isinstance(rebuilt_row, MyRow)
    assert str(rebuilt_row) == '(1, 3)'

    # Building a base Table from a MyTable falls back to the base row class.
    base = table.Table(plain)
    base_row = base[0]
    assert not isinstance(base_row, MyRow)
    assert str(base_row) != '(1, 3)'

    # Masked subclass instance.
    masked = MyTable([[1, 2], [3, 4]], masked=True)
    masked_row = masked[0]
    assert isinstance(masked_row, MyRow)
    assert str(masked_row) == '(1, 3)'
    assert isinstance(masked['col0'], MyMaskedColumn)
    assert isinstance(masked.formatter, MyTableFormatter)


class ParamsRow(table.Row):
    """
    Row class that allows access to an arbitrary dict of parameters
    stored as a dict object in the ``params`` column.

    Indexing with a name that is not one of the table's column names is
    redirected to the dict held in this row's ``params`` column.
    """

    def __getitem__(self, item):
        """Return the column value for ``item``, else the ``params`` entry.

        When ``item`` is not a column name, the lookup is done on the dict
        stored in the ``params`` column, so any error raised comes from that
        dict lookup.
        """
        if item in self.colnames:
            return super().__getitem__(item)
        return super().__getitem__('params')[item]

    def keys(self):
        """Return column names (except ``params``) plus the parameter names.

        The parameter names are the keys of the ``params`` dict, sorted and
        then lower-cased.
        """
        out = [name for name in self.colnames if name != 'params']
        params = [key.lower() for key in sorted(self['params'])]
        return out + params

    def values(self):
        """Return the values matching ``keys()``, in the same order.

        Parameter values are read from the ``params`` dict using its original
        keys. Going through ``keys()`` would look up the lower-cased names,
        which fails (or hits a column of the same name) whenever a parameter
        key contains upper-case characters.
        """
        params = self['params']
        column_values = [self[name] for name in self.colnames
                         if name != 'params']
        param_values = [params[key] for key in sorted(params)]
        return column_values + param_values


class ParamsTable(table.Table):
    """Table subclass whose rows are ``ParamsRow`` instances."""

    Row = ParamsRow


def test_params_table():
    """Check ``ParamsRow`` item access, ``keys()`` and ``values()``."""
    first_params = {'x': 1.5, 'y': 2.5}
    second_params = {'z': 'hello', 'id': 123123}

    tbl = ParamsTable(names=['a', 'b', 'params'], dtype=['i', 'f', 'O'])
    tbl.add_row((1, 2.0, first_params))
    tbl.add_row((2, 3.0, second_params))

    # The ``params`` column is reachable both column-first and row-first.
    assert tbl['params'][0] == {'x': 1.5, 'y': 2.5}
    assert tbl[0]['params'] == {'x': 1.5, 'y': 2.5}

    # Names that are not columns resolve through the ``params`` dict.
    assert tbl[0]['y'] == 2.5
    assert tbl[1]['id'] == 123123

    # keys()/values() list real columns first, then the sorted parameters.
    assert list(tbl[1].keys()) == ['a', 'b', 'id', 'z']
    assert list(tbl[1].values()) == [2, 3.0, 123123, 'hello']
