import numpy as np
import json
from jinja2 import Template

# Superevent name -> catalog name shown in the table's "Event" column.
# The catalog names contain a LaTeX-escaped underscore ("\_"); raw strings keep
# that backslash without relying on an invalid Python escape sequence (which
# emits a DeprecationWarning, and a SyntaxWarning since Python 3.12).
event_dictionary_pyring = {
    "S191109d": r"GW191109\_010717",
    "S191222n": r"GW191222\_033537",
    "S200129m": r"GW200129\_065458",
    "S200224ca": r"GW200224\_222234",
    "S200311bg": r"GW200311\_115853",
}

event_list_pyring = list(event_dictionary_pyring)
# Sorted, de-duplicated list of all events that appear in the table.
event_list = sorted(set(event_list_pyring))

# Paths of the (gzipped) posterior-sample text files, one file per event and
# parameter; np.loadtxt decompresses ".gz" files transparently.
data_file_path_template_pyring = "{event}/rin_{event}_pyring_{param}.dat.gz"
data_file_path_template_IMR = "{event}/rin_{event}_IMR_{param}.dat.gz"

# Parameters to be tabulated
params_pyring = ["f_t_0", "tau_t_0"]
params_IMR = ["freq_220", "tau_220"]
params = params_pyring + params_IMR

# Posterior samples, indexed as sample_dict[param][event].
sample_dict = {}

# Load the posterior sample files produced by pull_data_pyring.py.
# There is no error handling here: a missing pyRing file aborts the script.
for param in params_pyring:
    sample_dict[param] = {}
    for event in event_dictionary_pyring:
        sample_dict[param][event] = np.loadtxt(
            data_file_path_template_pyring.format(event=event, param=param)
        )

# Load the posterior sample files produced by pull_data_IMR.py.
for param in params_IMR:
    sample_dict[param] = {}
    for event in event_list:
        path = data_file_path_template_IMR.format(event=event, param=param)
        # FIXME: Also get the posterior samples for O1O2 events
        try:
            sample_dict[param][event] = np.loadtxt(path)
        except OSError:
            # IMR samples are not available for every event (see FIXME): a
            # missing or unreadable file just leaves this entry out of the
            # dict. Other errors (e.g. malformed file contents) propagate
            # instead of being silently dropped.
            pass

# Rows handed to the Jinja2 template: one dict per event, in event_list order
# (despite the name, this is a list of dicts). Each row starts with the catalog
# name used in the "Event" column.
jinja2_data_dict = [
    {"catalog_id": event_dictionary_pyring[event]}
    for event in event_list
]

# Compute summary statistics for every (event, parameter) pair: the median and
# the 90% credible interval, stored as distances from the median down to the
# 5th percentile and up to the 95th percentile.
for event, row in zip(event_list, jinja2_data_dict):
    for param in params:
        try:
            samples = sample_dict[param][event]
        except KeyError:
            # No samples were loaded for this (param, event) pair: store None
            # so the formatting step can tell the value is unavailable.
            row[param + "_median"] = None
            row[param + "_lower_limit"] = None
            row[param + "_upper_limit"] = None
            continue
        median = np.median(samples)
        p05, p95 = np.percentile(samples, [5, 95])
        row[param + "_median"] = median
        # 5th-percentile relative to the median
        row[param + "_lower_limit"] = median - p05
        # 95th-percentile relative to the median
        row[param + "_upper_limit"] = p95 - median

def _format_interval(row, param, fmt, scale=1):
    """Format one parameter's summary as a LaTeX ``$median^{+upper}_{-lower}$`` cell.

    Parameters
    ----------
    row : dict
        Table row holding ``<param>_median``, ``<param>_upper_limit`` and
        ``<param>_lower_limit`` entries.
    param : str
        Parameter name used as the key prefix.
    fmt : str
        Format spec applied to each number, e.g. ``".0f"``.
    scale : float
        Factor applied to all three numbers before formatting.

    Returns
    -------
    str
        The LaTeX string, or ``"$-$"`` (a dash) when any of the three values
        is missing, None, or NaN.
    """
    values = (
        row.get(param + "_median"),
        row.get(param + "_upper_limit"),
        row.get(param + "_lower_limit"),
    )
    if any(value is None or np.isnan(value) for value in values):
        return "$-$"
    median, upper, lower = (scale * value for value in values)
    return "${0:{f}}^{{+{1:{f}}}}_{{-{2:{f}}}}$".format(median, upper, lower, f=fmt)


# Format the output cells of every row; a dash is shown for missing (None) or
# NaN statistics. Each cell falls back to a dash on its own, so an available
# frequency is still shown when the matching damping time is missing.
for row in jinja2_data_dict:
    # IMR. Damping times are multiplied by 1e3 for the [ms] column
    # (samples presumably in seconds -- confirm against pull_data_IMR.py).
    row["freq_IMR"] = _format_interval(row, "freq_220", ".0f")
    row["tau_IMR"] = _format_interval(row, "tau_220", ".1f", scale=1e3)
    # pyRing (same units assumption as above).
    row["freq_pyRing"] = _format_interval(row, "f_t_0", ".0f")
    row["tau_pyRing"] = _format_interval(row, "tau_t_0", ".1f", scale=1e3)

# Jinja2 template for the table in LaTeX. Columns: event, IMR and DS (pyRing)
# frequency, an \hphantom{X} spacer, IMR and DS (pyRing) damping time.
# The header underlines cover columns 2-3 and 5-6; the spacer column 4 is left
# out so the two underlines stay visually separated.
LaTex_table_jinja_template = r"""
\begin{tabular}{lccccc}
\toprule
Event & \multicolumn{2}{c}{Redshifted} & \hphantom{X} & \multicolumn{2}{c}{Redshifted} \\
& \multicolumn{2}{c}{frequency [Hz]} & \hphantom{X} & \multicolumn{2}{c}{damping time [ms]} \\[0.075cm]
\cline{2-3}
\cline{5-6}
& IMR & DS & \hphantom{X} & IMR & DS \\
\midrule
{% for event in  jinja2_data_dict %}
{{ event.catalog_id }} &
{{ event.freq_IMR }} &
{{ event.freq_pyRing }}  &
\hphantom{X} &
{{ event.tau_IMR }} &
{{ event.tau_pyRing }}\\[0.075cm]
{% endfor %}
\bottomrule
\end{tabular}
"""

# Render the template with the per-event rows and write the LaTeX table to disk.
table_stream = Template(LaTex_table_jinja_template).stream(jinja2_data_dict=jinja2_data_dict)
table_stream.dump("tab_rin_freq_tau_results.tex")
