Source code for dlk.data.subprocessors.char_gather

# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dlk.utils.vocab import Vocabulary
from dlk.utils.config import ConfigTool
from typing import Dict, Callable, Iterable, Set, List, Union
from dlk.data.subprocessors import subprocessor_register, subprocessor_config_register, ISubProcessor
from dlk.utils.logger import Logger
from dlk.utils.config import BaseConfig

logger = Logger.get_logger()

[docs]@subprocessor_config_register('char_gather') class CharGatherConfig(BaseConfig): """Config for CharGather Config Example: >>> { >>> "_name": "char_gather", >>> "config": { >>> "train": { >>> "data_set": { // for different stage, this processor will process different part of data >>> "train": ["train", "valid", 'test'] >>> }, >>> "gather_columns": "*@*", //List of columns. Every cell must be sigle token or list of tokens or set of tokens >>> "deliver": "char_vocab", // output Vocabulary object (the Vocabulary of labels) name. >>> "ignore": "", // ignore the token, the id of this token will be -1 >>> "update": null, // null or another Vocabulary object to update >>> "unk": "[UNK]", >>> "pad": "[PAD]", >>> "min_freq": 1, >>> "most_common": -1, //-1 for all >>> } >>> } >>> } """ def __init__(self, stage: str, config: Dict): super(CharGatherConfig, self).__init__(config) self.config = ConfigTool.get_config_by_stage(stage, config) self.data_set = self.config.get('data_set', {}).get(stage, []) if not self.data_set: return self.ignore = self.config['ignore'] self.gather_columns = self.config["gather_columns"] self.deliver = self.config["deliver"] if self.data_set and (not self.deliver): raise ValueError("The 'deliver' value must not be null.") self.update = self.config['update'] self.unk = self.config['unk'] self.min_freq = self.config['min_freq'] self.most_common = self.config['most_common'] self.post_check(self.config, used=[ "data_set", "gather_columns", "deliver", "ignore", "update", "unk", "pad", "min_freq", "most_common" ])
[docs]@subprocessor_register('char_gather') class CharGather(ISubProcessor): """gather all character from the 'gather_columns' and deliver a vocab named 'char_vocab' """ def __init__(self, stage: str, config: CharGatherConfig): super().__init__() self.stage = stage self.config = config self.data_set = config.data_set if not self.data_set: logger.info(f"Skip 'char_gather' at stage {self.stage}") return self.update = config.update
[docs] def split_to_char(self, input: Union[str, Iterable]): """the char is from token or sentence, so we need split them to List[char] Args: input: auto detach the type of input and split it to char Returns: the same shape of the input but the str is split to List[char] """ if isinstance(input, str): return [c for c in input] else: return [self.split_to_char(sub_input) for sub_input in input]
[docs] def process(self, data: Dict)->Dict: """Charactor gather entry Args: data: >>> { >>> "data": {"train": ...}, >>> "tokenizer": .. >>> } Returns: data[self.config.deliver] = Vocabulary()(which gathered_char) """ if not self.data_set: return data if self.update: self.vocab = data[self.update] else: self.vocab = Vocabulary(do_strip=True, unknown=self.config.unk, ignore=self.config.ignore) for data_set_name in self.data_set: if data_set_name not in data['data']: logger.info(f'The {data_set_name} not in data. We will skip gather tokens from it.') continue data_set = data['data'][data_set_name] for column in self.config.gather_columns: self.vocab.auto_update(self.split_to_char(data_set[column])) self.vocab.filter_rare(self.config.min_freq, self.config.most_common) logger.info(f"The Char Vocab Num is {self.vocab.word_num}") data[self.config.deliver] = self.vocab.__dict__ return data