diff --git a/xpflow.py b/xpflow.py index 51ba0a6..a9c3e44 100644 --- a/xpflow.py +++ b/xpflow.py @@ -3,6 +3,8 @@ from easydict import EasyDict import hashlib import json +from sorcery import dict_of +import os, sys, traceback def override(xp): import argparse, sys @@ -96,3 +98,35 @@ def __str__(self): def __len__(self): return len([x for x in self]) + +class NoPrint: + def __enter__(self): + self._original_stdout = sys.stdout + sys.stdout = open(os.devnull, 'w') + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.stdout.close() + sys.stdout = self._original_stdout + +class Catch: + def __init__(self, exceptions=[], exit_fn=lambda:None): + self.allowed_exceptions = exceptions + self.encountered_expcetions=[] + self.exit_fn=exit_fn + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, tb): + self.exit_fn() + if exception_type==KeyboardInterrupt: + return False + global _EXCEPTIONS + try: + _EXCEPTIONS + except: + _EXCEPTIONS=[] + + if exception_type and (exception_type in self.allowed_exceptions or not self.allowed_exceptions): + print(f"{exception_type.__name__} swallowed!",exception_value,traceback.print_tb(tb)) + _EXCEPTIONS+=[dict_of(exception_type,exception_value,tb)] + return True