# Source code for towhee.hparam.hyperparameter

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
import json
import threading

from typing import Any, Dict, Set
from typing import Callable
# pylint: disable=pointless-string-statement

# Module-wide trackers of hyperparameter names accessed through an accessor
# (`hp().x` style). Reads and writes are kept in separate sets; query them
# with reads(), writes() and all_params().
_read_tracker: Set[str] = set()
_write_tracker: Set[str] = set()


def reads():
    """
    Return every hyperparameter name read through an accessor, sorted.

    Only reads made via the accessor API (``hp().name()`` /
    ``get_or_else``) are recorded; plain attribute access is not.

    Returns:
        List[str]: hyperparameter read operations

    Examples:
        >>> _read_tracker.clear()
        >>> hp = HyperParameter(a=1, b={'c': 2})
        >>> reads() # no read operations
        []

        >>> hp.a # accessing parameter directly
        1
        >>> reads() # not tracked
        []

        >>> hp().a() # accessing with accessor
        1
        >>> reads() # tracked!
        ['a']
    """
    # sorted() copies the set into a fresh list, so callers may mutate it freely.
    return sorted(_read_tracker)
def writes():
    """
    Return every hyperparameter name written through an accessor, sorted.

    Only writes made via the accessor API (``hp().name = value``) are
    recorded; plain attribute assignment is not.

    Returns:
        List[str]: hyperparameter write operations

    Examples:
        >>> _write_tracker.clear()
        >>> hp = HyperParameter(a=1, b={'c': 2})
        >>> writes()
        []

        >>> hp.a = 1
        >>> writes()
        []

        >>> hp().a = 1
        >>> hp().a.b.c = 1
        >>> writes()
        ['a', 'a.b.c']
    """
    # sorted() copies the set into a fresh list, so callers may mutate it freely.
    return sorted(_write_tracker)
def all_params():
    """
    Return every tracked hyperparameter name (read or written), sorted.

    Returns:
        List[str]: union of the read and write trackers, without duplicates.
    """
    return sorted(_read_tracker | _write_tracker)
class _Accessor(dict): """ Helper for accessing hyper-parameters. When reading an undefined parameter, the accessor will: 1. return false in `if` statement: >>> params = HyperParameter() >>> if not params.undefined_int: print("parameter undefined") parameter undefined 2. support default value for undefined parameter >>> params = HyperParameter() >>> params.undefined_int.get_or_else(10) 10 3. support to create nested parameter: >>> params = HyperParameter() >>> params.undefined_object.undefined_prop = 1 >>> print(params) {'undefined_object': {'undefined_prop': 1}} """ def __init__(self, root, path=None): super().__init__() self._root = root self._path = path def get_or_else(self, default: Any = None): """ Get value for the parameter, or get default value if the parameter is not defined. """ _read_tracker.add(self._path) value = self._root.get(self._path) return default if not value else value def __getattr__(self, name: str) -> Any: # _path and _root are not allowed as keys for user. if name in ['_path', '_root']: return self[name] if self._path: name = '{}.{}'.format(self._path, name) return _Accessor(self._root, name) def __setattr__(self, name: str, value: Any): # _path and _root are not allowed as keys for user. if name in ['_path', '_root']: return self.__setitem__(name, value) full_name = '{}.{}'.format(self._path, name) if self._path is not None else name _write_tracker.add(full_name) root = self._root root.put(full_name, value) return value def __str__(self): return '' def __bool__(self): return False def __call__(self, default: Any = None) -> Any: """ shortcut for get_or_else """ return self.get_or_else(default) __nonzero__ = __bool__
class DynamicDispatch:
    """
    Dynamic call dispatch

    Attribute access extends the dispatch name (``f.a.b`` -> ``'a.b'``) and
    indexing sets the dispatch index (``f[1, 2]`` -> ``(1, 2)``). Calling the
    object runs the wrapped function inside a ``param_scope`` whose ``_name``
    and ``_index`` hold those values.

    Examples:

    >>> @dynamic_dispatch
    ... def debug_print(*args, **kws):
    ...     hp = param_scope()
    ...     name = hp._name
    ...     index = hp._index
    ...     return (name, index, args, kws)

    >>> debug_print()
    (None, None, (), {})

    >>> debug_print.a()
    ('a', None, (), {})

    >>> debug_print.a.b.c()
    ('a.b.c', None, (), {})

    >>> debug_print[1]()
    (None, 1, (), {})

    >>> debug_print[1,2]()
    (None, (1, 2), (), {})

    >>> debug_print(1,2, a=1,b=2)
    (None, None, (1, 2), {'a': 1, 'b': 2})

    >>> debug_print.a.b.c[1,2](1, 2, a=1, b=2)
    ('a.b.c', (1, 2), (1, 2), {'a': 1, 'b': 2})
    """

    def __init__(self, func: Callable, name=None, index=None):
        self._func = func
        self._name = name
        self._index = index

    def __call__(self, *args, **kws) -> Any:
        """Call the wrapped function with ``_name``/``_index`` set in a new param_scope."""
        with param_scope(_index=self._index, _name=self._name):
            return self._func(*args, **kws)

    def __getattr__(self, name: str) -> Any:
        # Dunder names are protocol lookups (copy, pickle, inspect.unwrap, ...),
        # not dispatch path segments. The internal attributes are only missing
        # on a half-built instance (e.g. during copy); answering them here
        # would recurse forever through `self._name`.
        if (name.startswith('__') and name.endswith('__')) or name in ('_func', '_name', '_index'):
            raise AttributeError(name)
        if self._name is not None:
            name = '{}.{}'.format(self._name, name)
        return dynamic_dispatch(self._func, name, self._index)

    def __getitem__(self, index):
        return dynamic_dispatch(self._func, self._name, index)
def dynamic_dispatch(func, name=None, index=None):
    """Wraps function with a class to allow __getitem__ and __getattr__ on a function.

    A fresh ``DynamicDispatch`` subclass is created whose name is taken from
    ``func`` and whose ``__doc__`` is ``func.__doc__``.

    Args:
        func (Callable): the function to wrap.
        name (str, optional): dotted dispatch name, e.g. ``'a.b.c'``.
        index (Any, optional): dispatch index selected with ``[...]``.

    Returns:
        DynamicDispatch: an instance of the generated subclass.
    """
    # Not every callable has __name__ (functools.partial objects, callable
    # instances); fall back to the type's name instead of raising.
    class_name = getattr(func, '__name__', type(func).__name__)
    new_class = type(class_name, (DynamicDispatch,), {'__doc__': func.__doc__})
    return new_class(func, name, index)
class HyperParameter(dict):
    """
    HyperParameter is an extended dict with features for better parameter management.

    A HyperParameter can be created with:

    >>> hp = HyperParameter(param1=1, param2=2, obj1={'propA': 'A'})

    or

    >>> hp = HyperParameter(**{'param1': 1, 'param2': 2, 'obj1': {'propA': 'A'}})

    Once the HyperParameter object is created, you can access the values using the object-style api:

    >>> hp.param1
    1
    >>> hp.obj1.propA
    'A'

    or using the dict-style api (for legacy codes):

    >>> hp['param1']
    1
    >>> hp['obj1']['propA']
    'A'

    The object-style api also supports creating or updating the parameters:

    >>> hp.a.b.c = 1

    which avoids maintaining the dict data manually like this:

    >>> hp = {}
    >>> if 'a' not in hp: hp['a'] = {}
    >>> if 'b' not in hp['a']: hp['a']['b'] = {}
    >>> hp['a']['b']['c'] = 1

    You can also create a parameter with a string name:

    >>> hp = HyperParameter()
    >>> hp.put('a.b.c', 1)
    """

    def __init__(self, **kws):
        super().__init__()
        self.update(kws)

    def update(self, kws):
        """
        Merge ``kws`` into this HyperParameter.

        A dict value is merged recursively with a dict already stored under
        the same key (instead of replacing it) and is converted into a
        ``HyperParameter``; any other value overwrites the existing entry.

        Args:
            kws (dict): parameters to merge in.

        Examples:
            >>> hp = HyperParameter(a={'b': 1})
            >>> hp.update({'a': {'c': 2}})
            >>> hp
            {'a': {'b': 1, 'c': 2}}
        """
        for k, v in kws.items():
            if isinstance(v, dict):
                if k in self and isinstance(self[k], dict):
                    vv = HyperParameter(**self[k])
                    vv.update(v)
                    v = vv
                else:
                    v = HyperParameter(**v)
            self[k] = v

    def put(self, name: str, value: Any):
        """
        put/update a parameter with a string name

        Missing or non-dict intermediate nodes are replaced by empty
        HyperParameters. The value is passed through ``safe_numeric``, so
        numeric-looking strings are stored as int/float.

        Args:
            name (str): parameter name, 'obj.prop' is supported
            value (Any): parameter value

        Examples:
            >>> cfg = HyperParameter()
            >>> cfg.put('param1', 1)
            >>> cfg.put('obj1.propA', 'A')

            >>> cfg.param1
            1
            >>> cfg.obj1.propA
            'A'
        """
        path = name.split('.')
        obj = self
        for p in path[:-1]:
            if p not in obj or (not isinstance(obj[p], dict)):
                obj[p] = HyperParameter()
            obj = obj[p]
        obj[path[-1]] = safe_numeric(value)

    def get(self, name: str) -> Any:
        """
        get a parameter by a string name

        Args:
            name (str): parameter name

        Returns:
            Any: parameter value, or an (always falsy) ``_Accessor`` for the
            full ``name`` when any segment of the path is undefined.

        Examples:
            >>> cfg = HyperParameter(a=1, b = {'c':2, 'd': 3})
            >>> cfg.get('a')
            1
            >>> cfg.get('b.c')
            2
            >>> bool(cfg.get('a.x'))  # 'a' is a leaf value, so 'a.x' is undefined
            False
        """
        obj = self
        for p in name.split('.'):
            # Stop as soon as the path leaves the dict tree: the key is missing,
            # or an intermediate value is a leaf (for a str leaf, `in` would do
            # a substring test; for an int it would raise TypeError).
            if not isinstance(obj, dict) or p not in obj:
                # Root the accessor at `self` with the full name, so that
                # assigning through it (`acc.x = 1`) writes to `name.x`.
                return _Accessor(self, name)
            obj = obj[p]
        return obj

    def __setitem__(self, key, value):
        """
        set value and convert the value into `HyperParameter` if necessary
        """
        if isinstance(value, dict):
            return dict.__setitem__(self, key, HyperParameter(**value))
        return dict.__setitem__(self, key, value)

    def __getattr__(self, name):
        """
        read parameter with object-style api

        Examples:

        for simple parameters:

        >>> hp = HyperParameter(a=1, b = {'c':2, 'd': 3})
        >>> hp.a
        1

        for nested parameters:

        >>> hp.b.c
        2

        >>> getattr(hp, 'b.c')
        2
        """
        # Dunder names are protocol lookups, not parameters. Answering them
        # with an _Accessor makes copy.deepcopy() call the accessor and return
        # its memo dict, and makes pickle's __getnewargs_ex__ return None.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return self.get(name)

    def __setattr__(self, name, value):
        """
        create/update parameter with object-style api

        Examples:
        >>> hp = HyperParameter(a=1, b = {'c':2, 'd': 3})
        >>> hp.e = 4

        >>> hp['e']
        4

        >>> setattr(hp, 'A.B.C', 1)
        >>> hp.A.B.C
        1
        """
        self.put(name, value)

    def __call__(self) -> Any:
        """
        Return a parameter accessor.

        Returns:
            Any: holder of the current parameter

        Examples:
            >>> cfg = HyperParameter(a=1, b = {'c':2, 'd': 3})
            >>> cfg().a.get_or_else('default')   # default value for simple parameter
            1
            >>> cfg().b.c.get_or_else('default') # default value for nested parameter
            2
            >>> cfg().b.undefined.get_or_else('default')
            'default'
        """
        return _Accessor(self, None)

    def dispatch(self, callback: Callable = None):
        """
        Return a call holder.

        Calling ``holder.some.path[index](*args, **kws)`` invokes
        ``callback('some.path', index, *args, **kws)``.

        Examples:
        >>> def debug_print(path, index, *arg, **kws):
        ...     return (path, index, arg, kws)
        >>> ch = param_scope().dispatch(debug_print)
        >>> ch.my.foo(a=1,b=2)
        ('my.foo', None, (), {'a': 1, 'b': 2})

        >>> ch.myspace2.gee(c=1,d=2)
        ('myspace2.gee', None, (), {'c': 1, 'd': 2})
        """

        # pylint: disable=protected-access
        @dynamic_dispatch
        def wrapper(*arg, **kws):
            with param_scope() as hp:
                name = hp._name
                index = hp._index
            return callback(name, index, *arg, **kws)

        return wrapper

    def callholder(self, callback: Callable = None):
        """Alias of :meth:`dispatch`."""
        return self.dispatch(callback)

    @staticmethod
    def loads(s):
        """
        Load parameters from JSON string, similar as `json.loads`.
        """
        obj = json.loads(s)
        return HyperParameter(**obj)

    @staticmethod
    def load(f):
        """
        Load parameters from json file, similar as `json.load`.
        """
        obj = json.load(f)
        return HyperParameter(**obj)
class param_scope(HyperParameter):  # pylint: disable=invalid-name
    """
    thread-safe scoped hyperparameter

    Every thread keeps its own stack of active scopes in ``param_scope.tls``
    (a ``threading.local``). Entering a scope pushes it on that stack and
    leaving pops it. A new scope starts from the parameters of the innermost
    active scope, so nested scopes inherit and override. Use
    ``param_scope.init`` to seed the stack of a freshly started thread.

    Examples:
    create a scoped HyperParameter
    >>> with param_scope(**{'a': 1, 'b': 2}) as cfg:
    ...     print(cfg.a)
    1

    read parameter in a function
    >>> def foo():
    ...     with param_scope() as cfg:
    ...         return cfg.a
    >>> with param_scope(**{'a': 1, 'b': 2}) as cfg:
    ...     foo() # foo should get cfg using a with statement
    1

    update some config only in new scope
    >>> with param_scope(**{'a': 1, 'b': 2}) as cfg:
    ...     cfg.b
    ...     with param_scope(**{'b': 3}) as cfg2:
    ...         cfg2.b
    2
    3
    """
    tls = threading.local()

    def __init__(self, *args, **kws):
        """
        Build a scope from the innermost active scope, then ``kws``, then ``args``.

        Args:
            *args (str): ``'key=value'`` overrides applied through ``put``
                (values go through ``safe_numeric``); strings without ``'='``
                are ignored.
            **kws: parameter overrides merged with ``update``.
        """
        stack = getattr(param_scope.tls, 'history', None)
        if stack:
            # Nested scope: start from everything the enclosing scope sees.
            self.update(stack[-1])
        self.update(kws)
        for assignment in args:
            if '=' not in assignment:
                continue
            key, raw_value = assignment.split('=', 1)
            self.put(key, raw_value)

    def __enter__(self):
        local_state = param_scope.tls
        if not hasattr(local_state, 'history'):
            local_state.history = []
        local_state.history.append(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        param_scope.tls.history.pop()

    @staticmethod
    def init(params):
        """
        init param_scope for a new thread.

        Pushes ``params`` onto the current thread's scope stack so that later
        ``param_scope()`` calls in this thread start from it.
        """
        stack = getattr(param_scope.tls, 'history', None)
        if stack is None:
            stack = param_scope.tls.history = []
        stack.append(params)
# Tracker callback for auto_param; installed with set_auto_param_callback().
# auto_param skips reporting while this is None.
_callback: Callable = None
def set_auto_param_callback(func: Callable[[Dict[str, Any]], None]):
    """
    Install the callback that ``auto_param`` uses to report hyperparameters.

    The callback receives a ``Dict[str, Any]`` mapping hyperparameter names to
    the values collected by ``auto_param``, e.g. to forward them to a tracker
    such as ``mlflow.tracking``.

    Args:
        func: the reporting callback; passing ``None`` turns reporting off.
    """
    global _callback
    _callback = func
def auto_param(name_or_func):
    """
    Convert keyword arguments into hyperparameters

    Every parameter with a default value becomes the hyperparameter
    ``<namespace>.<func name>.<param>``. At call time, a value defined in the
    active ``param_scope`` replaces the default, unless the caller passed that
    argument explicitly (positionally or by keyword). If a callback is
    installed with ``set_auto_param_callback``, it receives the values used.

    Examples:

    >>> @auto_param
    ... def foo(a, b=2, c='c', d=None):
    ...     print(a, b, c, d)

    >>> foo(1)
    1 2 c None

    >>> with param_scope('foo.b=3'):
    ...     foo(2)
    2 3 c None

    classes are also supported:
    >>> @auto_param
    ... class foo:
    ...     def __init__(self, a, b=2, c='c', d=None):
    ...         print(a, b, c, d)

    >>> obj = foo(1)
    1 2 c None

    >>> with param_scope('foo.b=3'):
    ...     obj = foo(2)
    2 3 c None

    >>> @auto_param('my')
    ... def foo(a, b=2, c='c', d=None):
    ...     print(a, b, c, d)

    >>> foo(1)
    1 2 c None

    >>> with param_scope('foo.b=3'):
    ...     foo(2)
    2 2 c None

    >>> with param_scope('my.foo.b=3'):
    ...     foo(2)
    2 3 c None
    """
    if callable(name_or_func):
        return auto_param(None)(name_or_func)

    def wrapper(func):
        predef_kws = {}  # argument name -> hyperparameter name
        predef_val = {}  # hyperparameter name -> declared default
        if name_or_func is None:
            namespace = func.__name__
        else:
            namespace = name_or_func + '.' + func.__name__
        signature = inspect.signature(func)
        for k, v in signature.parameters.items():
            # Identity test: `!=` on defaults such as arrays does not yield a bool.
            if v.default is not inspect.Parameter.empty:
                name = '{}.{}'.format(namespace, k)
                predef_kws[k] = name
                _read_tracker.add(name)
                predef_val[name] = v.default

        # updated=() so decorating a class does not copy the class __dict__
        # into the wrapper function.
        @functools.wraps(func, updated=())
        def inner(*arg, **kws):
            # Arguments the caller supplied (positionally or by keyword) win
            # over the scope. Binding positionals here also prevents a
            # "got multiple values for argument" TypeError when the scope
            # defines a parameter that was passed positionally.
            try:
                supplied = signature.bind_partial(*arg, **kws).arguments
            except TypeError:
                supplied = kws  # the call itself will report the bad arguments
            with param_scope() as hp:
                local_params = {}
                for k, v in predef_kws.items():
                    _read_tracker.add(v)
                    if k in supplied:
                        local_params[v] = supplied[k]
                        continue
                    value = hp.get(v)
                    # Undefined paths come back as an _Accessor. Test for that
                    # explicitly so scope values like 0/False/'' are honoured.
                    if value is None or isinstance(value, _Accessor):
                        local_params[v] = predef_val[v]
                    else:
                        kws[k] = value
                        local_params[v] = value
                if _callback is not None:
                    _callback(local_params)
                return func(*arg, **kws)

        return inner

    return wrapper
def safe_numeric(value):
    """
    Convert a numeric-looking string to ``int`` or ``float``.

    Strings are tried as ``int`` first, then as ``float``. Strings that parse
    as neither, and all non-string values, are returned unchanged.

    Args:
        value (Any): value to convert.

    Returns:
        Any: the converted number, or ``value`` itself.

    Examples:
        >>> safe_numeric('3')
        3
        >>> safe_numeric('2.5')
        2.5
        >>> safe_numeric('abc')
        'abc'
    """
    if isinstance(value, str):
        # int()/float() signal an unparsable str with ValueError only; catching
        # just that keeps KeyboardInterrupt and real bugs visible.
        try:
            return int(value)
        except ValueError:
            pass
        try:
            return float(value)
        except ValueError:
            pass
    return value