diff --git a/pylot/core/util/pdf.py b/pylot/core/util/pdf.py index 9518ba0f..e0e3f078 100644 --- a/pylot/core/util/pdf.py +++ b/pylot/core/util/pdf.py @@ -3,9 +3,8 @@ import warnings import numpy as np -import scipy.optimize from obspy import UTCDateTime -from pylot.core.util.utils import find_nearest, clims +from pylot.core.util.utils import fit_curve, find_nearest, clims from pylot.core.util.version import get_git_version as _getVersionString __version__ = _getVersionString() @@ -57,7 +56,7 @@ def exp_parameter(te, tm, tl, eta): return tm, sig1, sig2, a -def gauss_branches(k, mu, sig1, sig2, a1, a2): +def gauss_branches(k, (mu, sig1, sig2, a1, a2)): ''' function gauss_branches takes an axes x, a center value mu, two sigma values sig1 and sig2 and two scaling factors a1 and a2 and return a @@ -91,7 +90,7 @@ def gauss_branches(k, mu, sig1, sig2, a1, a2): return _func(k, mu, sig1, sig2, a1, a2) -def exp_branches(k, mu, sig1, sig2, a): +def exp_branches(k, (mu, sig1, sig2, a)): ''' function exp_branches takes an axes x, a center value mu, two sigma values sig1 and sig2 and a scaling factor a and return a @@ -130,17 +129,22 @@ class ProbabilityDensityFunction(object): version = __version__ - def __init__(self, x0, incr, npts, pdf, mu, params): + def __init__(self, x0, incr, npts, pdf, mu, params, eta=0.01): self.x0 = x0 self.incr = incr self.npts = npts self.axis = create_axis(x0, incr, npts) self.mu = mu + self.eta = eta self._pdf = pdf self.params = params def __add__(self, other): + assert self.eta == other.eta, 'decline factors differ please use equally defined pdfs for comparison' + + eta = self.eta + x0, incr, npts = self.commonparameter(other) axis = create_axis(x0, incr, npts) @@ -153,16 +157,19 @@ class ProbabilityDensityFunction(object): npts = pdf.size x0 *= 2 axis = create_axis(x0, incr, npts) + mu = axis[np.where(pdf == max(pdf))][0] - params, pcov = scipy.optimize.curve_fit(branches['gauss'], axis, pdf) + func, params = fit_curve(axis, pdf) - mu = axis[np.where(pdf == max(pdf))] - - return ProbabilityDensityFunction(x0, incr, npts, branches['gauss'], mu, - params) + return ProbabilityDensityFunction(x0, incr, npts, func, mu, + params, eta) def __sub__(self, other): + assert self.eta == other.eta, 'decline factors differ please use equally defined pdfs for comparison' + + eta = self.eta + x0, incr, npts = self.commonparameter(other) axis = create_axis(x0, incr, npts) @@ -176,16 +183,12 @@ class ProbabilityDensityFunction(object): midpoint = npts / 2 x0 = -incr * midpoint axis = create_axis(x0, incr, npts) - mu = axis[np.where(pdf == max(pdf))][0] - bounds = ([mu, 0., 0., 0., 0.],[mu, np.inf, np.inf, np.inf, np.inf]) + func, params = fit_curve(axis, pdf) - params, pcov = scipy.optimize.curve_fit(branches['gauss'], axis, pdf, - bounds=bounds) - - return ProbabilityDensityFunction(x0, incr, npts, branches['gauss'], mu, - params) + return ProbabilityDensityFunction(x0, incr, npts, func, mu, + params, eta) def __nonzero__(self): prec = self.precision(self.incr) @@ -204,7 +207,15 @@ class ProbabilityDensityFunction(object): return prec if prec >= 0 else 0 def data(self, value): - return self._pdf(value, *self.params) + return self._pdf(value, self.params) + + @property + def eta(self): + return self._eta + + @eta.setter + def eta(self, value): + self._eta = value @property def mu(self): @@ -271,7 +282,7 @@ class ProbabilityDensityFunction(object): # return the object return ProbabilityDensityFunction(x0, incr, npts, pdf, barycentre, - params) + params, decfact) def broadcast(self, pdf, si, ei, data): try: @@ -292,16 +303,15 @@ class ProbabilityDensityFunction(object): ''' rval = 0 - axis = self.axis - self.x0 - for n, x in enumerate(axis): - rval += x * self.data(n) - return rval * self.incr + self.x0 + for x in self.axis: + rval += x * self.data(x) + return rval * self.incr def standard_deviation(self): mu = self.mu rval = 0 - for n, x in enumerate(self.axis): - rval += (x - mu) ** 2 * self.data(n) + for x in self.axis: + rval += (x - mu) ** 2 * self.data(x) return rval * self.incr def prob_lt_val(self, value): diff --git a/pylot/core/util/utils.py b/pylot/core/util/utils.py index 521c548b..b24efb4d 100644 --- a/pylot/core/util/utils.py +++ b/pylot/core/util/utils.py @@ -3,6 +3,7 @@ import hashlib import numpy as np +from scipy.interpolate import splrep, splev import os import pwd import re @@ -16,6 +17,17 @@ def _pickle_method(m): else: return getattr, (m.im_self, m.im_func.func_name) +def fit_curve(x, y): + + return splev, splrep(x, y) + +def getindexbounds(f, eta): + mi = f.argmax() + m = max(f) + b = m * eta + l = find_nearest(f[:mi], b) + u = find_nearest(f[mi:], b) + mi + return mi, l, u def worker(func, input, cores='max', async=False): import multiprocessing