# Source code for dlk.data.subprocessors.token_gather

# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dlk.utils.vocab import Vocabulary
from dlk.utils.config import BaseConfig, ConfigTool
from typing import Dict, Callable, Iterable, Set, List
from dlk.data.subprocessors import subprocessor_register, subprocessor_config_register, ISubProcessor
from dlk.utils.logger import Logger
import pandas as pd

# Module-level logger obtained from dlk's Logger helper; used by TokenGather below
# to report skipped stages/data sets and the final vocab size.
logger = Logger.get_logger()

@subprocessor_config_register('token_gather')
class TokenGatherConfig(BaseConfig):
    """Config for TokenGather

    Config Example:
        >>> {
        >>>     "_name": "token_gather",
        >>>     "config": {
        >>>         "train": {
        >>>             "data_set": {                   // for different stage, this processor will process different part of data
        >>>                 "train": ["train", "valid", 'test']
        >>>             },
        >>>             "gather_columns": "*@*", // list of columns; an element may be a dict to gather from nested values. Every cell must be a single token, a list of tokens or a set of tokens
        >>>             //"gather_columns": ['tokens']
        >>>             //"gather_columns": ['tokens', {"column": "entities_info", "trace": 'labels'}]
        >>>             // the trace only walks dicts; when a list is met on the trace path, the trace is applied to every element of the list.
        >>>             // for example, with cells like {"entities_info": [{'start': 1, 'end': 2, 'labels': ['Label1']}, ..]} the full path to the labels is
        >>>             // 'entities_info.labels', which is written as {"column": "entities_info", "trace": 'labels'} (the trace is relative to the column)
        >>>             "deliver": "*@*", // the name under which the gathered vocab is delivered (stored as the Vocabulary's attribute dict)
        >>>             "ignore": "", // ignore the token, the id of this token will be -1
        >>>             "update": null, // null, or the name of another delivered vocab to update
        >>>             "unk": "[UNK]",
        >>>             "pad": "[PAD]",
        >>>             "min_freq": 1,
        >>>             "most_common": -1, //-1 for all
        >>>         }
        >>>     }
        >>> }
    """
    def __init__(self, stage: str, config: Dict):
        """Read the token_gather options for ``stage``.

        Args:
            stage: the current stage name; used with ``ConfigTool.get_config_by_stage`` and
                to pick ``data_set[stage]``
            config: the raw processor config

        Raises:
            ValueError: if data sets are configured for this stage but ``deliver`` is empty
        """
        super().__init__(config)
        self.config = ConfigTool.get_config_by_stage(stage, config)
        self.data_set = self.config.get('data_set', {}).get(stage, [])
        if not self.data_set:
            # Nothing to gather at this stage: only `config` and `data_set` are set, and
            # TokenGather checks `data_set` and skips itself before reading anything else.
            return
        self.ignore = self.config['ignore']
        self.gather_columns = self.config["gather_columns"]
        self.deliver = self.config["deliver"]
        # data_set is known to be non-empty here, so only `deliver` needs checking.
        if not self.deliver:
            raise ValueError("The 'deliver' value must not be null.")
        self.update = self.config['update']
        self.unk = self.config['unk']
        self.pad = self.config['pad']
        self.min_freq = self.config['min_freq']
        self.most_common = self.config['most_common']
        self.post_check(self.config, used=[
            "data_set",
            "gather_columns",
            "deliver",
            "ignore",
            "update",
            "unk",
            "pad",
            "min_freq",
            "most_common",
        ])
@subprocessor_register('token_gather')
class TokenGather(ISubProcessor):
    """Gather all tokens from the 'gather_columns' and deliver the resulting vocab.

    The vocab is stored in the processed data under ``config.deliver`` as the
    Vocabulary's attribute dict (``vocab.__dict__``).
    """
    def __init__(self, stage: str, config: TokenGatherConfig):
        super().__init__()
        self.stage = stage
        self.config = config
        self.data_set = config.data_set
        if not self.data_set:
            logger.info(f"Skip 'token_gather' at stage {self.stage}")
            return
        self.update = config.update

    def get_elements_from_series_by_trace(self, data: pd.Series, trace: str) -> List:
        """Get the values found at ``trace`` inside every element of ``data``.

        Dicts are indexed by the next trace key. Lists/tuples are mapped, so the
        remaining trace is applied to each of their items.

        >>> for example:
        >>> data[0] = {'entities_info': [{'start': 0, 'end': 1, 'labels': ['Label1']}]}  // data is a series, and every element is as data[0]
        >>> trace = 'entities_info.labels'
        >>> return_result = [[['Label1']]]  // one entry per series element; data[0] yields [['Label1']]

        Args:
            data: origin data series
            trace: dot-separated key path to follow inside each element

        Returns:
            a list with, for each element of ``data``, the value(s) at the tail of the trace

        Raises:
            PermissionError: if a non-dict/list/tuple value is met before the trace is exhausted
        """
        # Split once rather than once per series element.
        trace_keys = trace.split('.')

        def get_elements_from_iter_by_trace(node, cur_trace_list: List):
            if not cur_trace_list:
                return node
            if isinstance(node, dict):
                return get_elements_from_iter_by_trace(node[cur_trace_list[0]], cur_trace_list[1:])
            if isinstance(node, (list, tuple)):
                # A list on the trace path: apply the remaining trace to every item.
                return [get_elements_from_iter_by_trace(sub_node, cur_trace_list) for sub_node in node]
            raise PermissionError(f"The trace path is only support type list and dict, but you provide {type(node)}")

        return [get_elements_from_iter_by_trace(one, trace_keys) for one in data]

    @staticmethod
    def _load_update_vocab(stored):
        """Turn the value stored under ``config.update`` into a usable vocab object.

        ``process`` delivers ``vocab.__dict__`` (a plain dict), so a vocab produced by a
        token_gather arrives here as a dict and must be rebuilt before calling
        ``auto_update``/``filter_rare`` on it. Vocabulary instances (and any other
        values) are returned unchanged.
        """
        if isinstance(stored, Vocabulary):
            return stored
        if isinstance(stored, dict):
            # Inverse of `data[deliver] = vocab.__dict__`. The dict is shared, not copied,
            # so updates are visible under the original key, matching in-place updating
            # of a Vocabulary object.
            vocab = Vocabulary.__new__(Vocabulary)
            vocab.__dict__ = stored
            return vocab
        return stored

    def process(self, data: Dict) -> Dict:
        """TokenGather entry

        Args:
            data:
                >>> {
                >>>     "data": {"train": ...},
                >>>     "tokenizer": ..
                >>> }
                Each ``data["data"][name]`` is indexed by column name (presumably a
                pandas DataFrame — confirm against the caller).

        Returns:
            ``data`` with ``data[self.config.deliver]`` set to the gathered Vocabulary's
            attribute dict (``vocab.__dict__``).

        Raises:
            PermissionError: if a gather column spec is neither str nor dict
        """
        if not self.data_set:
            return data

        if self.update:
            self.vocab = self._load_update_vocab(data[self.update])
        else:
            self.vocab = Vocabulary(do_strip=True, unknown=self.config.unk, ignore=self.config.ignore, pad=self.config.pad)

        gather_columns = self.config.gather_columns
        if isinstance(gather_columns, (str, dict)):
            # A single column spec: wrap it, otherwise iterating a str walks its
            # characters and iterating a dict walks its keys.
            gather_columns = [gather_columns]

        for data_set_name in self.data_set:
            if data_set_name not in data['data']:
                logger.info(f'The {data_set_name} not in data. We will skip gather tokens from it.')
                continue
            data_set = data['data'][data_set_name]
            for column in gather_columns:
                if isinstance(column, str):
                    self.vocab.auto_update(data_set[column])
                elif isinstance(column, dict):
                    # The trace is relative to the selected column.
                    self.vocab.auto_update(self.get_elements_from_series_by_trace(data_set[column['column']], trace=column['trace']))
                else:
                    raise PermissionError('The gather column currently is only support str or dict.')
        self.vocab.filter_rare(self.config.min_freq, self.config.most_common)
        logger.info(f"The Vocab Num is {self.vocab.word_num}")
        data[self.config.deliver] = self.vocab.__dict__
        return data