# -------------------------------------------------------------------------
#     Copyright (C) 2005-2010 Martin Strohalm <www.mmass.org>

#     This program is free software; you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation; either version 3 of the License, or
#     (at your option) any later version.

#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#     GNU General Public License for more details.

#     Complete text of GNU GPL can be found in the file LICENSE.TXT in the
#     main directory of the program
# -------------------------------------------------------------------------

# load libs
import numpy
import math
from numpy.linalg import solve as solveLinEq


# load configuration
import config


# SPECTRUM PROCESSING
# -------------------

def noise(points, minX=None, maxX=None, mz=None, window=0.1):
    """Calculate noise level and noise width for given points.
        points: (numpy.array) spectrum points as (x, y) rows, sorted by x
        minX, maxX: (float or None) x-range selection; used only when mz is None
            and both limits are given
        mz: (float or None) m/z value around which to calculate the noise, +- window
        window: (float) half-width of the range around mz in %/100, relative to mz
        Returns (noiseLevel, noiseWidth): the median intensity and twice the median
        absolute deviation from it, or (None, None) when no points are selected.
    """
    
    # use sub-portion of the data
    if mz is not None:
        halfRange = mz*window
        i1 = _getIndex(points, mz-halfRange)
        i2 = _getIndex(points, mz+halfRange)
        points = points[i1:i2]
    elif minX is not None and maxX is not None:
        i1 = _getIndex(points, minX)
        i2 = _getIndex(points, maxX)
        points = points[i1:i2]
    
    # check points
    if len(points) == 0:
        return None, None
    
    # intensities are the second column
    y = numpy.hsplit(points, 2)[1].flatten()
    
    # noise offset: median intensity
    noiseLevel = numpy.median(y)
    
    # noise width: twice the median absolute deviation from the offset
    noiseWidth = float(numpy.median(numpy.absolute(y - noiseLevel)))*2
    
    return noiseLevel, noiseWidth
# ----


def baseline(points, segments, offset=0., smooth=True):
    """Calculate baseline for given points.
        points: (numpy.array) spectrum points as (x, y) rows, sorted by x
        segments: (int) number of baseline segments
        offset: (float) intensity offset in %/100; baseline intensities are
            multiplied by (1 - offset)
        smooth: (bool) smooth final baseline (via smoothSG)
        Returns numpy.array of (x, intensity) baseline points: one at the first x,
        one at the middle of every full segment and one at the last x. Each
        intensity is the segment median minus its median absolute deviation.
    """
    
    # baseline level of one segment: median minus median absolute deviation
    def level(segment):
        med = float(numpy.median(segment))
        mad = numpy.median(numpy.absolute(segment - med))
        return med - mad
    
    # get number of points per segment
    width = max(1, int(len(points)/segments))
    
    # end points use half a segment; keep at least one point so the median is defined
    half = max(1, width // 2)
    
    # unpack x and y values
    xAxis, yAxis = numpy.hsplit(points, 2)
    xAxis = xAxis.flatten()
    yAxis = yAxis.flatten()
    
    # first point: level of the first half-segment
    base = [[xAxis[0], level(yAxis[:half])]]
    
    # one point in the middle of every full segment
    for i in range(0, len(points)-width, width):
        x = xAxis[i] + (xAxis[i+width]-xAxis[i])/2.
        base.append([x, level(yAxis[i:i+width])])
    
    # last point: level of the last half-segment, including the final point
    base.append([xAxis[-1], level(yAxis[-half:])])
    
    # convert to array
    base = numpy.array(base)
    
    # smooth baseline
    if smooth:
        windowSize = 5.*(points[-1][0]-points[0][0])/segments
        base = smoothSG(base, windowSize, 2)
    
    # offset baseline (intensities only)
    base = base * numpy.array([1., 1.-offset])
    
    return base
# ----


def correctBaseline(points, segments, offset=0., smooth=True):
    """Subtract baseline from given points.
        points: (numpy.array) spectrum points as (x, y) rows, sorted by x
        segments: (int) number of baseline segments
        offset: (float) intensity offset in %/100
        smooth: (bool) smooth final baseline
        Returns a new float array with the baseline subtracted from the
        intensities and negative intensities set to zero; x-values are unchanged.
    """
    
    # get baseline points (their x-values increase when points are sorted by x)
    base = baseline(points, segments, offset, smooth)
    
    # work on a float copy so the caller's array is not modified
    shifted = numpy.array(points, dtype=float)
    
    # baseline level at every x by linear interpolation between baseline points
    levels = numpy.interp(shifted[:, 0], base[:, 0], base[:, 1])
    
    # shift points to zero level and remove negative intensities
    shifted[:, 1] = numpy.maximum(shifted[:, 1] - levels, 0.)
    
    return shifted
# ----


def smoothMA(points, windowSize, cycles=1):
    """Smooth points by moving average.
        points: (numpy.array) points to be smoothed, as (x, y) rows sorted by x
        windowSize: (float) m/z window size for smoothing
        cycles: (int) number of repeating cycles (values <= 0 do no smoothing)
        Returns a new array of smoothed points, or the original points object when
        the window spans fewer than two points.
    """
    
    # fewer than two points (or no x-span) leaves nothing to smooth
    count = len(points)
    if count < 2:
        return points
    span = points[-1][0] - points[0][0]
    if span <= 0:
        return points
    
    # approximate number of points within windowSize; keep it odd so it is centred
    windowSize = int(float(windowSize)*count/span)
    if windowSize < 2:
        return points
    if not windowSize % 2:
        windowSize += 1
    
    # the reflected padding below needs at least windowSize points
    if windowSize > count:
        windowSize = count if count % 2 else count - 1
        if windowSize < 3:
            return points
    
    # unpack mz and intensity
    data = numpy.asarray(points, dtype=float)
    xAxis = data[:, 0]
    yAxis = data[:, 1]
    
    weights = numpy.ones(windowSize) / windowSize
    edge = windowSize - 1
    
    # smooth the points
    for cycle in range(cycles):
        # extend both ends by point reflection about the end values
        # (y[-k] = 2*y[0] - y[k]) so linear trends continue across the edges
        head = 2*yAxis[0] - yAxis[edge:0:-1]
        tail = 2*yAxis[-1] - yAxis[-2:-windowSize-1:-1]
        padded = numpy.concatenate((head, yAxis, tail))
        yAxis = numpy.convolve(padded, weights, mode='same')[edge:-edge]
    
    # return smoothed scan
    return numpy.column_stack((xAxis, yAxis))
# ----


def smoothSG(points, windowSize, cycles=1, order=3):
    """Smooth points by Savitzky-Golay filter.
        points: (numpy.array) points to be smoothed, as (x, y) rows sorted by x
        windowSize: (float) m/z window size for smoothing
        cycles: (int) number of repeating cycles (values <= 0 do no smoothing)
        order: (int) order of polynom used
        Returns a new array of smoothed points, or the original points object when
        the window holds too few points for a polynomial of the given order.
    """
    
    # fewer than two points (or no x-span) leaves nothing to smooth
    count = len(points)
    if count < 2:
        return points
    span = points[-1][0] - points[0][0]
    if span <= 0:
        return points
    
    # approximate number of points within windowSize
    windowSize = int(float(windowSize)*count/span)
    if windowSize <= order:
        return points
    
    # the filter uses an odd number of points centred on each point
    halfWindow = (windowSize-1) // 2
    offsets = numpy.arange(-halfWindow, halfWindow+1)
    
    # fitting needs at least order+1 points, otherwise it is underdetermined
    if len(offsets) <= order:
        return points
    
    # coefficients: first row of the pseudo-inverse of the Vandermonde matrix
    # gives the weights producing the fitted polynomial value at the window centre
    design = numpy.vander(offsets.astype(float), order+1, increasing=True)
    weights = numpy.linalg.pinv(design)[0]
    
    # unpack axes
    data = numpy.asarray(points, dtype=float)
    xAxis = data[:, 0]
    yAxis = data[:, 1]
    
    # smooth the data
    for cycle in range(cycles):
        # extend both ends with constant copies of the end values
        padded = numpy.concatenate((
            numpy.full(halfWindow, yAxis[0]),
            yAxis,
            numpy.full(halfWindow, yAxis[-1])))
        # out[i] = sum_j weights[j] * padded[i + j]
        yAxis = numpy.correlate(padded, weights, mode='valid')
    
    # return smoothed data
    return numpy.column_stack((xAxis, yAxis))
# ----


def _getIndex(points, x):
    """Get nearest index for selected point."""
    
    lo = 0
    hi = len(points)
    while lo < hi:
        mid = (lo + hi) / 2
        if x < points[mid][0]:
            hi = mid
        else:
            lo = mid + 1
        
    return lo
# ----



# DATA RE-CALIBRATION
# -------------------

def calibration(data, model='linear'):
    """Calculate calibration constants for given references.
        data: (list) pairs of (measured mass, reference mass)
        model: ('linear' or 'quadratic')
        Returns (fce, params, chiSquare): the model function, its fitted
        parameters and the final chi-square.
        Raises ValueError for an unknown model name.
        This function uses least square fitting written by Konrad Hinsen.
    """
    
    # single point calibration: pure shift, no fitting needed
    if model == 'linear' and len(data) == 1:
        shift = data[0][1] - data[0][0]
        return _linearModel, (1., shift), 1.0
    
    # set fitting model and initial values
    if model == 'linear':
        fce = _linearModel
        initials = (0.5, 0)
    elif model == 'quadratic':
        fce = _quadraticModel
        initials = (1., 0, 0)
    else:
        raise ValueError('Unknown calibration model: %s' % (model,))
    
    # calculate calibration constants
    params, chiSquare = _leastSquaresFit(fce, initials, data)
    
    # fce, parameters, chi-square
    return fce, params, chiSquare
# ----


def _linearModel(params, x):
    """Function for linear model."""
    
    a, b = params
    return a*x + b
    


def _quadraticModel(params, x):
    """Function for quadratic model."""
    
    a, b, c = params
    return a*x*x + b*x + c
    


def _leastSquaresFit(model, parameters, data, max_iterations=None, stopping_limit = 0.005):
    """General non-linear least-squares fit using the
    Levenberg-Marquardt algorithm and automatic derivatives.
        model: callable model(params, x) built from arithmetic on params
        parameters: initial parameter values
        data: sequence of (x, y) or (x, y, sigma) points
        max_iterations: (int or None) iteration limit; None means no limit
        stopping_limit: (float) stop once an accepted step lowers chi-square
            by less than this
        Returns (list of fitted parameter values, chi-square).
    """
    
    # one differentiable variable per parameter
    p = tuple(_DerivVar(param, i) for i, param in enumerate(parameters))
    identity = numpy.identity(len(parameters))
    l = 0.001
    chi_sq, alpha = _chiSquare(model, p, data)
    niter = 0
    while True:
        # damped normal equations: (alpha + l*diag(alpha)) * delta = -gradient/2
        delta = solveLinEq(alpha + l*numpy.diagonal(alpha)*identity, -0.5*numpy.array(chi_sq[1]))
        next_p = [a + b for a, b in zip(p, delta)]
        next_chi_sq, next_alpha = _chiSquare(model, next_p, data)
        if next_chi_sq[0] > chi_sq[0]:
            # worse: reject the step and increase damping
            l = 10.*l
        else:
            # better: accept the step and decrease damping
            l = 0.1*l
            if chi_sq[0] - next_chi_sq[0] < stopping_limit:
                break
            p = next_p
            chi_sq = next_chi_sq
            alpha = next_alpha
        niter = niter + 1
        if max_iterations is not None and niter >= max_iterations:
            # iteration budget exhausted: report the best accepted estimate
            next_p, next_chi_sq = p, chi_sq
            break
    return [v[0] for v in next_p], next_chi_sq[0]


def _isDerivVar(x):
    """Returns 1 if |x| is a DerivVar object."""
    return hasattr(x,'value') and hasattr(x,'deriv')


def _chiSquare(model, parameters, data):
    """ Count Chi-square. """
    
    n_param = len(parameters)
    chi_sq = 0.
    alpha = numpy.zeros((n_param, n_param))
    for point in data:
        sigma = 1
        if len(point) == 3:
            sigma = point[2]
        f = model(parameters, point[0])
        chi_sq = chi_sq + ((f-point[1])/sigma)**2
        d = numpy.array(f[1])/sigma
        alpha = alpha + d[:,numpy.newaxis]*d
    return chi_sq, alpha


def _mapderiv(func, a, b):
    """ Map a binary function on two first derivative lists. """
    
    nvars = max(len(a), len(b))
    a = a + (nvars-len(a))*[0]
    b = b + (nvars-len(b))*[0]
    return map(func, a, b)


class _DerivVar:
    """Number carrying its first-order partial derivatives (automatic differentiation).
    
    Works with any number of independent variables. A variable created with
    index=i has the derivative list i*[0] + [1], i.e. derivative 1 with respect
    to variable i. Plain numbers used in arithmetic are treated as constants
    (empty derivative list); the conversion is done explicitly in each operator
    so it works on Python 3 too (Python 3 ignores __coerce__).
    
        value: (number) current value
        index: (int or list) variable index, or an explicit derivative list
        order: (int) 0 for a constant, 1 for a variable; higher orders raise ValueError
    """
    
    def __init__(self, value, index=0, order=1):
        if order > 1:
            raise ValueError('Only first-order derivatives')
        self.value = value
        if order == 0:
            self.deriv = []
        elif isinstance(index, list):
            self.deriv = index
        else:
            self.deriv = index*[0] + [1]
    
    def _asDerivVar(self, other):
        """Return other if it is DerivVar-like, otherwise wrap it as a constant."""
        if _isDerivVar(other):
            return other
        return _DerivVar(other, [])
    
    def __getitem__(self, item):
        # [0] is the value, [1] the derivative list
        if item < 0 or item > 1:
            raise ValueError('Index out of range')
        if item == 0:
            return self.value
        else:
            return self.deriv
    
    def __coerce__(self, other):
        # used by Python 2 only
        return self, self._asDerivVar(other)
    
    def __cmp__(self, other):
        # used by Python 2 only; comparisons look at the value alone
        return cmp(self.value, self._asDerivVar(other).value)
    
    def __eq__(self, other):
        return self.value == self._asDerivVar(other).value
    
    def __ne__(self, other):
        return self.value != self._asDerivVar(other).value
    
    def __lt__(self, other):
        return self.value < self._asDerivVar(other).value
    
    def __le__(self, other):
        return self.value <= self._asDerivVar(other).value
    
    def __gt__(self, other):
        return self.value > self._asDerivVar(other).value
    
    def __ge__(self, other):
        return self.value >= self._asDerivVar(other).value
    
    def __add__(self, other):
        other = self._asDerivVar(other)
        return _DerivVar(self.value + other.value,
            list(_mapderiv(lambda a, b: a+b, self.deriv, other.deriv)))
    __radd__ = __add__
    
    def __sub__(self, other):
        other = self._asDerivVar(other)
        return _DerivVar(self.value - other.value,
            list(_mapderiv(lambda a, b: a-b, self.deriv, other.deriv)))
    
    def __rsub__(self, other):
        return self._asDerivVar(other).__sub__(self)
    
    def __mul__(self, other):
        other = self._asDerivVar(other)
        # product rule: d(uv) = v*du + u*dv
        return _DerivVar(self.value*other.value,
            list(_mapderiv(lambda a, b: a+b,
                [other.value*d for d in self.deriv],
                [self.value*d for d in other.deriv])))
    __rmul__ = __mul__
    
    def __div__(self, other):
        other = self._asDerivVar(other)
        if not other.value:
            raise ZeroDivisionError('DerivVar division')
        inv = 1./other.value
        # quotient rule: d(u/v) = du/v - u*dv/v^2
        return _DerivVar(self.value*inv,
            list(_mapderiv(lambda a, b: a-b,
                [inv*d for d in self.deriv],
                [self.value*inv*inv*d for d in other.deriv])))
    __truediv__ = __div__
    
    def __rdiv__(self, other):
        return self._asDerivVar(other).__div__(self)
    __rtruediv__ = __rdiv__
    
    def __pow__(self, other, z=None):
        if z is not None:
            raise TypeError('DerivVar does not support ternary pow()')
        other = self._asDerivVar(other)
        val1 = pow(self.value, other.value-1)
        val = val1*self.value
        # d(u^n) = n*u^(n-1)*du
        deriv1 = [val1*other.value*d for d in self.deriv]
        if len(other.deriv) > 0:
            # variable exponent: d(u^w) also has u^w*log(u)*dw
            deriv2 = [val*numpy.log(self.value)*d for d in other.deriv]
            return _DerivVar(val, list(_mapderiv(lambda a, b: a+b, deriv1, deriv2)))
        else:
            return _DerivVar(val, deriv1)
    
    def __rpow__(self, other):
        return self._asDerivVar(other).__pow__(self)
    

