# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""datamodules"""

import importlib
import os
from typing import Dict, Any
from dlk.utils.register import Register
from pytorch_lightning import LightningDataModule
import abc
from torch.nn.utils.rnn import pad_sequence
import torch

datamodule_config_register = Register("Datamodule config register")
datamodule_register = Register("Datamodule register")


collate_register = Register('Collate function register.')



@collate_register('default')
class DefaultCollate(object):
    """Default collate function: pad every key of a batch to a common shape"""

    def __init__(self, **config):
        super(DefaultCollate, self).__init__()
        self.key_padding_pairs = config.get("key_padding_pairs", {})
        self.key_padding_pairs_2d = config.get("key_padding_pairs_2d", {})
        self.gen_mask = config.get("gen_mask", {})

    def __call__(self, batch):
        keys = batch[0].keys()
        data_map: Dict[str, Any] = {}
        for key in keys:
            data_map[key] = []
        # gather the instances column-wise: one list of tensors per key
        for key in keys:
            for one_ins in batch:
                data_map[key].append(one_ins[key])
        # optionally generate a 0/1 mask for each configured key
        if self.gen_mask:
            for key, mask in self.gen_mask.items():
                data_map[mask] = []
                for item in data_map[key]:
                    data_map[mask].append(torch.tensor([1] * len(item), dtype=torch.int))
        for key in data_map:
            if key in self.key_padding_pairs_2d:
                # pad 2d tensors to the batch maximum in both dimensions
                max_m, max_n = 0, 0
                for ins in data_map[key]:
                    cur_m, cur_n = ins.shape
                    max_m = max(max_m, cur_m)
                    max_n = max(max_n, cur_n)
                _data = torch.full(
                    (len(data_map[key]), max_m, max_n),
                    fill_value=self.key_padding_pairs_2d.get(key, -1),
                    dtype=data_map[key][0].dtype,
                )
                for i, ins in enumerate(data_map[key]):
                    cur_m, cur_n = ins.shape
                    _data[i][:cur_m, :cur_n] = ins
                data_map[key] = _data
            else:
                try:
                    data_map[key] = pad_sequence(
                        data_map[key],
                        batch_first=True,
                        padding_value=self.key_padding_pairs.get(key, 0),
                    )
                except Exception:
                    # only 0-dim (scalar) tensors can be rescued here: give each
                    # a length-1 dimension, pad, then squeeze back
                    if data_map[key][0].size():
                        raise ValueError(f"The {data_map[key]} can not be concatenated by pad_sequence.")
                    _data = pad_sequence(
                        [i.unsqueeze(0) for i in data_map[key]],
                        batch_first=True,
                        padding_value=self.key_padding_pairs.get(key, 0),
                    ).squeeze()
                    if not _data.size():
                        _data.unsqueeze_(0)
                    data_map[key] = _data
        return data_map
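

# Illustrative usage sketch (not part of the dlk API): shows how DefaultCollate
# pads a toy batch, the way a torch DataLoader would call it as collate_fn.
# The keys "input_ids", "label", and "attention_mask" are hypothetical.
def _example_default_collate():
    collate = DefaultCollate(
        key_padding_pairs={"input_ids": 0},        # pad "input_ids" with 0
        gen_mask={"input_ids": "attention_mask"},  # emit a 0/1 mask for it
    )
    batch = [
        {"input_ids": torch.tensor([1, 2, 3]), "label": torch.tensor(0)},
        {"input_ids": torch.tensor([4, 5]), "label": torch.tensor(1)},
    ]
    collated = collate(batch)
    # "input_ids" is padded to shape (2, 3), "attention_mask" marks real tokens
    # with 1, and the scalar "label"s are stacked into shape (2,).
    return collated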


class IBaseDataModule(LightningDataModule):
    """Interface for all dlk datamodules"""

    def __init__(self):
        super(IBaseDataModule, self).__init__()

    def train_dataloader(self):
        """
        Raises:
            NotImplementedError
        """
        raise NotImplementedError("You must implement the train_dataloader for your own datamodule.")

    def predict_dataloader(self):
        """
        Raises:
            NotImplementedError
        """
        raise NotImplementedError("You must implement the predict_dataloader for your own datamodule.")

    def val_dataloader(self):
        """
        Raises:
            NotImplementedError
        """
        raise NotImplementedError("You must implement the val_dataloader for your own datamodule.")

    def test_dataloader(self):
        """
        Raises:
            NotImplementedError
        """
        raise NotImplementedError("You must implement the test_dataloader for your own datamodule.")

    @abc.abstractmethod
    def online_dataloader(self):
        """
        Raises:
            NotImplementedError
        """
        raise NotImplementedError("You must implement the online_dataloader for your own datamodule.")
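

# A minimal sketch (not part of dlk) of what a concrete datamodule built on
# IBaseDataModule can look like. In real code it would be registered with
# @datamodule_register('example'), following the @collate_register pattern
# above; the decorator is left out here so the sketch does not pollute the
# registry. The batch size and dataset arguments are hypothetical.
from torch.utils.data import DataLoader


class _ExampleDataModule(IBaseDataModule):
    """Hypothetical sketch: implements only the train/val loaders; the other
    dataloaders keep the NotImplementedError defaults from IBaseDataModule."""

    def __init__(self, train_data, valid_data, collate_fn=None):
        super(_ExampleDataModule, self).__init__()
        self.train_data = train_data
        self.valid_data = valid_data
        self.collate_fn = collate_fn or DefaultCollate()

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=32, shuffle=True,
                          collate_fn=self.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.valid_data, batch_size=32, shuffle=False,
                          collate_fn=self.collate_fn)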


def import_datamodules(datamodules_dir, namespace):
    """automatically import all datamodules under datamodules_dir into namespace"""
    for file in os.listdir(datamodules_dir):
        path = os.path.join(datamodules_dir, file)
        if (
            not file.startswith("_")
            and not file.startswith(".")
            and (file.endswith(".py") or os.path.isdir(path))
        ):
            datamodule_name = file[: file.find(".py")] if file.endswith(".py") else file
            importlib.import_module(namespace + "." + datamodule_name)


# automatically import any Python files in the datamodules directory
datamodules_dir = os.path.dirname(__file__)
import_datamodules(datamodules_dir, "dlk.data.datamodules")
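

# With the auto-import above, a new datamodule only needs to be dropped into
# this directory and decorated; it then becomes discoverable by name. A
# hypothetical "basic.py" in this package could look like:
#
#     @datamodule_config_register('basic')
#     class BasicDatamoduleConfig(object):
#         ...
#
#     @datamodule_register('basic')
#     class BasicDatamodule(IBaseDataModule):
#         ...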