Source code for dlk.utils.config
# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provide BaseConfig which provide the basic method for configs, and ConfigTool a general config(dict) process tool
"""
from typing import Any, Dict, Union, Callable, List, Tuple, Type
import json
import copy
import os
from dlk.utils.logger import Logger
from dlk.utils.register import Register
logger = Logger.get_logger()
[docs]class BaseConfig(object):
"""BaseConfig provide the basic function for all config"""
def __init__(self, config: Dict):
super(BaseConfig, self).__init__()
self._name = config.pop('_name')
[docs] def post_check(self, config, used=None):
"""check all the paras in config is used
Args:
config: paras
used: used paras
Returns:
None
Raises:
logger.warning("Unused")
"""
if not used:
used = []
def rec_pop(cur_node, trace):
"""recursive pop the node if the node == {} and the node path is in trace
"""
if len(trace) > 1:
rec_pop(cur_node[trace[0]], trace[1:])
if len(trace) == 1:
if cur_node.get(trace[0], {}) == {}:
cur_node.pop(trace[0])
config = copy.deepcopy(config)
parant_traces = set()
for key in used:
sp_key = key.split('.')
parant_traces.add(tuple(sp_key[:-1]))
cur_root = config
for key in sp_key[:-1]:
cur_root = cur_root[key]
cur_root.pop(sp_key[-1], None)
for trace in parant_traces:
rec_pop(config, trace)
if config:
logger.warning(f"In module '{self._name}', there are some params not be used: {config}")
[docs]class ConfigTool(object):
"""
This Class is not be used as much as I design.
"""
@staticmethod
def _inplace_update_dict(_base: Dict, _new: Dict):
"""use the _new dict inplace update the _base dict, recursively
if the _base['_name'] != _new["_name"], we will use _new cover the _base and logger a warning
otherwise, use _new update the _base recursively
Args:
_base: will be updated dict
_new: use _new update _base
Returns:
None
"""
for item in _new:
if (item not in _base) or (not isinstance(_base[item], Dict)):
# if item not in _base, or _base[item] is not Dict
_base[item] = _new[item]
elif isinstance(_base[item], Dict) and isinstance(_new[item], Dict):
if "_name" in _base[item] and "_name" in _new[item]:
if _base[item]['_name'] != _new[item]['_name']:
logger.warning(f"The Higher Config for {_new[item]['_name']} Coverd the Base {_base[item]['_name']} ")
_base[item] = _new[item]
continue
ConfigTool._inplace_update_dict(_base[item], _new[item])
else:
raise AttributeError("The base config and update config is not match. base: {}, new: {}. ".format(_base, _new))
[docs] @staticmethod
def do_update_config(config: dict, update_config: dict=None) ->Dict:
"""use the update_config dict update the config dict, recursively
see ConfigTool._inplace_update_dict
Args:
config: will be updated dict
update_confg: config: use _new update _base
Returns:
updated_config
"""
if not update_config:
update_config = {}
# BUG ?: if the config._name != update_config._name, should use the update_config conver the config wholely
config = copy.deepcopy(config)
ConfigTool._inplace_update_dict(config, update_config)
return config
[docs] @staticmethod
def get_leaf_module(module_register: Register, module_config_register: Register, module_name: str, config: Dict) -> Tuple[Any, object]:
"""get the module from module_register and module_config from module_config_register which name=module_name
Args:
module_register: register for module which has 'module_name'
module_config_register: config register for config which has 'module_name'
module_name: the module name which we want to get from register
Returns:
module(which name is module_name), module_config(which name is module_name)
"""
if isinstance(config, str):
name = config
extend_config = {}
else:
assert isinstance(config, dict), "{} config must be name(str) or config(dict), but you provide {}".format(module_name, config)
name = config.get('_name', "") # must provide _name_
extend_config = config
if not name:
raise KeyError('You must provide the {} name("name")'.format(module_name))
module_class, module_config_class = module_register.get(name), module_config_register.get(name)
if (not module_class) or not (module_config_class):
raise KeyError('The {} name {} is not registed.'.format(module_name, config))
module_config = module_config_class(extend_config)
return module_class, module_config
[docs] @staticmethod
def get_config_by_stage(stage:str, config:Dict)->Dict:
"""get the stage_config for special stage in provide config
it means the config of this stage equals to config[stage]
return config[config[stage]]
Config Example:
>>> config = {
>>> "train":{ //train、predict、online stage config, using '&' split all stages
>>> "data_pair": {
>>> "label": "label_id"
>>> },
>>> "data_set": { // for different stage, this processor will process different part of data
>>> "train": ['train', 'dev'],
>>> "predict": ['predict'],
>>> "online": ['online']
>>> },
>>> "vocab": "label_vocab", // usually provided by the "token_gather" module
>>> },
>>> "predict": "train",
>>> "online": ["train",
>>> {"vocab": "new_label_vocab"}
>>> ]
>>> }
>>> config.get_config['predict'] == config['predict'] == config['train']
Args:
stage: the stage, like 'train', 'predict', etc.
config: the base config which has different stage config
Returns:
stage_config
"""
config = config['config']
stage_config = config.get(stage, {})
if isinstance(stage_config, str):
stage_config = config.get(stage_config, {})
elif isinstance(stage_config, list):
assert len(stage_config) == 2
assert isinstance(stage_config[0], str)
assert isinstance(stage_config[1], dict)
base_config = config.get(stage_config[0], {})
stage_config = ConfigTool.do_update_config(base_config, stage_config[1])
return stage_config