%% Step 0: Housekeeping %%
% Start from a clean session (no figures, empty workspace, empty console),
% then record the working folder so output folders can hang off it.
close('all'); clear; clc;
current_path = pwd;
save_path_tables = '/Tables/';   % table output folder, appended to current_path
save_path_fig = '/Figures/';     % figure output folder, appended to current_path


%% SETUP TAX GRIDS AND OUTPUT STORAGE
% Candidate effective tax rates; both axes share the same step.
grid_step = 0.1;
tau_k = 0:grid_step:.40;     % capital-tax grid
tau_ell = 0:grid_step:.40;   % labor-tax grid
N_k = length(tau_k);
N_ell = length(tau_ell);

% Flatten the tax grid into (tau_k, tau_ell) pairs: tau_ell varies fastest and
% tau_k slowest, so reshaping a per-pair vector to [N_ell N_k] (column-major)
% puts tau_ell on the rows and tau_k on the columns.
grid_tau_k = repelem(tau_k, N_ell);
grid_tau_ell = repmat(tau_ell, 1, N_k);

% Per-pair result containers (one entry per grid pair).
out_ramsey_st = cell(N_ell * N_k, 1);
out_theta_dist_st = cell(N_ell * N_k, 1);
out_mkt_st = cell(N_ell * N_k, 1);
out_resids = cell(N_ell * N_k, 1);
out_exitflags = zeros(N_ell * N_k, 1);

% Indicators that a solver reached its iteration cap (maxit).
out_maxit_ram = zeros(N_ell * N_k, 1);
out_maxit_the = zeros(N_ell * N_k, 1);
out_maxit_mkt = zeros(N_ell * N_k, 1);

% Relative change in employment.
out_dl_ram = zeros(N_ell * N_k, 1);
out_dl_the = zeros(N_ell * N_k, 1);

% Relative change in labor share.
out_dshare_ram = zeros(N_ell * N_k, 1);
out_dshare_the = zeros(N_ell * N_k, 1);

%% Step 1: parametrization of basic parameters %%

% Baseline policy and technology settings; the calibrated A, zeta and E are
% added to main_par after the solver below.
main_par.N_tasks = 10000;
main_par.tau_current_labor = 0.255;              % average labor tax
main_par.tau_current_capital = 0.1;              % average capital wedge
main_par.tau_bar = main_par.tau_current_labor;   % lower bound on labor taxation
main_par.delta = 0.055;

% Default parameter set (main_par is NOT passed here). The anonymous functions
% below capture this params value, so they use its elast_k / elast_ell.
params = setParams();

% tau_ell as a function of tau_k, given elasticity inputs e_k, e_ell and rho:
%   tau_ell = (tau_k * e_k / e_ell - rho) / (1 - rho),
% evaluated with e_k = 1/elast_k, e_ell = 1/elast_ell at rho = 0 and rho = 0.15.
locus_tau_ell = @(e_k, tau_k, e_ell, rho) ...
    (tau_k .* e_k ./ e_ell - rho) ./ (1 - rho);
tau_ell_k_rho_low = @(tau_k) ...
    locus_tau_ell(1/params.elast_k, tau_k, 1/params.elast_ell, 0);
tau_ell_k_rho_high = @(tau_k) ...
    locus_tau_ell(1/params.elast_k, tau_k, 1/params.elast_ell, 0.15);

% tau_k as a function of tau_ell:
%   tau_k = 1 / (1 + (1/elast_k) * (elast_ell + rho) / (tau_ell/(1-tau_ell) + rho)),
% evaluated at rho = 0 and rho = 0.15.
% threshold_corrected=1/(1+(1/elast_k)*(elast_ell+varrho)/(tau_andrea_labor/(1-tau_andrea_labor)+varrho));
locus_tau_k_ell = @(tau_ell, rho) ...
    1./(1+(1/params.elast_k).*(params.elast_ell + rho)./ ...
    (tau_ell./(1 - tau_ell) + rho));
tau_k_ell_rho_low = @(tau_ell) locus_tau_k_ell(tau_ell, 0);
tau_k_ell_rho_high = @(tau_ell) locus_tau_k_ell(tau_ell, 0.15);

% SOLVER FOR A, zeta and E to match labor share, capital share and el. subs.
% Initial guess for [A; zeta; E]. The search runs over log values and the
% objective receives exp(a), so all three stay strictly positive.
A_zeta_E_guess = [   20.976617907612962; ...
   2.034427788124405; ...
   0.377961836037824];

labor_share_target = .56;
k_ratio_target = 3.32;   % NOTE(review): not passed to the objective below
K_share_target = .3;
el_subs_target = 1.35;
weight_labor = .66; % weight on deviation from labor and k share
weight_share = .5; % weight on deviation from k-share over total shares weight


fun = @(a) target_deviations_K_share(exp(a), labor_share_target,  el_subs_target, ...
    K_share_target, main_par.tau_current_labor, main_par.tau_current_capital, main_par,...
    weight_labor, weight_share);

opts = optimset('Display', 'off');

%% SOLUTION PART
[A_zeta_E_sol, fval, exitflag] = fminsearch(fun, log(A_zeta_E_guess), opts);
% [A_zeta_sol, fval, exitflag] = fsolve(fun, log(A_zeta_guess), opts)
A_zeta_E_sol = exp(A_zeta_E_sol);

% fminsearch: exitflag 1 = converged, 0 = hit MaxIter/MaxFunEvals,
% -1 = stopped by an output function.
if exitflag ~= 1
    disp('fminsearch failed')
end

% Store the calibrated (A, zeta, E) in main_par, then solve the market
% equilibrium at the baseline taxes.
[main_par.A, main_par.zeta, main_par.E] = ...
    deal(A_zeta_E_sol(1), A_zeta_E_sol(2), A_zeta_E_sol(3));

params = setParams(main_par);
out_mkt = market_eq_p(main_par.tau_current_labor, ...
    main_par.tau_current_capital, params);

% Report the baseline moments.
disp('Baseline shares:')
disp(['Labor share: ', num2str(out_mkt.labor_share)])
disp(['Net capital share: ', num2str(out_mkt.net_k_share)])
disp(['Capital/output ratio: ', num2str(out_mkt.k_ratio)])

% Initial labor/capital guesses taken from the baseline equilibrium.
% NOTE(review): the grid loop rebuilds params with setParams(main_par) before
% params is next used, so these two fields may never take effect - confirm.
params.guess_l = out_mkt.ell;
params.guess_k = out_mkt.k;

% Robustness targets are the baseline's own moments (labor share,
% el. of substitution, net capital share), with their objective weights.
labor_share_target_robustness = out_mkt.labor_share;
el_subs_target_robustness = out_mkt.el_subs_rw;
K_share_target_robustness = out_mkt.net_k_share;
weight_labor_robustness = 1;
weight_share_robustness = .1;


%% RE-SOLVE FOR A, zeta AND E AT EACH (tau_k, tau_ell) PAIR TO MATCH THE BASELINE TARGETS


% fminsearch options are identical for every grid pair, so build them once.
opts = optimset('Display', 'on', 'MaxFunEvals', 500);

for j_el_pair = 1:length(grid_tau_k)
    % Fresh settings for this (tau_k, tau_ell) pair; tau_bar follows the
    % pair's labor tax, as in the baseline.
    main_par = struct();
    main_par.N_tasks = 10000;
    main_par.tau_current_labor = grid_tau_ell(j_el_pair);
    main_par.tau_current_capital = grid_tau_k(j_el_pair);
    main_par.tau_bar = main_par.tau_current_labor;

    main_par.delta = 0.055;

    % Re-solve (in logs) for A, zeta and E at this pair's taxes, targeting the
    % baseline moments and starting from the baseline calibration.
    fun = @(a) target_deviations_K_share(exp(a), labor_share_target_robustness, ...
        el_subs_target_robustness, K_share_target_robustness, ...
        main_par.tau_current_labor, main_par.tau_current_capital, main_par, ...
        weight_labor_robustness, weight_share_robustness);
    [A_zeta_sol, fval, exitflag] = fminsearch(fun, log(A_zeta_E_sol), opts);
    A_zeta_sol = exp(A_zeta_sol);

    out_exitflags(j_el_pair) = exitflag;
    out_resids{j_el_pair} = fval;
    if exitflag ~= 1
        % Identify the failing pair; the flag itself is kept in out_exitflags.
        disp(['fminsearch failed for pair ' num2str(j_el_pair) ...
            ' (tau_k = ' num2str(main_par.tau_current_capital) ...
            ', tau_ell = ' num2str(main_par.tau_current_labor) ')'])
    end
    main_par.A = A_zeta_sol(1);
    main_par.zeta = A_zeta_sol(2);
    main_par.E = A_zeta_sol(3);

    % Market equilibrium at this pair's taxes; its revenue is the spending
    % level g handed to the unconstrained Ramsey problem.
    params = setParams(main_par);
    out_mkt_2020 = market_eq_p(main_par.tau_current_labor,...
        main_par.tau_current_capital, params);
    g = out_mkt_2020.revenue; % target government spending

    % Initial guesses for the planner taken from the market equilibrium.
    params.guess_k = out_mkt_2020.k;
    params.guess_l = out_mkt_2020.ell;
    params.index_guess = out_mkt_2020.theta_index;
    out_planner =  ramsey_unconstrained_p(g, params);

    % (disabled) Ramsey problem at current taxes and storage of full outputs:
    %     out_ramsey_current_tax = ...
    %         ramsey_theta_p(main_par.tau_current_labor,...
    %         main_par.tau_current_capital,params);
    %     out_mkt_st{j_el_pair} = out_mkt_2020;
    %     out_theta_dist_st{j_el_pair} = out_ramsey_current_tax;
    %     out_ramsey_st{j_el_pair} = out_planner;

    % Indicator "maxit reached": the iteration series has maxIter + 1 entries.
    out_maxit_ram(j_el_pair) = length(out_planner.index_series) == params.maxIter + 1;
    out_maxit_mkt(j_el_pair) = length(out_mkt_2020.index_series) == params.maxIter + 1;

    % Relative changes (planner vs. market) in employment and labor share.
    out_dl_ram(j_el_pair) = (out_planner.ell - out_mkt_2020.ell)/out_mkt_2020.ell;
    %     out_dl_the(j_el_pair) = (out_ramsey_current_tax.ell - out_mkt_2020.ell)/out_mkt_2020.ell;
    out_dshare_ram(j_el_pair) = (out_planner.labor_share - out_mkt_2020.labor_share)/out_mkt_2020.labor_share;
    %     out_dshare_the(j_el_pair) = (out_ramsey_current_tax.labor_share - out_mkt_2020.labor_share)/out_mkt_2020.labor_share;
    disp(['completed ' num2str(j_el_pair)])
end


% Save the entire workspace; the figure section reloads it from this file.
save('taxes_laborshare_opt_3D.mat')

% Per-pair employment changes back on the grid: rows follow tau_ell and
% columns follow tau_k (tau_ell varies fastest in the pair ordering).
plot_data_ls = reshape(out_dl_ram, N_ell, N_k);

figure
mesh(tau_k, tau_ell, plot_data_ls)
title('employment changes')
zlim([-0.005, Inf])


%% FIGURE WITH EMPLOYMENT CHANGES
% Load the saved workspace FIRST so this section also works on its own
% (e.g. "Run Section" after a restart). The file holds params, the tax grids,
% the per-pair results and the locus handles defined in Step 1. Previously the
% loci were redefined from params before loading, which fails when params is
% not yet in the workspace; in a full run the load overwrote them anyway.
load('taxes_laborshare_opt_3D.mat')

fig_name  = 'tau_emp_robust_noshade';

% tau_k(tau_ell) loci for rho = 0 and rho = 0.15 on a fine labor-tax grid.
% NOTE(review): tau_k_vec_low / tau_k_vec_high are not referenced later in the
% visible code; kept for interactive inspection.
tau_l_vec = 0:0.01:.4;
tau_k_vec_low = tau_k_ell_rho_low(tau_l_vec);
tau_k_vec_high = tau_k_ell_rho_high(tau_l_vec);

% Employment changes on the grid: rows follow tau_ell, columns follow tau_k.
plot_data_ls = reshape(out_dl_ram, [ N_ell N_k]);
% Figure window geometry in pixels; (fig_posx, fig_posy) is the lower-left corner.
fig_width = 1600;
fig_heigth = 900;
fig_posx = 100;
fig_posy = 100;

% tau_ell(tau_k) loci for rho = 0 and rho = 0.15 on a fine capital-tax grid.
tau_k_vec = 0:0.01:.4;
tau_ell_vec_rho_high = tau_ell_k_rho_high(tau_k_vec);
tau_ell_vec_rho_low = tau_ell_k_rho_low(tau_k_vec);

baseLine = .4;        % top of shading
levels = 0:0.02:10;   % contour levels for the employment-change map

% Anchor points and LaTeX labels for the two loci (in the visible code they
% are only used by a commented-out labelpoints call).
tau_k_to_label = repmat(tau_k_vec(end), 1, 2);
tau_ell_to_label = [tau_ell_vec_rho_low(end - 1), tau_ell_vec_rho_high(end - 1)];
tau_labels = {'$\rho = 0$', '$\rho = 0.15$'};

fontSize = 20;
fontSizeAx = 36;

% Open the figure window at the configured size/position, keep drawing on
% the same axes, and render tick labels with LaTeX.
hh = figure('Position', [fig_posx fig_posy fig_width fig_heigth]);
hold on
set(gca,'TickLabelInterpreter', 'latex');    


%# Add to the plot
% Shaded polygon (FaceAlpha 0.10) bounded by tau_k = 0, tau_ell = 0,
% tau_ell = baseLine and the rho = 0 locus tau_k = tau_k_ell_rho_low(tau_ell).
% x-vertices are tau_k values, y-vertices are tau_ell values; vertex order matters.
h2 = fill([0 0 tau_k_ell_rho_low(tau_l_vec(1)) tau_k_ell_rho_low(tau_l_vec(1)) tau_k_ell_rho_low(tau_l_vec),...
    tau_k_ell_rho_low(tau_l_vec(end))] ,...        %# Plot the first filled polygon (x = tau_k vertices)
          [baseLine 0 0 baseLine tau_l_vec baseLine],...
          'k','EdgeColor','none');
set(h2,'FaceAlpha',0.10);

% Band (FaceAlpha 0.15) between the rho = 0.15 locus, traced upward in
% tau_ell, and the rho = 0 locus, traced back down to close the polygon.
h3 = fill([tau_k_ell_rho_high(tau_l_vec) tau_k_ell_rho_low(tau_l_vec(end:-1:1))] ,...        %# Plot the second filled polygon
          [tau_l_vec tau_l_vec(end:-1:1)],...
          'k','EdgeColor','none');
set(h3,'FaceAlpha',0.15);

% Contour lines of the employment change at the values in `levels`,
% labelled in LaTeX (labels are rewritten as percentages further below).
[C,h] = contour(tau_k, tau_ell, plot_data_ls, levels, 'LineWidth', 3);
clabel(C,h, 'fontsize',fontSizeAx, 'interpreter', 'latex');

% Enlarge the axis font (the current x tick labels are re-applied with it).
a = get(gca, 'XTickLabel');
set(gca, 'XTickLabel', a, 'fontsize', fontSizeAx)

% Axis titles: tau_k on x, tau_ell on y.
xlabel('Effective tax on equipment and software,  $\tau^k$', ...
    'interpreter', 'latex', 'fontsize', fontSizeAx)
ylabel('Effective labor tax, $\tau^\ell$', ...
    'interpreter', 'latex', 'fontsize', fontSizeAx)
% labelpoints(tau_k_to_label, tau_ell_to_label, tau_labels, ...
%     'interpreter','latex', 'fontsize', fontSizeAx, 'position', 'E')

% Both axes span 0-40% with ticks every 5 points, labelled as LaTeX percentages.
xlim([0 .4])
ylim([0 .4])
xticks(0:0.05:.4)
yticks(0:.05:.4)
xticklabels(arrayfun(@(p) sprintf('$%d\\%%$', p), 0:5:40, 'UniformOutput', false))
yticklabels(arrayfun(@(p) sprintf('$%d\\%%$', p), 0:5:40, 'UniformOutput', false))
% Render now; presumably required so the contour's label text objects exist
% before they are read and edited below.
drawnow
% Rewrite each contour label from a fraction to a LaTeX percentage rounded to
% 2 decimals (e.g. '0.02' -> '2\%').
% NOTE(review): TextPrims is an undocumented property of the contour object
% and may change between MATLAB releases.
labels = h.TextPrims;
for idx = 1 : numel(labels)
    LabelValue = str2double(labels(idx).String);
    h.TextPrims(idx).String = [num2str(round(LabelValue*100,2)) '\%'];
end

% Save figure hh as a .fig file and a colour EPS (painters renderer) into
% current_path/Figures. Full paths are used instead of cd-ing into the folder.
% A missing folder no longer errors, and a failure while saving cannot leave
% the working directory changed.
fig_dir = fullfile(current_path, save_path_fig);
if ~exist(fig_dir, 'dir')
    mkdir(fig_dir);
end
fname = fullfile(fig_dir, fig_name);
set(hh, 'PaperPositionMode', 'auto');
savefig(hh, fname);
print(hh, '-depsc', '-painters', fname);
