# Source code for dlk.train

# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hjson
import os
from typing import Dict, Union, Callable, List, Any
from dlk.utils.parser import BaseConfigParser
from dlk.utils.config import ConfigTool
from dlk.data.datamodules import datamodule_register, datamodule_config_register
from dlk.managers import manager_register, manager_config_register
from dlk.core.imodels import imodel_register, imodel_config_register
import pickle as pkl
from dlk.utils.io import open
import json
from dlk.utils.logger import Logger

logger = Logger.get_logger()


class Train(object):
    """Trainer

    Config Example:
        >>> {
        >>>     "_focus": {
        >>>     },
        >>>     "_link": {},
        >>>     "_search": {},
        >>>     "config": {
        >>>         "save_dir": "*@*",  # must be provided
        >>>         "data_path": "*@*",  # must be provided
        >>>     },
        >>>     "task": {
        >>>         "_name": task_name
        >>>         ...
        >>>     }
        >>> }
    """

    def __init__(self, config: Union[str, Dict], ckpt: str = ""):
        super(Train, self).__init__()
        # Accept either an in-memory dict or a path to an hjson config file.
        if not isinstance(config, dict):
            with open(config, 'r') as config_file:
                config = hjson.load(config_file, object_pairs_hook=dict)
        self.ckpt = ckpt
        # `_focus` maps dotted config paths to short aliases used to name runs.
        self.focus = config.pop('_focus', {})
        # The parser may expand one config into several concrete configs (search).
        self.configs = BaseConfigParser(config).parser_with_check()
        if self.ckpt:
            assert len(
                self.configs
            ) == 1, f"Reuse the checkpoint(ckpt is not none), you must provide the (only one) config which generate the checkpoint."

        # Derive a human-readable name per concrete config from the focused values.
        self.config_names = []
        for candidate in self.configs:
            name_parts = []
            for dotted_path, alias in self.focus.items():
                node = candidate
                for key in dotted_path.split('.'):
                    node = node[key]
                name_parts.append(alias + str(node))
            if name_parts:
                self.config_names.append('_'.join(name_parts))
            else:
                # No focus entries: fall back to the task's declared name.
                self.config_names.append(candidate['root']['_name'])

        # Duplicate names would make the runs' save dirs collide; dump and abort.
        if len(self.config_names) != len(set(self.config_names)):
            for dup_config, dup_name in zip(self.configs, self.config_names):
                logger.info(
                    f"{dup_name}:\n{json.dumps(dup_config, indent=4, ensure_ascii=False)}"
                )
            raise NameError('The config_names is not unique.')
[docs] def run(self): """run for all configs Returns: None """ logger.info( f"You have {len(self.config_names)} training config(s), they all will be run." ) for i, (config, name) in enumerate(zip(self.configs, self.config_names)): logger.info(f"Runing the {i}th {name}...") self.run_oneturn(config, name)
[docs] def dump_config(self, config: Dict, name: str): """dump the config and change the log file path to config['config']['save_dir']+name Args: config: {"config": {"save_dir": '..'}} name: config name Returns: None """ log_path = os.path.join(config.get('config').get('save_dir'), name) with open(os.path.join(config.get('config').get('save_dir'), name, "config.json"), 'w') as f: json.dump( { "root": config, "_focus": self.focus }, f, ensure_ascii=False, indent=4 ) Logger.init_file_logger("log.txt", log_path)
[docs] def run_oneturn(self, config, name): """run this config Args: config: {"root": '...'} name: config name Returns: None """ config = config['root'] # save configure self.dump_config(config, name) # get data data = self.get_data(config) # set datamodule datamodule = self.get_datamodule(config, data) # set training manager manager = self.get_manager(config, name) # init imodel and inject the origin test and valid data imodel = self.get_imodel(config, data) # start training manager.fit(model=imodel, datamodule=datamodule) manager.test(model=imodel, datamodule=datamodule)
[docs] def get_data(self, config): """get the data decided by config Args: config: {"config": {"data_path": '..'}} Returns: loaded data """ with open(config['config']['data_path'], 'rb') as f: self.data = pkl.load(f).get('data', {}) return self.data
[docs] def get_datamodule(self, config, data): """get the datamodule decided by config, and fit the data to datamodule Args: config: {"task": {"datamodule": '..'}} data: {"train": '..', 'valid': '..', ..} Returns: datamodule """ DataModule, DataModuleConfig = ConfigTool.get_leaf_module( datamodule_register, datamodule_config_register, 'datamodule', config['task']['datamodule']) datamodule = DataModule(DataModuleConfig, data) return datamodule
[docs] def get_manager(self, config, name): """get the tranin/predict manager decided by config Args: config: {"task": {"manager": '..'}, "config": {"save_dir"}} name: the predict progress name Returns: manager """ Manager, ManagerConfig = ConfigTool.get_leaf_module( manager_register, manager_config_register, 'manager', config.get('task').get('manager')) manager = Manager(ManagerConfig, rt_config={ "save_dir": config.get('config').get("save_dir"), "name": name }) return manager
    def get_imodel(self, config, data):
        """get the imodel decided by config, and inject the origin test and valid data

        Args:
            config: {"task": {"imodel": '..'}}
            data: {"train": '..', 'valid': '..', ..}

        Returns:
            imodel
        """
        # Resolve the registered imodel class and its parsed config.
        IModel, IModelConfig = ConfigTool.get_leaf_module(
            imodel_register, imodel_config_register, 'imodel',
            config.get('task').get('imodel'))
        imodel = IModel(IModelConfig)
        if self.ckpt:
            logger.info(f"reuse the checkpoint at {self.ckpt}")
            # NOTE(review): if this is PyTorch Lightning's classmethod
            # `load_from_checkpoint`, it returns a NEW model instance and the
            # return value is discarded here, so the loaded weights would be
            # lost — confirm whether dlk's imodel overrides it to load in
            # place, otherwise this should be
            # `imodel = imodel.load_from_checkpoint(self.ckpt)`.
            imodel.load_from_checkpoint(self.ckpt)
        # Attach the raw (pre-datamodule) valid/test splits so the imodel can
        # reference the original examples, e.g. for metrics or prediction dumps.
        if 'valid' in data:
            imodel._origin_valid_data = data['valid']
        if 'test' in data:
            imodel._origin_test_data = data['test']
        return imodel