Source code for dlk.managers.lightning

# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torch.functional import Tensor
import torch.nn as nn
from . import manager_register, manager_config_register
from typing import Dict, List
import hjson
import pytorch_lightning as pl
from dlk.utils.config import BaseConfig, ConfigTool
from dlk.utils.get_root import get_root
import os
from pytorch_lightning.callbacks import ModelCheckpoint
from dlk.core.callbacks import callback_register, callback_config_register
from dlk.utils.io import open
import dlk.utils.parser as parser
from pytorch_lightning.loggers import TensorBoardLogger


@manager_config_register('lightning')
class LightningManagerConfig(BaseConfig):
    """docstring for LightningManagerConfig

    check https://pytorch-lightning.readthedocs.io trainer for parameter details
    """

    def __init__(self, config):
        super(LightningManagerConfig, self).__init__(config)
        manager_config = config.get('config')

        # callback configs only; the Callback objects themselves are initialized in LightningManager
        self.callbacks = self.get_callbacks_config(config)

        self.logger = manager_config["logger"]  # True
        self.enable_checkpointing = manager_config["enable_checkpointing"]  # True, use the checkpoint callback
        self.accelerator = manager_config["accelerator"]  # None
        self.default_root_dir = manager_config["default_root_dir"]  # None
        self.gradient_clip_val = manager_config["gradient_clip_val"]  # None
        self.gradient_clip_algorithm = manager_config["gradient_clip_algorithm"]  # None, TODO: default 'norm', can select 'norm' or 'value'
        self.num_nodes = manager_config["num_nodes"]  # 1
        self.devices = manager_config["devices"]  # None
        self.auto_select_gpus = manager_config["auto_select_gpus"]  # False
        self.ipus = manager_config["ipus"]  # None
        self.log_gpu_memory = manager_config["log_gpu_memory"]  # None
        self.enable_progress_bar = manager_config["enable_progress_bar"]  # True
        self.overfit_batches = eval(manager_config["overfit_batches"])
        self.track_grad_norm = manager_config["track_grad_norm"]  # -1
        self.check_val_every_n_epoch = manager_config["check_val_every_n_epoch"]  # 1
        self.fast_dev_run = manager_config["fast_dev_run"]  # False
        self.accumulate_grad_batches = manager_config["accumulate_grad_batches"]  # 1
        self.max_epochs = manager_config["max_epochs"]  # None
        self.min_epochs = manager_config["min_epochs"]  # None
        self.max_steps = manager_config["max_steps"]  # -1
        self.min_steps = manager_config["min_steps"]  # None
        self.max_time = manager_config["max_time"]  # None
        self.limit_train_batches = eval(manager_config["limit_train_batches"])
        self.limit_val_batches = eval(manager_config["limit_val_batches"])
        self.limit_test_batches = eval(manager_config["limit_test_batches"])
        self.limit_predict_batches = eval(manager_config["limit_predict_batches"])
        self.val_check_interval = eval(manager_config["val_check_interval"])
        self.log_every_n_steps = manager_config["log_every_n_steps"]  # 50
        self.strategy = manager_config["strategy"]  # 'ddp' is used as the default
        self.sync_batchnorm = manager_config["sync_batchnorm"]  # False
        self.precision = manager_config["precision"]  # 32
        self.enable_model_summary = manager_config["enable_model_summary"]  # True
        self.weights_summary = manager_config["weights_summary"]  # 'top'
        self.weights_save_path = manager_config["weights_save_path"]  # None
        self.num_sanity_val_steps = manager_config["num_sanity_val_steps"]  # 2
        self.resume_from_checkpoint = manager_config["resume_from_checkpoint"]  # None
        self.profiler = manager_config["profiler"]  # None, 'simple', 'pytorch', etc.
        self.benchmark = manager_config["benchmark"]  # False
        self.deterministic = manager_config["deterministic"]  # False
        self.reload_dataloaders_every_n_epochs = manager_config["reload_dataloaders_every_n_epochs"]  # 0
        self.auto_lr_find = manager_config["auto_lr_find"]  # False
        self.replace_sampler_ddp = manager_config["replace_sampler_ddp"]  # True
        self.detect_anomaly = manager_config["detect_anomaly"]  # False
        self.auto_scale_batch_size = manager_config["auto_scale_batch_size"]  # False
        self.plugins = manager_config["plugins"]  # None, TODO: add plugins from parser plugins config
        self.amp_backend = manager_config["amp_backend"]  # 'native', requires pytorch>1.6
        self.amp_level = manager_config["amp_level"]  # None, only needed when amp_backend is set to "apex"
        self.move_metrics_to_cpu = manager_config["move_metrics_to_cpu"]  # False
        self.multiple_trainloader_mode = manager_config["multiple_trainloader_mode"]  # 'max_size_cycle'
        self.stochastic_weight_avg = manager_config["stochastic_weight_avg"]  # False
        self.terminate_on_nan = manager_config["terminate_on_nan"]  # None

        self.post_check(manager_config, used=[
            "callbacks", "logger", "enable_checkpointing", "accelerator", "default_root_dir",
            "gradient_clip_val", "gradient_clip_algorithm", "num_nodes", "devices",
            "auto_select_gpus", "ipus", "log_gpu_memory", "enable_progress_bar",
            "overfit_batches", "track_grad_norm", "check_val_every_n_epoch", "fast_dev_run",
            "accumulate_grad_batches", "max_epochs", "min_epochs", "max_steps", "min_steps",
            "max_time", "limit_train_batches", "limit_val_batches", "limit_test_batches",
            "limit_predict_batches", "val_check_interval", "log_every_n_steps", "strategy",
            "sync_batchnorm", "precision", "enable_model_summary", "weights_summary",
            "weights_save_path", "num_sanity_val_steps", "resume_from_checkpoint", "profiler",
            "benchmark", "deterministic", "reload_dataloaders_every_n_epochs", "auto_lr_find",
            "replace_sampler_ddp", "detect_anomaly", "auto_scale_batch_size", "plugins",
            "amp_backend", "amp_level", "move_metrics_to_cpu", "multiple_trainloader_mode",
            "stochastic_weight_avg", "terminate_on_nan",
        ])
    def get_callbacks_config(self, config: Dict) -> List[Dict]:
        """get the configs for callbacks

        Args:
            config: {"config": {"callbacks": ["callback_names"..]}, "callback@callback_names": {config}}

        Returns:
            the configs whose names appear in config['config']['callbacks']
        """
        callback_names = config.get("config", {}).get("callbacks", [])
        callback_configs_list = []
        for callback_name in callback_names:
            callback_config = config.get(f"callback@{callback_name}", {})
            if not callback_config:
                # fall back to the default callback config shipped with dlk
                with open(os.path.join(get_root(), f'dlk/configures/core/callbacks/{callback_name}.hjson'), 'r') as f:
                    callback_config = hjson.load(f, object_pairs_hook=dict)
            parser_callback_config = parser.config_parser_register.get('callback')(callback_config).parser_with_check(parser_link=False)
            assert len(parser_callback_config) == 1, "Multiple configs for one callback are not supported."
            callback_config = parser_callback_config[0]
            assert not callback_config.get("_link", {}), "Callback config does not support _link."
            callback_configs_list.append(callback_config)
        return callback_configs_list
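
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original source): the shape of the
# config mapping that get_callbacks_config expects, based on its docstring.
# The callback name "checkpoint" and the nested keys are hypothetical
# placeholders; when a "callback@<name>" section is missing, the default is
# loaded from dlk/configures/core/callbacks/<name>.hjson instead.
#
#     {
#         "config": {
#             "callbacks": ["checkpoint"],   # callback names to resolve
#             ...                            # plus the trainer keys read in __init__ above
#         },
#         "callback@checkpoint": {           # optional per-callback override
#             "config": { ... }
#         }
#     }
# ---------------------------------------------------------------------------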
@manager_register('lightning')
class LightningManager(object):
    """pytorch-lightning training manager"""

    def __init__(self, config: LightningManagerConfig, rt_config: Dict):
        super().__init__()

        if config.logger:
            config.logger = TensorBoardLogger(save_dir=os.path.join(rt_config["save_dir"], rt_config["name"]), version='')
        if config.callbacks:
            config.callbacks = self.get_callbacks(config.callbacks, rt_config)
        config.__dict__.pop('_name')
        self.manager = pl.Trainer(**config.__dict__)
    def get_callbacks(self, callback_configs: List[Dict], rt_config: Dict):
        """init the callbacks and return the callbacks list

        Args:
            callback_configs: the config of every callback
            rt_config: {"save_dir": '..', "name": '..'}

        Returns:
            all callbacks
        """
        callbacks_list = []
        for callback_config in callback_configs:
            Callback, CallbackConfig = ConfigTool.get_leaf_module(callback_register, callback_config_register, "callback", callback_config)
            callbacks_list.append(Callback(CallbackConfig)(rt_config=rt_config))
        return callbacks_list
    def fit(self, **inputs):
        """fit the model and datamodule with the trainer

        Args:
            **inputs: dict of inputs, including "model" and "datamodule"

        Returns:
            Undefined
        """
        return self.manager.fit(**inputs)
    def predict(self, **inputs):
        """run prediction on the model and datamodule.predict_dataloader

        Args:
            **inputs: dict of inputs, including "model" and "datamodule"

        Returns:
            predict list
        """
        return self.manager.predict(**inputs)
    def test(self, **inputs):
        """test the model on datamodule.test_dataloader

        Args:
            **inputs: dict of inputs, including "model" and "datamodule"

        Returns:
            Undefined
        """
        return self.manager.test(**inputs)
    def validate(self, **inputs):
        """validate the model on the datamodule validation dataloader

        Args:
            **inputs: dict of inputs, including "model" and "datamodule"

        Returns:
            Undefined
        """
        return self.manager.validate(**inputs)
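
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original source). `parsed_config`
# must contain every key read by LightningManagerConfig.__init__ above (the
# project's default hjson configs supply them); `model`, `datamodule`, the
# save_dir, and the task name below are hypothetical placeholders.
#
#     manager_config = LightningManagerConfig(parsed_config)
#     manager = LightningManager(manager_config,
#                                rt_config={"save_dir": "logs", "name": "example_task"})
#     manager.fit(model=model, datamodule=datamodule)
#     manager.test(model=model, datamodule=datamodule)
# ---------------------------------------------------------------------------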