"""General purpose plotting functions using the matplotlib package."""

from scipy import *
from pylab import *
from matplotlib.colors import *
from .arts_geometry import EARTH_RADIUS
import arts_math
import artsXML
from types import IntType, FloatType, ListType

import numpy


#Plotting functions

def drawCloudBox(zbase,ztop,lat1,lat2,npts=40,format='k'):
    """Plot the closed outline of a cloudbox cross-section in x, y coordinates.

    The outline consists of three parts:

    * the arc at altitude *zbase* from *lat1* to *lat2*;
    * the arc at altitude *ztop* traced back from *lat2* to *lat1*;
    * the first bottom point repeated, to close the shape.

    Each arc has the *npts* latitudes returned by arts_math.nlinspace.
    Latitudes are in degrees. Each point is x = r*sin(lat), y = r*cos(lat),
    with r = EARTH_RADIUS/1e3 + altitude. Presumably EARTH_RADIUS is in metres
    and zbase/ztop in km, since the altitudes are added to EARTH_RADIUS/1e3
    (TODO confirm against arts_geometry).

    *format* is the matplotlib format string passed to plot(); plot's
    return value is returned.
    """
    radius_km = EARTH_RADIUS/1e3
    theta = pi*arts_math.nlinspace(lat1,lat2,npts)/180
    sin_t = sin(theta)
    cos_t = cos(theta)

    r_bottom = radius_km + zbase
    r_top = radius_km + ztop
    bottom_x = r_bottom*sin_t
    bottom_y = r_bottom*cos_t
    top_x = r_top*sin_t
    top_y = r_top*cos_t

    # bottom arc forwards, top arc backwards, then back to the starting point
    x_outline = numpy.asarray(
        numpy.concatenate((bottom_x, top_x[::-1], [bottom_x[0]])), dtype=float)
    y_outline = numpy.asarray(
        numpy.concatenate((bottom_y, top_y[::-1], [bottom_y[0]])), dtype=float)
    return plot(x_outline, y_outline, format)

def drawPpath(filename,format='k'):
    """Plot a propagation path read from an ARTS XML file in x, y (km).

    The file is loaded with artsXML.load. The code reads two fields:

    * 'GeometricalAltitudes' as the altitudes;
    * column 1 of 'PropagationPathPointPositions' as latitudes in degrees.

    Each point is placed at radius EARTH_RADIUS + altitude and converted to
    x = r*sin(lat), y = r*cos(lat). Both are divided by 1e3 before plotting,
    so the altitudes presumably share EARTH_RADIUS's units (metres); verify
    this against the file format.

    *format* is the matplotlib format string; plot's return value is returned.
    """
    ppath = artsXML.load(filename)
    altitude = ppath['GeometricalAltitudes']
    latitude = ppath['PropagationPathPointPositions'][:,1]
    radius = altitude + EARTH_RADIUS
    angle = pi*latitude/180
    return plot(radius*sin(angle)/1e3, radius*cos(angle)/1e3, format)

def drawSurface(lat1,lat2,npts=40,format='k'):
    """Plot the geoid surface as an arc of radius EARTH_RADIUS/1e3.

    The arc spans latitudes *lat1* to *lat2* (degrees) at the *npts* points
    returned by arts_math.nlinspace. Each point is placed at x = R*sin(lat),
    y = R*cos(lat). *format* is the matplotlib format string; plot's return
    value is returned.
    """
    radius_km = EARTH_RADIUS/1e3
    angle = pi*arts_math.nlinspace(lat1,lat2,npts)/180
    return plot(radius_km*sin(angle), radius_km*cos(angle), format)

def hotcoldmap(zmin,zmax):
    """Return a diverging colormap for data ranging from *zmin* to *zmax*.

    Values below zero use a black-blue-cyan scale, with black nearest zero
    and cyan at the most negative end. Values above zero use a
    black-red-yellow scale, with black nearest zero and yellow at the most
    positive end.

    When the range straddles zero, zero is placed at its true position on
    the colormap axis, with a black band of +/-0.05 (colormap units) around
    it. The black band's edges are clamped to [0, 1]. If one side of zero
    spans less than 1/20 of the largest magnitude, only the other side's
    colour scale is used.

    Raises ValueError if zmin > zmax.
    """
    if zmin>zmax:
        raise ValueError('zmax must be larger than zmin')
    if zmax < 0:
        segmentdata={'red':[(0,0,0),(1,0,0)],
                     'green':[(0,1,1),(0.5,0,0),(1,0,0)],
                     'blue':[(0,1,1),(0.5,1,1),(1,0,0)]}
    elif zmin > 0:
        segmentdata={'red':[(0,0,0),(0.5,1,1),(1,1,1)],
                     'green':[(0,0,0),(0.5,0,0),(1,1,1)],
                     'blue':[(0,0,0),(1,0,0)]}
    else:
        # zmin <= 0 <= zmax. Work in floats so that integer limits do not
        # truncate the divisions below under Python 2.
        zmin = float(zmin)
        zmax = float(zmax)
        span = zmax - zmin
        # position of zero on the [0, 1] colormap axis (0 for a zero span)
        if span > 0:
            xzero = min([1.0, abs(zmin)/span])
        else:
            xzero = 0.0
        # data half-width below which one side of zero is treated as absent
        dzblack = max([abs(zmin), abs(zmax)])/20
        # Edges of the black band around zero. They are clamped so every
        # breakpoint stays inside [0, 1]; near the branch thresholds the
        # unclamped values fall outside and give an invalid colormap.
        lo = max([0.0, xzero-0.05])
        hi = min([1.0, xzero+0.05])
        if zmin > -dzblack:
            segmentdata={'red':[(0,0,0),(hi,0,0),
                                ((hi+1)/2,1,1),(1,1,1)],
                         'green':[(0,0,0),((hi+1)/2,0,0),(1,1,1)],
                         'blue':[(0,0,0),(1,0,0)]}
        elif zmax<dzblack:
            segmentdata={'red':[(0,0,0),(1,0,0)],
                         'green':[(0,1,1),(lo/2,0,0),(1,0,0)],
                         'blue':[(0,1,1),(lo/2,1,1),(lo,0,0),(1,0,0)]}
        else:
            segmentdata={'red':[(0,0,0),(hi,0,0),((hi+1)/2,1,1),(1,1,1)],
                         'green':[(0,1,1),(lo/2,0,0),((hi+1)/2,0,0),
                                  (1,1,1)],
                         'blue':[(0,1,1),(lo/2,1,1),(lo,0,0),(1,0,0)]}
    return LinearSegmentedColormap('hotcold',segmentdata)

#class SentinelNorm(normalize):
#    """
#    For use with SentinelMap (see the SentinelMap docstring). Leaves sentinel values unchanged.
#    """
#    def __init__(self, sentinel, vmin=None, vmax=None, clip=True):
#        normalize.__init__(self, vmin=vmin, vmax=vmax, clip=clip)
#        self.sentinel = sentinel
#        
#
#    def __call__(self, value):
#        vnorm = normalize.__call__(self, value)
#        return where(value==self.sentinel, self.sentinel, vnorm)

#class SentinelMap(Colormap):
#    """SentinelMap is a matplotlib colormap that deals with data points that you want
#    to distinguish from the rest of the data. For example if bad data is stored as -999,
#    these values are plotted a specified rgb color, and the rest of the colormap (cmap)
#    is unchanged.  This needs to be used with the SentinelNorm class, e.g.
#
#    >>> cmap = SentinelMap(cm.jet, -999, (0,0,0))
#    >>> norm = SentinelNorm(-999)
#    >>> pcolor(x,y,z,norm=norm,cmap=cmap)
#
#    will plot the data with the usual jet colormap but with bad data values (-999) black. 
#
#    """
#    def __init__(self, cmap, sentinel, rgb):
#        self.N = cmap.N
#        self.name = 'SentinelMap'
#        self.cmap = cmap
#        self.sentinel = sentinel
#        self.monochrome=False
#        if len(rgb)!=3:
#            raise ValueError('sentinel color must be RGB')
#        
#        self.rgb = rgb
#        
#    def __call__(self, X, alpha=1):##
#
#        
#        retshape=concatenate((X.shape,(4,)))
#        
#        r,g,b = self.rgb
#        Xm = self.cmap(X)#
#
#        ret = zeros( retshape, typecode=float)#
#
##       if len(X.shape)==2:
#            ret[:,:,0] =  where(X==self.sentinel, r, Xm[:,:,0])
#            ret[:,:,1] =  where(X==self.sentinel, g, Xm[:,:,1])
#            ret[:,:,2] =  where(X==self.sentinel, b, Xm[:,:,2])
#            ret[:,:,3] =  where(X==self.sentinel, alpha, Xm[:,:,3])
#        else:
#            
#            ret[:,0] =  where(X==self.sentinel, r, Xm[:,0])
#            ret[:,1] =  where(X==self.sentinel, g, Xm[:,1])
#            ret[:,2] =  where(X==self.sentinel, b, Xm[:,2])
#            ret[:,3] =  where(X==self.sentinel, alpha, Xm[:,3])
#        return ret



from matplotlib.colors import Colormap, normalize
#import matplotlib.numerix as nx
import numpy as nx
from types import IntType, FloatType, ListType

class SentinelMap(Colormap):
    """Colormap wrapper that draws chosen sentinel values in fixed colours.

    Every data point equal to one of the keys of *sentinels* is drawn in the
    RGB colour mapped to that key, with the alpha passed to __call__.
    Everything else is coloured by the wrapped *cmap* after clipping to
    [0, 1]. This is meant to be combined with SentinelNorm, which normalises
    the data but leaves the sentinel values themselves unchanged, e.g.

        cmap = SentinelMap(cm.jet, {-999: (0, 0, 0)})
        norm = SentinelNorm(ignore=-999)
        pcolor(x, y, z, norm=norm, cmap=cmap)

    **input:**
           *cmap*, the colormap used for ordinary data;
           *sentinels*, a dict mapping sentinel value -> (r, g, b).

    Raises ValueError if any sentinel colour does not have 3 components.
    """
    def __init__(self, cmap, sentinels=None):
        # None instead of a shared {} default: the dict is stored on the
        # instance, so a mutable default would be shared by all instances.
        if sentinels is None:
            sentinels = {}
        # boilerplate normally set up by Colormap.__init__ (not called here)
        self.N = cmap.N
        self.name = 'SentinelMap'
        self.cmap = cmap
        self.sentinels = sentinels
        self.monochrome=False
        for rgb in sentinels.values():
            if len(rgb)!=3:
                raise ValueError('sentinel color must be RGB')


    def __call__(self, scaledImageData, alpha=1):
        """Map *scaledImageData* to RGBA, painting sentinels in their colours.

        The data are assumed to be normalised to [0, 1] already, apart from
        the sentinels. They are clipped before the lookup to be on the safe
        side. The sentinel test uses the unclipped data. The returned array
        has the wrapped colormap's output shape (data shape + (4,)).
        """
        rgbaValues = self.cmap(nx.clip(scaledImageData, 0.,1.))

        for sentinel, rgb in self.sentinels.items():
            # compute the mask once per sentinel, not once per channel
            isSentinel = (scaledImageData == sentinel)
            # [..., k] addresses channel k for data of any rank
            for channel, level in enumerate(tuple(rgb) + (alpha,)):
                rgbaValues[..., channel] = nx.where(isSentinel, level,
                                                    rgbaValues[..., channel])

        return rgbaValues



class SentinelNorm(normalize):
    """
    Linear normalisation to [0, 1] that leaves sentinel values unchanged.

    Values listed in *ignore* are handled specially:

    * they are skipped when vmin/vmax are inferred from the data;
    * they are returned with their original values, so a colormap such as
      SentinelMap can still recognise them after normalisation.

    **input:**
           *ignore*, a single number or a sequence of numbers to treat as
           sentinels;
           *vmin*, *vmax*, optional limits. A limit left as None is taken
           from the non-sentinel data on every call. When at least one limit
           is given, the data are clipped to [vmin, vmax] before scaling.
    """
    def __init__(self, ignore=(), vmin=None, vmax=None):
        self.vmin=vmin
        self.vmax=vmax

        # nx.isscalar also accepts numpy scalars (e.g. float64), which a
        # type() check against IntType/FloatType would reject
        if nx.isscalar(ignore):
            self.ignore = [ignore]
        else:
            self.ignore = list(ignore)


    def __call__(self, value):
        """Return *value* scaled to [0, 1], with sentinel entries unchanged.

        A scalar input gives a scalar result. Anything else is converted
        with nx.asarray, and an array of the same shape is returned. If
        vmin == vmax, every non-sentinel entry maps to 0.

        Raises ValueError if vmin > vmax.
        """
        vmin = self.vmin
        vmax = self.vmax

        scalar_input = nx.isscalar(value)
        if scalar_input:
            val = nx.array([value])
        else:
            val = nx.asarray(value)

        # Locate the sentinels in the *unclipped* data. Clipping would move
        # an out-of-range sentinel onto vmin/vmax and hide it.
        is_sentinel = nx.zeros(val.shape, dtype=bool)
        for sentinel in self.ignore:
            is_sentinel |= (val == sentinel)

        # With neither limit given, the limits come from the data itself,
        # so clipping would change nothing.
        needs_clipping = vmin is not None or vmax is not None

        if vmin is None or vmax is None:
            rval = nx.ravel(val)
            if self.ignore:
                # Sorted non-sentinel values; NaNs (if any) sort to the end.
                # Fall back to 0 when every value is a sentinel.
                valid = nx.sort(rval[~nx.ravel(is_sentinel)])
                if vmin is None:
                    vmin = valid[0] if valid.size else 0.
                if vmax is None:
                    vmax = valid[-1] if valid.size else 0.
            else:
                if vmin is None: vmin = min(rval)
                if vmax is None: vmax = max(rval)

        if vmin > vmax:
            raise ValueError("minvalue must be less than or equal to maxvalue")
        elif vmin==vmax:
            # degenerate range: everything maps to 0 (sentinels restored below)
            result = 0.*val
        else:
            clipped = nx.clip(val, vmin, vmax) if needs_clipping else val
            result = (1.0/(vmax-vmin))*(clipped-vmin)

        # replace sentinels with their original (non-normalised) values
        for sentinel in self.ignore:
            result = nx.where(val == sentinel, sentinel, result)

        if scalar_input:
            result = result[0]
        return result


def myPcolor(x,y,z,**kwargs):
    """pcolor wrapper that keeps the last row and column of data.

    matplotlib's pcolor loses the last row and column of data. To avoid
    this, the steps are:

    1. Pad x, y and z with one extra row and column; z is padded with zeros.
    2. Extrapolate x linearly into the new column and y linearly into the
       new row. Copy the row above / column to the left for the other
       direction.
    3. Shift the x grid back by half of the forward x step and the y grid
       by half of the forward y step. The last step is repeated at the edge.

    For evenly spaced grids, each patch is then centred on its (x, y) value.

    **input:**
           *x*, *y*, *z*, 2-D arrays. *y* and *z* must fit into x's shape.
           At least two rows and two columns are needed for the
           extrapolation. Presumably x varies along columns and y along
           rows (meshgrid layout); TODO confirm.
           Extra keyword arguments are passed to pcolor, whose result is
           returned.
    """
    padded_shape = tuple(n + 1 for n in x.shape)
    xe = zeros(padded_shape, float)
    ye = zeros(padded_shape, float)
    ze = zeros(padded_shape, float)
    xe[:-1, :-1] = x
    ye[:-1, :-1] = y
    ze[:-1, :-1] = z

    # x: linear extrapolation into the extra column, then copy the last row
    # (order matters: the corner takes the extrapolated value)
    xe[:, -1] = 2*xe[:, -2] - xe[:, -3]
    xe[-1, :] = xe[-2, :]
    # y: linear extrapolation into the extra row, then copy the last column
    ye[-1, :] = 2*ye[-2, :] - ye[-3, :]
    ye[:, -1] = ye[:, -2]

    # half of each forward step; the final column/row reuses the one before
    half_dx = zeros(xe.shape, float)
    half_dx[:, :-1] = (xe[:, 1:] - xe[:, :-1])/2
    half_dx[:, -1] = half_dx[:, -2]
    half_dy = zeros(ye.shape, float)
    half_dy[:-1, :] = (ye[1:, :] - ye[:-1, :])/2
    half_dy[-1, :] = half_dy[-2, :]

    return pcolor(xe - half_dx, ye - half_dy, ze, **kwargs)

def mySubplot(nrows,ncols,pnum,figpos=(0.05,0.05,0.9,0.9),
              axpos=(0.15,0.15,0.75,0.75)):
    """More like the matlab subplot than matplotlib.

    Divides the portion of the current figure given by *figpos* into
    *nrows* by *ncols* panels and creates axes in panel *pnum*. Panels are
    numbered from 1, left to right along the top row, then row by row
    downwards.

    **input:**
           *figpos*, [left, bottom, width, height] of the divided region,
           passed on in the same units as axes() positions;
           *axpos*, normalised [left, bottom, width, height] of the axes
           within one panel, as fractions of the panel size.

    The axes object is returned.
    """
    # float(): integer figpos entries must not truncate under Python 2
    panelwidth = figpos[2]/float(ncols)
    panelheight = figpos[3]/float(nrows)
    # Explicit floor division. A plain "/" only truncates for integers under
    # Python 2, and gives a fractional row index under Python 3.
    row, col = divmod(pnum - 1, ncols)
    pos = [figpos[0] + col*panelwidth + axpos[0]*panelwidth,
           figpos[1] + (nrows - 1 - row)*panelheight + axpos[1]*panelheight,
           panelwidth*axpos[2],
           panelheight*axpos[3]]
    return axes(pos)

def setDataAspectRatioByFigSize(ax,r):
    """Same idea as matlab's data aspect ratio, applied by resizing the figure.

    The current figure's width is kept and its height is recomputed. Assume
    *ax*'s position is [left, bottom, width, height] in figure fractions
    and stays fixed under the resize. Then the y data range per inch
    becomes *r* times the x data range per inch.
    """
    x_lo, x_hi = ax.get_xlim()
    dxlim = x_hi - x_lo
    y_lo, y_hi = ax.get_ylim()
    dylim = y_hi - y_lo
    pos = ax.get_position()
    axes_width, axes_height = pos[2], pos[3]
    fig = gcf()
    fig_width = fig.get_size_inches()[0]
    fig_height = dylim*axes_width*fig_width/r/axes_height/dxlim
    fig.set_figsize_inches(fig_width, fig_height)

def setDataAspectRatioByAxisPos(ax,r):
    """Same idea as matlab's data aspect ratio, applied by changing the axes height.

    The figure size and the axes' left edge, width and vertical centre are
    kept. Assume the position is [left, bottom, width, height] in figure
    fractions. The axes height is then recomputed so that the y data range
    per inch is *r* times the x data range per inch.
    """
    x_lo, x_hi = ax.get_xlim()
    dxlim = x_hi - x_lo
    y_lo, y_hi = ax.get_ylim()
    dylim = y_hi - y_lo
    pos = ax.get_position()
    left = pos[0]
    width = pos[2]
    centre_y = pos[1] + 0.5*pos[3]
    size = gcf().get_size_inches()
    fig_width, fig_height = size[0], size[1]
    new_height = dylim*width*fig_width/r/fig_height/dxlim
    ax.set_position([left, centre_y - 0.5*new_height, width, new_height])


def shiftaxes(ax,delta_pos):
    """Move and/or resize a matplotlib axes by a position offset.

    **input:**
           *ax*, a matplotlib.axes object;
           *delta_pos*, a 4 element list or array corresponding to
           [delta_x_start,delta_y_start,delta_width,delta_height] in
           normalised units. It is added element-wise to the current
           position.
    """
    current = array(ax.get_position())
    ax.set_position(current + delta_pos)

def niceticks(xmin,xmax,nmax):
    """Return no more than roughly *nmax* "nice" tick values between xmin and xmax.

    The step is a 1, 2, 5 or 10 multiple of a power of ten, chosen with
    arts_math.locate from the raw spacing (xmax-xmin)/nmax. Ticks start at
    the first multiple of the step >= xmin and stop before xmax (arange
    excludes its end point). Steps >= 1 give an integer array; smaller
    steps give a float array with near-zero ticks snapped to exactly 0.0.
    xmax must be greater than xmin, because the spacing goes through log10.

    This might not be needed so much anymore.
    """
    # Force true division. With integer arguments Python 2's "/" would
    # truncate the spacing, and log10(0) overflows when the range < nmax.
    rough = float(xmax-xmin)/nmax
    om = int(floor(log10(rough)))
    a = rough/10**om
    allowed = [1,2,5,10]
    i = arts_math.locate(allowed,a)
    deltax = allowed[i+1]*10**om
    # e.g. i=2, om=-1 gives 10*0.1 == 1.0: use an integer step from 1 up
    if deltax >= 1:
        deltax = int(deltax)
        typecode = int
    else:
        typecode = float
    ticks = arange(int(ceil(float(xmin)/deltax))*deltax,xmax,deltax,typecode)
    if deltax < 1:
        # catch floats that are zero up to rounding error
        ticks[abs(ticks/deltax) < 0.5] = 0.0
    return ticks

def linehist(values, edges):
    """Histogram *values* and return the result as a stepped outline.

    Each bin contributes two points, (left edge, count) and
    (right edge, count). plot(x, y) therefore draws the flat bin tops
    joined by vertical risers.

    **input:**
           *values*, the data;
           *edges*, passed to numpy.histogram as its bins argument.

    Returns (x, y). x holds the bin edges, each interior edge twice. y holds
    the matching counts as numpy.uint64.
    """
    counts, bin_edges = numpy.histogram(values, edges)
    left_corners = numpy.c_[bin_edges[:-1], counts]
    right_corners = numpy.c_[bin_edges[1:], counts]
    # one row per bin [left, count, right, count] -> two (x, y) rows per bin
    corners = numpy.hstack((left_corners, right_corners)).reshape(-1, 2)
    return (corners[:, 0], corners[:, 1].astype(numpy.uint64))

