From b12d92eebb0c6145cfa81ed5bbda5be6d97f2440 Mon Sep 17 00:00:00 2001 From: Sebastian Wehling-Benatelli Date: Wed, 12 Apr 2023 21:22:58 +0200 Subject: [PATCH] bugfix: refactor get_owner and get_hash; add tests --- pylot/core/io/location.py | 4 ++-- pylot/core/io/phases.py | 6 +++--- pylot/core/util/utils.py | 25 +++++++++++++++++++++---- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/pylot/core/io/location.py b/pylot/core/io/location.py index 6fa2bbac..0fc5130d 100644 --- a/pylot/core/io/location.py +++ b/pylot/core/io/location.py @@ -1,7 +1,7 @@ from obspy import UTCDateTime from obspy.core import event as ope -from pylot.core.util.utils import getLogin, getHash +from pylot.core.util.utils import getLogin, get_hash def create_amplitude(pickID, amp, unit, category, cinfo): @@ -210,7 +210,7 @@ def create_resourceID(timetohash, restype, authority_id=None, hrstr=None): ''' assert isinstance(timetohash, UTCDateTime), "'timetohash' is not an ObsPy" \ "UTCDateTime object" - hid = getHash(timetohash) + hid = get_hash(timetohash) if hrstr is None: resID = ope.ResourceIdentifier(restype + '/' + hid[0:6]) else: diff --git a/pylot/core/io/phases.py b/pylot/core/io/phases.py index d7dc5856..3f8f96b5 100644 --- a/pylot/core/io/phases.py +++ b/pylot/core/io/phases.py @@ -16,7 +16,7 @@ from pylot.core.io.inputs import PylotParameter from pylot.core.io.location import create_event, \ create_magnitude from pylot.core.pick.utils import select_for_phase, get_quality_class -from pylot.core.util.utils import getOwner, full_range, four_digits, transformFilterString4Export, \ +from pylot.core.util.utils import get_owner, full_range, four_digits, transformFilterString4Export, \ backtransformFilterString, loopIdentifyPhase, identifyPhase @@ -58,7 +58,7 @@ def readPILOTEvent(phasfn=None, locfn=None, authority_id='RUB', **kwargs): if phasfn is not None and os.path.isfile(phasfn): phases = sio.loadmat(phasfn) phasctime = UTCDateTime(os.path.getmtime(phasfn)) - phasauthor = getOwner(phasfn) + phasauthor = get_owner(phasfn) else: phases = None phasctime = None @@ -66,7 +66,7 @@ def readPILOTEvent(phasfn=None, locfn=None, authority_id='RUB', **kwargs): if locfn is not None and os.path.isfile(locfn): loc = sio.loadmat(locfn) locctime = UTCDateTime(os.path.getmtime(locfn)) - locauthor = getOwner(locfn) + locauthor = get_owner(locfn) else: loc = None locctime = None diff --git a/pylot/core/util/utils.py b/pylot/core/util/utils.py index f0342276..ede30ab6 100644 --- a/pylot/core/util/utils.py +++ b/pylot/core/util/utils.py @@ -301,7 +301,7 @@ def fnConstructor(s): if type(s) is str: s = s.split(':')[-1] else: - s = getHash(UTCDateTime()) + s = get_hash(UTCDateTime()) badchars = re.compile(r'[^A-Za-z0-9_. ]+|^\.|\.$|^ | $|^$') badsuffix = re.compile(r'(aux|com[1-9]|con|lpt[1-9]|prn)(\.|$)') @@ -473,16 +473,25 @@ def backtransformFilterString(st): return st -def getHash(time): +def get_hash(time): """ takes a time object and returns the corresponding SHA1 hash of the formatted date string :param time: time object for which a hash should be calculated :type time: `~obspy.core.utcdatetime.UTCDateTime` :return: SHA1 hash :rtype: str + + >>> time = UTCDateTime(0) + >>> get_hash(time) + '7627cce3b1b58dd21b005dac008b34d18317dd15' + >>> get_hash(0) + Traceback (most recent call last): + ... + AssertionError: 'time' is not an ObsPy UTCDateTime object """ + assert isinstance(time, UTCDateTime), '\'time\' is not an ObsPy UTCDateTime object' hg = hashlib.sha1() - hg.update(time.strftime('%Y-%m-%d %H:%M:%S.%f')) + hg.update(time.strftime('%Y-%m-%d %H:%M:%S.%f').encode('utf-8')) return hg.hexdigest() @@ -496,13 +505,21 @@ def getLogin(): return getpass.getuser() -def getOwner(fn): +def get_owner(fn): """ takes a filename and return the login ID of the actual owner of the file :param fn: filename of the file tested :type fn: str :return: login ID of the file's owner :rtype: str + + >>> import tempfile + >>> with tempfile.NamedTemporaryFile() as tmpfile: + ... tmpfile.write(b'') and True + ... tmpfile.flush() + ... get_owner(tmpfile.name) == os.path.expanduser('~').split('/')[-1] + 0 + True """ system_name = platform.system() if system_name in ["Linux", "Darwin"]: