"""
This module includes general-purpose math functions. It was developed before
scipy became a prerequisite, so some of the remaining functions duplicate
scipy functionality.
"""
#from scipy import *
import scipy
import scipy.stats
import scipy.integrate
import types
import copy
import general
import numpy

class ConvError(general.PyARTSError):
    """Raised by qromb when an array-valued integral has not converged
    within JMAX refinement steps."""
class OutOfRangeError(general.PyARTSError):
    """Raised by interp when a requested point lies outside the range of
    the tabulated abscissae xa."""


# Gauss-Legendre nodes ('a') and weights ('w') on [-1, 1], keyed by the number
# of points n. Only the non-negative abscissae are stored, in ascending order;
# the rule is symmetric, so gauss_leg also uses every nonzero abscissa mirrored
# with the same weight. For odd n the first abscissa is exactly 0 and is used
# once. Values are given to ~16 significant digits so the table itself does not
# limit the accuracy of gauss_leg.
GaussLegData={
    2:{'a':[0.5773502691896257],
       'w':[1.0]},
    3:{'a':[0.0,0.7745966692414834],
       'w':[0.8888888888888888,0.5555555555555556]},
    4:{'a':[0.3399810435848563,0.8611363115940526],
       'w':[0.6521451548625461,0.3478548451374538]},
    5:{'a':[0.0,0.5384693101056831,0.9061798459386640],
       'w':[0.5688888888888889,0.4786286704993665,0.2369268850561891]},
    6:{'a':[0.2386191860831969,0.6612093864662645,0.9324695142031521],
       'w':[0.4679139345726910,0.3607615730481386,0.1713244923791704]},
    7:{'a':[0.0,0.4058451513773972,0.7415311855993945,0.9491079123427585],
       'w':[0.4179591836734694,0.3818300505051189,0.2797053914892766,
            0.1294849661688697]},
    8:{'a':[0.1834346424956498,0.5255324099163290,0.7966664774136267,
            0.9602898564975363],
       'w':[0.3626837833783620,0.3137066458778873,0.2223810344533745,
            0.1012285362903763]},
    10:{'a':[0.1488743389816312,0.4333953941292472,0.6794095682990244,
             0.8650633666889845,0.9739065285171717],
        'w':[0.2955242247147529,0.2692667193099963,0.2190863625159820,
             0.1494513491505806,0.0666713443086881]}
}

# Gauss-Laguerre nodes ('a') and weights ('w') for \int_0^\infty exp(-x) f(x) dx,
# keyed by the number of points n (used by lag_gauss). All n nodes are stored.
# Values are given to ~16 significant digits.
LagGaussData={
    2:{'a':[0.5857864376269049,3.414213562373095],
       'w':[0.8535533905932737,0.1464466094067262]},
    3:{'a':[0.4157745567834791,2.294280360279042,6.289945082937479],
       'w':[0.7110930099291730,0.2785177335692409,0.01038925650158614]},
    4:{'a':[0.3225476896193923,1.745761101158347,4.536620296921128,
            9.395070912301133],
       'w':[0.6031541043416336,0.3574186924377997,0.03888790851500538,
            0.0005392947055613275]},
    5:{'a':[0.2635603197181409,1.413403059106517,3.596425771040722,
            7.085810005858837,12.64080084427578],
       'w':[0.5217556105828087,0.3986668110831759,0.07594244968170760,
            0.003611758679922048,2.336997238577623e-05]}
}

def array_simpson(data,h):
    """Integrate equally spaced samples with the composite Simpson's rule.

    data : 1-D array (or sequence) with an odd number of samples
    h    : spacing between consecutive samples

    Returns the estimated integral over the whole sampled interval. A single
    sample spans an interval of zero width, so the result is then 0.0.
    Raises AssertionError if data is not 1-D or has an even length.
    """
    data = numpy.asarray(data)
    assert(data.ndim==1),"array_simpson only works for 1_D arrays"
    n=len(data)
    assert(n%2==1),"Simpsons only works for odd lengthed vectors"
    if n == 1:
        return 0.0
    # interior samples alternate between weight 4 (odd index) and 2 (even)
    fodd = data[1:-1:2].sum()
    feven = data[2:-1:2].sum()
    # 3.0 avoids integer floor division for integer input under Python 2
    return (data[0] + 2*feven + 4*fodd + data[-1])*h/3.0


def array_trapzd(data,h):
    """Integrate equally spaced samples with the composite trapezium rule.

    data : 1-D array (or sequence) of samples
    h    : spacing between consecutive samples

    Returns the estimated integral over the whole sampled interval. A single
    sample spans an interval of zero width, so the result is then 0.0.
    Raises AssertionError if data is not 1-D.
    """
    data = numpy.asarray(data)
    assert(data.ndim==1),"array_trapzd only works for 1_D arrays"
    if len(data) == 1:
        return 0.0
    # 0.5* rather than /2 so integer input is not floor-divided (Python 2)
    return (0.5*data[0] + data[1:-1].sum() + 0.5*data[-1])*h


def array_order4(data,h):
    """Integrate equally spaced samples with a fourth-order extended
    trapezoidal rule.

    The first and last three samples get weights 3/8, 7/6 and 23/24 (from the
    end inwards); all other samples get weight 1.

    data : 1-D array (or sequence) with at least 6 samples
    h    : spacing between consecutive samples

    Returns the estimated integral over the whole sampled interval.
    Raises AssertionError if data is not 1-D or has fewer than 6 samples.
    """
    data = numpy.asarray(data)
    assert(data.ndim==1),"array_order4 only works for 1_D arrays"
    n=len(data)
    # with fewer than 6 samples the end corrections would overlap and some
    # samples would be counted twice
    assert(n>=6),"array_order4 needs at least 6 samples"
    # float constants avoid integer floor division under Python 2
    ends = ((3.0/8)*(data[0]+data[-1]) +
            (7.0/6)*(data[1]+data[-2]) +
            (23.0/24)*(data[2]+data[-3]))
    return (ends + data[3:-3].sum())*h


def gauss_leg(func,a,b,n):
    """Gauss-Legendre quadrature of func over [a, b] using n abscissae.

    n must be one of the keys of GaussLegData. That table holds only the
    non-negative abscissae; every nonzero one is used a second time,
    mirrored about the interval centre, with the same weight.
    """
    centre = 0.5*(b+a)
    half_width = 0.5*(b-a)
    table = GaussLegData[n]
    total = 0.0
    for node, weight in zip(table['a'], table['w']):
        total += weight*func(centre+half_width*node)
        if node != 0:
            # mirrored node; the zero node of odd-order rules is used once
            total += weight*func(centre-half_width*node)
    return half_width*total


def gridmerge(aa,ba):
    """Merge the values of ba into the sorted vector aa.

    aa : sorted 1-D numpy array (ascending or descending, as handled by locate)
    ba : 1-D numpy array of values to insert

    Every value of ba that is not already present in aa is inserted at its
    sorted position; values already present are skipped. Returns a new array
    (or aa itself if nothing was inserted); aa is not modified in place.
    """
    assert(len(aa.shape)==1)
    assert(len(ba.shape)==1)

    for b in ba:
        # locate returns i with b in [aa[i], aa[i+1]] (-1 before the start,
        # len(aa)-1 past the end). Both ends must be checked: for b equal to
        # the last element locate returns len(aa)-2, so checking aa[i] alone
        # would insert a duplicate.
        i=locate(aa,b)
        already_there = ((i >= 0 and aa[i] == b) or
                         (i+1 < len(aa) and aa[i+1] == b))
        if not already_there:
            aa = numpy.concatenate((aa[:i+1], [b], aa[i+1:]))
    return aa

def interp(xa,ya,x):
    """Linearly interpolate ya(xa) at the point(s) x.

    x can be a float, an int, a numpy scalar, a list or a 1-D array.
    Returns a 1-D array of floats with one value per point in x.
    Raises OutOfRangeError if any point lies outside the range of xa.

    See also: GriddedField.Interpolate
    """
    # atleast_1d turns any scalar (Python or numpy, int or float) into a
    # length-1 array; a 0-d array would make len(x) below fail.
    x = numpy.atleast_1d(numpy.asarray(x))
    y = numpy.zeros(x.shape, float)
    last = len(xa) - 1
    for i, xi in enumerate(x):
        ia = locate(xa, xi)
        if ia == -1 or ia == last:
            raise OutOfRangeError('x is outside range of xa')
        # two-point polint is linear interpolation within [xa[ia], xa[ia+1]]
        y[i] = polint(xa[ia:ia+2], ya[ia:ia+2], xi)[0]
    return y


def interp1DFieldByZkm(zfile,datafile,zkm):
    """Interpolate a 1-D ARTS field, stored in a file, to geometric altitude.

    zfile    : file with the altitude field (loaded as GriddedField3); its
               values are compared with zkm*1e3, i.e. presumably metres
    datafile : file with the field to interpolate (GriddedField3)
    zkm      : altitude(s) in km

    The altitude is first mapped to log(p) using zfile's data and p_grid;
    the data field is then interpolated linearly in log(p).
    """
    import arts_types
    altitude_m = zkm*1e3
    loader = arts_types.GriddedField3.load
    z_field = loader(zfile)
    data_field = loader(datafile)
    # altitude -> log(p), using the altitude field's own pressure grid
    log_p = interp(z_field.data.squeeze(), numpy.log(z_field.p_grid),
                   altitude_m)
    # log(p) -> field value
    return interp(numpy.log(data_field.p_grid), data_field.data.squeeze(),
                  log_p)




def lag_gauss(func,n):
    r"""Laguerre-Gauss quadrature with n points, i.e. an estimate of
    \int_0^\infty\exp(-x)f(x)dx with f = func.

    n must be a key of LagGaussData.
    """
    terms = (LagGaussData[n]['w'][k]*func(LagGaussData[n]['a'][k])
             for k in range(n))
    return sum(terms, 0.0)

def locate(xa,x):
    """Bisection search of the ordered sequence xa (ascending or descending).

    Returns the index i such that x lies between xa[i] and xa[i+1]. A result
    of -1 or len(xa)-1 means x lies beyond the range of xa. An x equal to
    the first element gives 0; an x equal to the last element gives
    len(xa)-2, so that it still falls inside the last interval.
    """
    size = len(xa)
    lower, upper = -1, size
    increasing = xa[-1] > xa[0]

    while upper - lower > 1:
        middle = (upper + lower) // 2
        if (x >= xa[middle]) == increasing:
            lower = middle
        else:
            upper = middle

    # exact hits on the end points are pinned to the first/last interval
    if x == xa[0]:
        return 0
    if x == xa[size-1]:
        return size - 2
    return lower

def multi_qromb(func,rangelist,EPS=0.0001,JMAX=20,K=5):
    """Nested Romberg integration over len(rangelist) dimensions.

    func takes a single list [x0, x1, ...] with one coordinate per entry of
    rangelist; rangelist[k] = (lower, upper) for coordinate k. The last
    coordinate is integrated innermost with qromb, then the routine recurses
    on the remaining ranges. EPS, JMAX and K are passed to every qromb call.
    Returns the integral (the qromb error estimates are discarded).
    """
    if len(rangelist) <= 1:
        integral, error = qromb(lambda x: func([x]),
                                rangelist[-1][0], rangelist[-1][1],
                                EPS, JMAX, K)
        return integral

    def reduced(leading_args):
        # integrate out the last coordinate for fixed leading coordinates
        def inner(last_arg):
            args = copy.deepcopy(leading_args)
            args.append(last_arg)
            return func(args)
        value, error = qromb(inner, rangelist[-1][0], rangelist[-1][1],
                             EPS, JMAX, K)
        return value

    return multi_qromb(reduced, rangelist[:-1], EPS, JMAX, K)

def multi_gauss_leg(func,rangelist,n=10):
    """Nested Gauss-Legendre integration over len(rangelist) dimensions.

    func takes a single list [x0, x1, ...] with one coordinate per entry of
    rangelist; rangelist[k] = (lower, upper) for coordinate k. The last
    coordinate is integrated innermost with gauss_leg, then the routine
    recurses on the remaining ranges. n (number of abscissae) is used for
    every dimension.
    """
    if len(rangelist) <= 1:
        return gauss_leg(lambda x: func([x]),
                         rangelist[-1][0], rangelist[-1][1], n)

    def reduced(leading_args):
        # integrate out the last coordinate for fixed leading coordinates
        def inner(last_arg):
            args = copy.deepcopy(leading_args)
            args.append(last_arg)
            return func(args)
        return gauss_leg(inner, rangelist[-1][0], rangelist[-1][1], n)

    return multi_gauss_leg(reduced, rangelist[:-1], n)


def nlinspace(start,stop,n):
    """Return n evenly spaced values from start to stop (both included)."""
    grid = numpy.linspace(start, stop, num=n)
    return grid

def nlogspace(start,stop,n):
    """Return n logarithmically spaced values from start to stop (both
    included), like the ARTS function of the same name and Matlab's
    logspace applied to log10 of the end points."""
    log_start = numpy.log10(start)
    log_stop = numpy.log10(stop)
    return numpy.logspace(log_start, log_stop, num=n)

def polint(xa,ya,x):
    """Polynomial interpolation/extrapolation by Neville's algorithm.

    Port of the Numerical Recipes routine ``polint``. It evaluates at x the
    polynomial of degree len(xa)-1 through the points (xa[i], ya[i]).

    xa : list or 1-D array of abscissae (must be distinct)
    ya : list/array of numbers, or a list of numpy arrays of equal shape (the
         interpolation is then done element-wise)
    x  : point at which to evaluate

    Returns (y, dy), where dy is the last correction added to y (the NR error
    estimate). With a single point, y is ya[0] and dy is zero.
    Raises general.PyARTSError if xa and ya differ in length or if xa has
    repeated values.
    """
    if len(xa)!=len(ya):
        raise general.PyARTSError("Input lists xa and ya must be the same length!")
    n=len(xa)

    # Work on float copies: copying an integer ya kept its integer dtype, and
    # every correction term assigned into it was silently truncated.
    c=[yi*1.0 for yi in ya]
    d=list(c)

    # start from the tabulated point closest to x
    ns=0
    dif=numpy.abs(x-xa[0])
    for i in range(n):
        dift=numpy.abs(x-xa[i])
        if dift<dif:
            ns=i
            dif=dift

    y=c[ns]
    dy=0.0*y  # defined (as zero) even when there is only one point
    ns=ns-1
    for m in range(1,n):
        for i in range(n-m):
            ho=xa[i]-x
            hp=xa[i+m]-x
            den=ho-hp
            if den==0.0:
                raise general.PyARTSError(
                    "error in polint: repeated abscissae (ho=%r, hp=%r)"
                    % (ho, hp))
            den=(c[i+1]-d[i])/den
            d[i]=hp*den
            c[i]=ho*den
        # take the correction that keeps the path through the tableau
        # centred on the starting point
        if 2*(ns+1) < n-m:
            dy=c[ns+1]
        else:
            dy=d[ns]
            ns=ns-1
        y=y+dy
    return y,dy

def _qromb_converged(Integral, error, EPS):
    """Element-wise convergence mask for an array-valued Romberg estimate.

    An element has converged if its error is within EPS relative to its own
    value, or if it is negligible (< EPS) compared with the largest element.
    An all-zero integral counts as fully converged.
    """
    abs_integral = numpy.abs(Integral)
    maxvalue = abs_integral.max()
    if maxvalue == 0:
        return numpy.ones(abs_integral.shape, dtype=bool)
    return ((numpy.abs(error) <= EPS*abs_integral) |
            (abs_integral/maxvalue < EPS))


def qromb(Integrand,a,b,EPS=0.0001,JMAX=20,K=5):
    """Romberg integration of Integrand from a to b, of order 2*K
    (after the Numerical Recipes routine ``qromb``).

    Successive trapezoidal refinements (trapzd) are extrapolated to zero step
    size by fitting a polynomial (polint) through the K most recent ones.

    Integrand may return a Python number or a numpy array/scalar. It is
    called once at a to find out which. For arrays, every element must
    converge (see _qromb_converged).

    EPS  : relative accuracy goal
    JMAX : maximum number of trapezoidal refinements
    K    : number of refinements used in each extrapolation

    Returns (Integral, error). If convergence is not reached within JMAX
    steps, a warning is printed. Number-valued integrands then return the
    last estimate; array-valued ones raise ConvError.
    """
    ytest=Integrand(a)
    # Python numbers have no .shape; numpy scalars and arrays do and take
    # the element-wise path.
    function_is_float = not hasattr(ytest, 'shape')
    if function_is_float:
        Integral=0
        error=0
        trapIntegral=0
    else:
        Integral=numpy.zeros(ytest.shape,float)
        error=numpy.zeros(ytest.shape,float)
        trapIntegral=numpy.zeros(ytest.shape,float)

    # h[j] is the squared step size of refinement j relative to the first
    # (each refinement halves the step, hence the factor 0.25)
    h=numpy.zeros(JMAX+2, float)
    h[0]=1.0
    traplist=[]
    for j in range(JMAX):
        trapIntegral=trapzd(Integrand,a,b,j+1,trapIntegral)
        traplist.append(trapIntegral)
        if j >= K-1:
            # extrapolate the K most recent estimates, including this one, to h=0
            Integral,error=polint(h[j-K+1:j+1],traplist[j-K+1:j+1],0.0)
            if function_is_float:
                if abs(error)<=EPS*abs(Integral):
                    return Integral, error
            elif _qromb_converged(Integral, error, EPS).all():
                return Integral, error
        h[j+1]=0.25*h[j]
    print("Warning!!:Too many steps in routine qromb")
    if not function_is_float:
        print("converged elements")
        print(_qromb_converged(Integral, error, EPS))
        raise ConvError
    return Integral, error



def trapzd(Integrand,a,b,n,last_answer):
    """Stage n of the extended trapezoidal rule (after the Numerical Recipes
    routine ``trapzd``).

    Stage 1 is the plain trapezoid on [a, b]. Each later stage halves the
    step size. It evaluates Integrand only at the 2**(n-2) new midpoints and
    combines them with last_answer, the result of stage n-1. Stages must
    therefore be computed in order n = 1, 2, 3, ...

    Integrand may return a number or a numpy array.
    Returns the refined estimate of the integral of Integrand from a to b.
    """
    if n == 1:
        # last_answer is not used by the first stage
        return 0.5*(b-a)*(Integrand(a)+Integrand(b))

    it = 2**(n-2)
    d = (b-a)/float(it)
    x = a+0.5*d
    total = 0.0
    for i in range(it):
        total = total+Integrand(x)
        x = x+d
    return 0.5*(last_answer+d*total)

def vanilla_mc(func,rangelist,N):
    """Plain Monte Carlo multi-dimensional integration.

    Samples are drawn from independent uniform distributions over the ranges
    in rangelist, a list of (lower, upper) pairs (lists, tuples or arrays).
    func is called as func(x0, x1, ...) with one argument per range. N is
    the number of samples.

    Returns (integral, error): integral is V*<f> and error is the one-sigma
    statistical estimate V*sqrt((<f^2> - <f>^2)/N), where V is the product of
    the range widths.
    """
    import RNG
    rngs=[]
    V=1
    for r in rangelist:
        rngs.append(RNG.CreateGenerator(0,RNG.UniformDistribution(r[0],r[1])))
        V*=r[1]-r[0]
    Sum=0
    SumSquared=0
    for i in range(N):
        arglist=[rng.ranf() for rng in rngs]
        f=func(*arglist)
        Sum+=f
        SumSquared+=f**2
    # float() avoids integer division for integer-valued func (Python 2)
    Nf=float(N)
    mean=Sum/Nf
    mean_sq=SumSquared/Nf
    # roundoff can make the sample variance slightly negative; clip at zero
    variance=numpy.maximum(mean_sq-mean**2, 0.0)
    integral=V*mean
    error=V*numpy.sqrt(variance/Nf)
    return integral,error




def integrate_quantity(M, quantity):
    """Integrate a field of a record array over height, e.g. Chevalier data.

    M        : 2-D structured array with a field "z" and the field named by
               quantity (for example Chevalier data as returned by
               io.read_chevalier); heights run along the second axis
    quantity : name of the field to integrate (not sanity-checked)

    Returns one value per row: the sum over layers of the layer-mean of the
    quantity times the layer thickness.
    """
    heights = M["z"]
    thickness = heights[:, 1:] - heights[:, :-1]
    values = M[quantity]
    layer_mean = (values[:, 1:] + values[:, :-1])/2
    return (layer_mean * thickness).sum(1)

def layer2level(z, q, ignore_negative=False):
    """Turn values at levels into per-layer integrals along the first axis.

    For each pair of neighbouring levels i, i+1 the result is the
    trapezoidal contribution (q[i] + q[i+1])/2 * (z[i+1] - z[i]).
    First dim. must be height.

    If ignore_negative is True, negative values of q are treated as zero.
    The caller's q is not modified.
    """
    dz = z[1:, ...] - z[:-1, ...]
    if ignore_negative:
        # copy first: clipping in place would modify the caller's array
        q = q.copy()
        q[q<0]=0
    y_avg = (q[1:, ...] + q[:-1, ...])/2
    return (y_avg * dz)

def integrate_with_height(z, q, ignore_negative=False):
    """Integrate q over height z along the first axis (trapezoidal rule).

    If ignore_negative is True, negative q values are treated as zero
    (passed on to layer2level).
    """
    layers = layer2level(z, q, ignore_negative)
    return layers.sum(axis=0)

def cum_integrate_with_height(z, q, ignore_negative=False):
    """Cumulative height integral of q along the first axis.

    Element k is the trapezoidal integral from z[0] up to z[k+1]. If
    ignore_negative is True, negative q values are treated as zero (passed
    on to layer2level).
    """
    layers = layer2level(z, q, ignore_negative)
    return layers.cumsum(axis=0)

def combine_axes(*args):
    """Combine n axes into a (p x n) matrix of all grid points, as in:

    [0, 1, 2], [0, 1, 2, 3] --> [0 0; 0 1; 0 2; 0 3; 1 0; 1 1; ...]

    p is the product of the axis lengths, and the last axis varies fastest.
    This prepares the points in the format of scipy.interpolate.griddata.

    With no arguments an empty array is returned. A single argument is
    returned unchanged.
    """
    if len(args) == 0:
        return numpy.array([])
    if len(args) == 1:
        return args[0]
    # 'ij' indexing makes the first axis vary slowest, giving the row order
    # shown above. The axis values are copied exactly; no arithmetic is done
    # on them, so axes of very different magnitude lose no precision.
    grids = numpy.meshgrid(*[numpy.asarray(ax).ravel() for ax in args],
                           indexing='ij')
    return numpy.column_stack([g.ravel() for g in grids])


def integrate_phasemat(angles, Z):
    """Integrates the phase-matrix over all angles

    Computes the integral of 2*pi*Z*sin(theta) dtheta with the trapezoidal
    rule; the 2*pi factor accounts for the azimuth, so Z is taken to depend
    on the polar angle only.

    Parameters
    ~~~~~~~~~~

    angles : 1D-array
        Angles to integrate over [degrees]

    Z : nD-array
        quantity to integrate. The last dimension must match the size of
        angles.

    Returns
    ~~~~~~~

    Z_int : (n-1)D-array
        Array with integrated Z. Has all but the last dimension of Z.
    """
    # scipy.integrate.trapz was removed in SciPy 1.14; trapezoid (SciPy >= 1.6)
    # is the same rule under its new name.
    trapezoid = getattr(scipy.integrate, "trapezoid", None)
    if trapezoid is None:
        trapezoid = scipy.integrate.trapz
    theta = numpy.deg2rad(angles)
    integrand = 2 * numpy.pi * Z * numpy.sin(theta)
    return trapezoid(integrand, theta)

def get_value_at_cloud_top(M, field, c="CIW"):
    """From a Chevalier-like 2-D record array, get field (z, t, ...) at the
    cloud top of every row.

    The "cloud top" is the first (lowest second-axis index) element of the
    row where field c is nonzero. NOTE(review): that is the top only if the
    second axis runs top-down -- confirm against the data layout.
    Uses CIW by default; pass c="CLW" to use liquid.

    Returns a masked array with one value per row of M; rows without any
    nonzero c are masked.
    """
    values = numpy.empty(M.shape[0])
    values[:] = numpy.nan
    nonzero = M[c].nonzero()
    rows = nonzero[0]
    cols = nonzero[1]
    # nonzero() lists indices in row-major order, so the first occurrence of
    # each row number marks that row's first nonzero column
    unique_rows, first = numpy.unique(rows, return_index=True)
    values[unique_rows] = M[field][rows[first], cols[first]]
    return numpy.ma.masked_array(values, mask=numpy.isnan(values))

def get_cth(M, f="CIW"):
    """Cloud-top height from a Chevalier-like record array: the "z" value at
    the first element of each row where field f (CIW by default) is nonzero.

    Returns a masked array with one value per row of M (length M.shape[0]);
    rows without any nonzero f are masked.
    """
    return get_value_at_cloud_top(M, field="z", c=f)

def get_near_top(z, iwc, y, data=None):
    """For an iwc matrix, find the level down to which the integrated value is y.

    Sizes of z and iwc must match. First dimension must be height, and z must
    increase along it. 'y' must be scalar. iwc is integrated from the top
    (largest z) downwards, with negative values treated as zero.

    Returns a masked array with one element per column. Columns where 'y' is
    not reached are masked.

    If 'data' is provided, returns data at near-cloud-top. Otherwise,
    returns indices.

    Raises general.PyARTSError if z is not monotonically increasing.
    """
    if not (z[1:]>z[:-1]).all():
        raise general.PyARTSError("z must be monotonically increasing")
    # Integrate top-down: with z reversed every dz is negative, so the
    # cumulative integral is negated before the comparison with y.
    cum_iwp = cum_integrate_with_height(z[::-1, ...], iwc[::-1, ...], True)
    mask = (-cum_iwp)>=y
    # find first non-zero per row. 'mask' is a step-function so XOR works,
    # just need to prepend with row of zeroes to make indexing work
    first_nz = mask[:-1, ...]^mask[1:, ...]
    row_of_zeroes = numpy.zeros((1,) + first_nz.shape[1:],
                                dtype=numpy.bool_)
    ind = numpy.concatenate((row_of_zeroes, first_nz))
    # NOTE(review): the [:, newaxis] broadcast here and z.shape[1] below assume
    # 2-D (height x column) input. A column whose topmost layer alone already
    # reaches y has no transition in 'ind', gets index 0 and is therefore
    # masked as "not reached" -- confirm whether that is intended.
    firsts = (numpy.arange(ind.shape[0])[:, numpy.newaxis]*ind).sum(0)
    # columns without a transition get a negative sentinel, masked below
    firsts[firsts==0] = numpy.iinfo(firsts[0]).min
    # map the index in the reversed (top-down) layer array back onto the
    # original height axis -- TODO confirm which level this is meant to give
    firsts = z.shape[0] - numpy.ma.masked_array(firsts, firsts<0)
    if data is not None:
        # assuming data has the same shape as z: data.T.flat runs column by
        # column, so column k starts at flat index k*z.shape[0]
        flatind = (firsts + z.shape[0]*numpy.arange(z.shape[1]))
        rv = numpy.ma.masked_array(numpy.zeros(z.shape[1:]), mask=flatind.mask)
        rv[~rv.mask] = data.T.flat[flatind[~flatind.mask]]
        return rv
    else:
        return firsts

def get_equal_content_bins(data, N):
    """Return bin edges that divide data into bins of (roughly) equal content.

    The percentiles 0, 100/(N-1), ..., 100 of data are computed (N values).
    All but the first are returned, i.e. N-1 upper bin edges, the last being
    max(data). The dropped first value is min(data).

    Note: if data are discrete, an exactly equal division is impossible.
    E.g. [1, 1, 1, 1, 2] divided over 2 bins will inevitably have 4 values
    in bin 1 and 1 in bin 2. This is visible e.g. for collocations.
    """
    percentiles = numpy.linspace(0, 100, N)
    all_edges = [scipy.stats.scoreatpercentile(data, p) for p in percentiles]
    upper_edges = all_edges[1:]
    return numpy.asarray(upper_edges)

def linreg_with_error(x, y, sigma):
    """Perform a weighted linear regression y = a0*x + a1 with errors in y.

    Input:
        x, y  : data points (array-like, same length)
        sigma : one-sigma uncertainty of each y value; a scalar applies the
                same uncertainty to all points
    Output:
        (a_fit, sig_a), two length-2 arrays:
        a_fit = [slope, intercept]
        sig_a = [uncertainty of slope, uncertainty of intercept]
    """

    # based on Garcia (2000), Numerical Methods for Physics, Listing 5A.1
    # (page 171)

    # float conversion: integer sigma cannot be raised to the power -2;
    # broadcasting lets a scalar sigma count once per data point
    x, y, sigma = numpy.broadcast_arrays(numpy.asarray(x, dtype=float),
                                         numpy.asarray(y, dtype=float),
                                         numpy.asarray(sigma, dtype=float))

    sigmaTerm = sigma ** -2
    s = sigmaTerm.sum()
    sx = (x * sigmaTerm).sum()
    sy = (y * sigmaTerm).sum()
    sxy = (x * y * sigmaTerm).sum()
    sxx = (x**2 * sigmaTerm).sum()
    denom = s*sxx - sx**2

    a_fit = numpy.array([
            (s * sxy - sx * sy)/denom,
            (sxx * sy - sx * sxy)/denom,
            ])

    sig_a = numpy.array([
            numpy.sqrt(s/denom),
            numpy.sqrt(sxx/denom),
            ])

    return (a_fit, sig_a)
