function [PRG,POG,FIG,AX,PRLIN,POLIN,BLIN,TIT] = plotpp(PR,PO,varargin)
% plotpp  Plot prior and/or posterior distributions and/or posterior mode.
%
% Syntax
% =======
%
%     [PRG,POG,FIG,AX,PRLIN,POLIN,BLIN,TIT] = grfun.plotpp(E,THETA,...)
%     [PRG,POG,FIG,AX,PRLIN,POLIN,BLIN,TIT] = grfun.plotpp(E,ST,...)
%     [PRG,POG,FIG,AX,PRLIN,POLIN,BLIN,TIT] = grfun.plotpp(E,PEST,...)
%
% Input arguments
% ================
%
% * `E` [ struct ] - Estimation input struct, see
% [`estimate`](model/estimate), with prior function handles from the
% [logdist](logdist/Contents) package.
%
% * `THETA` [ numeric | empty ] - Array with the chain of draws from the
% posterior simulator [`arwm`](poster/arwm); row `i` is used for the `i`-th
% field of `E`.
%
% * `ST` [ struct | empty ] - Output struct returned by the posterior
% simulator statistics function [`stats`](poster/stats); its `ksdensity`
% field is used if present, otherwise its `chain` field.
%
% * `PEST` [ struct | empty ] - Output struct returned by the
% [`model/estimate`](model/estimate) function containing the posterior mode
% estimates.
%
% The second input argument may be omitted or empty, in which case no
% posterior is plotted.
%
% Output arguments
% =================
%
% * `PRG` [ struct ] - Struct with x- and y-axis coordinates to plot the
% prior distribution for each parameter.
%
% * `POG` [ struct ] - Struct with x- and y-axis coordinates to plot the
% posterior distribution for each parameter.
%
% * `FIG` [ numeric ] - Handles to the figures created.
%
% * `AX` [ numeric ] - Handles to the axes (graphs) created.
%
% * `PRLIN` [ numeric ] - Handles to the prior lines plotted.
%
% * `POLIN` [ numeric ] - Handles to the posterior lines plotted.
%
% * `BLIN` [ numeric ] - Handles to the lower and upper bound lines
% plotted.
%
% * `TIT` [ numeric ] - Handles to the graph titles created.
%
% If neither priors nor posteriors are to be plotted, no graphics are
% created and `FIG`, `AX`, `PRLIN`, `POLIN`, `BLIN` and `TIT` are empty.
%
% Options
% ========
%
% * `'describePrior='` [ *'auto'* | true | false ] - Add one extra line to
% each graph title describing the prior (name, mean, and std dev), plus a
% line with the posterior mode when the posterior is a point estimate; if
% `'auto'` the description will be shown only if `'plotPrior='` is true.
%
% * `'plotPrior='` [ *`true`* | `false` ] - Plot prior distributions.
%
% * `'plotPoster='` [ *`true`* | `false` ] - Plot posterior distributions;
% always false if the second input argument is omitted or empty.
%
% * `'plotBounds='` [ *`true`* | `false` ] - Add lower and/or upper bounds to
% the distribution graphs; if false, the bounds are only added if they are
% within the graph x-limits.
%
% * `'sigma='` [ numeric | *3* ] - Number of std devs from the mean or the
% mode (whichever covers a larger area) to the left and to the right over
% which the prior distributions will be plotted unless running out of
% bounds.
%
% * `'subplot='` [ `'auto'` | numeric ] - Number of rows and columns of
% graphs in each figure, `[nrow,ncol]`; `'auto'` chooses a near-square
% layout.
%
% * `'tight='` [ *`true`* | `false` ] - Make graph axes tight.
%
% * `'xLims='` [ struct | *empty* ] - Struct with `[from,to]` x-limits for
% individual parameters; controls the range over which the prior
% distributions are evaluated.
%
% Description
% ============
%
% If you call `plotpp` with `PEST` (i.e. a struct with posterior mode
% estimates) as the second argument, the posterior modes are plotted as
% vertical lines (stem graphs).
%
% Example
% ========
%

% -IRIS Toolbox.
% -Copyright (c) 2007-2012 Jaromir Benes.

% irisopt.grfun
options = passvalopt('grfun.plotpp',varargin{:});

% The posterior input (chain, stats struct, or mode estimates) is optional.
if nargin < 2
    PO = [];
end

if isempty(PO)
    options.plotposter = false;
end

if isequal(options.describeprior,'auto')
    options.describeprior = options.plotprior;
end

%**************************************************************************

w = options.sigma;

% Get lower and upper bounds for individual params.
b = plt_getbounds(PR);

% Get prior function handles.
prf = plt_getprfunc(PR);

% Get x-limits for individual priors.
prxlim = plt_getprxlims(prf,b,w,options.xlims);

% Compute x- and y-axis co-ordinates for prior graphs.
PRG = plt_getprgraphs(prf,prxlim);

% Compute x- and y-axis co-ordinates for posterior graphs.
POG = plt_getpographs(PO,b,w,PRG);

% Get x-limits for individual posteriors (not used further at the moment).
poxlim = plt_getpoxlims(POG); %#ok<NASGU>

% We're done if actual plots are not requested. Assign every graphics
% output so that callers requesting any of them do not fail.
if ~options.plotprior && ~options.plotposter
    FIG = [];
    AX = [];
    PRLIN = [];
    POLIN = [];
    BLIN = [];
    TIT = [];
    return
end

% Create titles.
descript = plt_createtitles(prf,POG,options);

% Create graphs (one per parameter, spread over as many figures as needed).
[FIG,AX,TIT] = plt_creategraphs(PR,descript,options.subplot);

% Plot priors.
PRLIN = [];
if options.plotprior
    PRLIN = plt_plotpriors(AX,PRG,options.tight);
end

% Plot posteriors.
POLIN = [];
if options.plotposter
    POLIN = plt_plotposter(AX,POG,options.tight);
end

% Plot bounds as vertical lines.
BLIN = plt_plotbounds(AX,b,options.plotbounds);
set(BLIN(isfinite(BLIN)),'color','red');

end

% Subfunctions.

%**************************************************************************
function B = plt_getbounds(PR)
% plt_getbounds  Extract the [lower,upper] bounds of each parameter.
%
% Each field of PR is either a cell array {init,lower,upper,...} or a
% numeric vector [init,lower,upper]. The lower and upper bounds are taken
% from the 2nd and 3rd elements. A missing or empty lower bound defaults to
% -Inf and a missing or empty upper bound to Inf. Every entry of the output
% struct B is therefore a 1-by-2 vector. (An empty bound would otherwise
% yield a 1-element vector and break callers indexing B.(name)(2).)

list = fieldnames(PR);
nlist = numel(list);
B = struct();
for i = 1 : nlist
    pr = PR.(list{i});
    low = plt_getboundelement(pr,2,-Inf);
    upp = plt_getboundelement(pr,3,Inf);
    B.(list{i}) = [low,upp];
end

end

function X = plt_getboundelement(PR,K,DEFAULT)
% plt_getboundelement  Return the K-th element of a cell or numeric spec,
% or DEFAULT if that element does not exist or is empty.
X = DEFAULT;
try
    if iscell(PR)
        value = PR{K};
    else
        value = PR(K);
    end
catch %#ok<CTCH>
    return
end
if ~isempty(value)
    X = value;
end
end
% plt_getbounds().

%**************************************************************************
function PRF = plt_getprfunc(PR)
% plt_getprfunc  Collect the prior function handle of each parameter.
%
% The prior is taken from the 4th cell of each field of PR. Fields that are
% not cell arrays, or have fewer than four cells, get an empty prior.

names = fieldnames(PR);
PRF = struct();
for k = 1 : numel(names)
    name = names{k};
    spec = PR.(name);
    handle = [];
    try %#ok<TRYNC>
        handle = spec{4};
    end
    PRF.(name) = handle;
end

end
% plt_getprfunc().

%**************************************************************************
function PRXLIM = plt_getprxlims(PRF,B,W,USRXLIMS)
% plt_getprxlims  Determine the x-range over which each prior is evaluated.
%
% For each parameter, a user-supplied [from,to] vector in USRXLIMS takes
% precedence. Otherwise, if a prior function handle exists, the range spans
% W std devs on either side of both the prior mean and the prior mode, and
% is clipped to the parameter bounds in B. If the upper end is not finite,
% W times the larger of the mean and the mode is used instead. Parameters
% with neither a usable user range nor a prior keep NaN limits.

list = fieldnames(PRF);
nlist = numel(list);
PRXLIM = struct();
for i = 1 : nlist
    name = list{i};
    f = PRF.(name);
    from = NaN;
    to = NaN;
    try %#ok<TRYNC>
        % USRXLIMS may be empty or lack this parameter; keep NaNs then.
        from = double(USRXLIMS.(name)(1));
        to = double(USRXLIMS.(name)(2));
    end
    if (isnan(from) || isnan(to)) && ~isempty(f)
        low = B.(name)(1);
        upp = B.(name)(2);
        mu = f([],'mean');
        sgm = f([],'sigma');
        md = f([],'mode');
        from = min([mu-W*sgm,md-W*sgm]);
        from = max([from,low]);
        to = max([mu+W*sgm,md+W*sgm]);
        if ~isfinite(to)
            to = max([W*mu,W*md]);
        end
        to = min([to,upp]);
    end
    PRXLIM.(name) = [from,to];
end

end
% plt_getprxlims().

%**************************************************************************
function PRG = plt_getprgraphs(PRF,PRXLIM)
% plt_getprgraphs  Evaluate each prior density on a 1000-point grid.
%
% Returns a struct holding a {x,y} cell per parameter, with x spanning
% PRXLIM.(name) and y = f(x,'proper'). Parameters without a prior function
% handle get {NaN,NaN}.

names = fieldnames(PRF);
PRG = struct();
for k = 1 : numel(names)
    name = names{k};
    prior = PRF.(name);
    if ~isempty(prior)
        lims = PRXLIM.(name);
        xx = linspace(lims(1),lims(2),1000);
        PRG.(name) = {xx,prior(xx,'proper')};
    else
        PRG.(name) = {NaN,NaN};
    end
end

end
% plt_getprgraphs().

%**************************************************************************
function [FIG,AX,TIT] = plt_creategraphs(PR,DESCRIPT,NSUB)
% plt_creategraphs  Create one titled subplot per parameter.
%
% NSUB is [nrow,ncol] or 'auto' (near-square grid). Once a figure's grid is
% full, a new figure is opened. Returns the figure handles, the axes handles
% and the title handles; hold is switched on in every axes.

names = fieldnames(PR);
n = numel(names);

if isequal(NSUB,'auto')
    % Use k-1 rows if that is enough for all graphs, otherwise k rows.
    k = ceil(sqrt(n));
    if k*(k-1) < n
        NSUB = [k,k];
    else
        NSUB = [k-1,k];
    end
end
perFigure = prod(NSUB);

FIG = figure();
AX = nan(1,n);
TIT = nan(1,n);
slot = 0;
for j = 1 : n
    slot = slot + 1;
    if slot > perFigure
        FIG(end+1) = figure(); %#ok<AGROW>
        slot = 1;
    end
    AX(j) = subplot(NSUB(1),NSUB(2),slot);
    TIT(j) = title(DESCRIPT{j},'interpreter','tex');
    hold(AX(j),'all');
end
grfun.clicktocopy(AX);

end
% plt_creategraphs().

%**************************************************************************
function PRLIN = plt_plotpriors(AX,PRG,tight)
% plt_plotpriors  Plot each prior density into its own axes.
%
% The k-th field of PRG ({x,y}) is plotted into AX(k). Returns the line
% handles in parameter order. If `tight` is true, grfun.yaxistight is
% applied to each axes; a grid is always switched on.

names = fieldnames(PRG);
PRLIN = [];
for k = 1 : numel(names)
    ax = AX(k);
    coords = PRG.(names{k});
    h = plot(ax,coords{:});
    PRLIN = [PRLIN,h]; %#ok<AGROW>
    if tight
        grfun.yaxistight(ax);
    end
    grid(ax,'on');
end

end
% plt_plotpriors().

%**************************************************************************
function POLIN = plt_plotposter(AX,POG,tight)
% plt_plotposter  Plot each posterior into its own axes.
%
% A posterior whose x-coordinates consist of a single value (a point
% estimate) is drawn with stem(); otherwise it is drawn with plot().
% Returns the handles in parameter order.

names = fieldnames(POG);
POLIN = [];
for k = 1 : numel(names)
    ax = AX(k);
    coords = POG.(names{k});
    isPoint = length(coords{1}) == 1;
    if isPoint
        h = stem(ax,coords{:});
    else
        h = plot(ax,coords{:});
    end
    POLIN = [POLIN,h]; %#ok<AGROW>
    if tight
        grfun.yaxistight(ax);
    end
    grid(ax,'on');
end

end
% plt_plotposter().

%**************************************************************************
function BLIN = plt_plotbounds(AX,B,forceplot)
% plt_plotbounds  Draw the lower and upper bounds as vertical lines.
%
% Unless `forceplot` is true, a bound lying outside the current x-limits of
% its axes is replaced with NaN before being passed to grfun.vline.
% Returns the line handles in parameter order.

names = fieldnames(B);
BLIN = [];
for k = 1 : numel(names)
    ax = AX(k);
    lims = get(ax,'xLim');
    bounds = B.(names{k});
    lo = bounds(1);
    hi = bounds(2);
    if ~forceplot
        outside = @(v) v < lims(1) || v > lims(2);
        if outside(lo)
            lo = NaN;
        end
        if outside(hi)
            hi = NaN;
        end
    end
    BLIN = [BLIN,grfun.vline(ax,[lo,hi])]; %#ok<AGROW>
end

end
% plt_plotbounds().

%**************************************************************************
function POG = plt_getpographs(PO,B,W,PRG) %#ok<INUSL>
% plt_getpographs  Compute x- and y-coordinates of the posterior graphs.
%
% For each parameter (field of B), a {x,y} cell is taken from the first
% available source:
%
% 1. a precomputed entry PO.ksdensity.(name), used as is;
% 2. a chain of draws (row i of a numeric PO, or PO.chain.(name)), from
%    which a density is estimated with ksdensity, thirdparty.kde, or, as a
%    last resort, a normalised histogram; the result is trimmed to
%    +/- W std devs around the chain mean;
% 3. a point estimate x = PO.(name) with y at 98% of the prior peak (or 1
%    if the prior is missing);
%
% and {[],[]} if none of these is available.
%
% NOTE: The input W is overridden below by a hard-coded 5.

W = 5;
list = fieldnames(B);
nlist = numel(list);
POG = struct();
for i = 1 : nlist
    name = list{i};
    % Precomputed kernel density estimates take precedence.
    try %#ok<TRYNC>
        POG.(name) = PO.ksdensity.(name);
        continue
    end
    low = B.(name)(1);
    upp = B.(name)(2);
    % Replace one-sided supports before estimating the density: a positive
    % finite lower bound is relaxed to zero, a negative one gets a large
    % finite upper bound (and symmetrically for a finite upper bound).
    if isinf(upp) && ~isinf(low) && low ~= 0
        if low > 0
            low = 0;
        else
            upp = 1e10;
        end
    elseif isinf(low) && ~isinf(upp) && upp ~= 0
        if upp < 0
            upp = 0;
        else
            low = -1e10;
        end
    end
    theta = do_getchain();
    if ~isempty(theta)
        try
            [y,x] = ksdensity(theta,'support',[low,upp]);
        catch %#ok<CTCH>
            try
                [~,y,x] = thirdparty.kde(theta,2^10,low,upp);
            catch %#ok<CTCH>
                % Normalised histogram (integrates to one).
                n = length(theta);
                [y,x] = hist(theta,max([2,round(n/50)]));
                width = x(2) - x(1);
                y = (y/n) / width;
            end
        end
        Mean = mean(theta);
        Std = std(theta);
        % Trim the tails. Skip this for a degenerate (zero) spread, which
        % would otherwise remove every point except the mean itself.
        if Std > 0
            index = x < Mean-W*Std | x > Mean+W*Std;
            x(index) = [];
            y(index) = [];
        end
    else
        try
            x = PO.(name);
            y = 0.98*max(PRG.(name){2});
            if isnan(y)
                % This happens if there's no prior distribution on this
                % parameter.
                y = 1;
            end
        catch %#ok<CTCH>
            x = [];
            y = [];
        end
    end
    POG.(name) = {x,y};
end

    function theta = do_getchain()
        % Chain of draws for the i-th parameter: row i of a numeric PO,
        % or PO.chain.(name) for a struct; empty if neither exists.
        theta = [];
        if isnumeric(PO)
            try %#ok<TRYNC>
                theta = PO(i,:);
            end
        else
            try %#ok<TRYNC>
                theta = PO.chain.(list{i});
            end
        end
    end
% do_getchain().

end
% plt_getpographs().

%**************************************************************************
function POXLIM = plt_getpoxlims(POG)
% plt_getpoxlims  Return the [min,max] of the x-coordinates of each
% posterior graph.

POXLIM = struct();
names = fieldnames(POG);
for k = 1 : numel(names)
    xx = POG.(names{k}){1};
    POXLIM.(names{k}) = [min(xx),max(xx)];
end

end
% plt_getpoxlims().

%**************************************************************************
function TIT = plt_createtitles(PRF,POG,options)
% plt_createtitles  Build a TeX title string for each parameter graph.
%
% Each title starts with the parameter name in bold, with underscores
% escaped. If options.describeprior is true, two lines may follow:
% - a prior description (name, mean, std dev), or "flat" if there is no
%   prior;
% - the posterior mode, when the posterior x-coordinate is a single value.

names = fieldnames(PRF);
TIT = cell(1,numel(names));
for k = 1 : numel(names)
    name = names{k};
    str = ['{\bf',strrep(name,'_','\_'),'}'];
    if ~options.describeprior
        % Parameter name only.
        TIT{k} = str;
        continue
    end
    prior = PRF.(name);
    if isempty(prior)
        str = [str,sprintf('\nprior: flat')];
    else
        try %#ok<TRYNC>
            str = [str,sprintf('\nprior: %s {\\mu=}%g {\\sigma=}%g', ...
                prior([],'name'),prior([],'mean'),prior([],'std'))];
        end
    end
    try %#ok<TRYNC>
        xx = POG.(name){1};
        if length(xx) == 1
            str = [str,sprintf('\nposter: {mode}=%g',xx)];
        end
    end
    TIT{k} = str;
end
end
% plt_createtitles().