Spaces:
Runtime error
Runtime error
| import inspect | |
| from copy import deepcopy | |
| from functools import wraps | |
| import torch.nn as nn | |
def serialize(init):
    """Decorate an ``__init__`` so each instance records its constructor config.

    The wrapper stores ``self._config`` *before* running the original
    ``init``. The config has the shape::

        {"class": "<module>.<qualname>",
         "params": {name: {"type": "builtin" | "class",
                           "value": <value or dotted class name>,
                           "specified": <bool>}}}

    Keyword arguments are deep-copied. Positional arguments are matched to
    the parameter names of ``init`` (skipping ``self``) and stored by
    reference. Every defaulted parameter reported by ``get_default_params``
    for the instance's class that the caller did not pass is filled in with
    its default and marked ``"specified": False``. Class-valued parameters
    are stored by their dotted name.

    ``functools.wraps`` copies ``__name__``/``__doc__`` and sets
    ``__wrapped__``, so ``inspect.signature`` on the decorated ``__init__``
    still reports the original parameters. ``get_default_params`` depends
    on that when it inspects ``__init__`` signatures along the MRO.
    """
    parameters = list(inspect.signature(init).parameters)

    @wraps(init)
    def new_init(self, *args, **kwargs):
        params = deepcopy(kwargs)
        # Map positional args onto their names; parameters[0] is `self`.
        for pname, value in zip(parameters[1:], args):
            params[pname] = value

        config = {"class": get_classname(self.__class__), "params": {}}
        specified_params = set(params)

        # Record defaults for anything the caller left out.
        for pname, param in get_default_params(self.__class__).items():
            if pname not in params:
                params[pname] = param.default

        for name, value in params.items():
            param_type = "builtin"
            if inspect.isclass(value):
                # Classes are not serializable values; store their dotted path.
                param_type = "class"
                value = get_classname(value)

            config["params"][name] = {
                "type": param_type,
                "value": value,
                "specified": name in specified_params,
            }

        self._config = config
        init(self, *args, **kwargs)

    return new_init
def load_model(config, **kwargs):
    """Rebuild a model instance from a config recorded by ``serialize``.

    Args:
        config: dict with ``"class"`` (dotted class path) and ``"params"``
            (name -> {"type", "value", "specified"}), as stored in
            ``self._config`` by the ``serialize`` decorator.
        **kwargs: constructor arguments that override values from the config.

    Returns:
        A new instance, ``model_class(**model_args)``.

    A config parameter is passed to the constructor when it was specified
    explicitly at save time, or when its recorded value differs from the
    class's current default. Unspecified parameters that the class no longer
    declares are dropped. They are dropped before their value is resolved, so
    a stale class path is never imported.

    Raises:
        ValueError: a parameter that was specified at save time is not one of
            the class's defaulted constructor parameters. Because
            ``get_default_params`` only reports parameters that have a
            default, this also applies to required constructor parameters.
    """
    model_class = get_class_from_str(config["class"])
    model_default_params = get_default_params(model_class)

    model_args = {}
    for pname, param in config["params"].items():
        specified = param["specified"]

        if pname not in model_default_params:
            if not specified:
                # Stale default from an older class version: ignore it.
                continue
            # Explicit check instead of `assert`, so it survives `python -O`.
            raise ValueError(
                f"Parameter {pname!r} is specified in the config for "
                f"{config['class']} but is not a defaulted constructor "
                "parameter of that class"
            )

        value = param["value"]
        if param["type"] == "class":
            value = get_class_from_str(value)

        # Unchanged defaults are left to the constructor.
        if not specified and model_default_params[pname].default == value:
            continue
        model_args[pname] = value

    model_args.update(kwargs)
    return model_class(**model_args)
def get_config_repr(config):
    """Render a serialized model config as an aligned, human-readable table.

    The first line names the model class. Each following line has the form
    ``name = value``, with the name left-padded to 22 columns and the value
    to 12. Class-valued parameters show only their last dotted component.
    Parameters that were not passed explicitly are tagged ``(default)``.
    Every line, including the last, ends with a newline.
    """
    lines = [f'Model: {config["class"]}\n']
    for pname, entry in config["params"].items():
        shown = entry["value"]
        if entry["type"] == "class":
            shown = shown.rsplit(".", 1)[-1]
        suffix = "" if entry["specified"] else " (default)"
        lines.append(f"{pname:<22} = {str(shown):<12}{suffix}\n")
    return "".join(lines)
def get_default_params(some_class):
    """Collect constructor parameters that have default values, across the MRO.

    Walks ``some_class.mro()`` from most- to least-derived, skipping
    ``nn.Module`` and ``object``, and inspects each class's ``__init__``
    signature. When several classes declare the same parameter name, the
    most-derived declaration wins.

    Returns:
        dict mapping parameter name to its ``inspect.Parameter``, in the
        order first encountered.
    """
    params = {}
    for mclass in some_class.mro():
        if mclass is nn.Module or mclass is object:
            continue
        for pname, param in inspect.signature(mclass.__init__).parameters.items():
            # Identity test against the sentinel. `!=` would call the default
            # value's own __ne__, which is elementwise (and not truth-testable)
            # for array/tensor defaults.
            if param.default is not param.empty and pname not in params:
                params[pname] = param
    return params
def get_classname(cls):
    """Return the dotted ``"<module>.<qualname>"`` path of ``cls``.

    The module prefix is omitted only when ``cls.__module__`` is ``None`` or
    equals ``"__builtin__"`` (the Python 2 name of the builtins module). On
    Python 3, builtins live in ``"builtins"`` and so keep the prefix, e.g.
    ``"builtins.int"``. Nested classes keep their full qualified name,
    e.g. ``"pkg.mod.Outer.Inner"``.
    """
    qualname = cls.__qualname__
    module = cls.__module__
    if module is None or module == "__builtin__":
        return qualname
    return f"{module}.{qualname}"
def get_class_from_str(class_str):
    """Resolve a dotted path (as produced by ``get_classname``) to an object.

    The path is split into an importable module part and a trailing
    attribute chain. The longest module prefix is tried first, which covers
    the common ``"pkg.mod.Class"`` case. Shorter prefixes are tried next, so
    nested qualified names such as ``"pkg.mod.Outer.Inner"`` also resolve.

    Raises:
        ValueError: ``class_str`` has no ``"."``, so there is no module part.
        ImportError: no prefix of the path is importable. The error for the
            longest candidate is re-raised. A missing dependency *inside* an
            existing module is raised immediately.
        AttributeError: the attribute chain does not exist on the module.
    """
    import importlib

    components = class_str.split(".")
    if len(components) < 2:
        raise ValueError(f"Expected a dotted 'module.Name' path, got {class_str!r}")

    first_error = None
    for split in range(len(components) - 1, 0, -1):
        module_name = ".".join(components[:split])
        try:
            obj = importlib.import_module(module_name)
        except ModuleNotFoundError as err:
            missing = err.name
            # Fall back to a shorter prefix only when this candidate (or one
            # of its parents) is itself what is missing.
            if missing is None or not (
                module_name == missing or module_name.startswith(missing + ".")
            ):
                raise
            if first_error is None:
                first_error = err
            continue
        for attr in components[split:]:
            obj = getattr(obj, attr)
        return obj
    raise first_error