|
import importlib
import logging
from collections import defaultdict
from typing import TypeVar, Type, Dict, List, Optional
|
|
|
logger = logging.getLogger("toolbox")

# Type variable so ``register``/``by_name`` are typed in terms of the class
# they are called on.
T = TypeVar("T")


class Registrable(object):
    """Mixin that gives a base class a ``name -> subclass`` registry.

    Each class that ``register`` / ``by_name`` / ``list_available`` is called on
    gets its own table inside ``Registrable._registry`` (keyed by the class
    object), so names only need to be unique per base class. Example::

        class Model(Registrable):
            pass

        @Model.register("linear")
        class LinearModel(Model):
            pass

        Model.by_name("linear")  # -> LinearModel
    """

    # Maps a class to its {name: registered subclass} table. The dict is shared
    # by every Registrable subclass; the outer key keeps the tables separate.
    _registry: Dict[Type, Dict[str, Type]] = defaultdict(dict)

    # Name that ``list_available`` puts first; ``None`` means "no default".
    default_implementation: Optional[str] = None

    # Overwritten on each subclass by ``register`` with its registered name.
    register_name: str = "unknown"

    @classmethod
    def register(cls: Type[T], name: str, exist_ok=False):
        """Return a class decorator that registers a class under ``name`` for ``cls``.

        Args:
            name: Key under which the decorated class is stored.
            exist_ok: If True, an existing registration for ``name`` is
                replaced (and an info message is logged). If False, a
                duplicate name is an error.

        Returns:
            A decorator that stores the class in ``cls``'s table, sets the
            class's ``register_name`` attribute to ``name`` and returns the
            class unchanged.

        Raises:
            ValueError: Raised by the returned decorator when ``name`` is
                already registered for ``cls`` and ``exist_ok`` is False.
        """
        # Indexing the defaultdict creates cls's table on first registration.
        registry = Registrable._registry[cls]

        def add_subclass_to_registry(subclass: Type[T]) -> Type[T]:
            if name in registry:
                if not exist_ok:
                    raise ValueError(
                        f"Cannot register {name} as {cls.__name__}; "
                        f"name already in use for {registry[name].__name__}"
                    )
                # Report the class that actually replaces the old one.
                logger.info(
                    "%s has already been registered as %s, but "
                    "exist_ok=True, so overwriting with %s",
                    name,
                    registry[name].__name__,
                    subclass.__name__,
                )
            # Tag the class only after the duplicate check, so a rejected
            # registration leaves it untouched.
            subclass.register_name = name
            registry[name] = subclass
            return subclass

        return add_subclass_to_registry

    @classmethod
    def by_name(cls: Type[T], name: str) -> Type[T]:
        """Return the class registered under ``name`` for ``cls``.

        Raises:
            ValueError: If ``name`` is not registered for ``cls``. The message
                lists the names that are registered, in registration order.
        """
        # .get avoids inserting an empty table for cls into the defaultdict.
        registry = Registrable._registry.get(cls, {})
        try:
            return registry[name]
        except KeyError:
            available = ", ".join(registry)
            raise ValueError(
                f"{name} is not a registered name for {cls.__name__}. "
                f"Available names: [{available}]"
            ) from None

    @classmethod
    def list_available(cls) -> List[str]:
        """Return the names registered for ``cls``.

        The names are in registration order. When ``default_implementation``
        is set, it is moved to the front.

        Raises:
            ValueError: If ``default_implementation`` is set but is not one of
                the registered names.
        """
        # .get avoids inserting an empty table for cls into the defaultdict.
        keys = list(Registrable._registry.get(cls, {}))
        default = cls.default_implementation

        if default is None:
            return keys
        if default not in keys:
            raise ValueError(f"Default implementation {default} is not registered")
        return [default] + [k for k in keys if k != default]
|
|