# Source code for dlk.core.base_module

# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from typing import Dict, List, Callable, Any, Set
from dlk.core.modules import module_register, module_config_register
from dlk.utils.logger import Logger
from dlk.utils.config import BaseConfig
import abc

logger = Logger.get_logger()


class BaseModuleConfig(BaseConfig):
    """Base configuration shared by modules.

    On construction the following entries are removed (popped) from the
    incoming ``config`` dict and kept on the instance:

    * ``config['config']['output_map']`` -> ``self._output_map`` (default ``{}``)
    * ``config['config']['input_map']``  -> ``self._input_map`` (default ``{}``)
    * ``config['module@logits_gather']`` -> wrapped by the config class
      registered under ``"logits_gather"`` and stored as
      ``self._logits_gather_config`` (default ``{}``)

    Args:
        config: the raw configuration dict; it is mutated by the pops above

    """

    def __init__(self, config: Dict):
        super().__init__(config)

        # The name maps live in the nested 'config' section.
        inner_config = config['config']
        self._output_map = inner_config.pop("output_map", {})
        self._input_map = inner_config.pop('input_map', {})

        # Look the config class up first, then pop its sub-config from the
        # top level of `config`.
        gather_config_cls = module_config_register.get("logits_gather")
        gather_settings = config.pop("module@logits_gather", {})
        self._logits_gather_config = gather_config_cls(gather_settings)
class ModuleOutputRenameMixin:
    """Rename keys according to the configured name maps.

    The helpers translate names so the outputs of one module line up with the
    input fields of a downstream module. ``get_input_name`` and
    ``get_output_name`` read ``self.config._input_map`` and
    ``self.config._output_map``, so the host class must provide ``self.config``.
    """

    def dict_rename(self, input: Dict, output_map: Dict[str, str]) -> Dict:
        """Return a copy of ``input`` whose keys are renamed by ``output_map``.

        Args:
            input: the dict whose keys should be renamed
            output_map: old name -> new name; keys not listed are kept as-is

        Returns:
            a new dict with the renamed keys

        Raises:
            PermissionError: ``input`` is not a dict

        """
        if not isinstance(input, dict):
            raise PermissionError("Not Defined")
        return {output_map.get(old_key, old_key): item for old_key, item in input.items()}

    def get_real_name(self, name: str, name_map: Dict[str, str]) -> str:
        """Translate ``name`` through ``name_map``.

        Args:
            name: the name to look up
            name_map: name -> real name

        Returns:
            the mapped name, or ``name`` itself when it is not in the map

        """
        return name_map.get(name, name)

    def get_input_name(self, name: str) -> str:
        """Translate ``name`` through ``self.config._input_map``.

        Args:
            name: input name

        Returns:
            real name

        """
        input_map = self.config._input_map
        return self.get_real_name(name, input_map)

    def get_output_name(self, name: str) -> str:
        """Translate ``name`` through ``self.config._output_map``.

        Args:
            name: output name

        Returns:
            real name

        """
        output_map = self.config._output_map
        return self.get_real_name(name, output_map)

    def set_rename(self, input: Set, output_map: Dict[str, str]) -> Set:
        """Return a new set with every name in ``input`` renamed by ``output_map``.

        Args:
            input: a set of names
            output_map: old name -> new name; names not listed are kept as-is

        Returns:
            a new set with the renamed names

        Raises:
            PermissionError: ``input`` is not a set

        """
        if not isinstance(input, set):
            raise PermissionError("Not Defined")
        return {output_map.get(old_name, old_name) for old_name in input}
class IModuleIO(metaclass=abc.ABCMeta):
    """Interface for checking that the inputs and outputs of modules line up."""

    @abc.abstractmethod
    def provide_keys(self) -> List[str]:
        """Return all keys of the dict this module returns.

        Returns:
            all keys

        """
        pass

    @abc.abstractmethod
    def check_keys_are_provided(self, provide: List[str]) -> bool:
        """Check that the keys this module requires are in ``provide``.

        Args:
            provide: the keys offered by the upstream module

        Returns:
            pass or not

        """
        pass

    def check_module_chain(self, module_list: List['BaseModule']) -> bool:
        """Check that each module in the list accepts its predecessor's keys.

        For every adjacent pair, the later module's ``check_keys_are_provided``
        is called with the earlier module's ``provide_keys()``; a falsy result
        is treated as a failed check.

        Args:
            module_list: a series of modules, in the order they are chained

        Returns:
            True when every adjacent pair passes

        Raises:
            ValueError: a module did not accept the keys of its predecessor

        """
        assert len(module_list) > 1

        for producer, consumer in zip(module_list, module_list[1:]):
            provided = producer.provide_keys()
            if consumer.check_keys_are_provided(provided):
                continue

            # Not every module defines `_name`; fall back to the class name so
            # building the message cannot itself fail with AttributeError.
            producer_name = getattr(producer, "_name", type(producer).__name__)
            consumer_name = getattr(consumer, "_name", type(consumer).__name__)
            provided_text = ", ".join(sorted(str(key) for key in provided))

            # Report the consumer's declared requirements when it exposes them
            # (the interface itself has no accessor for required keys).
            required = getattr(consumer, "_required_keys", None)
            if required is None:
                raise ValueError(
                    f'The module "{consumer_name}" did not accept the keys "{provided_text}" '
                    f'provided by the module "{producer_name}"!'
                )
            required_text = ", ".join(sorted(str(key) for key in required))
            raise ValueError(
                f'The module "{consumer_name}" requires "{required_text}", '
                f'but the module "{producer_name}" provides "{provided_text}"!'
            )
        return True
class IModuleStep(metaclass=abc.ABCMeta):
    """Interface for the per-batch steps (predict / train / validate / test)."""

    @abc.abstractmethod
    def predict_step(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run prediction on one batch.

        Args:
            inputs: one mini-batch inputs

        Returns:
            the predicted outputs

        """
        raise NotImplementedError

    @abc.abstractmethod
    def training_step(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run training on one batch.

        Args:
            inputs: one mini-batch inputs

        Returns:
            one mini-batch outputs

        """
        raise NotImplementedError

    @abc.abstractmethod
    def validation_step(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run validation on one batch.

        Args:
            inputs: one mini-batch inputs

        Returns:
            one mini-batch outputs

        """
        raise NotImplementedError

    def test_step(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run testing on one batch; by default this reuses ``validation_step``.

        Args:
            inputs: one mini-batch inputs

        Returns:
            one mini-batch outputs

        """
        outputs = self.validation_step(inputs)
        return outputs
class BaseModel(nn.Module, ModuleOutputRenameMixin, IModuleIO, IModuleStep):
    """Base class that all pytorch models should inherit from.

    It combines ``nn.Module`` with the key-renaming helpers and the IO / step
    interfaces; ``forward`` is left for subclasses to implement.
    """

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Compute the outputs for one batch; every model must override this.

        Args:
            inputs: one mini-batch inputs

        Returns:
            one mini-batch outputs

        Raises:
            NotImplementedError: always, until a subclass overrides it

        """
        raise NotImplementedError
class BaseModule(nn.Module, ModuleOutputRenameMixin, IModuleIO, IModuleStep):
    """Base class that all pytorch modules should inherit from.

    Key bookkeeping used by the IO checks:

    * ``_required_keys``: the keys this module needs as input
    * ``_provide_keys``: the keys this module itself produces (before the
      output-map rename)
    * ``_provided_keys``: the keys the upstream module offered; recorded by
      ``check_keys_are_provided`` and included again by ``provide_keys``

    Args:
        config: the module config; its ``_input_map`` / ``_output_map`` drive
            the renames and ``_logits_gather_config`` builds the logits gather

    """

    def __init__(self, config: BaseModuleConfig):
        super(BaseModule, self).__init__()
        # Build the sub-module registered under "logits_gather" from its config.
        self._logits_gather = module_register.get("logits_gather")(config._logits_gather_config)
        self._provided_keys = set()
        self._required_keys = set()
        self._provide_keys = set()
        self.config = config  # for better completion, you can re-declare this in the child module

    def provide_keys(self) -> Set:
        """Return all keys of the dict this module returns.

        Returns:
            this module's own keys renamed by ``config._output_map``, together
            with the keys recorded from the upstream module

        """
        return self.set_rename(self._provide_keys, self.config._output_map).union(self._provided_keys)

    def check_keys_are_provided(self, provide: Set[str]) -> bool:
        """Check that every required key of this module is provided.

        The offered keys are recorded in ``self._provided_keys``, then renamed
        through ``config._input_map`` before being compared with
        ``self._required_keys``.

        Args:
            provide: the keys offered by the upstream module; a non-set
                iterable (e.g. the list form declared by ``IModuleIO``) is
                converted to a set first

        Returns:
            True when all required keys are present

        Raises:
            PermissionError: a required key is missing

        """
        # `set_rename` only accepts sets, while the IModuleIO interface
        # declares a list of keys; accept both.
        if not isinstance(provide, set):
            provide = set(provide)
        self._provided_keys = provide
        provide = self.set_rename(provide, self.config._input_map)
        for required_key in self._required_keys:
            if required_key not in provide:
                raise PermissionError(f"The {self.__class__.__name__} Module required '{required_key}' as input.")
        # IModuleIO.check_module_chain treats a falsy result as a failed check,
        # so success must be reported explicitly rather than returning None.
        return True

    def init_weight(self, method):
        """Initialise the weights of the submodules with ``method``.

        Args:
            method: init method; applied via ``apply`` to each direct child

        Returns:
            None

        """
        for module in self.children():
            module.apply(method)

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Compute the outputs for one batch; every module must override this.

        Args:
            inputs: one mini-batch inputs

        Returns:
            one mini-batch outputs

        Raises:
            NotImplementedError: always, until a subclass overrides it

        """
        raise NotImplementedError
class SimpleModule(BaseModule):
    """A module whose train / predict / validation / test steps all call the module itself.

    Subclasses only need to implement ``forward``; every step method passes the
    batch to ``self(inputs)`` and returns the result.
    """

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Compute the outputs for one batch; every step is routed here.

        Args:
            inputs: one mini-batch inputs

        Returns:
            one mini-batch outputs

        Raises:
            NotImplementedError: always, until a subclass overrides it

        """
        raise NotImplementedError

    def predict_step(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Predict on one batch by calling the module.

        Args:
            inputs: one mini-batch inputs

        Returns:
            one mini-batch outputs

        """
        outputs = self(inputs)
        return outputs

    def training_step(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Train on one batch by calling the module.

        Args:
            inputs: one mini-batch inputs

        Returns:
            one mini-batch outputs

        """
        outputs = self(inputs)
        return outputs

    def validation_step(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Validate on one batch by calling the module.

        Args:
            inputs: one mini-batch inputs

        Returns:
            one mini-batch outputs

        """
        outputs = self(inputs)
        return outputs

    def test_step(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Test on one batch by calling the module.

        Args:
            inputs: one mini-batch inputs

        Returns:
            one mini-batch outputs

        """
        outputs = self(inputs)
        return outputs