repick button for plot_traces

This commit is contained in:
Marcel Paffrath 2015-10-19 10:33:51 +02:00
parent d78b0f1cff
commit 4498c72c90

View File

@ -7,6 +7,8 @@ from pylot.core.pick.CharFuns import HOScf
from pylot.core.pick.CharFuns import AICcf
from pylot.core.pick.utils import getSNR
from pylot.core.pick.utils import earllatepicker
import matplotlib.pyplot as plt
plt.interactive('True')
class SeismicShot(object):
'''
@ -127,8 +129,15 @@ class SeismicShot(object):
def getSourcefile(self):
return self.paras['sourcefile']
def getPick(self, traceID):
return self.pick[traceID]['mpp']
def getPick(self, traceID, returnRemoved = False):
if not self.getFlag(traceID) == 0:
return self.pick[traceID]['mpp']
if returnRemoved == True:
#print('getPick: Returned removed pick for shot %d, traceID %d' %(self.getShotnumber(), traceID))
return self.pick[traceID]['mpp']
def getPickIncludeRemoved(self, traceID):
return self.getPick(traceID, returnRemoved = True)
def getEarliest(self, traceID):
return self.pick[traceID]['epp']
@ -136,6 +145,12 @@ class SeismicShot(object):
def getLatest(self, traceID):
return self.pick[traceID]['lpp']
def getSymmetricPickError(self, traceID):
pickerror = self.pick[traceID]['spe']
if np.isnan(pickerror) == True:
print "SPE is NaN for shot %s, traceID %s"%(self.getShotnumber(), traceID)
return pickerror
def getPickError(self, traceID):
pickerror = abs(self.getEarliest(traceID) - self.getLatest(traceID))
if np.isnan(pickerror) == True:
@ -209,7 +224,7 @@ class SeismicShot(object):
:type: int
'''
return HOScf(self.getSingleStream(traceID), self.getCut(),
self.getTmovwind(), self.getOrder())
self.getTmovwind(), self.getOrder(), stealthMode = True)
def getAICcf(self, traceID):
'''
@ -232,7 +247,7 @@ class SeismicShot(object):
tr_cf = Trace()
tr_cf.data = self.getHOScf(traceID).getCF()
st_cf += tr_cf
return AICcf(st_cf, self.getCut(), self.getTmovwind())
return AICcf(st_cf, self.getCut(), self.getTmovwind(), stealthMode = True)
def getSingleStream(self, traceID): ########## SEG2 / SEGY ? ##########
'''
@ -291,11 +306,13 @@ class SeismicShot(object):
def setEarllatepick(self, traceID, nfac = 1.5):
tgap = self.getTgap()
tsignal = self.getTsignal()
tnoise = self.getPick(traceID) - tgap
tnoise = self.getPickIncludeRemoved(traceID) - tgap
(self.pick[traceID]['epp'], self.pick[traceID]['lpp'], tmp) = earllatepicker(self.getSingleStream(traceID),
nfac, (tnoise, tgap, tsignal),
self.getPick(traceID))
(self.pick[traceID]['epp'], self.pick[traceID]['lpp'],
self.pick[traceID]['spe']) = earllatepicker(self.getSingleStream(traceID),
nfac, (tnoise, tgap, tsignal),
self.getPickIncludeRemoved(traceID),
stealthMode = True)
def threshold(self, hoscf, aiccf, windowsize, pickwindow, folm = 0.6):
'''
@ -463,7 +480,7 @@ class SeismicShot(object):
def setFlag(self, traceID, flag):
'Set flag = 0 if pick is invalid, else flag = 1'
self.pick[traceID]['flag'] = 0
self.pick[traceID]['flag'] = flag
def getFlag(self, traceID):
return self.pick[traceID]['flag']
@ -570,42 +587,89 @@ class SeismicShot(object):
# plt.plot(self.getDistArray4ttcPlot(), pickwindowarray_upperb, ':k')
def plot_traces(self, traceID, folm = 0.6): ########## 2D, muss noch mehr verbessert werden ##########
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
def onclick(event):
self.setPick(traceID, event.xdata)
self._drawStream(traceID, refresh = True)
self._drawCFs(traceID, folm, refresh = True)
fig.canvas.mpl_disconnect(self.traces4plot[traceID]['cid'])
plt.draw()
def connectButton(event = None):
cid = fig.canvas.mpl_connect('button_press_event', onclick)
self.traces4plot[traceID]['cid'] = cid
fig = plt.figure()
ax1 = fig.add_subplot(2,1,1)
ax2 = fig.add_subplot(2,1,2, sharex = ax1)
axb = fig.add_axes([0.15, 0.91, 0.05, 0.03])
button = Button(axb, 'repick', color = 'red', hovercolor = 'grey')
button.on_clicked(connectButton)
self.traces4plot = {}
if not traceID in self.traces4plot.keys():
self.traces4plot[traceID] = {'fig': fig,
'ax1': ax1,
'ax2': ax2,
'axb': axb,
'button': button,
'cid': None,}
self._drawStream(traceID)
self._drawCFs(traceID, folm)
def _drawStream(self, traceID, refresh = False):
from pylot.core.util.utils import getGlobalTimes
from pylot.core.util.utils import prepTimeAxis
stream = self.getSingleStream(traceID)
stime = getGlobalTimes(stream)[0]
timeaxis = prepTimeAxis(stime, stream[0])
timeaxis -= stime
plt.interactive('True')
ax = self.traces4plot[traceID]['ax1']
if refresh == True:
xlim, ylim = ax.get_xlim(), ax.get_ylim()
ax.clear()
if refresh == True:
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_title('Shot: %s, traceID: %s, pick: %s'
%(self.getShotnumber(), traceID, self.getPick(traceID)))
ax.plot(timeaxis, stream[0].data, 'k', label = 'trace')
ax.plot([self.getPick(traceID), self.getPick(traceID)],
[min(stream[0].data),
max(stream[0].data)],
'r', label = 'mostlikely')
ax.legend()
def _drawCFs(self, traceID, folm, refresh = False):
hoscf = self.getHOScf(traceID)
aiccf = self.getAICcf(traceID)
ax = self.traces4plot[traceID]['ax2']
fig = plt.figure()
ax1 = plt.subplot(2,1,1)
plt.title('Shot: %s, traceID: %s, pick: %s' %(self.getShotnumber(), traceID, self.getPick(traceID)))
ax1.plot(timeaxis, stream[0].data, 'k', label = 'trace')
ax1.plot([self.getPick(traceID), self.getPick(traceID)],
[min(stream[0].data),
max(stream[0].data)],
'r', label = 'mostlikely')
plt.legend()
ax2 = plt.subplot(2,1,2, sharex = ax1)
ax2.plot(hoscf.getTimeArray(), hoscf.getCF(), 'b', label = 'HOS')
ax2.plot(hoscf.getTimeArray(), aiccf.getCF(), 'g', label = 'AIC')
ax2.plot([self.getPick(traceID), self.getPick(traceID)],
if refresh == True:
xlim, ylim = ax.get_xlim(), ax.get_ylim()
ax.clear()
if refresh == True:
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.plot(hoscf.getTimeArray(), hoscf.getCF(), 'b', label = 'HOS')
ax.plot(hoscf.getTimeArray(), aiccf.getCF(), 'g', label = 'AIC')
ax.plot([self.getPick(traceID), self.getPick(traceID)],
[min(np.minimum(hoscf.getCF(), aiccf.getCF())),
max(np.maximum(hoscf.getCF(), aiccf.getCF()))],
'r', label = 'mostlikely')
ax2.plot([0, self.getPick(traceID)],
ax.plot([0, self.getPick(traceID)],
[folm * max(hoscf.getCF()), folm * max(hoscf.getCF())],
'm:', label = 'folm = %s' %folm)
plt.xlabel('Time [s]')
plt.legend()
ax.set_xlabel('Time [s]')
ax.legend()
def plot3dttc(self, step = 0.5, contour = False, plotpicks = False, method = 'linear', ax = None):
'''
Plots a 3D 'traveltime cone' as surface plot by interpolating on a regular grid over the traveltimes, not yet regarding the vertical offset of the receivers.
@ -622,7 +686,6 @@ class SeismicShot(object):
:param: method (optional), interpolation method; can be 'linear' (default) or 'cubic'
:type: 'string'
'''
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
@ -636,8 +699,8 @@ class SeismicShot(object):
y.append(self.getRecLoc(traceID)[1])
z.append(self.getPick(traceID))
xaxis = np.arange(min(x)+1, max(x), step)
yaxis = np.arange(min(y)+1, max(y), step)
xaxis = np.arange(min(x), max(x), step)
yaxis = np.arange(min(y), max(y), step)
xgrid, ygrid = np.meshgrid(xaxis, yaxis)
zgrid = griddata((x, y), z, (xgrid, ygrid), method = method)
@ -662,7 +725,7 @@ class SeismicShot(object):
plotmethod[method](*args)
def matshow(self, step = 0.5, method = 'linear', ax = None, plotRec = False, annotations = False):
def matshow(self, ax = None, step = 0.5, method = 'linear', plotRec = True, annotations = True, colorbar = True):
'''
Plots a 2D matrix of the interpolated traveltimes. This needs less performance than plot3dttc
@ -672,27 +735,32 @@ class SeismicShot(object):
:param: method (optional), interpolation method; can be 'linear' (default) or 'cubic'
:type: 'string'
:param: plotRec (optional), plot the receiver positions
:param: plotRec (optional), plot the receiver positions (colored scatter plot, should not be
deactivated because there might be receivers that are not inside the interpolated area)
:type: 'logical'
:param: annotations (optional), displays traceIDs as annotations
:type: 'logical'
'''
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
# plt.interactive('True')
x = []
y = []
z = []
x = []; xcut = []
y = []; ycut = []
z = []; zcut = []
tmin, tmax = self.getCut()
for traceID in self.pick.keys():
if self.getFlag(traceID) != 0:
x.append(self.getRecLoc(traceID)[0])
y.append(self.getRecLoc(traceID)[1])
z.append(self.getPick(traceID))
if self.getFlag(traceID) == 0 and self.getPickIncludeRemoved(traceID) is not None:
xcut.append(self.getRecLoc(traceID)[0])
ycut.append(self.getRecLoc(traceID)[1])
zcut.append(self.getPickIncludeRemoved(traceID))
xaxis = np.arange(min(x)+1, max(x), step)
yaxis = np.arange(min(y)+1, max(y), step)
xaxis = np.arange(min(x), max(x), step)
yaxis = np.arange(min(y), max(y), step)
xgrid, ygrid = np.meshgrid(xaxis, yaxis)
zgrid = griddata((x, y), z, (xgrid, ygrid), method='linear')
@ -700,14 +768,28 @@ class SeismicShot(object):
fig = plt.figure()
ax = plt.axes()
ax.imshow(zgrid, interpolation = 'none', extent = [min(x), max(x), min(y), max(y)])
if annotations == True:
for i, traceID in enumerate(self.pick.keys()):
if shot.picks[traceID] != None:
ax.annotate('%s' % traceID, xy=(x[i], y[i]), fontsize = 'x-small')
ax.matshow(zgrid, extent = [min(x), max(x), min(y), max(y)], origin = 'lower')
plt.text(0.45, 0.9, 'shot: %s' %self.getShotnumber(), transform = ax.transAxes)
sc = ax.scatter(x, y, c = z, s = 30, label = 'picked shots', vmin = tmin, vmax = tmax, linewidths = 1.5)
sccut = ax.scatter(xcut, ycut, c = zcut, s = 30, edgecolor = 'm', label = 'cut out shots', vmin = tmin, vmax = tmax, linewidths = 1.5)
if colorbar == True:
plt.colorbar(sc)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.plot(self.getSrcLoc()[0], self.getSrcLoc()[1],'*k', markersize = 15) # plot source location
if plotRec == True:
ax.plot(x, y, 'k.')
ax.scatter(x, y, c = z, s = 30)
if annotations == True:
for traceID in self.getTraceIDlist():
if self.getFlag(traceID) is not 0:
ax.annotate(' %s' %traceID , xy = (self.getRecLoc(traceID)[0], self.getRecLoc(traceID)[1]),
fontsize = 'x-small', color = 'k')
else:
ax.annotate(' %s' %traceID , xy = (self.getRecLoc(traceID)[0], self.getRecLoc(traceID)[1]),
fontsize = 'x-small', color = 'r')
plt.show()