# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import model_register, model_config_register
from typing import Dict, List
from dlk.utils.config import BaseConfig, ConfigTool
from dlk.core.base_module import BaseModel
import torch
from dlk.core.layers.embeddings import embedding_config_register, embedding_register
from dlk.core.initmethods import initmethod_config_register, initmethod_register
from dlk.core.layers.encoders import encoder_config_register, encoder_register
from dlk.core.layers.decoders import decoder_config_register, decoder_register
@model_config_register('basic')
class BasicModelConfig(BaseConfig):
    """Config for BasicModel.

    Resolves the ``embedding``, ``encoder``, ``decoder`` and ``initmethod``
    sections of the raw config into ``(module class, module config)`` pairs
    and keeps the model-level ``config`` section as-is.

    Config Example:
        >>> {
        >>>     embedding: {
        >>>         _base: "static",
        >>>         config: {
        >>>             embedding_file: "*@*", //the embedding file, must be saved as numpy array by pickle
        >>>             embedding_dim: "*@*",
        >>>             //if the embedding_file is a dict, you should provide the dict trace to embedding
        >>>             embedding_trace: ".", //default the file itself is the embedding
        >>>             /*embedding_trace: "embedding", //this means the <embedding = pickle.load(embedding_file)["embedding"]>*/
        >>>             /*embedding_trace: "meta.embedding", //this means the <embedding = pickle.load(embedding_file)['meta']["embedding"]>*/
        >>>             freeze: false, //whether to freeze the embedding
        >>>             dropout: 0, //dropout rate
        >>>             output_map: {},
        >>>         },
        >>>     },
        >>>     decoder: {
        >>>         _base: "linear",
        >>>         config: {
        >>>             input_size: "*@*",
        >>>             output_size: "*@*",
        >>>             pool: null,
        >>>             dropout: "*@*", //the decoder output does not need dropout
        >>>             output_map: {}
        >>>         },
        >>>     },
        >>>     encoder: {
        >>>         _base: "lstm",
        >>>         config: {
        >>>             output_map: {},
        >>>             hidden_size: "*@*",
        >>>             input_size: "*@*",
        >>>             output_size: "*@*",
        >>>             num_layers: 1,
        >>>             dropout: "*@*", // dropout between layers
        >>>         },
        >>>     },
        >>>     "initmethod": {
        >>>         "_base": "range_norm"
        >>>     },
        >>>     "config": {
        >>>         "embedding_dim": "*@*",
        >>>         "dropout": "*@*",
        >>>         "embedding_file": "*@*",
        >>>         "embedding_trace": "token_embedding",
        >>>     },
        >>>     _link: {
        >>>         "config.embedding_dim": ["embedding.config.embedding_dim",
        >>>                                  "encoder.config.input_size",
        >>>                                  "encoder.config.output_size",
        >>>                                  "encoder.config.hidden_size",
        >>>                                  "decoder.config.output_size",
        >>>                                  "decoder.config.input_size"
        >>>                                  ],
        >>>         "config.dropout": ["encoder.config.dropout", "decoder.config.dropout", "embedding.config.dropout"],
        >>>         "config.embedding_file": ['embedding.config.embedding_file'],
        >>>         "config.embedding_trace": ['embedding.config.embedding_trace']
        >>>     },
        >>>     _name: "basic"
        >>> }
    """

    def __init__(self, config: Dict):
        """Resolve every sub-module section of ``config``.

        Args:
            config: the raw model config; must contain the keys
                ``embedding``, ``encoder``, ``decoder``, ``initmethod``
                and ``config``.
        """
        super().__init__(config)
        # Each getter returns a (class, config) pair; the model later builds
        # the module by calling the class with its config.
        self.embedding, self.embedding_config = self.get_embedding(config["embedding"])
        self.encoder, self.encoder_config = self.get_encoder(config["encoder"])
        self.decoder, self.decoder_config = self.get_decoder(config["decoder"])
        self.init_method, self.init_method_config = self.get_init_method(config["initmethod"])
        # Model-level options are stored untouched.
        self.config = config['config']

    def get_embedding(self, config: Dict):
        """Return the Embedding and EmbeddingConfig.

        Args:
            config: the embedding config

        Returns:
            Embedding, EmbeddingConfig
        """
        return ConfigTool.get_leaf_module(embedding_register, embedding_config_register, "embedding", config)

    def get_init_method(self, config: Dict):
        """Return the InitMethod and InitMethodConfig.

        Args:
            config: the init method config

        Returns:
            InitMethod, InitMethodConfig
        """
        return ConfigTool.get_leaf_module(initmethod_register, initmethod_config_register, "init method", config)

    def get_encoder(self, config: Dict):
        """Return the Encoder and EncoderConfig.

        Args:
            config: the encoder config

        Returns:
            Encoder, EncoderConfig
        """
        return ConfigTool.get_leaf_module(encoder_register, encoder_config_register, "encoder", config)

    def get_decoder(self, config: Dict):
        """Return the Decoder and DecoderConfig.

        Args:
            config: the decoder config

        Returns:
            Decoder, DecoderConfig
        """
        return ConfigTool.get_leaf_module(decoder_register, decoder_config_register, "decoder", config)
@model_register('basic')
class BasicModel(BaseModel):
    """Basic & General Model.

    Chains three sub-modules built from the config: an embedding, an encoder
    and a decoder. In every step the output of one stage is passed as the
    input of the next, and the decoder output is returned.
    """

    def __init__(self, config: BasicModelConfig, checkpoint):
        """Build the sub-modules and optionally initialise their weights.

        Args:
            config: the resolved model config holding the module classes and
                their configs.
            checkpoint: when falsy, the weights of all three sub-modules are
                initialised with the configured init method; when truthy,
                initialisation is skipped (presumably because weights are
                loaded from the checkpoint afterwards; confirm with caller).
        """
        super().__init__()
        self.embedding = config.embedding(config.embedding_config)
        self.encoder = config.encoder(config.encoder_config)
        self.decoder = config.decoder(config.decoder_config)

        if not checkpoint:
            # One init method instance is shared by all three sub-modules.
            init_method = config.init_method(config.init_method_config)
            self.embedding.init_weight(init_method)
            self.encoder.init_weight(init_method)
            self.decoder.init_weight(init_method)

        self.config = config.config
        self._provided_keys = self.config.get("provided_keys", [])

    def provide_keys(self) -> List[str]:
        """Return all keys of the dict the model returns.

        This method may be unused and may be removed in the future.

        Returns:
            all keys provided by the decoder, the last stage of the model
        """
        # Uses the same `provide_keys()` accessor that
        # `check_keys_are_provided` calls on the embedding and the encoder;
        # assumes the decoder exposes that interface too.
        return self.decoder.provide_keys()

    def check_keys_are_provided(self, provide: List[str] = None) -> None:
        """Check that the keys required by every sub-module are provided.

        The given keys are appended to the keys already recorded on the
        model; the embedding is checked against that combined list, the
        encoder against the keys the embedding provides, and the decoder
        against the keys the encoder provides.

        Args:
            provide: extra keys available to the model; ``None`` (the
                default) means no extra keys.

        Returns:
            None

        Raises:
            PermissionError: presumably raised by a sub-module's own check
                when a required key is missing; nothing is raised directly
                here.
        """
        # `None` replaces a mutable `[]` default so that no list object is
        # shared between calls.
        extra_keys = [] if provide is None else list(provide)
        self._provided_keys = self._provided_keys + extra_keys
        self.embedding.check_keys_are_provided(self._provided_keys)
        self.encoder.check_keys_are_provided(self.embedding.provide_keys())
        self.decoder.check_keys_are_provided(self.encoder.provide_keys())

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Do forward on a mini batch.

        Args:
            inputs: a mini-batch of inputs

        Returns:
            the decoder outputs
        """
        embedding_outputs = self.embedding(inputs)
        encode_outputs = self.encoder(embedding_outputs)
        decode_outputs = self.decoder(encode_outputs)
        return decode_outputs

    def predict_step(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Do predict for one batch.

        Args:
            inputs: one mini-batch of inputs

        Returns:
            the predict outputs
        """
        embedding_outputs = self.embedding.predict_step(inputs)
        encode_outputs = self.encoder.predict_step(embedding_outputs)
        decode_outputs = self.decoder.predict_step(encode_outputs)
        return decode_outputs

    def training_step(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Do training for one batch.

        Args:
            inputs: one mini-batch of inputs

        Returns:
            the training outputs
        """
        embedding_outputs = self.embedding.training_step(inputs)
        encode_outputs = self.encoder.training_step(embedding_outputs)
        decode_outputs = self.decoder.training_step(encode_outputs)
        return decode_outputs

    def validation_step(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Do validation for one batch.

        Args:
            inputs: one mini-batch of inputs

        Returns:
            the validation outputs
        """
        embedding_outputs = self.embedding.validation_step(inputs)
        encode_outputs = self.encoder.validation_step(embedding_outputs)
        decode_outputs = self.decoder.validation_step(encode_outputs)
        return decode_outputs

    def test_step(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Do test for one batch.

        Args:
            inputs: one mini-batch of inputs

        Returns:
            the test outputs
        """
        embedding_outputs = self.embedding.test_step(inputs)
        encode_outputs = self.encoder.test_step(embedding_outputs)
        decode_outputs = self.decoder.test_step(encode_outputs)
        return decode_outputs