Spaces:
Build error
Build error
from typing import Mapping, Any | |
import importlib | |
from torch import nn | |
def get_obj_from_str(string: str, reload: bool = False) -> object:
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    Everything before the last dot is imported as a module; the component
    after the last dot is looked up as an attribute of that module.

    Args:
        string: Dotted import path of the object to fetch.
        reload: When true, the module is passed through ``importlib.reload``
            before the attribute lookup.

    Returns:
        The attribute named by the last component of ``string``.

    Raises:
        ValueError: If ``string`` contains no dot (there is no module part),
            or if the module part is empty.
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_name, dot, attr_name = string.rpartition(".")
    if not dot:
        # Fail with a message that names the bad input instead of an opaque
        # tuple-unpacking error.
        raise ValueError(
            f"Expected a dotted path like 'package.module.Name', got {string!r}."
        )

    module = importlib.import_module(module_name)
    if reload:
        # reload() hands back the module object now registered in
        # sys.modules, so a second import_module() call is unnecessary.
        module = importlib.reload(module)
    return getattr(module, attr_name)
def instantiate_from_config(config: Mapping[str, Any]) -> object:
    """Build an object from a ``{"target": ..., "params": ...}`` mapping.

    Args:
        config: Mapping with a ``"target"`` key holding the dotted import path
            of a callable, and an optional ``"params"`` key holding the
            keyword arguments to call it with. A missing ``"params"`` key or
            an explicit ``None`` value both mean "no keyword arguments".

    Returns:
        Whatever the resolved target returns when called with the params.

    Raises:
        KeyError: If ``config`` has no ``"target"`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")

    params = config.get("params")
    if params is None:
        # A key that is present but empty (e.g. `params:` in YAML) would
        # otherwise crash on `**None`.
        params = {}
    return get_obj_from_str(config["target"])(**params)
def disabled_train(self: nn.Module, mode: bool = True) -> nn.Module:
    """Overwrite ``model.train`` with this function to make sure train/eval
    mode does not change anymore.

    Args:
        self: The module whose mode must stay untouched.
        mode: Accepted and ignored. It mirrors the ``mode`` argument of
            ``nn.Module.train`` so this function can replace it as a bound
            method (``nn.Module.eval`` calls ``self.train(False)``).

    Returns:
        ``self``, unchanged.
    """
    return self
def _frozen_train(module: nn.Module, mode: bool = True) -> nn.Module:
    """Stand-in for ``nn.Module.train`` that ignores ``mode`` and returns ``module``."""
    return module


def frozen_module(module: nn.Module) -> None:
    """Freeze ``module`` in place: eval mode, locked mode switch, no gradients.

    The module is put into eval mode, its ``train`` attribute is replaced by
    a no-op, and ``requires_grad`` is turned off on every parameter.

    Args:
        module: The module to freeze. It is modified in place.
    """
    import functools

    # Switch to eval mode first, while the real `train` is still in place.
    module.eval()

    # A plain function stored on the instance is not bound to it, so the
    # module must be supplied explicitly. Binding it with partial lets both
    # `module.train()` and `module.train(mode)` return the module;
    # `module.eval()` goes through `self.train(False)` and is covered too.
    module.train = functools.partial(_frozen_train, module)

    for param in module.parameters():
        param.requires_grad = False
def load_state_dict(model: nn.Module, state_dict: Mapping[str, Any], strict: bool = False) -> None:
    """Load weights into ``model``, reconciling a leading ``"module."`` key prefix.

    If ``state_dict`` has a ``"state_dict"`` entry, that nested mapping is
    used as the weights. The first key of the model's own state dict and the
    first key of the incoming weights are then compared: when exactly one
    side starts with ``"module."``, the incoming keys are rewritten (prefix
    added or removed) to match the model before loading.

    Args:
        model: Module that receives the weights.
        state_dict: The weights, or a mapping holding them under the
            ``"state_dict"`` key.
        strict: Forwarded to ``model.load_state_dict``.
    """
    prefix = "module."
    state_dict = state_dict.get("state_dict", state_dict)

    # Only the first key on each side is inspected. The "" default keeps an
    # empty model or an empty checkpoint from raising IndexError.
    model_has_prefix = next(iter(model.state_dict()), "").startswith(prefix)
    weights_have_prefix = next(iter(state_dict), "").startswith(prefix)

    if model_has_prefix and not weights_have_prefix:
        state_dict = {f"{prefix}{key}": value for key, value in state_dict.items()}
    elif weights_have_prefix and not model_has_prefix:
        # Strip the prefix only from keys that carry it, so any key without
        # the prefix is passed through intact instead of being truncated.
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    model.load_state_dict(state_dict, strict=strict)