Source code for dlk.data.subprocessors.token_norm

# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import PercentStyle
from dlk.utils.vocab import Vocabulary
from dlk.utils.config import BaseConfig, ConfigTool
from typing import Dict, Callable, Set, List
from dlk.data.subprocessors import subprocessor_register, subprocessor_config_register, ISubProcessor
from functools import partial
import pandas as pd
from dlk.utils.logger import Logger
from tokenizers.models import WordLevel
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import WhitespaceSplit

logger = Logger.get_logger()

[docs]@subprocessor_config_register('token_norm') class TokenNormConfig(BaseConfig): """Config for TokenNorm Config Example: >>> { >>> "_name": "token_norm", >>> "config": { >>> "train":{ >>> "data_set": { // for different stage, this processor will process different part of data >>> "train": ['train', 'valid', 'test', 'predict'], >>> "predict": ['predict'], >>> "online": ['online'] >>> }, >>> "zero_digits_replaced": true, >>> "lowercase": true, >>> "extend_vocab": "", //when lowercase is true, this upper_case_vocab will collection all tokens the token is not in vocab but it's lowercase is in vocab. this is only for token gather process >>> "tokenizer": "whitespace_split", //the path to vocab(if the token in vocab skip norm it), the file is setted to one token per line >>> "data_pair": { >>> "sentence": "norm_sentence" >>> }, >>> }, >>> "predict": "train", >>> "online": "train", >>> } >>> } """ def __init__(self, stage, config: Dict): super(TokenNormConfig, self).__init__(config) self.config = ConfigTool.get_config_by_stage(stage, config) self.data_set = self.config.get('data_set', {}).get(stage, []) if not self.data_set: return self.data_pair = self.config.pop('data_pair', {}) self.zero_digits_replaced = self.config.pop('zero_digits_replaced', True) self.lowercase = self.config.pop('lowercase', True) self.tokenizer = Tokenizer.from_file(self.config['tokenizer']) self.unk = self.tokenizer.model.unk_token self.vocab = self.tokenizer.get_vocab() self.do_extend_vocab = self.config['extend_vocab'] self.prefix = self.tokenizer.model.continuing_subword_prefix self.post_check(self.config, used=[ "data_set", "zero_digits_replaced", "lowercase", "extend_vocab", "tokenizer", "data_pair", ])
[docs] def tokenize(self, seq): """tokenize the seq """ encode = self.tokenizer.encode(seq) return encode
[docs]@subprocessor_register('token_norm') class TokenNorm(ISubProcessor): """ This part could merged to fast_tokenizer(it will save some time), but not all process need this part(except some special dataset like conll2003), and will make the fast_tokenizer be heavy. Token norm: Love -> love 3281 -> 0000 """ def __init__(self, stage: str, config: TokenNormConfig): super().__init__() self.stage = stage self.config = config self.data_set = config.data_set if not self.data_set: logger.info(f"Skip 'token_norm' at stage {self.stage}") return if self.config.do_extend_vocab: self.extend_vocab = set() self._zero_digits_replaced_num = 0 self._lower_case_num = 0 self._lower_case_zero_digits_replaced_num = 0
[docs] def token_norm(self, token: str)->str: """norm token, the result len(result) == len(token), exp. 12348->00000 Args: token: origin token Returns: normed_token """ if token in self.config.vocab: return token if self.config.zero_digits_replaced: norm = '' digit_num = 0 for c in token: if c.isdigit() or c=='.': norm += '0' digit_num += 1 else: norm += c if norm in self.config.vocab or self.config.prefix+norm in self.config.vocab: self._zero_digits_replaced_num += 1 return norm elif self.config.do_extend_vocab and digit_num == len(token) and digit_num<20: self.extend_vocab.add(norm) if self.config.lowercase: norm = token.lower() if norm in self.config.vocab or self.config.prefix+norm in self.config.vocab: self._lower_case_num += 1 if self.config.do_extend_vocab: self.extend_vocab.add(token) return norm if self.config.lowercase and self.config.zero_digits_replaced: norm = '' for c in token.lower(): if c.isdigit() or c=='.': norm += '0' else: norm += c if norm in self.config.vocab or self.config.prefix+norm in self.config.vocab: self._lower_case_zero_digits_replaced_num += 1 return norm return ''
[docs] def seq_norm(self, key:str, one_item: pd.Series)->str: """norm a sentence, the sentence is from one_item[key] Args: key: the name in one_item one_item: a pd.Series which include the key Returns: norm_sentence """ seq = one_item[key] norm_seq = [c for c in seq] encode = self.config.tokenize(seq) for i, token in enumerate(encode.tokens): if token == self.config.unk: token_offset = encode.offsets[i] prenorm_token = seq[token_offset[0]: token_offset[1]] norm_token = self.token_norm(prenorm_token) if not norm_token: continue assert len(norm_token) == token_offset[1] - token_offset[0], f"Prenorm '{prenorm_token}', postnorm: '{norm_token}' and {len(norm_token)}!= {token_offset[1]} - {token_offset[0]}" norm_seq[token_offset[0]: token_offset[1]] = norm_token return ''.join(norm_seq)
[docs] def process(self, data: Dict)->Dict: """TokenNorm entry Args: data: { "data": {"train": ...}, "tokenizer": .. } Returns: norm data """ if not self.data_set: return data for data_set_name in self.data_set: if data_set_name not in data['data']: logger.info(f'The {data_set_name} not in data. We will skip do token_norm on it.') continue data_set = data['data'][data_set_name] for key, value in self.config.data_pair.items(): _seq_norm = partial(self.seq_norm, key) data_set[value] = data_set.apply(_seq_norm, axis=1) # WARNING: if you change the apply to parallel_apply, you should change the _zero_digits_replaced_num, etc. to multiprocess safely(BTW, use parallel_apply in tokenizers==0.10.3 will make the process very slow) # data_set[value] = data_set.apply(_seq_norm, axis=1) logger.info(f"We use zero digits to replace digit token num is {self._zero_digits_replaced_num}, do lowercase token num is {self._lower_case_num}, do both num is {self._lower_case_zero_digits_replaced_num}") if self.config.do_extend_vocab: logger.info(f"We will extend {len(self.extend_vocab)} tokens and deliver to {self.config.do_extend_vocab}") data[self.config.do_extend_vocab] = list(self.extend_vocab) return data