import pytest

from ..interpolation import interpolate_max

# Fixture sequences for ``test_interpolate_max``. Each one is passed to
# ``interpolate_max`` together with its centre index, ``(len(y) - 1) // 2``.
# NOTE(review): these look like sampled complex time series (e.g. SNR
# samples around a peak); their physical origin is not shown here -- confirm
# before relying on that interpretation.

# 9 complex samples (centre index 4); used with window=4 for every method.
y0 = [9.135017-2.8185585j, 9.995214-1.1222992j, 10.682851+0.8188147j,
      10.645139+3.0268786j, 9.713133+5.5589147j, 7.9043484+7.9039335j,
      5.511646+9.333084j, 2.905198+9.715742j, 0.5302934+9.544538j]
# 33 complex samples (centre index 16); used with window=16, 'lanczos'.
y1 = [-14.04298458-12.42025719j, -13.54512279-13.9975707j,
      -12.81987146-15.65632332j, -11.83145220-17.25894236j,
      -10.67978351-18.78930327j, -9.38321345-20.26279014j,
      -7.93078349-21.75487168j, -6.20864157-23.23521056j,
      -4.18656412-24.61039025j, -1.91352079-25.76366939j,
      0.50485300-26.81349584j, 3.24674313-27.88200749j,
      6.61982798-28.75234467j, 10.67117579-29.02191297j,
      15.20752937-28.32539646j, 19.86163875-26.31802974j,
      23.95519327-22.82715585j, 26.67973660-18.22160184j,
      27.69204702-13.56075191j, 27.61097719-9.60342867j,
      27.15105065-6.36397975j, 26.54148343-3.43587179j,
      25.59970437-0.77676502j, 24.42081565+1.48165824j,
      23.21572597+3.26870707j, 22.16738214+4.89881504j,
      21.06343508+6.52136893j, 19.79261510+7.95655696j,
      18.53300183+9.1869077j, 17.27943443+10.35891335j,
      15.85961229+11.44834434j, 14.29135572+12.25702765j,
      12.80306731+12.77700535j]
# 33 complex samples (centre index 16); used with window=16, 'lanczos'.
y2 = [24.33546507-4.37237844j, 25.21430734-2.94850093j,
      26.24653016-1.46698428j, 27.43392439+0.31641935j,
      28.55556797+2.4344684j, 29.53970968+4.9329713j,
      30.24829181+7.68385687j, 30.74244848+10.63488387j,
      31.05555847+13.79768684j, 31.16910840+17.35620378j,
      30.81906320+21.42675033j, 29.71256983+25.80061299j,
      27.85940204+30.3305615j, 25.15948657+34.9632285j,
      21.29005445+39.71231569j, 15.69023426+44.03212162j,
      8.40941881+46.68943372j, 0.65145839+46.75721585j,
      -6.03343803+44.61857557j, -11.05937409+41.54119126j,
      -14.90679982+38.44188372j, -18.17229175+35.31228839j,
      -20.85054167+32.08798444j, -22.97315194+28.83554109j,
      -24.64066997+25.66341162j, -25.94581892+22.51742823j,
      -26.82315799+19.46526862j, -27.37637665+16.55731713j,
      -27.69785243+13.83218299j, -27.85474897+11.17681254j,
      -27.73414707+8.52204497j, -27.23875261+6.08054244j,
      -26.58717550+3.98060113j]
# 33 complex samples (centre index 16); used with window=16, 'lanczos'.
y3 = [25.38672374-10.46990646j, 26.98036717-8.74105323j,
      28.42577358-6.75887934j, 29.67590345-4.50925463j,
      30.62545698-2.16923242j, 31.44275377+0.18575934j,
      32.27276923+2.57731966j, 33.19687126+5.33734012j,
      33.89400960+8.58856349j, 34.13857476+12.22779037j,
      33.87819938+16.02691928j, 33.23756141+20.00079092j,
      32.15025759+24.27606864j, 30.38635264+28.89883078j,
      27.61488144+33.82233139j, 23.36306372+38.76765413j,
      17.29380596+42.88182178j, 9.89246651+44.9882819j,
      2.54059732+44.71650049j, -3.67718110+42.65873447j,
      -8.48361774+40.01375385j, -12.50375338+37.27989278j,
      -16.04996471+34.31356422j, -18.98265862+31.01719682j,
      -21.19277639+27.6541388j, -22.88668959+24.44207892j,
      -24.24895445+21.30383581j, -25.21392241+18.14916359j,
      -25.71011223+15.13051498j, -25.90832412+12.43072643j,
      -26.04673459+9.92719558j, -26.08115387+7.46027909j,
      -25.90777937+5.03328563j]
# 5 samples (centre index 2); used with window=2 for the Catmull-Rom
# variants. The first entry is the plain integer 0, so the list mixes
# int and complex values.
y4 = [0, 1.23412285-1.561222751j, -2.06157986+0.782987141j,
      -0.52811449+0.478591891j, 0.42450302+2.068777141j]


@pytest.mark.parametrize(
    ('y', 'window', 'method', 'expected_tmax', 'expected_ymax'),
    [
        (y0, 4, 'lanczos',
         4.508785386872524, 8.906417801973994+6.839630391972431j),
        (y0, 4, 'catmull-rom',
         4.606828337716582, 8.69890509129855+7.057010747638154j),
        (y0, 4, 'quadratic-fit',
         4.047784375601079, 9.635253882349538+5.677856407731738j),
        (y0, 4, 'nearest-neighbor',
         4, 9.713133+5.5589147j),
        (y1, 16, 'lanczos',
         15.643137901686925, 22.607356086592343-24.241989069241296j),
        (y2, 16, 'lanczos',
         16 + 0.0080079263861989602, 8.346533081471625+46.70076293771792j),
        (y3, 16, 'lanczos',
         16 + 0.35219617295969, 14.769739126762206+43.89473896632356j),
        ([0, 1, 8, 27, 256], 2, 'catmull-rom', 3, 27),
        ([256, 27, 8, 1, 0], 2, 'catmull-rom', 1, 27),
        (y4, 2, 'catmull-rom', 2.108106552865, -2.10041938439+0.85409051094j),
        ([0, 1, 8, 27, 256], 2, 'catmull-rom-amp-phase', 3, 27),
        ([256, 27, 8, 1, 0], 2, 'catmull-rom-amp-phase', 1, 27),
        (y4, 2, 'catmull-rom-amp-phase',
         2.0, -2.06157986+0.782987141j),
    ])
def test_interpolate_max(y, window, method, expected_tmax, expected_ymax):
    """Compare ``interpolate_max`` output with stored reference values.

    The routine is started from the middle sample of ``y``. Both the returned
    position (``tmax``) and the returned value (``ymax``, possibly complex)
    must match the expectations within ``pytest.approx``'s default tolerance.
    """
    # Integer index of the middle sample; for even lengths this rounds down.
    center = (len(y) - 1) // 2
    found_t, found_y = interpolate_max(center, y, window, method)
    assert found_t == pytest.approx(expected_tmax)
    assert found_y == pytest.approx(expected_ymax)
