import json
import warnings

import numpy as np
from jinja2 import Template

# Map each event id used in the sample-file paths (e.g. "S191109d") to the
# catalog name printed in the first column of the table. The underscore is
# written as "\_" because the name is inserted verbatim into LaTeX. Raw strings
# keep that backslash explicitly instead of relying on Python's
# invalid-escape-sequence fallback, which warns (SyntaxWarning on 3.12+).
event_dictionary = {
    "S191109d": r"GW191109\_010717",
    "S191222n": r"GW191222\_033537",
    "S200129m": r"GW200129\_065458",
    "S200224ca": r"GW200224\_222234",
    "S200311bg": r"GW200311\_115853",
}

# Relative path pattern of each gzipped posterior-sample file; filled in with
# the event id and the (possibly prefixed) parameter tag.
data_file_path_template = "{event}/rin_{event}_{param}.dat.gz"

# Parameters to be tabulated: final mass (Mf) and final spin (af) for each
# analysis, ordered Kerr_220, Kerr_221, Kerr_HM, IMR (mass before spin).
# FIXME: MMRDNP parameter names refer to m1 and m2; in any case, those values
# are not used below.
params = [
    "{}_{}".format(analysis, quantity)
    for analysis in ("Kerr_220", "Kerr_221", "Kerr_HM", "IMR")
    for quantity in ("Mf", "af")
]

# Load the posterior sample files into sample_dict[param][event].
# Kerr_* samples are read from files whose param tag is prefixed with
# "pyring_"; IMR samples use the bare param name.
sample_dict = {}
for param in params:
    sample_dict[param] = {}
    for event in event_dictionary:
        if "IMR" in param:
            # IMR results are treated as optional. If the file is missing or
            # unreadable, a single 0.0 placeholder sample is stored so the
            # table still renders; that entry then reads 0.0 with zero error
            # bars. Only I/O and parse errors are caught: a bare `except:`
            # would also swallow KeyboardInterrupt/SystemExit and hide bugs.
            # FIXME: Load O1O2 results as well
            imr_path = data_file_path_template.format(event=event, param=param)
            try:
                sample_dict[param][event] = np.loadtxt(imr_path)
            except (OSError, ValueError) as err:
                warnings.warn(
                    "Could not load IMR samples from {}: {}; "
                    "using placeholder [0.0]".format(imr_path, err)
                )
                sample_dict[param][event] = [0.0]
        elif "Kerr" in param:
            # pyring results: no fallback, so a missing file stops the script.
            sample_dict[param][event] = np.loadtxt(
                data_file_path_template.format(event=event, param="pyring_" + param)
            )

# Load the pyring log10 Bayes factors from JSON. Each object is indexed later
# by the keys of event_dictionary (e.g. "S191109d").
# JSON text is UTF-8 by specification (RFC 8259), so the files are decoded
# explicitly as UTF-8 rather than with the platform's locale-dependent
# default encoding.
with open("rin_pyring_log10_BFs_HM_vs_noHM.json", "r", encoding="utf-8") as f:
    log10_BFs_HM_vs_noHM = json.load(f)
with open("rin_pyring_log10_BFs_OT_vs_noOT.json", "r", encoding="utf-8") as f:
    log10_BFs_OT_vs_noOT = json.load(f)
with open("rin_pyring_log10_BFs_TIGER_modGR_vs_GR.json", "r", encoding="utf-8") as f:
    log10_BFs_TIGER_modGR_vs_GR = json.load(f)

# Build the list of per-event row dicts fed to jinja2, one per event in
# event_dictionary order. Each row starts with only the catalog name shown in
# the table's first column; later steps add the statistics.
jinja2_data_dict = [
    {"catalog_id": catalog_name} for catalog_name in event_dictionary.values()
]

# Compute summary statistics for every (event, param): the median, plus the
# offsets from the median down to the 5th percentile ("_lower_limit") and up
# to the 95th percentile ("_upper_limit"), i.e. the asymmetric error bars of
# the central 90% credible interval.
# jinja2_data_dict is ordered like event_dictionary, so zipping the two pairs
# each row with its event.
for entry, event in zip(jinja2_data_dict, event_dictionary):
    for param in params:
        samples = sample_dict[param][event]
        # Compute the median once and reuse it for both offsets; get both
        # percentiles in a single np.percentile call.
        median = np.median(samples)
        p5, p95 = np.percentile(samples, [5, 95])
        entry[param + "_median"] = median
        entry[param + "_lower_limit"] = median - p5
        entry[param + "_upper_limit"] = p95 - median

# Attach the three log10 values read from the JSON files (HM vs no-HM,
# overtones vs no-overtones, TIGER modGR vs GR) to each event's row. The
# values are indexed by the event_dictionary keys; a missing event raises
# KeyError.
# (A leftover debug print that dumped the whole HM dict on every iteration
# has been removed.)
for entry, event in zip(jinja2_data_dict, event_dictionary):
    entry["log10_BF_HM_vs_noHM"] = log10_BFs_HM_vs_noHM[event]
    entry["log10_BF_OT_vs_noOT"] = log10_BFs_OT_vs_noOT[event]
    entry["log10_BF_TIGER_modGR_vs_GR"] = log10_BFs_TIGER_modGR_vs_GR[event]

# FIXME: (the original note was left empty -- TODO(review): record what needs
# fixing in this template)
# Jinja2 template for the results table in LaTeX.
# Layout (15 columns, spec "lllllllllllrrrr"):
#   event name | 4 final-mass columns (IMR, Kerr_220, Kerr_221, Kerr_HM) |
#   spacer | 4 final-spin columns (same order) | spacer |
#   log10 BF higher modes vs 220 | spacer |
#   log10 BF 221 vs 220 and log10 TIGER modGR vs GR (shared "Overtones" header).
# One row is emitted per entry of `jinja2_data_dict`. Each entry must provide
# `catalog_id`; `<param>_median`, `<param>_upper_limit` and
# `<param>_lower_limit` for every tabulated param; and the three
# `log10_BF_*` values. Masses are printed with 1 decimal, spins and log10
# values with 2 decimals.
# NOTE(review): uses \toprule/\midrule/\bottomrule, so the including document
# presumably loads the booktabs package -- confirm.
LaTex_table_jinja_template = r"""
\begin{tabular}{lllllllllllrrrr}
\toprule
Event & \multicolumn{4}{c}{Redshifted final mass} & \hphantom{X} & \multicolumn{4}{c}{Final spin} & \hphantom{X} & \multicolumn{1}{c}{Higher} & \hphantom{X} & \multicolumn{2}{c}{Overtones} \\
& \multicolumn{4}{c}{$(1+z)M_\mathrm{f} \; [M_{\odot}]$} & \hphantom{X} & \multicolumn{4}{c}{$\chi_{\mathrm{f}}$} & \hphantom{X} & \multicolumn{1}{c}{modes} & \hphantom{X} &  \multicolumn{2}{c}{} \\[0.075cm]
\cline{2-5}
\cline{7-10}
\cline{12-12}
\cline{14-15}
& IMR & $\mathrm{Kerr_{220}}$ & $\mathrm{Kerr_{221}}$ & $\mathrm{Kerr_{HM}}$ & \hphantom{X} & IMR & $\mathrm{Kerr_{220}}$ & $\mathrm{Kerr_{221}}$ & $\mathrm{Kerr_{HM}}$ & \hphantom{X} &  \multicolumn{1}{c}{$\log_{10} \mathcal{B}^{\rm HM}_{\rm 220}$} & \hphantom{X} & \multicolumn{1}{c}{$\log_{10} \mathcal{B}^{\rm 221}_{\rm 220}$} & \multicolumn{1}{c}{$\log_{10} \mathcal{O}^{\rm modGR}_{\rm GR}$} \\
\midrule
{% for event in  jinja2_data_dict %}
{{ event.catalog_id }} &
$ {{ "{:.1f}".format(event.IMR_Mf_median) }}^{+ {{ "{:.1f}".format(event.IMR_Mf_upper_limit) }} }_{- {{ "{:.1f}".format(event.IMR_Mf_lower_limit) }} } $ &
$ {{ "{:.1f}".format(event.Kerr_220_Mf_median) }}^{+ {{ "{:.1f}".format(event.Kerr_220_Mf_upper_limit) }} }_{- {{ "{:.1f}".format(event.Kerr_220_Mf_lower_limit) }} } $ &
$ {{ "{:.1f}".format(event.Kerr_221_Mf_median) }}^{+ {{ "{:.1f}".format(event.Kerr_221_Mf_upper_limit) }} }_{- {{ "{:.1f}".format(event.Kerr_221_Mf_lower_limit) }} } $ &
$ {{ "{:.1f}".format(event.Kerr_HM_Mf_median) }}^{+ {{ "{:.1f}".format(event.Kerr_HM_Mf_upper_limit) }} }_{- {{ "{:.1f}".format(event.Kerr_HM_Mf_lower_limit) }} } $ &
\hphantom{X} &
$ {{ "{:.2f}".format(event.IMR_af_median) }}^{+ {{ "{:.2f}".format(event.IMR_af_upper_limit) }} }_{- {{ "{:.2f}".format(event.IMR_af_lower_limit) }} } $ &
$ {{ "{:.2f}".format(event.Kerr_220_af_median) }}^{+ {{ "{:.2f}".format(event.Kerr_220_af_upper_limit) }} }_{- {{ "{:.2f}".format(event.Kerr_220_af_lower_limit) }} } $ &
$ {{ "{:.2f}".format(event.Kerr_221_af_median) }}^{+ {{ "{:.2f}".format(event.Kerr_221_af_upper_limit) }} }_{- {{ "{:.2f}".format(event.Kerr_221_af_lower_limit) }} } $ &
$ {{ "{:.2f}".format(event.Kerr_HM_af_median) }}^{+ {{ "{:.2f}".format(event.Kerr_HM_af_upper_limit) }} }_{- {{ "{:.2f}".format(event.Kerr_HM_af_lower_limit) }} } $ &
\hphantom{X} & 
$ {{ "{:.2f}".format(event.log10_BF_HM_vs_noHM) }} $ &
\hphantom{X} &
$ {{ "{:.2f}".format(event.log10_BF_OT_vs_noOT) }} $ &
$ {{ "{:.2f}".format(event.log10_BF_TIGER_modGR_vs_GR) }} $ \\[0.075cm]
{% endfor %}
\bottomrule
\end{tabular}
"""

# Render the LaTeX table from the per-event rows and stream it into the output
# .tex file.
output_tex_path = "tab_rin_pyring_results.tex"
table_stream = Template(LaTex_table_jinja_template).stream(jinja2_data_dict=jinja2_data_dict)
table_stream.dump(output_tex_path)
