#%%
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
from config import *

from parse_data import *

# Column identifiers for the polarization and violence series
# (presumably V-Dem variable names -- confirm against the data loader).
pol_name = 'v2cacamps_mean'
vio_name = 'v2caviol_osp'


# Reference year (as a 'YYYY' string) for every country listed in Table 1,
# keyed by three-letter country code.
table1_year_map = {
    'PAN': '2010', 'MUS': '2008', 'PHL': '2008', 'PER': '2008', 'ARM': '2007',
    'EGY': '2008', 'PRY': '2007', 'TUR': '2007', 'JOR': '2008', 'BGR': '2005',
    'GEO': '2007', 'JAM': '2018', 'IDN': '2005', 'LBN': '2011',
}

# Labels of the explanatory columns that follow the outcome column in the
# statistics tables.
factor_lst = ['polar', 'vio', 'bbr', 'dr']

# Table 1 countries in declaration order, plus an extended sample with three
# additional countries that have no fixed reference year.
table1_country_list = [*table1_year_map]
table1_country_list_ext = [*table1_country_list, 'IRL', 'ISL', 'BRB']

# Fiscal-rule indicators are only looked up from this date onwards; earlier
# dates are treated as "no rule" (0) by the construct_5y_* helpers.
min_year_fiscal_rule = pd.to_datetime('1985', format='%Y')


# Load the input data sets once at import time; the functions below read
# these module-level frames directly.
vdem, code_map, polar, vio = get_vdem_data()
imf_debt, imf_bal = get_imf_data()
frule, bbr, dr = get_frule_data()

# oil_lst, poor_lst = get_default_list()
# hi_lst = get_high_income_list()

# Debt reduction over the preceding five rows: gaps are forward-filled, the
# five-row difference is taken, and the sign is flipped so that a positive
# value means debt fell. Each value is aligned on the LAST row of its window.
df_debt_reduction = imf_debt.ffill().diff(5).mul(-1)
# Re-parse the index labels as years so rows can be addressed by Timestamp.
df_debt_reduction.index = pd.to_datetime(df_debt_reduction.index, format='%Y')



def construct_5y_reduction(i, j, delta=False):
    """Assemble one row of debt-reduction statistics for a country and year.

    Args:
        i: Column label of the country in the module-level frames.
        j: End year of the five-year window; anything ``pd.to_datetime``
            accepts with ``format='%Y'``.
        delta: If true, report polarization and violence as the change over
            the five years preceding the window start instead of their level
            at the window start.

    Returns:
        A list ``[debt_reduction, polar, vio, bbr, dr]``. The debt reduction
        is read at the window end; every other entry refers to the window
        start (five years earlier).
    """
    window_end = pd.to_datetime(j, format='%Y')
    window_start = window_end - pd.DateOffset(years=5)

    # Fiscal-rule indicators default to 0 and are only looked up when the
    # window start is not earlier than min_year_fiscal_rule and the country
    # has a column in the respective frame.
    bbr_val = 0
    dr_val = 0
    if window_start >= min_year_fiscal_rule:
        if i in bbr.columns:
            bbr_val = bbr[i].loc[window_start]
        if i in dr.columns:
            dr_val = dr[i].loc[window_start]

    debt_cut = df_debt_reduction[i].loc[window_end]
    polar_series = polar[i]
    vio_series = vio[i]

    if not delta:
        return [
            debt_cut,
            polar_series.loc[window_start],
            vio_series.loc[window_start],
            bbr_val,
            dr_val,
        ]

    # Change between ten and five years before the window end.
    baseline = window_end - pd.DateOffset(years=10)
    polar_change = polar_series.loc[window_start] - polar_series.loc[baseline]
    vio_change = vio_series.loc[window_start] - vio_series.loc[baseline]
    return [debt_cut, polar_change, vio_change, bbr_val, dr_val]


def construct_5y_fiscal_surplus(i, j, delta=False):
    """Assemble one row of fiscal-balance statistics for a country and year.

    Args:
        i: Column label of the country in the module-level frames.
        j: End year of the five-year window; anything ``pd.to_datetime``
            accepts with ``format='%Y'``.
        delta: If true, the first entry is the change in the fiscal balance
            across the window (end minus start); otherwise it is the balance
            at the window start.

    Returns:
        A list ``[balance, polar, vio, bbr, dr]`` where everything except a
        ``delta`` balance is read at the window start (five years before
        ``j``).
    """
    end = pd.to_datetime(j, format='%Y')
    start = end - pd.DateOffset(years=5)

    # Rule indicators count as 0 before min_year_fiscal_rule or when the
    # country has no column in the respective frame.
    rules_in_range = start >= min_year_fiscal_rule
    bbr_val = bbr[i].loc[start] if rules_in_range and i in bbr.columns else 0
    dr_val = dr[i].loc[start] if rules_in_range and i in dr.columns else 0

    balance = imf_bal[i]
    if delta:
        outcome = balance.loc[end] - balance.loc[start]
    else:
        outcome = balance.loc[start]

    return [outcome, polar[i].loc[start], vio[i].loc[start], bbr_val, dr_val]


def _table1_frame(country_years, fiscal_surp, delta):
    """Build one statistics table from ``(country, year)`` pairs."""
    build_row = construct_5y_fiscal_surplus if fiscal_surp else construct_5y_reduction

    names = []
    rows = []
    for country, year in country_years:
        names.append(country)
        rows.append(build_row(country, year, delta=delta))

    # Passing the labels to the constructor keeps an empty selection valid;
    # assigning ``columns`` afterwards fails on a frame without columns.
    frame = pd.DataFrame(rows, index=names, columns=['debt_reduction'] + factor_lst)
    frame['fiscal_rule'] = frame['bbr'] + 2 * frame['dr']
    return frame


def calc_table1_stats(lst = table1_country_list_ext, fiscal_surp = False, delta = False):
    """Return relevant statistics for the countries in the Table 1 list.

    Args:
        lst: Country codes to include. The default list is only read, never
            modified.
        fiscal_surp: If true, rows come from ``construct_5y_fiscal_surplus``
            and countries without a column in ``imf_bal`` are left out of
            BOTH tables. Otherwise rows come from ``construct_5y_reduction``.
            The first column is labelled ``'debt_reduction'`` either way.
        delta: Forwarded unchanged to the row constructor.

    Returns:
        A tuple ``(res80, res)`` of DataFrames indexed by country code with
        columns ``debt_reduction``, ``polar``, ``vio``, ``bbr``, ``dr`` and
        ``fiscal_rule`` (``bbr + 2 * dr``). ``res80`` covers the years up to
        2004, ``res`` the years 2005-2022.
    """
    # Apply the balance-data filter once so both periods treat a country the
    # same way. The early period used to raise KeyError for a country the
    # later period silently skipped.
    countries = [c for c in lst if not fiscal_surp or c in imf_bal]

    # Early period. df_debt_reduction is aligned on the last year of each
    # 5-year window, hence the cut-off at 2004. Countries without any data in
    # the period are dropped.
    # NOTE(review): this takes idxmin, i.e. the year with the SMALLEST value
    # of df_debt_reduction, whereas the later period uses idxmax. Confirm
    # which extreme is intended for the early period.
    early_years = df_debt_reduction.loc[:"2004", countries].idxmin().dropna()
    res80 = _table1_frame(early_years.items(), fiscal_surp, delta)

    # Later period: use the year from Table 1 when the country has one,
    # otherwise the year with the largest value of df_debt_reduction in
    # 2005-2022.
    late_years = df_debt_reduction.loc["2005":"2022", countries].idxmax().dropna()
    late_pairs = [
        (c, table1_year_map[c] if c in table1_year_map else late_years.loc[c])
        for c in countries
    ]
    res = _table1_frame(late_pairs, fiscal_surp, delta)
    return res80, res


def calc_all_sample_stats():
    """Stack the module-level frames into one long table for the full sample.

    Returns:
        A DataFrame with one row per stacked (row label, column label) pair
        and the columns ``debt_change``, ``fiscal_surp``, ``polar_change``,
        ``vio_change``, ``polar_level`` and ``vio_level``.

    NOTE(review): ``imf_debt`` is used with its own index here, while
    ``df_debt_reduction`` re-parses that index with ``pd.to_datetime``.
    Confirm that all input frames share one index type so the rows align.
    """
    labelled_frames = (
        # Forward-looking change: the row at t holds value(t+5) - value(t).
        ('debt_change', imf_debt.ffill().diff(5).shift(-5)),
        ('fiscal_surp', imf_bal),
        # Backward-looking changes over the preceding five rows.
        ('polar_change', polar.diff(5)),
        ('vio_change', vio.diff(5)),
        ('polar_level', polar),
        ('vio_level', vio),
    )

    sample = pd.concat([frame.stack() for _, frame in labelled_frames], axis=1)
    sample.columns = [label for label, _ in labelled_frames]

    # Handy sample-size checks:
    # sample[['debt_change', 'polar_change']].dropna().index.levels[1].nunique()
    # sample[['fiscal_surp', 'polar_level']].dropna().index.levels[1].nunique()
    return sample

