dlk.data.datamodules package

Submodules

dlk.data.datamodules.basic module

class dlk.data.datamodules.basic.BasicDatamodule(config: dlk.data.datamodules.basic.BasicDatamoduleConfig, data: Dict[str, Any])[source]

Bases: dlk.data.datamodules.IBaseDataModule

Basic and General DataModule

online_dataloader()[source]

get the online set dataloader

predict_dataloader()[source]

get the predict set dataloader

real_key_type_pairs(key_type_pairs: Dict, data: Dict, field: str)[source]

return the entries of key_type_pairs whose keys also appear in data.columns, i.e. keys = key_type_pairs.keys() ∩ data.columns

Parameters
  • key_type_pairs – maps each column name in data to the tensor type that column should be converted to

  • data – the pd.DataFrame

  • field – training/valid/test, etc.

Returns

real_key_type_pairs where keys = key_type_pairs.keys() ∩ data.columns
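The intersection described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the helper name `intersect_key_type_pairs` is hypothetical, and a list of column names stands in for `data.columns` so the snippet runs without pandas.

```python
from typing import Dict, Iterable


def intersect_key_type_pairs(
    key_type_pairs: Dict[str, str], columns: Iterable[str]
) -> Dict[str, str]:
    """Keep only the entries whose key is also a column name (hypothetical helper)."""
    available = set(columns)
    return {key: dtype for key, dtype in key_type_pairs.items() if key in available}


key_type_pairs = {"input_ids": "int", "label_ids": "long", "type_ids": "long"}
columns = ["input_ids", "type_ids", "raw_text"]  # stand-in for data.columns

result = intersect_key_type_pairs(key_type_pairs, columns)
print(result)  # label_ids is dropped because no such column exists
```

Filtering this way would let one config be shared across fields, for example a predict set that has no label column.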

test_dataloader()[source]

get the test set dataloader

train_dataloader()[source]

get the train set dataloader

val_dataloader()[source]

get the validation set dataloader

class dlk.data.datamodules.basic.BasicDatamoduleConfig(config)[source]

Bases: dlk.utils.config.BaseConfig

Config for BasicDatamodule

Config Example:
>>> {
>>>     "_name": "basic",
>>>     "config": {
>>>         "pin_memory": null,
>>>         "collate_fn": "default",
>>>         "num_workers": null,
>>>         "shuffle": {
>>>             "train": true,
>>>             "predict": false,
>>>             "valid": false,
>>>             "test": false,
>>>             "online": false
>>>         },
>>>         "key_type_pairs": {
>>>              'input_ids': 'int',
>>>              'label_ids': 'long',
>>>              'type_ids': 'long',
>>>          },
>>>         "gen_mask": {
>>>              'input_ids': 'attention_mask',
>>>          },
>>>         "key_padding_pairs": { //default all 0
>>>              'input_ids': 0,
>>>          },
>>>         "key_padding_pairs_2d": { //default all 0, for 2 dimension data
>>>              'input_ids': 0,
>>>          },
>>>         "train_batch_size": 32,
>>>         "predict_batch_size": 32, //the predict and test batch sizes are equal to valid_batch_size
>>>         "online_batch_size": 1,
>>>     }
>>> },
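
To make the `key_padding_pairs` and `gen_mask` options more concrete, here is a rough sketch of what a padding-and-mask step could look like. It is an assumption about the intent of these keys, not the behaviour of `DefaultCollate`: the function `pad_and_mask` is hypothetical, and it works on plain Python lists rather than tensors.

```python
from typing import Dict, List


def pad_and_mask(
    batch: List[Dict[str, List[int]]],
    key_padding_pairs: Dict[str, int],
    gen_mask: Dict[str, str],
) -> Dict[str, List[List[int]]]:
    """Pad each key to the longest sequence in the batch and build masks (hypothetical)."""
    collated: Dict[str, List[List[int]]] = {}
    for key in batch[0]:
        sequences = [sample[key] for sample in batch]
        max_len = max(len(seq) for seq in sequences)
        # Assumed default: pad with 0 when the key is not listed in key_padding_pairs.
        pad_value = key_padding_pairs.get(key, 0)
        collated[key] = [seq + [pad_value] * (max_len - len(seq)) for seq in sequences]
        if key in gen_mask:
            # 1 marks a real position, 0 marks padding.
            collated[gen_mask[key]] = [
                [1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences
            ]
    return collated


batch = [{"input_ids": [5, 6, 7]}, {"input_ids": [8]}]
collated = pad_and_mask(batch, {"input_ids": 0}, {"input_ids": "attention_mask"})
print(collated["input_ids"])
print(collated["attention_mask"])
```

Under these assumptions, `gen_mask` names the output key that receives the mask for a given input key, which is why `input_ids` maps to `attention_mask` in the config example.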

class dlk.data.datamodules.basic.BasicDataset(key_type_pairs: Dict[str, str], data: pandas.core.frame.DataFrame)[source]

Bases: torch.utils.data.dataset.Dataset

Basic and General Dataset

Module contents

datamodules

class dlk.data.datamodules.DefaultCollate(**config)[source]

Bases: object

docstring for DefaultCollate

class dlk.data.datamodules.IBaseDataModule[source]

Bases: pytorch_lightning.core.datamodule.LightningDataModule

docstring for IBaseDataModule

abstract online_dataloader()[source]
Raises

NotImplementedError

predict_dataloader()[source]
Raises

NotImplementedError

test_dataloader()[source]
Raises

NotImplementedError

train_dataloader()[source]
Raises

NotImplementedError

val_dataloader()[source]
Raises

NotImplementedError
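
The pattern above, where each dataloader method raises NotImplementedError until a subclass overrides it, can be illustrated without any dependencies. The sketch below is only an analogy: `BaseDataModuleSketch` and `ListDataModule` are hypothetical names, they do not inherit from `LightningDataModule`, and plain iterators stand in for real dataloaders.

```python
class BaseDataModuleSketch:
    """Stand-in for the interface: every loader raises until overridden."""

    def train_dataloader(self):
        raise NotImplementedError

    def val_dataloader(self):
        raise NotImplementedError

    def test_dataloader(self):
        raise NotImplementedError


class ListDataModule(BaseDataModuleSketch):
    """Hypothetical subclass that implements only the train and valid loaders."""

    def __init__(self, train, valid):
        self._train = train
        self._valid = valid

    def train_dataloader(self):
        return iter(self._train)

    def val_dataloader(self):
        return iter(self._valid)


module = ListDataModule(train=[1, 2, 3], valid=[4])
train_items = list(module.train_dataloader())

# test_dataloader was not overridden, so the base implementation raises.
try:
    module.test_dataloader()
except NotImplementedError:
    test_supported = False
else:
    test_supported = True

print(train_items, test_supported)
```

A subclass in this style only has to provide the loaders it actually uses; calling one it left out fails loudly instead of silently returning nothing.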

dlk.data.datamodules.import_datamodules(datamodules_dir, namespace)[source]