import types
from scipy import *
from scipy import io,special
import numpy
import artsXML
import arts_scat
import arts
import os
from . import arts_types

# Directory of data files shipped with PyARTS: the 'share' subdirectory that
# sits next to this module (e.g. the ARTS template looked up by
# TrainingDataGenerator.run).
PyARTS_SHARE_DIR = os.path.join(os.path.dirname(__file__), 'share')

class TrainingDataGenerator:
    """Generate ARTS batch-calculation inputs for a set of cloudy profiles.

    Construction loads the clear-sky and cloudy atmospheric profile files,
    extracts every point with non-zero ice water content (IWC), writes the
    per-profile cloudbox limits to '<identifier>_cblims.xml.gz' and assigns
    gamma size distribution parameters to every non-zero IWC point.  The
    remaining steps are called explicitly; each relies on attributes set by
    the earlier ones:

    1. find_basis_particle_sizes  (n_sizes, equiv_radius,
                                   total_pnds_nonzero, volume_frac_error)
    2. habit_frac_gen             (habit_frac, n_habits)
    3. pnd_field_calc             (writes the batch pnd_field file)
    4. scat_data_calc             (writes the single scattering data file)
    5. run                        (currently raises NotImplementedError)

    Array maths uses the explicit ``numpy`` namespace rather than names
    re-exported by ``from scipy import *``, which newer SciPy releases no
    longer provide.
    """

    def __init__(self, atm_clear_file, atm_cloud_file, identifier, PSD_file):
        """
        atm_clear_file: ARTS XML file holding an array of clear-sky profile
                        matrices; column 1 of each matrix is read as
                        temperature (K).
        atm_cloud_file: ARTS XML file holding the matching array of cloud
                        profile matrices; column 1 of each matrix is read as
                        IWC (multiplied by 1000 to give g/m^3, so presumably
                        stored in kg/m^3 - confirm against the data set).
        identifier:     prefix used for every output file name.
        PSD_file:       whitespace-delimited table of size distribution
                        parameters (see _generate_size_dist).
        """
        self.identifier = identifier
        self.atm_clear_file = atm_clear_file
        self.atm_cloud_file = atm_cloud_file
        self._extract_nonzero_iwc_T()
        self._generate_size_dist(PSD_file)

    def _extract_nonzero_iwc_T(self):
        """Load the profiles (written for the Chevallier profiles of the
        arts-xml-data package) and extract all points with non-zero IWC.

        The cloudbox limits are written to self.cloudbox_limits_file.  This
        method assigns the following data members, which are used elsewhere
        by the class:

        n_profiles: The number of profiles in the atmospheric data set (e.g.
                    5000 for Chevallier).

        n_pres:     The number of pressure levels in each profile (e.g. 92
                    for Chevallier).

        n_nonzero:  The number of points in all profiles of the atmospheric
                    data set that have non-zero IWC.

        non_zero_indices: A tuple holding two arrays of indices, each of
                    length n_nonzero, pointing to the profile number and
                    pressure level of the non-zero IWC values.

        T_nonzero:  T values (C) corresponding to the non-zero IWC points in
                    all profiles (n_nonzero).

        iwc_nonzero: Non-zero IWC values (g/m^3) in all profiles
                    (n_nonzero).

        T_range:    Range of temperatures (K) associated with non-zero IWC.
                    This is needed for calculating scattering properties
                    (refractive index).

        cbl_array:  A list of n_profiles lists of 2 integers giving the
                    cloudbox limits (first and last pressure index) to use in
                    ARTS.

        cloudbox_limits_file: Name of the file cbl_array is saved to.
        """
        cloud_data = artsXML.load(self.atm_cloud_file)
        clear_data = artsXML.load(self.atm_clear_file)

        # artsXML.load may return the profile array wrapped in a dictionary
        # under the key 'Array'; unwrap each file independently.
        if isinstance(cloud_data, dict):
            cloud_data = cloud_data['Array']
        if isinstance(clear_data, dict):
            clear_data = clear_data['Array']

        self.n_profiles = len(clear_data)

        # n_profiles x n_pres arrays; column 1 of each profile matrix holds
        # the temperature (clear file) or the IWC (cloud file).
        T = numpy.array([clear_data[i]['Matrix'][:, 1]
                         for i in range(self.n_profiles)])
        iwc = numpy.array([cloud_data[i]['Matrix'][:, 1]
                           for i in range(self.n_profiles)])
        self.n_pres = T.shape[1]

        # (profile indices, pressure-level indices) of the non-zero IWC points
        non_zero_indices = numpy.nonzero(iwc)
        profile_idx, pres_idx = non_zero_indices

        T_nonzero_K = T[non_zero_indices]
        self.T_nonzero = T_nonzero_K - 273.15
        self.iwc_nonzero = iwc[non_zero_indices] * 1000
        self.n_nonzero = len(self.T_nonzero)

        # Temperature range (K), needed later for the scattering property
        # calculation.
        self.T_range = numpy.array((T_nonzero_K.min(), T_nonzero_K.max()))

        # numpy.nonzero returns indices in row-major order, so profile_idx is
        # sorted and the pressure indices of each profile form one
        # contiguous, ascending run: pres_idx[bounds[i]:bounds[i+1]].
        # This avoids scanning all n_nonzero points once per profile.
        bounds = numpy.searchsorted(profile_idx,
                                    numpy.arange(self.n_profiles + 1))

        self.cbl_array = []
        for i in range(self.n_profiles):
            levels = pres_idx[bounds[i]:bounds[i + 1]]
            if len(levels) == 0:
                self.cbl_array.append([23, 43])
                print("profile %d has no ice, adding bogus cloudbox limits" % i)
            elif len(levels) == 1:
                level = levels[0]
                # Pad the single cloudy level by 5 levels on either side,
                # clamped to levels 10..40.  The box is widened where needed
                # so that the cloudy level itself always lies inside it.  The
                # clamp alone could give bottom > top, or a box that misses
                # the ice entirely.
                self.cbl_array.append([min(max(10, level - 5), level),
                                       max(min(level + 5, 40), level)])
            else:
                # First and last cloudy level.  The list cast makes
                # artsXML.save write this as an ArrayOfArrayOfIndex.
                self.cbl_array.append(list(levels[[0, -1]]))

        self.cloudbox_limits_file = '%s_cblims.xml.gz' % self.identifier
        artsXML.save(self.cbl_array, self.cloudbox_limits_file)

        # keep non_zero_indices for pnd_field_calc
        self.non_zero_indices = non_zero_indices

    def _generate_size_dist(self, PSD_file):
        """Find gamma distribution parameters for each non-zero IWC point by
        taking the nearest point of the Heymsfield data set in PSD_file.

        The squared distance between points of the two data sets is
        (Tc-Th)^2/sigma_T^2 + (lnIc-lnIh)^2/sigma_lnI^2, where the subscripts
        h and c denote the Heymsfield and Chevallier data sets, T is the
        temperature and lnI the log of IWC.  sigma_T and sigma_lnI are the
        standard deviations of T_nonzero and log(iwc_nonzero).

        PSD_file: anything numpy.loadtxt accepts.  Column 0 is mu, column 2
                  is lambda, column 3 is temperature and column 4 is IWC.
                  Columns 3 and 4 are compared directly with T_nonzero (C)
                  and iwc_nonzero (g/m^3), so they must use the same units.

        Data members produced by this method:

        psds_nonzero : A (n_nonzero, n_size_params) array, which holds the
                       size distribution parameters for each non-zero IWC
                       point.

        _volume_truth: The total volume (up to the constant pi/6) of the
                       chosen size distribution at each non-zero IWC point,
                       assuming spherical particles (n_nonzero).  Used later
                       to assess the accuracy of the binning scheme.
        """
        # ndmin=2 keeps a single-row table two-dimensional
        psds = numpy.loadtxt(PSD_file, ndmin=2)

        # Trimming of extreme parameters.  This is only meaningful for the
        # full Heymsfield set, not for the small phony set used in
        # PyARTS/examples/training_data_gen.py.
        if psds.shape[0] > 4000:
            mu_sorted = numpy.sort(psds[:, 0])
            lam_sorted = numpy.sort(psds[:, 2])
            # Drop roughly the 50 largest and 50 smallest mu values, and the
            # 50 largest and 100 smallest lambda values.  The comparisons are
            # strict, so the cut values themselves are dropped as well.
            mu_upper_cut = mu_sorted[-50]
            mu_lower_cut = mu_sorted[50]
            lam_upper_cut = lam_sorted[-50]
            lam_lower_cut = lam_sorted[100]
            keep = ((psds[:, 0] < mu_upper_cut) &
                    (psds[:, 0] > mu_lower_cut) &
                    (psds[:, 2] > lam_lower_cut) &
                    (psds[:, 2] < lam_upper_cut))
            psds = psds[keep]

        sigma_T = numpy.std(self.T_nonzero)
        sigma_iwc = numpy.std(numpy.log(self.iwc_nonzero))

        # Normalised Heymsfield values as 1 x N_heyms rows and Chevallier
        # values as n_nonzero x 1 columns, so broadcasting forms the distance
        # matrix.
        T_h = (psds[:, 3] / sigma_T)[numpy.newaxis, :]
        ln_iwc_h = (numpy.log(psds[:, 4]) / sigma_iwc)[numpy.newaxis, :]
        T_c = (self.T_nonzero / sigma_T)[:, numpy.newaxis]
        ln_iwc_c = (numpy.log(self.iwc_nonzero) / sigma_iwc)[:, numpy.newaxis]

        # The Chevallier set is too big for a single n_nonzero x N_heyms
        # distance matrix, so process it in blocks of rows.
        block_size = 1000
        chosen = [numpy.zeros(0, int)]
        for start in range(0, self.n_nonzero, block_size):
            stop = min(start + block_size, self.n_nonzero)
            dsqrd = ((T_h - T_c[start:stop]) ** 2 +
                     (ln_iwc_h - ln_iwc_c[start:stop]) ** 2)
            # index of the closest Heymsfield point to each Chevallier point
            chosen.append(numpy.argmin(dsqrd, axis=-1))
        chosen_indices = numpy.concatenate(chosen)

        self.psds_nonzero = psds[chosen_indices, :]

        lam = self.psds_nonzero[:, 2]
        mu = self.psds_nonzero[:, 0]

        # int D**(mu+3) * exp(-lam*D) dD = Gamma(mu+4) / lam**(mu+4).  This
        # equals the sum of the generalised Gauss-Laguerre weights
        # (alpha=mu+3) used previously, but is exact and vectorised.
        self._volume_truth = special.gamma(mu + 4) / lam ** (4 + mu)

    @staticmethod
    def _get_pylab(interactive):
        """Import and return pylab.  If *interactive* is false, first select
        the non-interactive 'Agg' backend."""
        import matplotlib
        if not interactive:
            matplotlib.use('Agg')
        import pylab
        return pylab

    def plot_psds(self, interactive=False):
        """Plot the particle size distribution parameters (lambda and mu) for
        every non-zero IWC point in every profile.

        The points are plotted against temperature and log10(IWC).  The
        figure is always saved to '<identifier>_psds.png'.  If *interactive*
        (bool) is true, it is also displayed.  The figure window must then be
        closed before other commands are executed.
        """
        pylab = self._get_pylab(interactive)
        lam = self.psds_nonzero[:, 2]
        mu = self.psds_nonzero[:, 0]
        log_iwc = numpy.log10(self.iwc_nonzero)

        pylab.figure(figsize=(10, 4))
        pylab.clf()

        pylab.subplot(121)
        pylab.scatter(self.T_nonzero, log_iwc, c=numpy.log10(lam),
                      s=9, edgecolors='none', zorder=2)
        pylab.colorbar()
        pylab.title(r'log$_{10}$ of gamma parameter $\lambda$')
        pylab.xlabel('T')
        pylab.ylabel('log$_{10}$(IWC)')

        pylab.subplot(122)
        pylab.scatter(self.T_nonzero, log_iwc, c=mu, s=9,
                      edgecolors='none', zorder=2)
        pylab.colorbar()
        pylab.title(r'Gamma parameter $\mu$')
        pylab.xlabel('T')
        pylab.ylabel('log$_{10}$(IWC)')

        pylab.savefig('%s_psds.png' % self.identifier)
        if interactive:
            pylab.show()

    def find_basis_particle_sizes(self, n_sizes=None, tolerance=0.1,
                                  n_sizes_max=30, n_sizes_step=5):
        """Use Gauss-Laguerre abscissae to determine *n_sizes* particle sizes
        that represent the whole data set.

        If n_sizes is not specified, n_sizes = 5, 5+n_sizes_step, ... (up to
        *n_sizes_max*) are tried until the absolute fractional error of the
        volume test is below *tolerance* at every point.  If no value
        succeeds, the last one tried is kept.

        Data members produced by this method:

        n_sizes:   The number of discrete particle sizes used to represent
                   the polydispersion.

        volume_frac_error:  An array of size n_nonzero that holds the
                   fractional error in total particle volume for each
                   non-zero IWC point using the chosen basis particle sizes.

        equiv_radius: Equivalent-mass sphere radius (microns) of each basis
                   size (n_sizes).

        total_pnds_nonzero: Particle number densities, n_nonzero x n_sizes,
                   scaled so that sum(pnd * mass) equals iwc_nonzero.
        """
        if n_sizes is None:
            for n_sizes in range(5, n_sizes_max + 1, n_sizes_step):
                self.find_basis_particle_sizes(n_sizes)
                if numpy.all(numpy.abs(self.volume_frac_error) < tolerance):
                    break
            return

        self.n_sizes = n_sizes
        lam = self.psds_nonzero[:, 2:3]  # n_nonzero x 1
        mu = self.psds_nonzero[:, 0:1]   # n_nonzero x 1

        lam_0 = numpy.median(lam[:, 0])
        mu_0 = numpy.median(mu[:, 0])

        # Quadrature for the weight function x**(mu_0+3) * exp(-x)
        A = special.genlaguerre(n=n_sizes, alpha=mu_0 + 3)
        x = numpy.real(A.weights[:, 0])[numpy.newaxis, :]  # 1 x n_sizes
        w = numpy.real(A.weights[:, 1])[numpy.newaxis, :]  # 1 x n_sizes

        # Volume of each distribution, using the substitution D = x / lam_0
        estimate = numpy.sum(x ** (mu - mu_0) * numpy.exp((1 - lam / lam_0) * x) * w,
                             axis=-1) / lam_0 ** (4 + mu[:, 0])
        self.volume_frac_error = (estimate - self._volume_truth) / self._volume_truth

        D = x / lam_0               # maximum dimension (cm)
        mass = 2.94e-3 * D ** 1.9   # mass (g)
        # equivalent-mass sphere radius (microns), ice density 0.917 g/cm^3
        self.equiv_radius = (3 * mass[0] / 4 / numpy.pi / 0.917) ** (1. / 3.) * 1e4

        print("maximum radius: %f microns" % self.equiv_radius[-1])
        # size parameter 2*pi*r/lambda at 180 GHz (c = 3e14 microns/s)
        print("which is a size parameter of %f"
              % (2 * numpy.pi * self.equiv_radius[-1] * 180e9 / 3e14))

        # PNDs (n_nonzero x n_sizes) for N(D) = D**mu * exp(-lam*D).  The
        # quadrature weight function x**(mu_0+3) * exp(-x) is divided out, so
        # sum(pnd * D**3) reproduces `estimate` above.  The volume test
        # therefore checks the pnds that are actually used.
        self.total_pnds_nonzero = (w * D ** mu * numpy.exp(x - lam * D)
                                   / lam_0 / x ** (3 + mu_0))

        # scale to give the correct IWC
        self.total_pnds_nonzero *= (
            self.iwc_nonzero / numpy.sum(self.total_pnds_nonzero * mass, axis=-1)
        )[:, numpy.newaxis]

    def plot_frac_error(self, interactive=False):
        """Plot the fractional error of the total particle volume test
        against temperature and log10(IWC).

        The figure is saved to '<identifier>_frac_error_<n_sizes>.png' and
        is also shown if *interactive* is true.
        """
        pylab = self._get_pylab(interactive)
        pylab.figure()
        pylab.clf()

        pylab.scatter(self.T_nonzero, numpy.log10(self.iwc_nonzero),
                      c=self.volume_frac_error, s=9, edgecolors='none')
        pylab.colorbar()
        pylab.title('fractional error in volume calculation: N=%d' % self.n_sizes)
        pylab.xlabel('T')
        pylab.ylabel('log$_{10}$(IWC)')

        pylab.savefig('%s_frac_error_%d.png' % (self.identifier, self.n_sizes))
        if interactive:
            pylab.show()

    def habit_frac_gen(self, centre=0.5, spread=0.25):
        """Generate an n_nonzero x 2 array of habit fractions (self.habit_frac).

        Two particle types are assumed: spheres and horizontally aligned
        oblate spheroids.  The first column (spheres) is drawn from a normal
        distribution with mean *centre* and standard deviation *spread*.
        Draws outside [0, 1] are replaced by a uniform [0, 1) draw.  The
        second column (spheroids) is one minus the first.
        """
        habit_frac = numpy.random.normal(loc=centre, scale=spread,
                                         size=self.n_nonzero)
        uniform = numpy.random.uniform(size=self.n_nonzero)
        out_of_range = (habit_frac < 0) | (habit_frac > 1)
        habit_frac = numpy.where(out_of_range, uniform, habit_frac)
        self.habit_frac = numpy.c_[habit_frac, 1 - habit_frac]
        self.n_habits = 2

    def pnd_field_calc(self):
        """Build and save the per-profile pnd_fields for ARTS.

        The pnds from find_basis_particle_sizes are multiplied by the habit
        fractions from habit_frac_gen.  The result is folded back into the
        full n_pres profiles and trimmed to each profile's cloudbox.  The
        array of Tensor4 pnd_fields is saved to self.batch_pnd_fields_file.
        The particle axis is habit-major (index = habit * n_sizes + size).
        """
        n_elements = self.n_habits * self.n_sizes

        # n_nonzero x n_habits x n_sizes, flattened to n_nonzero x n_elements
        pnds_nonzero = (self.total_pnds_nonzero[:, numpy.newaxis, :] *
                        self.habit_frac[:, :, numpy.newaxis])
        pnds_nonzero = pnds_nonzero.reshape(self.n_nonzero, n_elements)

        # full pnd field for all pressure levels, zero where there is no ice
        pnds = numpy.zeros((self.n_profiles, self.n_pres, n_elements), float)
        pnds[self.non_zero_indices[0], self.non_zero_indices[1], :] = pnds_nonzero

        # one Tensor4 (n_elements, n_cloudbox_levels, 1, 1) per profile
        self.pnd_fields = []
        for i in range(self.n_profiles):
            cbl_bottom, cbl_top = self.cbl_array[i][0], self.cbl_array[i][1]
            n_levels = cbl_top - cbl_bottom + 1
            profile_pnds = pnds[i, cbl_bottom:cbl_top + 1, :].T
            self.pnd_fields.append(profile_pnds.reshape(n_elements, n_levels, 1, 1))

        self.batch_pnd_fields_file = '%s_pnd_fields_%d.xml' % (self.identifier,
                                                               self.n_sizes)
        artsXML.save(self.pnd_fields, self.batch_pnd_fields_file)

    def scat_data_calc(self, f_grid, n_temp=2, num_proc=1, aspect_ratios=None):
        """Calculate the single scattering properties for all
        n_habits x n_sizes particles and save them to self.scat_data_file.

        f_grid:        frequency grid passed to arts_scat.batch_generate.
        n_temp:        number of temperatures spanning self.T_range.
        num_proc:      number of processes for arts_scat.batch_generate.
        aspect_ratios: spheroid aspect ratios.  Defaults to
                       numpy.array((1.00001, 3.0)); presumably the first is a
                       near-sphere standing in for the sphere habit.
        """
        if aspect_ratios is None:
            aspect_ratios = numpy.array((1.00001, 3.0))
        # arts_scat.batch_generate accepts a dictionary of scat_params, where
        # (almost) each value is a list.  It cascades through each list,
        # giving scattering files for each combination.  A list of
        # dictionaries would be a better design (pending fix).
        scat_params = {'f_grid': [f_grid],
                       'za_grid': numpy.arange(0, 190, 10),
                       'aa_grid': numpy.arange(0, 190, 10),
                       'ptype': [30],
                       "T_grid": [numpy.linspace(self.T_range[0], self.T_range[1],
                                                 n_temp)],
                       'phase': 'ice',
                       'NP': [-1],       # Spheroid
                       'aspect_ratio': aspect_ratios,
                       'equiv_radius': self.equiv_radius,
                       }
        scat_files = arts_scat.batch_generate(scat_params, num_proc)
        self.scat_data_list = [arts_types.SingleScatteringData.load(f)
                               for f in scat_files]

        self.scat_data_file = '%s_scat_data_raw_%d.xml.gz' % (self.identifier,
                                                             self.n_sizes)
        arts_types.ArrayOfSingleScatteringData(
            self.scat_data_list).save(self.scat_data_file)

    def run(self, ybatch_n=None, template_file=None):
        """Run the ARTS batch calculation for the generated inputs.

        Not currently usable: it raises NotImplementedError immediately
        because the implementation depended on the old template mechanism.
        The code after the raise is unreachable and is kept only as a
        reference for the update.
        """
        raise NotImplementedError("This stuff used template, to be updated!")
        if template_file is None:
            template_file = os.path.join(PyARTS_SHARE_DIR,
                                         'AMSUcloudyBatch.arts.tmplt')

        if ybatch_n is None:
            ybatch_n = self.n_profiles

        arts_params = {'ybatch_n': ybatch_n,
                       'batch_pnd_fields': self.batch_pnd_fields_file,
                       'batch_cloudbox_limits': self.cloudbox_limits_file,
                       'scat_data_raw': self.scat_data_file,
                       'atm_clear_profiles': self.atm_clear_file}
        ybatch_file = '%s_%s.ybatch.xml' % (self.identifier, self.n_sizes)
        if os.path.exists(ybatch_file):
            print('Moving %s out of the way' % ybatch_file)
            os.rename(ybatch_file, ybatch_file + '.old')

        self.arts_run = arts.ArtsRun(
            params=arts_params, run_type='batch',
            filename='%s_%s.arts' % (self.identifier, self.n_sizes),
            template_file=template_file).run()

        self.radiances = self.arts_run.output['ybatch']

                     
        
