[edit] probability density function superclass implemented due to the different character of these functions

This commit is contained in:
Sebastian Wehling-Benatelli 2016-02-06 09:04:50 +01:00
parent ada9f4e780
commit 303a5f9cf0

View File

@ -109,6 +109,50 @@ class ProbabilityDensityFunction(object):
version = __version__ version = __version__
def __init__(self, x, pdf):
self.axis = x
self.data = pdf
def __add__(self, other):
assert isinstance(other, ProbabilityDensityFunction), \
'both operands must be of type ProbabilityDensityFunction'
if self.sampling_rate() == other.sampling_rate():
max = np.maximum(self.axis, other.axis)
x = np.arange(-max, max + self.sampling_rate(), self.sampling_rate())
else:
raise ValueError('Sampling rates do not match!')
pdf1 = self.data
return ProbabilityDensityFunction(x, pdf)
def __sub__(self, other):
pass
def __nonzero__(self):
return True
@property
def data(self):
return self.data
@data.setter
def data(self, pdf):
self.data = np.array(pdf)
@property
def axis(self):
return self.axis
@axis.setter
def axis(self, x):
self.axis = np.array(x)
def sampling_rate(self):
return self.axis[1] - self.axis[0]
class PickPDF(ProbabilityDensityFunction):
def __init__(self, x, lbound, midpoint, rbound, decfact=0.01, type='gauss'): def __init__(self, x, lbound, midpoint, rbound, decfact=0.01, type='gauss'):
''' '''
Initialize a new ProbabilityDensityFunction object. Takes arguments x, Initialize a new ProbabilityDensityFunction object. Takes arguments x,
@ -126,15 +170,9 @@ class ProbabilityDensityFunction(object):
branches branches
''' '''
self.axis = np.array(x) self.nodes = dict(te=lbound, tm=midpoint, tl=rbound, eta=decfact)
self.nodes = dict(lbound=lbound, midpoint=midpoint, rbound=rbound, eta=decfact)
self.type = type self.type = type
super(PickPDF, self).__init__(x, self.pdf())
def __add__(self, other):
pass
def __sub__(self, other):
pass
@property @property
def type(self): def type(self):
@ -147,5 +185,8 @@ class ProbabilityDensityFunction(object):
def params(self): def params(self):
return parameter[self.type](**self.nodes) return parameter[self.type](**self.nodes)
def data(self): def get(self, key):
return branches[self.type](self.axis, self.nodes['midpoint'], *self.params()) return self.nodes[key]
def pdf(self):
return branches[self.type](self.axis, self.get('tm'), *self.params())