import os
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import cm

from config import *
from analytics import *

# Sequential red colormap, sampled below at fractions such as cmap(0.65) to
# colour markers. ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7
# and removed in 3.9; ``pyplot.get_cmap`` returns the same Colormap and is
# still available.
cmap = plt.get_cmap('Reds')

# Country names (spelled as in ``vdem['country']``) making up the EM sample.
em_country = [
    'Mexico', 'Ghana', 'South Africa', 'Burma/Myanmar', 'Russia',
    'Albania', 'Egypt', 'Yemen', 'Brazil', 'El Salvador', 'Bangladesh',
    'Bolivia', 'Haiti', 'Honduras', 'Mali', 'Pakistan', 'Peru',
    'Senegal', 'Sudan', 'Vietnam', 'Afghanistan', 'Argentina',
    'Ethiopia', 'Kenya', 'Nigeria', 'Philippines', 'Tanzania',
    'Uganda', 'Venezuela', 'Benin', 'Burkina Faso', 'Cambodia',
    'Indonesia', 'Mozambique', 'Nicaragua', 'Niger', 'Zambia',
    'Zimbabwe', 'Guinea', 'Ivory Coast', 'Mauritania', 'Botswana',
    'Burundi', 'Cape Verde', 'Central African Republic', 'Costa Rica',
    'Ecuador', 'Guatemala', 'Iran', 'Iraq', 'Jordan', 'Lesotho',
    'Liberia', 'Malawi', 'Maldives', 'Mongolia', 'Morocco',
    'Sierra Leone', 'Tunisia', 'Turkey', 'Ukraine', 'Algeria',
    'Angola', 'Cameroon', 'Chad', 'Democratic Republic of the Congo',
    'Republic of the Congo', 'Djibouti', 'Dominican Republic',
    'Eritrea', 'Gabon', 'The Gambia', 'Georgia', 'Guinea-Bissau',
    'Jamaica', 'Kazakhstan', 'Kyrgyzstan', 'Laos', 'Madagascar',
    'Moldova', 'Rwanda', 'Somalia', 'Sri Lanka', 'Tajikistan', 'Togo',
    'Bosnia and Herzegovina', 'Cuba', 'Equatorial Guinea', 'Guyana',
    'North Macedonia', 'Paraguay', 'Romania', 'Sao Tome and Principe',
    'Serbia', 'Solomon Islands', 'Vanuatu'
]

# Unique V-Dem ``code`` values for the rows whose country name is listed above.
# NOTE(review): ``vdem`` and ``polar`` come from the star imports above.
EM_country_list = vdem.loc[vdem['country'].isin(em_country), 'code'].unique()

# Keep only codes that are columns of ``polar`` (it is later indexed as
# ``polar[EM_country_list_avail]``).
EM_country_list_avail = [code for code in EM_country_list if code in polar]


# Human-readable axis labels, keyed by the short variable codes that the
# plotting functions accept (e.g. ``x='polar'``, ``y='debt_red'``).
var_label = dict(
    polar='Polarization Level',
    vio='Violence Level',
    polar_inc='Polarization Increment',
    vio_inc='Violence Increment',
    debt_red='5yrs Debt Reduction',
    fiscal_surp='Fiscal Surplus',
)

# Legend text per fiscal-rule category: entry i labels category value i.
# ``rule_labels`` and ``labels`` are two names for the same list object.
labels = [
    'No Fiscal Rule',
    'Budget Balance Rule',
    'Debt Rule',
    'Both Rule',
]
rule_labels = labels


def get_legend_handlers(mkers):
    """Build marker-only proxy artists for a figure legend.

    Parameters
    ----------
    mkers : iterable of tuple
        ``(marker, colour, hollow)`` designs, in the same form that
        ``plot_scatter_tb1(markers=...)`` consumes.

    Returns
    -------
    list of matplotlib.lines.Line2D
        One handle per design, in input order, with no connecting line.
    """
    handles = []
    for design in mkers:
        style = dict(marker=design[0], color=design[1], linestyle='')
        if design[2]:
            # Hollow marker: outline it in the design colour, matching how
            # plot_scatter_tb1 draws hollow points (markeredgecolor=design[1]).
            # This was previously hard-coded to 'black', so the legend
            # mismatched any non-black hollow design.
            style.update(markeredgecolor=design[1], markerfacecolor='none')
        handles.append(plt.Line2D([0], [0], **style))
    return handles


def get_marker():
    """Return the default per-fiscal-rule marker designs and legend handles.

    Returns
    -------
    mkers : list of tuple
        ``(marker, colour, hollow)`` for categories 0..3, in the order of
        ``labels``. Category 0 is a hollow black dot; the rest are filled
        shapes in increasingly dark shades of ``cmap``.
    handles : list
        Legend proxy artists built from ``mkers`` by ``get_legend_handlers``.
    """
    shapes = ['.', '^', 's', 'D']
    colours = ['black', cmap(0.5), cmap(0.65), cmap(0.8)]
    hollow_flags = [True, False, False, False]
    designs = list(zip(shapes, colours, hollow_flags))
    return designs, get_legend_handlers(designs)


def plot_reg(df, xlab, ylab, ax, x_range=(-4, 6)):
    """Overlay the least-squares line of ``df[ylab]`` on ``df[xlab]``.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data.
    xlab, ylab : str
        Column names of the regressor and the response.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    x_range : tuple of float, optional
        The two x values between which the fitted line is drawn. The default
        ``(-4, 6)`` is the span this function has always used.

    Returns
    -------
    matplotlib.axes.Axes
        ``ax`` itself, so ``ax = plot_reg(...)`` is safe. This function
        previously returned None.
    """
    # Fit only on rows where both coordinates are present: np.polyfit fails on NaNs.
    mask = df[xlab].notna() & df[ylab].notna()
    slope, intercept = np.polyfit(df.loc[mask, xlab], df.loc[mask, ylab], 1)
    xs = np.asarray(x_range, dtype=float)
    ax.plot(xs, slope * xs + intercept, color='k', linestyle='-.', alpha=0.5)
    return ax


def _annotate_points(ax, df, xlab, ylab, keep, fontsize):
    """Write the index of each row of *df* that passes *keep* next to its point.

    A row indexed ``'JAM'`` is labelled in font size 12 and only after every
    other label has been added.
    """
    jam_xy = None
    for idx, row in df.iterrows():
        if not keep(idx):
            continue
        if idx == 'JAM':
            # Deferred: added after all other labels (presumably to stay on top).
            jam_xy = (row[xlab], row[ylab])
        else:
            ax.annotate(idx, (row[xlab] + 0.07, row[ylab] - 0.07), fontsize=fontsize)
    if jam_xy is not None:
        ax.annotate('JAM', (jam_xy[0] + 0.07, jam_xy[1] - 0.07), fontsize=12)


def plot_scatter_tb1(df, xlab, frule, ylab='debt_reduction', rot=False, figsize=(10, 5), annotate=True, **kwargs):
    """Scatter ``df[ylab]`` against ``df[xlab]`` on one axes.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per point. Index values are used as point annotations.
    xlab, ylab : str
        Column names plotted on the x and y axes.
    frule : str
        Column whose values, cast to int, select a design from ``markers``
        and a label from ``legend``. Only used when ``markers`` is given.
    rot : bool
        Swap the axes: plot ``ylab`` on x and ``xlab`` on y. Given
        ``xlabel``/``ylabel`` kwargs move with their data. Missing ones
        default to the column name that ends up on each axis.
    figsize : tuple
        Size of the new figure, used only when no ``ax`` is supplied.
    annotate : bool
        When ``annotate_lst`` is not given, label every point with its index.
    **kwargs
        ax : Axes to draw on (a new figure is created if absent).
        markers : list of ``(marker, colour, hollow)`` per category, or
            ``True`` for a built-in palette. If absent, plain dots are drawn.
        legend : sequence of series labels, indexed by category.
        annotate_lst : only these index values are labelled (font size 8).
        xlabel, ylabel, title : axis labels and title.
        ylim, xlim : axis limits. Without ``ylim`` the y range is padded by 10%.
        hline : pandas.Series; draws a vertical line at each value, labelled
            by its index.
        reg : truthy to overlay a least-squares line via ``plot_reg``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes drawn on.
    """
    if rot:
        # Resolve labels with the pre-swap names: the column formerly on y
        # (``ylab``) is now on x, and vice versa.
        kwargs['xlabel'], kwargs['ylabel'] = (
            kwargs.get('ylabel', ylab), kwargs.get('xlabel', xlab))
        xlab, ylab = ylab, xlab

    ax = kwargs.get('ax')
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    mkers = kwargs.get('markers')
    if mkers is True:
        # Built-in palette: (marker, colour, hollow) for categories 0..3.
        mkers = list(zip(
            ['o', '^', 's', 'D'],
            ['orange', cmap(0.65), cmap(0.8), 'purple'],
            [False, False, False, False],
        ))
    legend = kwargs.get('legend')

    if mkers is not None:
        for frule_val, group in df.groupby(frule):
            category = int(frule_val)
            design = mkers[category]
            # Without a legend list the series are simply left unlabelled.
            label = legend[category] if legend else None
            if design[2]:
                ax.plot(group[xlab], group[ylab], marker=design[0], linestyle='', ms=8,
                        label=label, markeredgecolor=design[1], markerfacecolor='none')
            else:
                ax.plot(group[xlab], group[ylab], marker=design[0], linestyle='',
                        ms=7, label=label, color=design[1])
    else:
        ax.plot(df[xlab], df[ylab], marker='o', linestyle='', ms=8)

    if annotate_lst := kwargs.get('annotate_lst'):
        _annotate_points(ax, df, xlab, ylab, lambda idx: idx in annotate_lst, 8)
    elif annotate:
        _annotate_points(ax, df, xlab, ylab, lambda idx: True, 10)

    if xlabel := kwargs.get('xlabel'):
        ax.set_xlabel(xlabel, fontsize=12)
    if ylabel := kwargs.get('ylabel'):
        ax.set_ylabel(ylabel, fontsize=12)
    if title := kwargs.get('title'):
        ax.set_title(title, fontsize=14)
    if ylim := kwargs.get('ylim'):
        ax.set_ylim(ylim)
    else:
        # Pad the autoscaled y range by 10% on each side.
        lo, hi = ax.get_ylim()
        pad = 0.1 * (hi - lo)
        ax.set_ylim(lo - pad, hi + pad)
    if xlim := kwargs.get('xlim'):
        ax.set_xlim(xlim)

    # Despite its name, ``hline`` draws *vertical* reference lines at the
    # Series values, each labelled with its index at the bottom of the axes.
    ref_lines = kwargs.get('hline')
    if isinstance(ref_lines, pd.Series):
        for name, xval in ref_lines.items():
            ax.axvline(xval, color='k', linestyle='-.', lw=0.5, alpha=0.5)
            ax.text(xval, ax.get_ylim()[0], name,
                    fontsize=10, ha='center', va='bottom')

    if kwargs.get('reg'):
        # plot_reg draws in place. Keep our own reference so this function
        # always returns the axes; assigning plot_reg's result could yield None.
        plot_reg(df, xlab, ylab, ax=ax)

    return ax


def plot_scatter_max_debt_reduction(plot_type='em_table1', x=None, y=None, mker=True, title=False, xlim1=None, xlim2=None, ):
    """Plot side-by-side scatters for 1980 ~ 1999 and 2000 ~ 2022
    for the countries listed in Table 1 of Arslanalp, Eichengreen, and Henry (2024).

    Each panel also gets vertical reference lines at the 25%/50%/75%
    quantiles of ``polar`` over the available EM countries in that period.

    Parameters
    ----------
    plot_type : {'em_table1', 'high_income'}
        Selects default settings. Only ``'em_table1'`` is implemented.
    x, y : str, optional
        Keys of ``var_label``. ``y`` must be ``'debt_red'`` or
        ``'fiscal_surp'``. When omitted, the plot type's defaults are used.
    mker : bool, optional
        Use per-fiscal-rule markers (``get_marker``) and a shared figure
        legend. ``None`` falls back to the plot type's default.
    title : bool
        Add a ``"<x label> vs. <y label>"`` figure title.
    xlim1, xlim2 : tuple, optional
        x-limits of the first and second panel.

    Returns
    -------
    fig, ax
        The figure and its array of two axes.

    Raises
    ------
    ValueError
        If ``plot_type`` is unknown or not implemented, or ``y`` is unsupported.
    """
    # Default settings for each plot type
    default_settings = {
        'em_table1': {
            'x': 'polar',
            'y': 'debt_red',
            'mker': True,
            'xlim1': (-0.5, 4),
            'xlim2': (-0.5, 4)
        },
        'high_income': {
            'x': 'polar',
            'y': 'debt_red',
            'mker': True,
            'xlim1': (0, 3.5),
            'xlim2': (-0.3, 3)
        }
    }

    # Validate everything up front, before any data work or figure creation.
    if plot_type not in default_settings:
        raise ValueError(
            f'plot_type must be one of {sorted(default_settings)}, got {plot_type!r}')
    if plot_type == 'high_income':
        # TODO: implement. Needs hi_lst filtered to codes present in polar and
        # imf_debt (and imf_bal for fiscal_surp), with
        # annotate_lst=['USA', 'DEU', 'JPN', 'ITA'].
        raise ValueError('high_income plot type is not implemented')

    settings = default_settings[plot_type]

    # Override defaults with provided arguments
    x = x if x is not None else settings['x']
    y = y if y is not None else settings['y']
    mker = mker if mker is not None else settings['mker']
    xlim1 = xlim1 if xlim1 is not None else settings['xlim1']
    xlim2 = xlim2 if xlim2 is not None else settings['xlim2']

    if y not in ('debt_red', 'fiscal_surp'):
        raise ValueError('y must be either debt_red or fiscal_surp')

    if mker:
        mkers, handles = get_marker()
    else:
        mkers = handles = None

    xlab = var_label[x]
    ylab = var_label[y]

    if y == 'fiscal_surp':
        res80, res00 = calc_table1_stats(fiscal_surp=True)
    else:
        res80, res00 = calc_table1_stats()

    fig, ax = plt.subplots(figsize=(12, 5), ncols=2)

    # (panel title, stats frame, x-limits, polarization sample for quantiles)
    panels = [
        ('1980 ~ 1999', res80, xlim1, polar[EM_country_list_avail].loc[:'1999']),
        ('2000 ~ 2022', res00, xlim2, polar[EM_country_list_avail].loc['2000':]),
    ]
    for axis, (period_title, stats, xlim, polar_sample) in zip(ax, panels):
        quartiles = polar_sample.stack().describe()[['25%', '50%', '75%']]
        # NOTE(review): ``ylab`` is not passed, so plot_scatter_tb1 plots its
        # default 'debt_reduction' column even when y == 'fiscal_surp'.
        # Confirm that calc_table1_stats(fiscal_surp=True) names its column so.
        plot_scatter_tb1(
            stats, x,
            ylabel=ylab, xlabel=xlab,
            title=period_title,
            xlim=xlim,
            frule='fiscal_rule',
            ax=axis, legend=labels, markers=mkers,
            annotate_lst=None,
            hline=quartiles
        )

    if title:
        fig.suptitle(f'{xlab} vs. {ylab}', fontsize=12)

    if mker:
        fig.legend(handles, labels, loc='lower center', ncol=4,
                   fontsize=10, bbox_to_anchor=(0.5, -0.05))

    fig.tight_layout()
    return fig, ax


def plot_scatter_all_sample(figsize=(12, 5)):
    """Draw two regression scatters over the full sample from ``calc_all_sample_stats``.

    Left panel: 5-year polarization change vs. 5-year debt change.
    Right panel: polarization level vs. fiscal surplus, with y limited to
    [-100, 100].

    Returns
    -------
    fig, ax
        The figure and its array of two axes.
    """
    stats = calc_all_sample_stats()
    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # (x column, y column, x label, y label, y-limits or None)
    panel_specs = (
        ('polar_change', 'debt_change',
         "5-Yrs Polarization Change", "5-Yrs Debt Change", None),
        ('polar_level', 'fiscal_surp',
         "Polarization Level", "Fiscal Surplus", (-100, 100)),
    )
    for panel_ax, (xcol, ycol, xtext, ytext, ylim) in zip(axes, panel_specs):
        sns.regplot(
            x=xcol, y=ycol, data=stats,
            ci=False, line_kws={"color": "red"}, scatter_kws={"color": "#1f77b4"},
            ax=panel_ax,
        )
        panel_ax.set_xlabel(xtext, fontsize=12)
        panel_ax.set_ylabel(ytext, fontsize=12)
        if ylim is not None:
            panel_ax.set_ylim(*ylim)

    return fig, axes


def plot_time_series_variables(code):
    """Plot one country's debt on the left axis, with polarization and violence on the right.

    Parameters
    ----------
    code : hashable
        Column key selecting the country in ``imf_debt``, ``polar`` and ``vio``.
        Only periods where all three series are present are plotted.

    Returns
    -------
    fig, ax
        The figure and the primary (debt) axes.
    """
    names = ['Government Gross Debt', 'Polarization', 'Violence']
    panel = pd.concat([imf_debt[code], polar[code], vio[code]], axis=1).dropna()
    panel.columns = names

    debt_ax = panel[names[0]].plot(ylabel='% of GDP', legend=True, figsize=(10, 5))
    # secondary_y=True: pandas returns the right-hand axes for these series.
    level_ax = panel[names[1:]].plot(
        secondary_y=True, ylim=(0, 4.5), ax=debt_ax)

    debt_ax.set_xlabel('Year', fontsize=12)
    debt_ax.set_ylabel('% of GDP', fontsize=12)
    level_ax.set_ylabel('Level', fontsize=12)

    return debt_ax.figure, debt_ax
