import matplotlib as mpl

# Non-interactive backend; it is selected before pyplot is imported below.
mpl.use('Agg')

# Global figure style: inward ticks drawn on all four sides, white axes
# background with black frame, and fixed font / tick sizes.
mpl.rcParams.update({
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'xtick.top': True,
    'ytick.right': True,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'axes.facecolor': 'w',
    'axes.edgecolor': 'k',
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.major.size': 4,
    'xtick.minor.size': 2,
    'ytick.major.size': 4,
    'ytick.minor.size': 2,
})

from matplotlib.ticker import MultipleLocator, FormatStrFormatter, AutoMinorLocator

import matplotlib.pyplot as plt

import numpy as np
import scipy.stats as st
from scipy import interpolate
import scipy.ndimage as ndimage

from astropy.io import fits

def make_cdf(f):
    """Turn a map of weights into a map of "mass ranked above this cell".

    The cells of ``f`` are sorted by value.  Each output cell holds
    ``1 - cumsum / total`` evaluated at that cell's position in the sorted
    sequence, i.e. the fraction of the total carried by the cells that sort
    after it.  The largest cell therefore maps to (approximately) 0 and the
    smallest cells map to values close to 1, so a contour of the result at
    level ``p`` outlines the highest-valued cells holding roughly a fraction
    ``p`` of the total.

    Parameters
    ----------
    f : array_like
        Weights (e.g. histogram counts) of any shape.  Integer input is
        accepted; the result is always floating point.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``f``.

    Notes
    -----
    The relative rank of cells with equal values follows ``argsort`` and is
    not otherwise defined.  If ``f`` sums to zero the division yields NaN.
    """
    f = np.asarray(f)
    flat = f.ravel()

    order = flat.argsort()
    remaining = 1 - np.cumsum(flat[order]) / np.sum(flat)

    # Scatter the sorted values back to their original cells.  The buffer
    # takes the dtype of the (floating-point) fractions, not of the input,
    # so integer histograms are not truncated to zero.
    cdf = np.empty_like(remaining)
    cdf[order] = remaining

    return cdf.reshape(f.shape)

def make_distr(X, Y, xmin, xmax, nx, ymin, ymax, ny, cdf=True):
    """Histogram 2-D samples on a regular grid and return it with its mesh.

    Parameters
    ----------
    X, Y : array_like
        Sample coordinates, one entry per sample.
    xmin, xmax, nx : float, float, int
        Range and number of bins along x.
    ymin, ymax, ny : float, float, int
        Range and number of bins along y.
    cdf : bool, optional
        If true (default) the counts are passed through ``make_cdf``;
        otherwise the raw counts are returned.

    Returns
    -------
    f : numpy.ndarray, shape (ny, nx)
        Counts, or their ``make_cdf`` transform, with rows indexed by y and
        columns by x.
    xx, yy : numpy.ndarray, shape (ny, nx)
        Bin-centre coordinates from ``np.meshgrid``, aligned with ``f``.

    Notes
    -----
    Samples outside the given ranges are not counted.
    """
    counts, xedges, yedges = np.histogram2d(
        X, Y, bins=(nx, ny), range=((xmin, xmax), (ymin, ymax)))

    x = 0.5 * (xedges[:-1] + xedges[1:])
    y = 0.5 * (yedges[:-1] + yedges[1:])
    xx, yy = np.meshgrid(x, y)

    # histogram2d indexes its result as [x, y]; the mesh (and image/contour
    # plotting) uses [y, x].  Transpose in every case so that the returned
    # array always matches xx and yy, not only when cdf is requested.
    f = counts.T

    if cdf:
        f = make_cdf(f)

    return f, xx, yy

def load_int_sp():
    """Read the alpha / Epeak samples from the two FITS chain files.

    The ``alpha__1`` and ``Epeak__2`` columns of HDU 1 of each file are
    read, the two files are concatenated in order, and rows
    ``imin:imax`` of the concatenated arrays are returned.

    Returns
    -------
    X, Y : numpy.ndarray
        The selected ``alpha__1`` and ``Epeak__2`` samples.

    Notes
    -----
    The file 'sp1_4.npz' read by ``load_sp`` uses the keys ``alpha`` and
    ``Ep``; it can be regenerated from this function's output with
    ``np.savez_compressed('sp1_4', alpha=X, Ep=Y)``.
    """
    chain_files = (
        '../chains_sp1_4/1-4_gr1f_cplflux_1',
        '../chains_sp1_4/1-4_gr1f_cplflux_2',
    )

    alpha_parts = []
    epeak_parts = []
    for path in chain_files:
        # Each file is opened in its own context so the handle is closed,
        # and each file contributes its *own* table (previously the first
        # file's table was used for both halves).
        with fits.open(path) as hdul:
            table = hdul[1].data
            # Copy the columns so nothing refers to the file after closing.
            alpha_parts.append(np.array(table['alpha__1']))
            epeak_parts.append(np.array(table['Epeak__2']))

    X = np.concatenate(alpha_parts)
    Y = np.concatenate(epeak_parts)

    # Row window applied to the concatenated chain.
    # NOTE(review): presumably imin discards burn-in -- confirm.
    imin = 10000
    imax = 4000000

    return X[imin:imax], Y[imin:imax]

def load_sp():
    """Load the cached samples from 'sp1_4.npz' in the working directory.

    Returns
    -------
    X, Y : numpy.ndarray
        The arrays stored under the keys ``alpha`` and ``Ep``.
    """
    # NpzFile keeps the archive open until closed; the context manager
    # releases it once both arrays have been read into memory.
    with np.load('sp1_4.npz') as data:
        X = data['alpha']
        Y = data['Ep']

    return X, Y

def get_contours(cs):
    """Return the vertex arrays of all contour lines in a contour set.

    Parameters
    ----------
    cs : object with an ``allsegs`` attribute
        Typically the ``ContourSet`` returned by ``Axes.contour``.
        ``allsegs`` is a list with one entry per level, each entry being a
        list of vertex arrays (one per line segment).

    Returns
    -------
    list
        All segments in a single flat list, ordered by level and, within a
        level, in the order given by ``allsegs``.

    Notes
    -----
    ``allsegs`` is used instead of iterating ``cs.collections`` because the
    ``collections`` attribute was deprecated in Matplotlib 3.8 and later
    removed, whereas ``allsegs`` is available before and after that change.
    """
    return [segment
            for level_segments in cs.allsegs
            for segment in level_segments]

def make_pub_figure():
    """Draw the alpha-Epeak probability map and save it as 'sp_int.pdf'.

    Samples come from ``load_sp`` and are binned with ``make_distr`` on a
    50 x 50 grid.  The resulting map is shown as a gray-scale image with a
    colorbar, and contours of a Gaussian-smoothed copy are overlaid at the
    levels 0.683, 0.954 and 0.9973, labelled 68%, 95% and 99.7% CL.

    The coordinates of the lowest-alpha point on the contour at level
    index 1 (0.954) are printed to stdout.
    """
    Xint, Yint = load_sp()

    xmin, xmax, nx = -1.8, 1.0, 50
    ymin, ymax, ny = 50, 130, 50

    f_int, xx_int, yy_int = make_distr(Xint, Yint, xmin, xmax, nx, ymin, ymax, ny)

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5))

    image = ax.imshow(f_int, cmap='gray', extent=[xmin, xmax, ymin, ymax],
                      aspect='auto', origin='lower', vmin=0, vmax=1)

    # Free space on the right and place the colorbar in its own axes with
    # the same vertical extent as the main panel.
    fig.subplots_adjust(right=0.8)
    panel_box = ax.get_position()
    cbar_ax = fig.add_axes([0.83, panel_box.y0, 0.04, panel_box.y1 - panel_box.y0])

    cbar = fig.colorbar(image, cax=cbar_ax)
    cbar.ax.set_ylabel('Probability')

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    ax.xaxis.set_major_locator(MultipleLocator(0.5))
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    ax.xaxis.set_minor_locator(MultipleLocator(0.25))

    ax.yaxis.set_major_locator(MultipleLocator(10))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%d'))
    ax.yaxis.set_minor_locator(MultipleLocator(5))

    labels = ['68% CL', '95% CL', '99.7% CL']
    levels = np.array([0.683, 0.954, 0.9973])
    lw = (0.5, 0.5, 0.5)

    # Light smoothing (sigma in grid cells) before contouring.
    f_int_sm = ndimage.gaussian_filter(f_int, sigma=0.6, order=0)

    CS1 = ax.contour(xx_int, yy_int, f_int_sm, levels, linewidths=lw,
                     colors=('r', 'green', 'blue'))

    # Lowest-alpha point of the contour at levels[1].  allsegs is indexed by
    # level, so this stays on that level even when a lower level is split
    # into several disconnected pieces.
    # NOTE(review): assumes the index 1 used previously was meant to select
    # the 0.954 level -- confirm.
    segments_95 = [seg for seg in CS1.allsegs[1] if len(seg)]
    if segments_95:
        contour_95 = np.concatenate(segments_95)
        idx = np.argmin(contour_95[:, 0])
        min_alp = contour_95[idx, 0]
        min_ep = contour_95[idx, 1]
        print(min_alp, min_ep)

    # Legend handles come from legend_elements() rather than from
    # CS1.collections, which was deprecated in Matplotlib 3.8 and later
    # removed.
    handles, _ = CS1.legend_elements()
    ax.legend(handles, labels, loc='upper right')

    ax.set_xlabel(r'$\alpha$')
    ax.set_ylabel(r'$E_\mathrm{p}$ (keV)')

    fig.savefig('sp_int.pdf')
    plt.close(fig)

# Script entry point: build and save the figure when run directly.
if __name__ == "__main__":
    make_pub_figure()