Source code for dlk.core.imodels

# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""imodels"""
import importlib
import os
from typing import Callable, Dict, List, Tuple, Any
from dlk.utils.register import Register
import torch

# Module-level registries for this package. Presumably imodel config classes and
# imodel classes register themselves here when their modules are imported
# (see dlk.utils.register.Register; the registration calls are not visible here).
imodel_config_register, imodel_register = (
    Register("IModel config register."),
    Register("IModel register."),
)

class GatherOutputMixin(object):
    """Gather the small-batch outputs produced on every worker into one big batch.

    Intended as a mixin: ``gather_outputs`` reads ``self.trainer.world_size`` and
    ``self.local_rank`` and calls ``self.all_gather``, so the host class (presumably
    a pytorch_lightning ``LightningModule``) must provide those attributes.
    """

    @staticmethod
    def proc_dist_outputs(dist_outputs: List[Dict]) -> List[Dict]:
        """Reshape all-gathered outputs so they look like single-worker outputs.

        Every value of every dict must be a tensor with at least two dimensions,
        where dim 0 is the worker axis and dim 1 the per-worker batch axis.
        This is presumably ``(world_size, batch, ...)`` as returned by
        ``all_gather``; TODO confirm against callers. The two leading axes are
        swapped and merged. The result therefore has shape
        ``(batch * world_size, ...)``, with rows ordered
        ``[w0_row0, w1_row0, ..., w0_row1, w1_row1, ...]``.

        Args:
            dist_outputs: the inputs of pytorch_lightning train/test/.._epoch_end when using ddp

        Returns:
            the inputs of pytorch_lightning train/test/.._epoch_end when only run on one worker.

        Raises:
            KeyError: naming the offending key and its value when a value cannot
                be reshaped (e.g. it is not a tensor or has fewer than two
                dimensions). The underlying error is chained as ``__cause__``.
        """
        outputs = []
        for dist_output in dist_outputs:
            one_output = {}
            for key, value in dist_output.items():
                try:
                    # (world, batch, ...) -> (batch, world, ...) -> (batch * world, ...).
                    # Same values/order as cat(swapaxes(value, 0, 1).unbind(), dim=0),
                    # but one op, and it also works for an empty batch.
                    one_output[key] = torch.swapaxes(value, 0, 1).flatten(0, 1)
                except Exception as err:
                    # Narrow catch (not bare ``except``) so KeyboardInterrupt/SystemExit
                    # propagate; chain the cause so the real error is not lost.
                    raise KeyError(f"{key}: {value}") from err
            outputs.append(one_output)
        return outputs

    def gather_outputs(self, outputs: List[Dict]) -> List[Dict]:
        """Gather the distributed outputs of all workers.

        With more than one worker, ``self.all_gather`` is called on every rank
        before the rank check. It is a collective in pytorch_lightning, so every
        rank must reach it. Only the process whose ``local_rank`` is 0 or -1 then
        reshapes the gathered data with ``proc_dist_outputs``. Every other rank
        gets its own un-gathered ``outputs`` back unchanged.

        Args:
            outputs: one node outputs

        Returns:
            all outputs on local rank 0/-1 (or when ``world_size <= 1``),
            otherwise the input ``outputs`` as-is.
        """
        if self.trainer.world_size > 1:
            dist_outputs = self.all_gather(outputs)
            if self.local_rank in (0, -1):
                outputs = self.proc_dist_outputs(dist_outputs)
        return outputs

    def concat_list_of_dict_outputs(self, outputs: List[Dict]) -> Dict:
        """Concatenate a list of per-batch dicts into one dict, key by key.

        Only supports outputs whose tensors agree on every dim except dim 0;
        now deprecated.

        Args:
            outputs: multi node returned output (list of dict)

        Returns:
            a dict mapping each key to ``torch.cat`` (along dim 0) of all values
            seen for that key. Keys keep first-seen order; ``{}`` for empty input.
        """
        key_all_batch_map: Dict[Any, List] = {}
        for batch in outputs:
            for key, value in batch.items():
                key_all_batch_map.setdefault(key, []).append(value)
        return {key: torch.cat(values, dim=0) for key, values in key_all_batch_map.items()}
def import_imodels(imodels_dir, namespace):
    """Import every candidate imodel entry found directly inside ``imodels_dir``.

    An entry is imported as ``<namespace>.<name>`` when its name does not start
    with ``_`` or ``.`` and it is either a file ending in ``.py`` or a
    directory. For a ``.py`` file, ``<name>`` is everything before the first
    ``.py``. The scan is not recursive, and nothing is returned; importing is
    the only side effect.

    Args:
        imodels_dir: directory whose entries are scanned.
        namespace: dotted package path under which those entries are importable.

    Raises:
        Whatever ``importlib.import_module`` raises (e.g. ModuleNotFoundError),
        propagated unchanged.
    """
    for entry in os.listdir(imodels_dir):
        # Skip private/dunder entries (e.g. __init__.py, __pycache__) and hidden ones.
        if entry.startswith(("_", ".")):
            continue
        is_py_file = entry.endswith(".py")
        if not is_py_file and not os.path.isdir(os.path.join(imodels_dir, entry)):
            continue
        module_name = entry.partition(".py")[0] if is_py_file else entry
        importlib.import_module(".".join((namespace, module_name)))
# On import of this package, load every (non-private) Python module and
# sub-package that sits next to this file, so their import-time side effects
# run (presumably registering into the registries above; not visible here).
imodels_dir = os.path.split(__file__)[0]
import_imodels(imodels_dir=imodels_dir, namespace="dlk.core.imodels")