diff --git a/pylot/core/util/pdf.py b/pylot/core/util/pdf.py index c130b617..74c07360 100644 --- a/pylot/core/util/pdf.py +++ b/pylot/core/util/pdf.py @@ -109,6 +109,50 @@ class ProbabilityDensityFunction(object): version = __version__ + def __init__(self, x, pdf): + self.axis = x + self.data = pdf + + def __add__(self, other): + assert isinstance(other, ProbabilityDensityFunction), \ + 'both operands must be of type ProbabilityDensityFunction' + if self.sampling_rate() == other.sampling_rate(): + max = np.maximum(self.axis, other.axis) + x = np.arange(-max, max + self.sampling_rate(), self.sampling_rate()) + else: + raise ValueError('Sampling rates do not match!') + + pdf1 = self.data + + return ProbabilityDensityFunction(x, pdf) + + def __sub__(self, other): + pass + + def __nonzero__(self): + return True + + @property + def data(self): + return self.data + + @data.setter + def data(self, pdf): + self.data = np.array(pdf) + + @property + def axis(self): + return self.axis + + @axis.setter + def axis(self, x): + self.axis = np.array(x) + + def sampling_rate(self): + return self.axis[1] - self.axis[0] + +class PickPDF(ProbabilityDensityFunction): + def __init__(self, x, lbound, midpoint, rbound, decfact=0.01, type='gauss'): ''' Initialize a new ProbabilityDensityFunction object. Takes arguments x, @@ -126,15 +170,9 @@ class ProbabilityDensityFunction(object): branches ''' - self.axis = np.array(x) - self.nodes = dict(lbound=lbound, midpoint=midpoint, rbound=rbound, eta=decfact) + self.nodes = dict(te=lbound, tm=midpoint, tl=rbound, eta=decfact) self.type = type - - def __add__(self, other): - pass - - def __sub__(self, other): - pass + super(PickPDF, self).__init__(x, self.pdf()) @property def type(self): @@ -147,5 +185,8 @@ class ProbabilityDensityFunction(object): def params(self): return parameter[self.type](**self.nodes) - def data(self): - return branches[self.type](self.axis, self.nodes['midpoint'], *self.params()) + def get(self, key): + return self.nodes[key] + + def pdf(self): + return branches[self.type](self.axis, self.get('tm'), *self.params())