Source code for dlk.utils.parser

# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hjson
import os
import copy
from typing import Callable, List, Dict, Union
from dlk.utils import register
from dlk.utils.register import Register
from dlk.utils.config import ConfigTool
from dlk.utils.logger import Logger
from dlk.utils.get_root import get_root
from dlk.utils.register import Register
from dlk.core import (
    embedding_config_register,
    embedding_register,
    callback_config_register,
    callback_register,
    decoder_config_register,
    decoder_register,
    encoder_config_register,
    encoder_register,
    imodel_config_register,
    imodel_register,
    initmethod_config_register,
    initmethod_register,
    model_config_register,
    model_register,
    scheduler_config_register,
    loss_config_register,
    loss_register,
    module_config_register,
    module_register,
    optimizer_config_register,
    optimizer_register,
    scheduler_register
)
from dlk.data import (
    datamodule_config_register,
    datamodule_register,
    postprocessor_config_register,
    postprocessor_register,
    processor_config_register,
    processor_register,
    subprocessor_config_register,
    subprocessor_register
)
from dlk.managers import (
    manager_config_register, 
    manager_register
)
import json

# Registry mapping a module kind (e.g. 'task', 'embedding') to the parser
# class that expands configs of that kind; the parser classes in this file
# register themselves through `@config_parser_register('<kind>')`.
config_parser_register = Register("Config parser register")

# Module-wide logger shared by all parsers in this file.
logger = Logger.get_logger()


class LinkUnionTool(object):
    """Union-find helper that resolves the ``_link`` entries of a config.

    Every linked key is a node, and keys joined by links share one root key.
    Links of the current (top) level config take priority over links coming
    from lower levels; this class mainly resolves the conflicts between the
    high-level and low-level links.
    """

    def __init__(self):
        # Maps each registered key to its parent key; a root maps to itself.
        self.link_union = {}

    def find(self, key: str):
        """Find the root of ``key``.

        Args:
            key: a token (config parameter path)

        Returns:
            the root of ``key``, or ``None`` when ``key`` (or a key on its
            parent chain) has never been registered
        """
        # Iterative walk so long link chains cannot hit the recursion limit.
        while key in self.link_union:
            parent = self.link_union[key]
            if parent == key:
                return key
            key = parent
        return None

    def low_level_union(self, link_from: str, link_to: str):
        """Union a low-level ``link_from -> link_to`` pair.

        Low-level links are registered on top of the existing high-level links:

        * neither key registered yet: both are linked with ``link_from`` as root;
        * only ``link_from`` registered: ``link_to`` is attached to its root;
        * only ``link_to`` registered: ``link_from`` is attached to the root of
          ``link_to`` and a warning is logged, because the link direction is
          effectively reversed;
        * both registered: nothing changes when they share a root (a warning is
          logged when ``link_to`` is itself that root); otherwise an error is
          raised.

        Args:
            link_from: the link-from key
            link_to: the link-to key

        Returns:
            None

        Raises:
            PermissionError: both keys are already linked to different roots
        """
        # Roots are computed once; nothing is mutated before they are used.
        from_root = self.find(link_from)
        to_root = self.find(link_to)
        # Compare against None explicitly: a root key may itself be falsy ("").
        if from_root is not None and to_root is not None:
            if from_root != to_root:
                raise PermissionError(f"The {link_from} and {link_to} has been linked to different values, but now you want to link them together.")
            if from_root == link_to:
                logger.warning(f"High level config has the link '{link_to} -> {link_from}', and the low level reversed link '{link_from} -> {link_to}' is been ignored.")
            return
        if to_root is not None:
            # only link_to has been linked
            logger.warning(f"Parameter '{link_to}' has been linked in high level config, the link '{link_from} -> {link_to}' is invalid, and the real link is been reversed as '{link_to} -> {link_from}'.")
            self.link_union[link_from] = to_root
        elif from_root is not None:
            # only link_from has been linked
            self.link_union[link_to] = from_root
        else:
            self.link_union[link_from] = link_from
            self.link_union[link_to] = link_from

    def top_level_union(self, link_from: str, link_to: str):
        """Union a top-level ``link_from -> link_to`` pair.

        All links of the same (top) level config are registered with this
        method. A key may appear at most once in the link-to position;
        otherwise the assignment would be ambiguous.

        Args:
            link_from: the link-from key
            link_to: the link-to key

        Returns:
            None

        Raises:
            AssertionError: ``link_to`` has already been registered
        """
        if link_from not in self.link_union:
            self.link_union[link_from] = link_from
        # Explicit raise (not `assert`) so the check also runs under `python -O`;
        # AssertionError is kept for backward compatibility.
        if link_to in self.link_union:
            raise AssertionError(f"{link_to} is repeated assignment")
        self.link_union[link_to] = self.find(link_from)
class BaseConfigParser(object):
    """Base class of the config parsers.

    The parse order is: inherit (``_base``) -> search (``_search``) -> link (``_link``).

    A value marked ``"*@*"`` has no default and must be overridden by the user
    (like ``label_nums``); ``check_config`` rejects configs that still contain it.
    """

    def __init__(self, config_file: Union[str, Dict, List], config_base_dir: str="", register: Register=None):
        """Load ``config_file`` and split off its special keys.

        Args:
            config_file: the placeholder ``"*@*"``, an already-loaded config dict,
                or a config name. A name is looked up as
                ``<get_root()>/<config_base_dir>/<config_file>.hjson``; if no such
                file exists, the ``default_config`` of ``register.get(config_file)``
                is used.
            config_base_dir: directory of the hjson files, relative to ``get_root()``
            register: register used for the ``default_config`` fallback

        Raises:
            KeyError: the named config could not be loaded, or ``config_file``
                is neither a str nor a dict
            PermissionError: ``_name`` is set on a config that also has a base config
        """
        super(BaseConfigParser, self).__init__()
        if isinstance(config_file, str):
            if config_file == '*@*':
                self.config_file = "*@*"
                return
            try:
                config_path = os.path.join(get_root(), config_base_dir, config_file + '.hjson')
                if os.path.isfile(config_path):
                    self.config_file = self.load_hjson_file(config_path)
                else:
                    self.config_file = copy.deepcopy(register.get(config_file).default_config)
            except Exception as e:
                logger.error(f"You must provide a configure file at {os.path.join(get_root(), config_base_dir, config_file)} or provide `default_config` as a class variable")
                raise KeyError(e) from e
        elif isinstance(config_file, Dict):
            # Work on a private copy: the special keys below are popped in place,
            # and the caller's dict (often a sub-dict of a parent config) must
            # not be modified.
            self.config_file = copy.deepcopy(config_file)
        else:
            raise KeyError('The config file must be str or dict. You provide {}.'.format(config_file))

        self.config_name = self.config_file.pop('_name', "")
        self.search = self.config_file.pop('_search', {})

        base = self.config_file.pop('_base', "")
        self.base_config = {}
        if base:
            self.base_config = self.get_base_config(base)
        if "_focus" in self.config_file:
            self.base_config['_focus'] = self.config_file.pop('_focus')

        # Merge the base and current config `_link`s: the current config's links
        # are registered as top level, the base config's as low level.
        link_union = LinkUnionTool()
        link_union.register_top_links(self.config_file.pop('_link', {}))
        link_union.register_low_links(self.base_config.pop('_link', {}))
        self.config_file['_link'] = link_union.get_links()

        if self.base_config and self.config_name:
            raise PermissionError("You should put the _name to the leaf config.")
        self.modules = self.config_file

    @classmethod
    def get_base_config(cls, config_name: str) -> Dict:
        """Get the base config named ``config_name``.

        Args:
            config_name: the config name

        Returns:
            the (unlinked) config of ``config_name``, or ``{}`` if it parses to nothing

        Raises:
            PermissionError: the base config expands to more than one config
        """
        base_configs = cls(config_name).parser(parser_link=False)
        if len(base_configs) > 1:
            raise PermissionError("The base config don't support _search now.")
        return base_configs[0] if base_configs else {}

    def parser_with_check(self, parser_link=True) -> List[Dict]:
        """Parse the config and check that every config is valid.

        Args:
            parser_link: whether to resolve the links

        Returns:
            all valid configs

        Raises:
            ValueError: a "*@*" placeholder is left in some config
        """
        configs = self.parser(parser_link)
        self.check_config(configs)
        return configs

    def parser(self, parser_link=True) -> List:
        """Parse the config into all of its possible configs.

        Args:
            parser_link: whether to resolve the links

        Returns:
            all possible configs (``['*@*']`` for a placeholder config)

        Raises:
            ValueError: two of the resulting configs are identical
        """
        if self.config_file == '*@*':
            return ['*@*']

        # Parse every submodule to get its list of possible configs.
        modules_config = self.map_to_submodule(self.modules, self.get_kind_module_base_config)
        # Combine the submodule configs into every possible module config.
        possible_config_list = self.get_named_list_cartesian_prod(modules_config)
        # Use each specific module config to update the base config.
        if possible_config_list:
            possible_config_list = [ConfigTool.do_update_config(self.base_config, possible_config) for possible_config in possible_config_list]
        else:
            possible_config_list = [self.base_config]

        # Expand all `_search` parameters.
        all_possible_config_list = []
        for possible_config in possible_config_list:
            all_possible_config_list.extend(self.flat_search(self.search, possible_config))

        # Resolve the links.
        if parser_link:
            for possible_config in all_possible_config_list:
                all_level_links = self.collect_link(possible_config)
                link_union = LinkUnionTool()
                # `all_level_links` is indexed by consecutive integers from 0.
                for level in range(len(all_level_links)):
                    link_union.register_low_links(all_level_links[level])
                self.config_link_para(link_union.get_links(), possible_config)

        return_list = []
        for possible_config in all_possible_config_list:
            config = copy.deepcopy(possible_config)
            if self.config_name:
                config['_name'] = self.config_name
            return_list.append(config)

        if self.is_rep_config(return_list):
            logger.warning("The Configures is Repeated, Please Check The Configures Carefully.")
            for i, config in enumerate(return_list):
                logger.info(f"The {i}th Configure is:")
                logger.info(json.dumps(config, indent=2, ensure_ascii=False))
            raise ValueError('REPEAT CONFIG')
        return return_list

    def get_kind_module_base_config(self, abstract_config: Union[dict, str], kind_module: str="") -> List[dict]:
        """Get the whole config of ``kind_module`` from ``abstract_config``.

        Args:
            abstract_config: the config (or config name) to expand
            kind_module: the module kind, like 'embedding' or 'subprocessor',
                registered in config_parser_register

        Returns:
            the parsed (unlinked) configs of ``abstract_config``
        """
        return config_parser_register.get(kind_module)(abstract_config).parser(parser_link=False)

    def map_to_submodule(self, config: dict, map_fun: Callable) -> Dict:
        """Apply ``map_fun`` to every submodule in ``config``.

        Args:
            config: a dict of submodules; each key is a module kind registered
                in config_parser_register
            map_fun: called as ``map_fun(submodule_config, kind_module)``

        Returns:
            a dict mapping each module kind to ``map_fun``'s result
        """
        return {kind_module: map_fun(config[kind_module], kind_module) for kind_module in config}

    def load_hjson_file(self, file_path: str) -> Dict:
        """Load the hjson file at ``file_path`` as a dict.

        Args:
            file_path: the file path

        Returns:
            the loaded dict
        """
        # `with` closes the handle; hjson configs are read as UTF-8 regardless
        # of the platform's locale encoding.
        with open(file_path, encoding="utf-8") as f:
            return hjson.load(f, object_pairs_hook=dict)

    def get_cartesian_prod(self, list_of_list_of_dict: List[List[Dict]]) -> List[List[Dict]]:
        """Get the Cartesian product of lists of configs.

        Args:
            list_of_list_of_dict: [[config_a1, config_a2], [config_b1, config_b2]]

        Returns:
            [[config_a1, config_b1], [config_a1, config_b2], [config_a2, config_b1], [config_a2, config_b2]]
            (an empty input gives an empty list). Every config is deep-copied.
        """
        if not list_of_list_of_dict:
            return []
        first, rest = list_of_list_of_dict[0], list_of_list_of_dict[1:]
        if not rest:
            # Base case: each config of the single list forms a one-element combo.
            return [[copy.deepcopy(config)] for config in first]
        rest_prod = self.get_cartesian_prod(rest)
        return [[copy.deepcopy(cur_config)] + copy.deepcopy(reserve) for cur_config in first for reserve in rest_prod]

    @staticmethod
    def check_config(configs: Union[Dict, List[Dict]]) -> None:
        """Check that every "*@*" placeholder has been replaced.

        Args:
            configs: a config dict or a list of config dicts

        Returns:
            None

        Raises:
            ValueError: some (nested) dict value is still "*@*"
        """
        def _check(config):
            """Raise if "*@*" is a value of ``config`` or of its nested dicts."""
            for key in config:
                if isinstance(config[key], dict):
                    _check(config[key])
                elif config[key] == '*@*':
                    raise ValueError(f'In Config: \n {json.dumps(config, indent=4, ensure_ascii=False)}\n The must be provided key "{key}" marked with "*@*" is not provided.')

        if isinstance(configs, list):
            for config in configs:
                _check(config)
        else:
            _check(configs)

    @staticmethod
    def get_named_list_cartesian_prod(dict_of_list: Dict[str, List]=None) -> List[Dict]:
        """Get the Cartesian product of named lists.

        Args:
            dict_of_list: {'name1': [1,2,3], 'name2': "list(range(1, 4))"}

        Returns:
            [{'name1': 1, 'name2': 1}, {'name1': 1, 'name2': 2}, {'name1': 1, 'name2': 3}, ...]

        Raises:
            AssertionError: a candidate value is not a list (after evaluation)
        """
        if not dict_of_list:
            return []
        dict_of_list = copy.deepcopy(dict_of_list)
        cur_name, cur_paras = dict_of_list.popitem()
        if isinstance(cur_paras, str):
            # NOTE(security): string candidates are evaluated as Python code
            # (e.g. "list(range(1, 4))"); only parse trusted config files.
            cur_paras = eval(cur_paras)
        if not isinstance(cur_paras, list):
            raise AssertionError(f"The search candidates must be list, but you provide {cur_paras}({type(cur_paras)})")
        cur_para_search_list = [{cur_name: para} for para in cur_paras]
        if len(dict_of_list) == 0:
            return cur_para_search_list

        reserve_para_list = BaseConfigParser.get_named_list_cartesian_prod(dict_of_list)
        all_config_list = []
        for cur_config in cur_para_search_list:
            for reserve_config in reserve_para_list:
                _cur_config = copy.deepcopy(cur_config)
                _cur_config.update(copy.deepcopy(reserve_config))
                all_config_list.append(_cur_config)
        return all_config_list

    def is_rep_config(self, list_of_dict: List[dict]) -> bool:
        """Check whether ``list_of_dict`` contains a repeated config.

        Args:
            list_of_dict: a list of dict

        Returns:
            True if at least two configs are equal
        """
        # json.dumps with sort_keys maps equal dicts to the same string.
        list_of_str = [json.dumps(dic, sort_keys=True, ensure_ascii=False) for dic in list_of_dict]
        return len(set(list_of_str)) != len(list_of_dict)
@config_parser_register('config')
class ConfigConfigParser(BaseConfigParser):
    """Parser for plain parameter configs (the ``config`` kind).

    These configs support ``_search`` and ``_link`` but not ``_base`` or ``_name``.
    """

    def __init__(self, config_file):
        # NOTE(review): 'NONEPATH' presumably never contains hjson files, so
        # configs must be given as dicts — confirm.
        super(ConfigConfigParser, self).__init__(config_file, config_base_dir='NONEPATH')
        if self.base_config:
            raise AttributeError('The paras config do not support _base.')
        if self.config_name:
            raise AttributeError('The paras config do not support _name.')

    def parser(self, parser_link=True):
        """Parse the config; supports ``_search`` and ``_link``.

        Args:
            parser_link: whether to resolve the links

        Returns:
            all possible configs
        """
        config_list = self.flat_search(self.search, self.modules)
        # Resolve the links.
        if parser_link:
            for possible_config in config_list:
                all_level_links = self.collect_link(possible_config)
                link_union = LinkUnionTool()
                for level in range(len(all_level_links)):
                    link_union.register_low_links(all_level_links[level])
                # Apply the merged links, as BaseConfigParser.parser does; the
                # raw per-level links were previously passed here, which left
                # the merged link_union unused.
                self.config_link_para(link_union.get_links(), possible_config)
        return config_list
@config_parser_register('_link')
class LinkConfigParser(object):
    """Parser for the ``_link`` entry of a config.

    ``_link`` needs no expansion: it is checked to be a dict and returned
    as the only possible config.
    """

    def __init__(self, config_file):
        self.config = config_file
        # Explicit raise (not `assert`) so the check also runs under `python -O`;
        # AssertionError is kept for backward compatibility.
        if not isinstance(self.config, dict):
            raise AssertionError(f"The '_link' must be a dict, but you provide '{self.config}'")

    def parser(self, parser_link=False):
        """Return the ``_link`` config unchanged.

        Args:
            parser_link: must be False

        Returns:
            a one-element list holding the ``_link`` dict

        Raises:
            AssertionError: ``parser_link`` is not False
        """
        if parser_link is not False:
            raise AssertionError("The parser_link para must be False when parser the _link")
        return [self.config]
# Directory (relative to the project root) holding the hjson configs of each
# module kind; used as `config_base_dir` by the parsers below.
module_dir_map = dict(
    task="dlk/configures/tasks",
    manager="dlk/configures/managers",
    callback="dlk/configures/core/callbacks",
    datamodule="dlk/configures/data/datamodules",
    imodel="dlk/configures/core/imodels",
    model="dlk/configures/core/models",
    optimizer="dlk/configures/core/optimizers",
    scheduler="dlk/configures/core/schedulers",
    initmethod="dlk/configures/core/initmethods",
    loss="dlk/configures/core/losses",
    encoder="dlk/configures/core/layers/encoders",
    decoder="dlk/configures/core/layers/decoders",
    embedding="dlk/configures/core/layers/embeddings",
    module="dlk/configures/core/modules",
    processor="dlk/configures/data/processors",
    subprocessor="dlk/configures/data/subprocessors",
    postprocessor="dlk/configures/data/postprocessors",
)
@config_parser_register('task')
class TaskConfigParser(BaseConfigParser):
    """Parser for ``task`` configs stored under ``module_dir_map['task']``.

    No register is given, so a task named by string has no ``default_config`` fallback.
    """

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['task'])
@config_parser_register('root')
class RootConfigParser(BaseConfigParser):
    """Parser for root configs, looked up directly under the project root.

    No register is given, so a config named by string has no ``default_config`` fallback.
    """

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir='')
@config_parser_register('manager')
class ManagerConfigParser(BaseConfigParser):
    """Parser for ``manager`` configs (hjson under ``module_dir_map['manager']``,
    falling back to ``manager_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['manager'], register=manager_config_register)
@config_parser_register('callback')
class CallbackConfigParser(BaseConfigParser):
    """Parser for ``callback`` configs (hjson under ``module_dir_map['callback']``,
    falling back to ``callback_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['callback'], register=callback_config_register)
@config_parser_register('datamodule')
class DatamoduleConfigParser(BaseConfigParser):
    """Parser for ``datamodule`` configs (hjson under ``module_dir_map['datamodule']``,
    falling back to ``datamodule_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['datamodule'], register=datamodule_config_register)
@config_parser_register('imodel')
class IModelConfigParser(BaseConfigParser):
    """Parser for ``imodel`` configs (hjson under ``module_dir_map['imodel']``,
    falling back to ``imodel_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['imodel'], register=imodel_config_register)
@config_parser_register('model')
class ModelConfigParser(BaseConfigParser):
    """Parser for ``model`` configs (hjson under ``module_dir_map['model']``,
    falling back to ``model_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['model'], register=model_config_register)
@config_parser_register('optimizer')
class OptimizerConfigParser(BaseConfigParser):
    """Parser for ``optimizer`` configs (hjson under ``module_dir_map['optimizer']``,
    falling back to ``optimizer_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['optimizer'], register=optimizer_config_register)
@config_parser_register('scheduler')
class ScheduleConfigParser(BaseConfigParser):
    """Parser for ``scheduler`` configs (hjson under ``module_dir_map['scheduler']``,
    falling back to ``scheduler_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['scheduler'], register=scheduler_config_register)
@config_parser_register('initmethod')
class InitMethodConfigParser(BaseConfigParser):
    """Parser for ``initmethod`` configs (hjson under ``module_dir_map['initmethod']``,
    falling back to ``initmethod_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['initmethod'], register=initmethod_config_register)
@config_parser_register('loss')
class LossConfigParser(BaseConfigParser):
    """Parser for ``loss`` configs (hjson under ``module_dir_map['loss']``,
    falling back to ``loss_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['loss'], register=loss_config_register)
@config_parser_register('encoder')
class EncoderConfigParser(BaseConfigParser):
    """Parser for ``encoder`` configs (hjson under ``module_dir_map['encoder']``,
    falling back to ``encoder_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['encoder'], register=encoder_config_register)
@config_parser_register('decoder')
class DecoderConfigParser(BaseConfigParser):
    """Parser for ``decoder`` configs (hjson under ``module_dir_map['decoder']``,
    falling back to ``decoder_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['decoder'], register=decoder_config_register)
@config_parser_register('embedding')
class EmbeddingConfigParser(BaseConfigParser):
    """Parser for ``embedding`` configs (hjson under ``module_dir_map['embedding']``,
    falling back to ``embedding_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['embedding'], register=embedding_config_register)
@config_parser_register('module')
class ModuleConfigParser(BaseConfigParser):
    """Parser for ``module`` configs (hjson under ``module_dir_map['module']``,
    falling back to ``module_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['module'], register=module_config_register)
@config_parser_register('processor')
class ProcessorConfigParser(BaseConfigParser):
    """Parser for ``processor`` configs (hjson under ``module_dir_map['processor']``,
    falling back to ``processor_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['processor'], register=processor_config_register)
@config_parser_register('subprocessor')
class SubProcessorConfigParser(BaseConfigParser):
    """Parser for ``subprocessor`` configs (hjson under ``module_dir_map['subprocessor']``,
    falling back to ``subprocessor_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['subprocessor'], register=subprocessor_config_register)
@config_parser_register('postprocessor')
class PostProcessorConfigParser(BaseConfigParser):
    """Parser for ``postprocessor`` configs (hjson under ``module_dir_map['postprocessor']``,
    falling back to ``postprocessor_config_register`` defaults)."""

    def __init__(self, config_file):
        super().__init__(config_file, config_base_dir=module_dir_map['postprocessor'], register=postprocessor_config_register)