"""This module defines the SLIData2 class, which allows the creation of
optimized grids for 2D sequential linear interpolation, as described by
[Changetal97]_ 
(`pdf <http://dynamo.ecn.purdue.edu/~bouman/publications/orig-pdf/ip9.pdf>`__).

The main difference between the method used by SLIData2 and the one
described in the paper is that we start with a coarse grid, and every
function evaluation is included in the final grid.  The motivation for
this is that function evaluations (e.g. ARTS RT simulations) are
expected to be expensive.

For an example of use, see the `arts.create_incoming_lookup`_ function. 

"""
from scipy import *
from scipy import interpolate, integrate
from .arts_geometry import EARTH_RADIUS
import arts_math
import artsXML
from physics import c,k
import tempfile

#these are just helper functions for SLIData2

def morepoints(interval_density,x,N):
    """Return the new points to insert into grid x so it has about N points.

    The target cumulative number of points along x is the running
    trapezoidal integral of (N-1)*interval_density.  Wherever that target
    exceeds the number of points already placed by more than one,
    floor(excess) new points are spaced evenly strictly inside the
    interval (never on an existing point), and the count is carried
    forward to later intervals.

    Parameters
    ----------
    interval_density : array
        Desired point density at each x (presumably normalised to integrate
        to 1 over x, as SLIData2.refine does -- confirm for other callers).
    x : array
        Existing grid (assumed sorted).
    N : int or float
        Desired total number of points; a float is accepted.

    Returns
    -------
    array
        Only the new points (not merged with x), in interval order;
        an empty array if no points are needed.
    """
    # cumtrapz was renamed cumulative_trapezoid (SciPy 1.6) and the old name
    # was later deprecated and removed; use whichever this SciPy provides.
    _cumtrapz=getattr(integrate,'cumulative_trapezoid',None) or integrate.cumtrapz
    N0=len(x)
    # Z[i]: target cumulative point count at x[i+1].
    Z=_cumtrapz((N-1)*interval_density,x=x)
    # placed[i]: number of points already placed up to and including x[i].
    placed=arange(float(N0))
    newx=[]
    for i in range(N0-1):
        excess=Z[i]-placed[i+1]
        if excess>1:
            # linspace requires an integer count (floats raise TypeError in
            # modern NumPy); excess > 1 so int() truncation equals floor.
            k=int(excess)
            newx.append(linspace(x[i],x[i+1],k+2)[1:-1])
            placed[i+1:]+=k
    if newx:
        return concatenate(newx)
    return array([])
    
def d2fdx2(x,y):
    """Estimate d2y/dx2 at every point of a (possibly non-uniform) grid.

    Interior points use the three-point difference
    2*(s_right - s_left)/(h_left + h_right), where s are the slopes of
    the two adjacent segments and h their widths.  Each end point copies
    the estimate of its nearest neighbour.

    Returns a float array with one entry per point of x.
    """
    widths=diff(x)
    slopes=diff(y)/widths
    curvature=zeros(len(x),float)
    curvature[1:-1]=2*(slopes[1:]-slopes[:-1])/(widths[:-1]+widths[1:])
    # No one-sided formula at the ends: reuse the neighbouring estimate.
    curvature[0]=curvature[1]
    curvature[-1]=curvature[-2]
    return curvature

def call_func(func,x1,x2):
    """Evaluate func on every point of a ragged (x1, x2) grid in one call.

    x1 is a sequence of x1 values and x2 a dict mapping each of them to
    the x2 values on that line.  func is called exactly once with two
    equal-length arrays (the flattened x1 and x2 coordinates).  The flat
    result is split back into a dict mapping each x1 value to an array of
    function values aligned with x2[x1 value].
    """
    # Flatten the grid into two parallel coordinate lists.
    flat_x1=[key for key in x1 for _ in x2[key]]
    flat_x2=[val for key in x1 for val in x2[key]]
    flat_y=func.__call__(array(flat_x1),array(flat_x2))
    # Hand each line back its consecutive run of results.
    result={}
    start=0
    for key in x1:
        stop=start+len(x2[key])
        result[key]=array([flat_y[j] for j in range(start,stop)])
        start=stop
    return result

def Bcalc(x1,x2,y):
    """Return the x2-curvature measure B for every x1 grid line.

    For line i, B[i] = (integral over x2 of |d2y/dx2|**0.4)**5, with the
    second derivative estimated by d2fdx2 and the integral by Simpson's
    rule on that line's own x2 grid.

    x1 is the sequence of x1 values; x2 and y are dicts mapping each x1
    value to that line's x2 grid and function values.  Returns a float
    array of len(x1).

    Raises ValueError if some B[i] is zero (y linear along x2 on that
    line), because SLIData2.refine divides by B.
    """
    B=zeros(len(x1),float)
    for i,x1i in enumerate(x1):
        d2ydx2=d2fdx2(x2[x1i],y[x1i])
        # (d2**2)**0.2 is |d2|**0.4.
        B[i]=float(integrate.simps((d2ydx2**2)**0.2,x2[x1i])**5)
        if B[i]==0:
            # Call form works on Python 2 and 3; 'raise E, msg' is a
            # SyntaxError on Python 3.
            raise ValueError("Function is too linear in y for x = "+str(x1i))
    return B



class SLIData2:
    """Class for 2D sequential linear interpolation.

    The grid is stored as:

    * ``x1`` -- sequence of x1 grid values;
    * ``x2`` -- dict mapping each x1 value to the array of x2 grid values
      on that x1 line (different lines may have different x2 grids);
    * ``y``  -- dict mapping each x1 value to the array of function values
      at the corresponding ``x2`` points.
    """
    def __init__(self,func=None,x1=None,x2=None):
        """Initialise the SLIData2 object with a standard grid defined by
        vectors x1 and x2.

        func must be a function y=func(x1,x2), where x1, x2 and y are
        vectors of the same length.  Every x1 line starts with the same x2
        grid, and func is evaluated on the whole grid.  If func is None,
        nothing is evaluated and self.y is not set (e.g. for use with
        load()).
        """
        self.func=func
        self.x1=x1
        self.x2={}
        # Identity test: '==' could be overridden by an arbitrary callable.
        if func is not None:
            for x in x1:
                self.x2[x]=x2
            self.y=call_func(self.func,self.x1,self.x2)


    def refine(self,N):
        """Refine the grid by increasing the total number of gridpoints
        to 'about' N.  Generally it is good to call refine 2 or more times
        to successively add more points to the grid.

        Refinement is done in two stages:

        1. New x1 lines are inserted where an interval density built from
           A (squared curvature along x1) and B (see Bcalc) asks for them.
           Each new line starts with the x2 grid of the existing line at
           index arts_math.locate(x1, new_x1), and func is evaluated on it.
        2. Each x1 line receives a share n2[i] of N, proportional to
           (B[i]*cell_width[i])**0.2, and extra x2 points are placed where
           |d2y/dx2|**0.4 is large.

        Every new function value is merged into the grid.  The x2 and y
        dicts are modified in place; self.x1, self.x2 and self.y are
        updated and self is returned.
        """
        x1=self.x1
        x2=self.x2
        y=self.y
        N0=len(x1)

        # ---- Stage 1: add new x1 lines ------------------------------
        # Linear interpolant of y along x2 for each line, used to evaluate
        # neighbouring lines at this line's x2 points.
        x2interpfuns={}
        for xi in x1:
            x2interpfuns[xi]=interpolate.interp1d(x2[xi],y[xi])
        B=Bcalc(x1,x2,y)
        # A[i]: integral over x2 of the squared three-point estimate of
        # d2y/dx1**2 on interior line i; end values are copied.
        # NOTE(review): with fewer than 3 x1 lines A is either unset
        # (IndexError) or all zero, and the normalisation below divides
        # by zero -- confirm callers always start with >= 3 lines.
        A=zeros(N0,float)
        for i in range(1,N0-1):
            x2i=x2[x1[i]]
            yi=y[x1[i]]
            d2zdx1=2*((x2interpfuns[x1[i+1]](x2i)-yi)/(x1[i+1]-x1[i])-\
                      (yi-x2interpfuns[x1[i-1]](x2i))/(x1[i]-x1[i-1]))/\
                      (-x1[i-1]+x1[i+1])
            A[i]=float(integrate.simps(d2zdx1**2,x2i))
        A[0]=A[1]
        A[-1]=A[-2]
        # Normalised x1 interval density.
        interval_density=(A**5/B)**(1/24.0)
        interval_density/=integrate.simps(interval_density,x1)
        # n1: desired total number of x1 lines (passed to morepoints as N).
        n1=sqrt(N)*(integrate.simps(A/interval_density**4,x1)/\
                    integrate.simps(B**0.2*interval_density**0.8,x1)**5)**0.125+1
        newx1=morepoints(interval_density,x1,n1)
        # Each new line starts with the x2 grid of the line at
        # arts_math.locate(x1, x1i) (presumably its left neighbour).
        newx2={}
        for x1i in newx1:
            newx2[x1i]=x2[x1[arts_math.locate(x1,x1i)]]
        newy=call_func(self.func,newx1,newx2)
        newB=Bcalc(newx1,newx2,newy)
        # Merge the new lines, keeping x1 (and B alongside it) sorted.
        x1=concatenate((x1,newx1))
        B=concatenate((B,newB))
        si=argsort(x1)
        x1=take(x1,si)
        B=take(B,si)
        x2.update(newx2)
        y.update(newy)

        # ---- Stage 2: add x2 points on every line -------------------
        N0=len(x1)
        # z: cell boundaries around each x1 (midpoints; grid ends at the
        # outer edges).
        z=zeros(N0+1,float)
        z[1:-1]=(x1[:-1]+x1[1:])/2
        z[-1]=x1[-1]
        z[0]=x1[0]
        # Desired number of x2 points per line, proportional to C.
        C=(B*(z[1:]-z[:-1]))**0.2
        Csum=sum(C)
        n2=N*C/Csum
        # For each line, place new x2 points according to |d2y/dx2|**0.4.
        newx2={}
        for i in range(N0):
            x1i=x1[i]
            d2ydx2=d2fdx2(x2[x1i],y[x1i])
            interval_density=(d2ydx2**2)**0.2/integrate.simps((d2ydx2**2)**0.2,x2[x1i])
            newx2[x1i]=morepoints(interval_density,x2[x1i],n2[i])
        newy=call_func(self.func,x1,newx2)
        # Merge the new x2 points (and their y values) into each line,
        # keeping every line sorted in x2.
        for x1i in x1:
            if len(newx2[x1i])>0:
                x2[x1i]=concatenate((x2[x1i],newx2[x1i]))
                y[x1i]=concatenate((y[x1i],newy[x1i]))
                si=argsort(x2[x1i])
                x2[x1i]=take(x2[x1i],si)
                y[x1i]=take(y[x1i],si)

        self.x1=x1
        self.x2=x2
        self.y=y
        return self

    def plot(self):
        """Create a simple scatter plot of the grid.

        Each grid point is drawn at (x2, x1) and coloured by its y value on
        a colour scale shared by all lines, followed by a colour bar.
        """
        # NOTE(review): scatter and colorbar are not explicitly imported in
        # the visible imports of this module; they must be present in the
        # module globals (e.g. via the scipy star import or pylab), or this
        # raises NameError -- confirm.
        miny=min([min(self.y[x1i]) for x1i in self.x1])
        maxy=max([max(self.y[x1i]) for x1i in self.x1])
        for x1i in self.x1:
            h=scatter(self.x2[x1i],x1i*ones(len(self.x2[x1i])),
                    c=self.y[x1i],vmin=miny,vmax=maxy)
            h.set_edgecolor([(0,0,0,0)])  # makes edges invisible
        colorbar()

    def save(self,filename):
        """Output the SLIData2 object in ARTS XML format.

        Inside an <SLIData2> element this writes the x1 vector, then the
        list of per-line x2 vectors, then the list of per-line y vectors,
        all in x1 order.  The file is closed even if writing fails.
        """
        f=artsXML.XMLfile(filename)
        try:
            f.write('<SLIData2>\n')
            x2list=[self.x2[x1i] for x1i in self.x1]
            ylist=[self.y[x1i] for x1i in self.x1]
            f.add(self.x1)
            f.add(x2list)
            f.add(ylist)
            f.write('</SLIData2>\n')
        finally:
            f.close()

    def load(self,filename):
        """Read an SLIData2 object from an ARTS XML file (the layout
        written by save()).

        Replaces self.x1, self.x2 and self.y, and returns self.
        """
        data=artsXML.load(filename)['SLIData2']
        self.x1=data['Vector']
        self.x2={}
        self.y={}
        # NOTE(review): assumes artsXML.load keys the second <Array> (the
        # y values) as 'Array 0' -- confirm.
        for i,x1i in enumerate(self.x1):
            self.x2[x1i]=data['Array'][i]['Vector']
            self.y[x1i]=data['Array 0'][i]['Vector']
        return self

    def interp(self,x1,x2):
        """Interpolate SLIData2 at x1 and x2 (single numeric values only).

        Interpolates linearly along x2 on the two x1 lines bracketing x1,
        then linearly in x1 between those two values.

        Raises ValueError if x1 or x2 is outside the grid.
        """
        # NOTE(review): assumes locate returns the index of the grid point
        # at or left of x1, -1 below the grid -- confirm in arts_math.
        l1=arts_math.locate(self.x1,x1)
        if l1 in [-1,len(self.x1)-1]:
            raise ValueError("x1 is out of range")
        x1l=self.x1[l1]
        x1r=self.x1[l1+1]
        wr=(x1-x1l)/(x1r-x1l)
        wl=1-wr
        try:
            # interp1d returns a 0-d array for a scalar x2, which cannot be
            # indexed with [0]; ravel() makes both 0-d and 1-d results work.
            yl=interpolate.interp1d(self.x2[x1l],self.y[x1l])(x2).ravel()[0]
            yr=interpolate.interp1d(self.x2[x1r],self.y[x1r])(x2).ravel()[0]
        except ValueError:
            raise ValueError("x1,x2 out of range")
        return wl*yl+wr*yr


        
    
            

