import sys

import numpy as np
import matplotlib.pyplot as mpl
import matplotlib.widgets as mpw
import scipy.ndimage as snd
import scipy.optimize as sco

import galaxy_simulator as galSIM

#==============================================================================
#  Create Mock Galaxy Data
#==============================================================================

galaxy1 = galSIM.create_mock_galaxy()
galaxy2 = galSIM.create_mock_galaxy()
spectra1D, spectra2D = galSIM.create_mock_spectra()
## Mock wavelength grid: one sample every 5 AA, starting at 3500 AA
wavelength = np.arange(np.size(spectra1D)) * 5.0 + 3500.

## Initial zoom window: samples strictly within +/-300 AA (600 AA total) of 5000 AA.
## The 2D spectrum is sliced along its second axis, i.e. its columns follow `wavelength`.
wave_zoom = 5000
wave_window = 300
zoom_selection = (wavelength > wave_zoom - wave_window) & (wavelength < wave_zoom + wave_window)
spectra1Dzoom = spectra1D[zoom_selection]
spectra2Dzoom = spectra2D[:, zoom_selection]
wavelength_zoom = wavelength[zoom_selection]

#==============================================================================
# Interactive Actions 
#==============================================================================
def smooth_1Dspectra(val):
    """Slider callback: redraw the zoomed 1D spectrum after Gaussian smoothing.

    Parameters
    ----------
    val : float
        Current value of the smoothing slider, passed as the ``sigma`` of
        ``scipy.ndimage.gaussian_filter`` (in array samples).
    """
    smoothed = snd.gaussian_filter(spectra1Dzoom, val)
    speczoom.set_data(wavelength_zoom, smoothed)
    fig.canvas.draw_idle()
    
def gaussian(x, a, c, x0, sigma):
    """Evaluate a Gaussian profile on top of a constant offset.

    Parameters
    ----------
    x : float or array_like
        Positions at which to evaluate the model.
    a : float
        Peak height above (or, if negative, below) the offset.
    c : float
        Constant offset (continuum level).
    x0 : float
        Centre of the Gaussian.
    sigma : float
        Standard deviation of the Gaussian.

    Returns
    -------
    float or ndarray
        ``c + a * exp(-(x - x0)**2 / (2 * sigma**2))``.
    """
    offset = x - x0
    return c + a * np.exp(-offset**2 / (2 * sigma**2))

def fit_gaussian_to_data():
    """Fit ``gaussian`` to the current zoomed 1D spectrum and draw the model.

    Uses the raw (unsmoothed) module-level ``wavelength_zoom`` /
    ``spectra1Dzoom`` arrays and writes the best-fit curve into ``gaussline``.
    If the window has fewer samples than free parameters, or the optimiser
    fails to converge, a message is printed and the plot is left unchanged
    instead of raising out of the key-press callback.
    """
    x = np.asarray(wavelength_zoom, dtype=float)
    y = np.asarray(spectra1Dzoom, dtype=float)
    n_params = 4  # a, c, x0, sigma
    if x.size < n_params:
        print('Not enough points in the zoom window to fit a Gaussian.')
        return None

    # Data-driven starting point: continuum at the median flux, line centre at
    # the sample deviating most from it (works for emission and absorption),
    # amplitude equal to that deviation, width of one sampling step.
    continuum = np.median(y)
    peak = int(np.argmax(np.abs(y - continuum)))
    step = np.median(np.diff(x))
    p0 = (y[peak] - continuum, continuum, x[peak], step)

    try:
        pars, _ = sco.curve_fit(gaussian, x, y, p0=p0)
    except RuntimeError as exc:
        # curve_fit raises RuntimeError when the least-squares fit fails.
        print('Gaussian fit did not converge: {}'.format(exc))
        return None

    gaussline.set_data(x, gaussian(x, *pars))
    fig.canvas.draw_idle()
    return None
    
def change_contrast1(val):
    """Contrast-slider callback for the upper galaxy image (``axImage1``)."""
    change_contrast(val, eixo=axImage1)

def change_contrast2(val):
    """Contrast-slider callback for the lower galaxy image (``axImage2``)."""
    change_contrast(val, eixo=axImage2)
    
def change_contrast(val, eixo):
    """Rescale the colour limits of the first image drawn in ``eixo``.

    The lower limit is pinned to the image minimum. The upper limit sits a
    fraction ``val`` of the way from the minimum to the maximum, so
    ``val == 1`` shows the full data range and smaller values saturate the
    brighter pixels.

    Parameters
    ----------
    val : float
        Fraction of the data range used for the upper colour limit.
    eixo : matplotlib Axes
        Axes whose first image is rescaled ("eixo" is Portuguese for "axis").
    """
    image = eixo.get_images()[0]
    pixels = image.get_array()
    low = np.amin(pixels)
    high = np.amax(pixels)
    image.set_clim(vmin=low, vmax=low + val * (high - low))
    fig.canvas.draw_idle()

def trigger_keyboard(event):
    """Key-press callback implementing the keyboard shortcuts.

    * ``f``      -- fit a Gaussian to the zoomed spectrum.
    * ``escape`` -- exit the Python process via ``sys.exit()``.
    * ``q``      -- close all figures.

    Any other key is ignored.
    """
    key = event.key
    if key == 'f':
        fit_gaussian_to_data()
    elif key == 'escape':
        # Requires the module-level `import sys`; without it this raised NameError.
        sys.exit()
    elif key == 'q':
        mpl.close('all')
    return None
            
def reset_data(event):
    """Reset-button callback: restore all sliders and clear the Gaussian fit.

    Parameters
    ----------
    event : matplotlib event
        Click event supplied by the button; unused.
    """
    slider_smooth.reset()
    slider_contrast2.reset()
    slider_contrast1.reset()
    gaussline.set_data([], [])
    # The slider resets need not trigger a redraw (e.g. when every slider is
    # already at its initial value), so request one explicitly; otherwise the
    # cleared fit line could stay on screen.
    fig.canvas.draw_idle()
    return None

def get_spectral_position(event):
    """Mouse-click callback: re-centre the zoom window on the clicked wavelength.

    A click inside the full 1D spectrum panel (``axFullSpec1D``) selects all
    samples within +/- ``wave_window`` of the clicked wavelength. It then
    updates the zoomed 1D/2D panels and moves the green highlight band.
    Clicks elsewhere are ignored. So are clicks whose window contains no
    samples (e.g. after panning past the data).

    Side effects: rebinds the module-level ``wavelength_zoom``,
    ``spectra1Dzoom`` and ``spectra2Dzoom``, and clears any previous
    Gaussian fit.
    """
    global wavelength_zoom, spectra1Dzoom, spectra2Dzoom
    if event.inaxes != axFullSpec1D:
        return None

    wave_selected = event.xdata
    zoom_selection = ((wavelength > wave_selected - wave_window) &
                      (wavelength < wave_selected + wave_window))
    if not np.any(zoom_selection):
        # An empty window would make the [0]/[-1] indexing below fail.
        return None

    spectra1Dzoom = spectra1D[zoom_selection]
    spectra2Dzoom = spectra2D[:, zoom_selection]
    wavelength_zoom = wavelength[zoom_selection]

    # Re-apply the current smoothing so the curve matches the slider setting.
    speczoom.set_data(wavelength_zoom,
                      snd.gaussian_filter(spectra1Dzoom, slider_smooth.val))
    # A fit from the previous window no longer describes the displayed data.
    gaussline.set_data([], [])
    speczoomimage.set_data(spectra2Dzoom)
    speczoomimage.set_extent([wavelength_zoom[0], wavelength_zoom[-1],
                              0, spectra2Dzoom.shape[0]])
    axZoomSpec1D.set_xlim(wavelength_zoom[0], wavelength_zoom[-1])
    # Rescale y the same way as the initial set-up, so the new window is not clipped.
    axZoomSpec1D.set_ylim(0.95*np.amin(spectra1Dzoom), 1.05*np.amax(spectra1Dzoom))

    # Move the highlight band. Removing items from the list while iterating
    # over it skips elements, and Axes.collections is a read-only view in
    # Matplotlib >= 3.5, so remove each artist from a snapshot instead.
    for collection in list(axFullSpec1D.collections):
        collection.remove()
    axFullSpec1D.fill_betweenx([0.95*np.amin(spectra1D), 1.05*np.amax(spectra1D)],
                               wave_selected - wave_window, wave_selected + wave_window,
                               color='LimeGreen', alpha=0.25, zorder=-1)
    fig.canvas.draw_idle()
    return None

#==============================================================================
# Create Figure
#==============================================================================

fig = mpl.figure(figsize=(18, 12))

# Right-hand column, rects given as (left, bottom, width, height) in figure
# fractions. On top: the full spectrum, a 2D strip above its 1D trace.
# Below: the zoomed spectrum, a 1D trace above its 2D strip. Each 2D panel
# shares its x axis with the 1D panel it belongs to.
axFullSpec1D = fig.add_axes((0.22, 0.70, 0.73, 0.15))
axFullSpec2D = fig.add_axes((0.22, 0.85, 0.73, 0.10), sharex=axFullSpec1D)
axZoomSpec1D = fig.add_axes((0.22, 0.15, 0.73, 0.55))
axZoomSpec2D = fig.add_axes((0.22, 0.05, 0.73, 0.10), sharex=axZoomSpec1D)

# Left-hand column: two galaxy images with linked x and y axes.
axImage1 = fig.add_axes((0.01, 0.65, 0.19, 0.28))
axImage2 = fig.add_axes((0.01, 0.20, 0.19, 0.28), sharex=axImage1, sharey=axImage1)
        


#==============================================================================
# Plot Spectra
#==============================================================================
axFullSpec1D.plot(wavelength, spectra1D, 'k-', lw=2, drawstyle='steps-mid')
axFullSpec2D.imshow(spectra2D, cmap='viridis', vmin=-1, vmax=10,
                    extent=(wavelength[0], wavelength[-1], 0, spectra2D.shape[0]),
                    aspect='auto')
# Shaded band marking the initial zoom window (wave_zoom +/- wave_window).
axFullSpec1D.fill_betweenx([0.95*np.amin(spectra1D), 1.05*np.amax(spectra1D)],
                           wave_zoom - wave_window, wave_zoom + wave_window,
                           color='LimeGreen', alpha=0.25, zorder=-1)


speczoom, = axZoomSpec1D.plot(wavelength_zoom, spectra1Dzoom, '-', color='DodgerBlue',
                              lw=2, drawstyle='steps-mid')
# Starts empty; fit_gaussian_to_data fills it with the best-fit model.
gaussline, = axZoomSpec1D.plot([], [], color='red', ls='--', linewidth=2)

speczoomimage = axZoomSpec2D.imshow(spectra2Dzoom, cmap='viridis', vmin=-1, vmax=10,
                                    extent=(wavelength_zoom[0], wavelength_zoom[-1],
                                            0, spectra2Dzoom.shape[0]),
                                    aspect='auto')

# Tick-label switches must be booleans. Current Matplotlib does not support
# the legacy 'on'/'off' strings, and a non-empty string such as 'off' is
# truthy, which would leave the labels visible.
axFullSpec2D.tick_params(labelleft=False, labelright=False, labelbottom=False, labeltop=True)
axFullSpec1D.tick_params(labelleft=False, labelright=False, labelbottom=False, labeltop=False)
axZoomSpec1D.tick_params(labelleft=False, labelright=True, labelbottom=False, labeltop=False)
axZoomSpec2D.tick_params(labelleft=False, labelright=False, labelbottom=True, labeltop=False)

## LIMITS FULL SPECTRA
axFullSpec1D.set_ylim(0.95*np.amin(spectra1D), 1.05*np.amax(spectra1D))
axFullSpec1D.set_xlim(wavelength[0], wavelength[-1])

## LIMITS ZOOM SPECTRA
axZoomSpec1D.set_ylim(0.95*np.amin(spectra1Dzoom), 1.05*np.amax(spectra1Dzoom))
axZoomSpec1D.set_xlim(wavelength_zoom[0], wavelength_zoom[-1])

## SMOOTHING SLIDER: its value is the Gaussian sigma used by smooth_1Dspectra
smooth_slid = fig.add_axes([0.30, 0.64, 0.16, 0.02])
slider_smooth = mpw.Slider(smooth_slid, 'Smooth', 0.0, 3.0, valinit=0.0, color='HotPink', closedmin=True)
slider_smooth.on_changed(smooth_1Dspectra)
        
        
#==============================================================================
# Plot Images
#==============================================================================
HalfSize_ArcSec = 3

# Both galaxy images are drawn over +/- HalfSize_ArcSec on each axis
# (arcseconds, going by the variable name).
axImage1.imshow(galaxy1, cmap='magma', aspect='equal',
                extent=(-HalfSize_ArcSec, HalfSize_ArcSec, -HalfSize_ArcSec, HalfSize_ArcSec))
axImage2.imshow(galaxy2, cmap='magma', aspect='equal',
                extent=(-HalfSize_ArcSec, HalfSize_ArcSec, -HalfSize_ArcSec, HalfSize_ArcSec))

## LIMITS IMAGE (axImage2 shares x/y with axImage1, so this sets both)
axImage2.set_xlim(-HalfSize_ArcSec, HalfSize_ArcSec)
axImage2.set_ylim(-HalfSize_ArcSec, HalfSize_ArcSec)

# One contrast slider per image. closedmin=False stops the slider from
# reaching exactly 0.0.
contrast_slid1 = fig.add_axes([0.045, 0.95, 0.12, 0.02])
slider_contrast1 = mpw.Slider(contrast_slid1, 'Contr.', 0.0, 1.0,
                              valinit=1.0, color='Gold', closedmin=False)
slider_contrast1.on_changed(change_contrast1)

contrast_slid2 = fig.add_axes([0.045, 0.51, 0.12, 0.02])
slider_contrast2 = mpw.Slider(contrast_slid2, 'Contr.', 0.0, 1.0,
                              valinit=1.0, color='Gold', closedmin=False)
slider_contrast2.on_changed(change_contrast2)


#==============================================================================
# Added Buttons
#==============================================================================

# Reset button in the bottom-left corner. The widget is kept referenced by
# the module-level name `resbut` so it stays responsive.
resbut = mpw.Button(ax=fig.add_axes([0.015, 0.05, 0.1, 0.05]), label='Reset',
                    color='Gold', hovercolor='GoldenRod')
resbut.on_clicked(reset_data)

#==============================================================================
# Connect to figure  
#==============================================================================

# Matplotlib's default keymap binds 'f' to fullscreen toggling, which would
# fire together with the 'f' (fit) shortcut in trigger_keyboard. Unbind it
# so 'f' only triggers the fit.
mpl.rcParams['keymap.fullscreen'] = [k for k in mpl.rcParams['keymap.fullscreen'] if k != 'f']

fig.canvas.mpl_connect('key_press_event', trigger_keyboard)
fig.canvas.mpl_connect('button_press_event', get_spectral_position)

mpl.show()
