# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle as pkl
import json
from typing import Dict, List, Optional, Tuple, Union
import os
import numpy as np
import pandas as pd
import torch
from typing import Dict
from dlk.data.postprocessors import postprocessor_register, postprocessor_config_register, IPostProcessor, IPostProcessorConfig
from dlk.utils.logger import Logger
from dlk.utils.vocab import Vocabulary
from tokenizers import Tokenizer
from dlk.utils.io import open
import torchmetrics
logger = Logger.get_logger()
[docs]@postprocessor_config_register('seq_lab')
class SeqLabPostProcessorConfig(IPostProcessorConfig):
"""Config for SeqLabPostProcessor
Config Example:
>>> {
>>> "_name": "seq_lab",
>>> "config": {
>>> "meta": "*@*",
>>> "use_crf": false, //use or not use crf
>>> "word_ready": false, //already gather the subword first token as the word rep or not
>>> "ignore_position": true, // calc the metrics, whether ignore the ground_truth and predict position info.( if set to true, only focus on the entity content not position.)
>>> "ignore_char": " ", // if the entity begin or end with this char, will ignore these char
>>> //"ignore_char": " ()[]-.,:", // if the entity begin or end with this char, will ignore these char
>>> "meta_data": {
>>> "label_vocab": 'label_vocab',
>>> "tokenizer": "tokenizer",
>>> },
>>> "input_map": {
>>> "logits": "logits",
>>> "predict_seq_label": "predict_seq_label",
>>> "_index": "_index",
>>> },
>>> "origin_input_map": {
>>> "uuid": "uuid",
>>> "sentence": "sentence",
>>> "input_ids": "input_ids",
>>> "entities_info": "entities_info",
>>> "offsets": "offsets",
>>> "special_tokens_mask": "special_tokens_mask",
>>> "word_ids": "word_ids",
>>> "label_ids": "label_ids",
>>> },
>>> "save_root_path": ".", //save data root dir
>>> "save_path": {
>>> "valid": "valid", // relative dir for valid stage
>>> "test": "test", // relative dir for test stage
>>> },
>>> "start_save_step": 0, // -1 means the last
>>> "start_save_epoch": -1,
>>> "aggregation_strategy": "max", // AggregationStrategy item
>>> "ignore_labels": ['O', 'X', 'S', "E"], // Out, Out, Start, End
>>> }
>>> }
"""
def __init__(self, config: Dict):
super(SeqLabPostProcessorConfig, self).__init__(config)
self.use_crf = self.config['use_crf']
self.word_ready = self.config['word_ready']
self.aggregation_strategy = self.config['aggregation_strategy']
self.ignore_labels = set(self.config['ignore_labels'])
self.ignore_char = set(self.config['ignore_char'])
self.ignore_position = self.config['ignore_position']
self.sentence = self.origin_input_map['sentence']
self.offsets = self.origin_input_map['offsets']
self.entities_info = self.origin_input_map['entities_info']
self.uuid = self.origin_input_map['uuid']
self.word_ids = self.origin_input_map['word_ids']
self.special_tokens_mask = self.origin_input_map['special_tokens_mask']
self.input_ids = self.origin_input_map['input_ids']
self.label_ids = self.origin_input_map['label_ids']
self.logits = self.input_map['logits']
self.predict_seq_label = self.input_map['predict_seq_label']
self._index = self.input_map['_index']
if isinstance(self.config['meta'], str):
with open(self.config['meta'], 'rb') as f:
meta = pkl.load(f)
else:
raise PermissionError("You must provide meta data(vocab & tokenizer) for ner postprocess.")
vocab_trace_path = []
vocab_trace_path_str = self.config['meta_data']['label_vocab']
if vocab_trace_path_str and vocab_trace_path_str.strip()!='.':
vocab_trace_path = vocab_trace_path_str.split('.')
assert vocab_trace_path, "We need vocab and tokenizer all in meta, so you must provide the trace path from meta"
self.label_vocab = meta
for trace in vocab_trace_path:
self.label_vocab = self.label_vocab[trace]
self.label_vocab = Vocabulary.load(self.label_vocab)
tokenizer_trace_path = []
tokenizer_trace_path_str = self.config['meta_data']['tokenizer']
if tokenizer_trace_path_str and tokenizer_trace_path_str.strip()!='.':
tokenizer_trace_path = tokenizer_trace_path_str.split('.')
assert tokenizer_trace_path, "We need vocab and tokenizer all in meta, so you must provide the trace path from meta"
tokenizer_config = meta
for trace in tokenizer_trace_path:
tokenizer_config = tokenizer_config[trace]
self.tokenizer = Tokenizer.from_str(tokenizer_config)
self.save_path = self.config['save_path']
self.save_root_path = self.config['save_root_path']
self.start_save_epoch = self.config['start_save_epoch']
self.start_save_step = self.config['start_save_step']
self.post_check(self.config, used=[
"meta",
"use_crf",
"word_ready",
"meta_data",
"input_map",
"origin_input_map",
"save_root_path",
"save_path",
"start_save_step",
"start_save_epoch",
"aggregation_strategy",
"ignore_labels",
"ignore_char",
"ignore_position",
])
[docs]class AggregationStrategy(object):
"""docstring for AggregationStrategy"""
NONE = "none"
SIMPLE = "simple"
FIRST = "first"
AVERAGE = "average"
MAX = "max"
[docs]@postprocessor_register('seq_lab')
class SeqLabPostProcessor(IPostProcessor):
"""PostProcess for sequence labeling task"""
def __init__(self, config: SeqLabPostProcessorConfig):
super( SeqLabPostProcessor, self).__init__()
self.config = config
self.label_vocab = self.config.label_vocab
self.tokenizer = self.config.tokenizer
self.metric = torchmetrics.Accuracy()
[docs] def do_predict(self, stage: str, list_batch_outputs: List[Dict], origin_data: pd.DataFrame, rt_config: Dict)->List:
"""Process the model predict to human readable format
There are three predictor for diffrent seq_lab task dependent on the config.use_crf(the predict is already decoded to ids), and config.word_ready(subword has gathered to firstpiece)
Args:
stage: train/test/etc.
list_batch_outputs: a list of outputs
origin_data: the origin pd.DataFrame data, there are some data not be able to convert to tensor
rt_config:
>>> current status
>>> {
>>> "current_step": self.global_step,
>>> "current_epoch": self.current_epoch,
>>> "total_steps": self.num_training_steps,
>>> "total_epochs": self.num_training_epochs
>>> }
Returns:
all predicts
"""
predicts = []
if self.config.use_crf:
predicts = self.crf_predict(list_batch_outputs=list_batch_outputs, origin_data=origin_data)
elif self.config.word_ready:
predicts = self.word_predict(list_batch_outputs=list_batch_outputs, origin_data=origin_data)
else:
predicts = self.predict(list_batch_outputs=list_batch_outputs, origin_data=origin_data)
return predicts
[docs] def do_calc_metrics(self, predicts: List, stage: str, list_batch_outputs: List[Dict], origin_data: pd.DataFrame, rt_config: Dict)->Dict:
"""calc the scores use the predicts or list_batch_outputs
Args:
predicts: list of predicts
stage: train/test/etc.
list_batch_outputs: a list of outputs
origin_data: the origin pd.DataFrame data, there are some data not be able to convert to tensor
rt_config:
>>> current status
>>> {
>>> "current_step": self.global_step,
>>> "current_epoch": self.current_epoch,
>>> "total_steps": self.num_training_steps,
>>> "total_epochs": self.num_training_epochs
>>> }
Returns:
the named scores, recall, precision, f1
"""
def _flat_entities_info(entities_info: List[Dict], text: str)->Dict:
"""gather the same labeled entity to the same list
Args:
entities_info:
>>> [
>>> {
>>> "start": start1,
>>> "end": end1,
>>> "labels": ["label_1"]
>>> },
>>> {
>>> "start": start2,
>>> "end": end2,
>>> "labels": ["label_2"]
>>> },....
>>> ]
text: be labeled text
Returns:
>>> { "label_1" [text[start1:end1]], "label_2": [text[start_2: end_2]]...}
"""
info = {}
for item in entities_info:
label = item['labels'][0]
if label not in info:
info[label] = []
start_position, end_position = item['start'], item['end']
while start_position < end_position:
if text[start_position] in self.config.ignore_char:
start_position += 1
else:
break
while start_position < end_position:
if text[end_position-1] in self.config.ignore_char:
end_position -= 1
else:
break
if start_position == end_position: # if the entity after remove ignore char be null, we set it to origin
start_position, end_position = item['start'], item['end']
if self.config.ignore_position:
info[label].append(text[item['start']: item['end']].strip())
else:
info[label].append((start_position, end_position))
return info
all_predicts = []
all_ground_truths = []
for predict in predicts:
text = predict['sentence']
predict_ins = _flat_entities_info(predict['predict_entities_info'], text)
ground_truth_ins = _flat_entities_info(predict['entities_info'], text)
all_predicts.append(predict_ins)
all_ground_truths.append(ground_truth_ins)
precision, recall, f1 = self.calc_score(all_predicts, all_ground_truths)
real_name = self.loss_name_map(stage)
logger.info(f'{real_name}_precision: {precision*100}, {real_name}_recall: {recall*100}, {real_name}_f1: {f1*100}')
return {f'{real_name}_precision': precision*100, f'{real_name}_recall': recall*100, f'{real_name}_f1': f1*100}
[docs] def do_save(self, predicts: List, stage: str, list_batch_outputs: List[Dict], origin_data: pd.DataFrame, rt_config: Dict, save_condition: bool=False):
"""save the predict when save_condition==True
Args:
predicts: list of predicts
stage: train/test/etc.
list_batch_outputs: a list of outputs
origin_data: the origin pd.DataFrame data, there are some data not be able to convert to tensor
rt_config:
>>> current status
>>> {
>>> "current_step": self.global_step,
>>> "current_epoch": self.current_epoch,
>>> "total_steps": self.num_training_steps,
>>> "total_epochs": self.num_training_epochs
>>> }
save_condition: True for save, False for depend on rt_config
Returns:
None
"""
if self.config.start_save_epoch == -1 or self.config.start_save_step == -1:
self.config.start_save_step = rt_config.get('total_steps', 0) - 1
self.config.start_save_epoch = rt_config.get('total_epochs', 0) - 1
if not save_condition and (rt_config['current_step']>=self.config.start_save_step or rt_config['current_epoch']>=self.config.start_save_epoch):
save_condition = True
if save_condition:
save_path = os.path.join(self.config.save_root_path, self.config.save_path.get(stage, ''))
if "current_step" in rt_config:
save_file = os.path.join(save_path, f"step_{str(rt_config['current_step'])}_predict.json")
else:
save_file = os.path.join(save_path, 'predict.json')
logger.info(f"Save the {stage} predict data at {save_file}")
with open(save_file, 'w') as f:
json.dump(predicts, f, indent=4, ensure_ascii=False)
[docs] def calc_score(self, predict_list: List, ground_truth_list: List):
"""use predict_list and ground_truth_list to calc scores
Args:
predict_list: list of predict
ground_truth_list: list of ground_truth
Returns:
precision, recall, f1
"""
category_tp = {}
category_fp = {}
category_fn = {}
def _care_div(a, b):
"""return a/b or 0.0 if b == 0
"""
if b==0:
return 0.0
return a/b
def _calc_num(_pred: List, _ground_truth: List):
"""calc tp, fn, fp
Args:
pred: pred list
ground_truth: groud truth list
Returns:
tp, fn, fp
"""
num_p = len(_pred)
num_t = len(_ground_truth)
truth = 0
for p in _pred:
if p in _ground_truth:
truth += 1
return truth, num_t-truth, num_p-truth
for predict, ground_truth in zip(predict_list, ground_truth_list):
keys = set(list(predict.keys())+list(ground_truth.keys()))
for key in keys:
tp, fn, fp = _calc_num(predict.get(key, []), ground_truth.get(key, []))
# logger.info(f"{key} tp num {tp}, fn num {fn}, fp num {fp}")
category_tp[key] = category_tp.get(key, 0) + tp
category_fn[key] = category_fn.get(key, 0) + fn
category_fp[key] = category_fp.get(key, 0) + fp
all_tp, all_fn, all_fp = 0, 0, 0
for key in category_tp:
if key in self.config.ignore_labels:
continue
tp, fn, fp = category_tp[key], category_fn[key], category_fp[key]
all_tp += tp
all_fn += fn
all_fp += fp
precision = _care_div(tp, tp+fp)
recall = _care_div(tp, tp+fn)
f1 = _care_div(2*precision*recall, precision+recall)
logger.info(f"For entity 「{key}」, the precision={precision*100 :.2f}%, the recall={recall*100:.2f}%, f1={f1*100:.2f}%")
precision = _care_div(all_tp,all_tp+all_fp)
recall = _care_div(all_tp, all_tp+all_fn)
f1 = _care_div(2*precision*recall, precision+recall)
return precision, recall, f1
[docs] def get_entity_info(self, sub_tokens_index: List, offset_mapping: List, word_ids: List, label: str)->Dict:
"""gather sub_tokens to get the start and end
Args:
sub_tokens_index: the entity tokens index list
offset_mapping: every token offset in text
word_ids: every token in the index of words
label: predict label
Returns:
entity_info
"""
if (not sub_tokens_index) or (not label) or (label in self.config.ignore_labels):
return {}
start = offset_mapping[sub_tokens_index[0]][0]
end = offset_mapping[sub_tokens_index[-1]][1]
return {
"start": start,
"end": end,
"labels": [label]
}
def _process4predict(self, predict: torch.LongTensor, index: int, origin_data: pd.DataFrame)->Dict:
"""gather the predict and origin text and ground_truth_entities_info for predict
Args:
predict: the predict label_ids
index: the data index in origin_data
origin_data: the origin pd.DataFrame
Returns:
>>> one_ins info
>>> {
>>> "sentence": "...",
>>> "uuid": "..",
>>> "entities_info": [".."],
>>> "predict_entities_info": [".."],
>>> }
"""
one_ins = {}
origin_ins = origin_data.iloc[int(index)]
one_ins["sentence"] = origin_ins[self.config.sentence]
one_ins["uuid"] = origin_ins[self.config.uuid]
one_ins["entities_info"] = origin_ins[self.config.entities_info]
word_ids = origin_ins[self.config.word_ids]
rel_token_len = len(word_ids)
offset_mapping = origin_ins[self.config.offsets][:rel_token_len]
predict = list(predict[:rel_token_len])
# predict = list(origin_ins['label_ids'][:rel_token_len])
predict_entities_info = []
pre_label = ''
sub_tokens_index = []
for i, label_id in enumerate(predict):
if offset_mapping[i] == (0, 0): # added token like [CLS]/<s>/..
continue
label = self.config.label_vocab[label_id]
if label in self.config.ignore_labels \
or (label[0]=='B') \
or (label.split('-')[-1] != pre_label): # label == "O" or label=='B' or label.tail != previor_label
entity_info = self.get_entity_info(sub_tokens_index, offset_mapping, word_ids, pre_label)
if entity_info:
predict_entities_info.append(entity_info)
pre_label = ''
sub_tokens_index = []
if label not in self.config.ignore_labels:
assert len(label.split('-')) == 2
pre_label = label.split('-')[-1]
sub_tokens_index.append(i)
entity_info = self.get_entity_info(sub_tokens_index, offset_mapping, word_ids, pre_label)
if entity_info:
predict_entities_info.append(entity_info)
one_ins['predict_entities_info'] = predict_entities_info
return one_ins
[docs] def crf_predict(self, list_batch_outputs: List[Dict], origin_data: pd.DataFrame)->List:
"""use the crf predict label_ids get predict info
Args:
list_batch_outputs: the crf predict info
origin_data: the origin data
Returns:
all predict instances info
"""
if self.config.sentence not in origin_data:
logger.error(f"{self.config.sentence} not in the origin data")
raise PermissionError(f"{self.config.sentence} must be provided")
if self.config.uuid not in origin_data:
logger.error(f"{self.config.uuid} not in the origin data")
raise PermissionError(f"{self.config.uuid} must be provided")
if self.config.entities_info not in origin_data:
logger.error(f"{self.config.entities_info} not in the origin data")
raise PermissionError(f"{self.config.entities_info} must be provided")
predicts = []
for outputs in list_batch_outputs:
batch_predict = outputs[self.config.predict_seq_label]
# batch_special_tokens_mask = outputs[self.config.special_tokens_mask]
indexes = list(outputs[self.config._index])
outputs = []
for predict, index in zip(batch_predict, indexes):
one_ins = self._process4predict(predict, index, origin_data)
predicts.append(one_ins)
return predicts
[docs] def word_predict(self, list_batch_outputs: List[Dict], origin_data: pd.DataFrame)->List:
"""use the firstpiece or whole word predict label_logits get predict info
Args:
list_batch_outputs: the predict labels logits info
origin_data: the origin data
Returns:
all predict instances info
"""
if self.config.sentence not in origin_data:
logger.error(f"{self.config.sentence} not in the origin data")
raise PermissionError(f"{self.config.sentence} must be provided")
if self.config.uuid not in origin_data:
logger.error(f"{self.config.uuid} not in the origin data")
raise PermissionError(f"{self.config.uuid} must be provided")
if self.config.entities_info not in origin_data:
logger.error(f"{self.config.entities_info} not in the origin data")
raise PermissionError(f"{self.config.entities_info} must be provided")
predicts = []
for outputs in list_batch_outputs:
batch_logits = outputs[self.config.logits].detach().cpu().numpy()
# batch_special_tokens_mask = outputs[self.config.special_tokens_mask]
indexes = list(outputs[self.config._index])
outputs = []
for logits, index in zip(batch_logits, indexes):
origin_ins = origin_data.iloc[int(index)]
word_ids = origin_ins[self.config.word_ids]
rel_token_len = len(word_ids)
logits = logits[:rel_token_len]
predict = logits.argmax(-1)
one_ins = self._process4predict(predict, index, origin_data)
predicts.append(one_ins)
return predicts
[docs] def predict(self, list_batch_outputs: List[Dict], origin_data: pd.DataFrame)->List:
"""general predict process (especially for subword)
Args:
list_batch_outputs: the predict (sub-)labels logits info
origin_data: the origin data
Returns:
all predict instances info
"""
if self.config.sentence not in origin_data:
logger.error(f"{self.config.sentence} not in the origin data")
raise PermissionError(f"{self.config.sentence} must be provided")
if self.config.uuid not in origin_data:
logger.error(f"{self.config.uuid} not in the origin data")
raise PermissionError(f"{self.config.uuid} must be provided")
if self.config.entities_info not in origin_data:
logger.error(f"{self.config.entities_info} not in the origin data")
raise PermissionError(f"{self.config.entities_info} must be provided")
predicts = []
for outputs in list_batch_outputs:
batch_logits = outputs[self.config.logits].detach().cpu().numpy()
# batch_special_tokens_mask = outputs[self.config.special_tokens_mask]
indexes = list(outputs[self.config._index])
outputs = []
for logits, index in zip(batch_logits, indexes):
one_ins = {}
origin_ins = origin_data.iloc[int(index)]
input_ids = origin_ins[self.config.input_ids]
one_ins["sentence"] = origin_ins[self.config.sentence]
one_ins["uuid"] = origin_ins[self.config.uuid]
one_ins["entities_info"] = origin_ins[self.config.entities_info]
rel_token_len = len(input_ids)
special_tokens_mask = np.array(origin_data.iloc[int(index)][self.config.special_tokens_mask][:rel_token_len])
offset_mapping = origin_data.iloc[int(index)][self.config.offsets][:rel_token_len]
logits = logits[:rel_token_len]
entity_idx = logits.argmax(-1)
labels = []
for i, idx in enumerate(list(entity_idx)):
labels.append(self.config.label_vocab[idx])
maxes = np.max(logits, axis=-1, keepdims=True)
shifted_exp = np.exp(logits - maxes)
scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
pre_entities = self.gather_pre_entities(
one_ins["sentence"], input_ids, scores, offset_mapping, special_tokens_mask)
grouped_entities = self.aggregate(pre_entities, self.config.aggregation_strategy)
# Filter anything that is in self.ignore_labels
entities = [
entity
for entity in grouped_entities
if entity.get("entity", None) not in self.config.ignore_labels
and entity.get("entity_group", None) not in self.config.ignore_labels
]
predict_entities_info = []
for entity in entities:
one_predict = {}
one_predict['start'] = entity['start']
one_predict['end'] = entity['end']
one_predict['labels'] = [entity['entity_group']]
predict_entities_info.append(one_predict)
one_ins['predict_entities_info'] = predict_entities_info
predicts.append(one_ins)
return predicts
[docs] def aggregate(self, pre_entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
if aggregation_strategy in {AggregationStrategy.NONE, AggregationStrategy.SIMPLE}:
entities = []
for pre_entity in pre_entities:
entity_idx = pre_entity["scores"].argmax()
score = pre_entity["scores"][entity_idx]
entity = {
"entity": self.config.label_vocab[entity_idx],
"score": score,
"index": pre_entity["index"],
"word": pre_entity["word"],
"start": pre_entity["start"],
"end": pre_entity["end"],
}
entities.append(entity)
else:
entities = self.aggregate_words(pre_entities, aggregation_strategy)
if aggregation_strategy == AggregationStrategy.NONE:
return entities
return self.group_entities(entities)
[docs] def aggregate_word(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> dict:
word = self.tokenizer.decode([self.tokenizer.token_to_id(entity['word']) for entity in entities])
if aggregation_strategy == AggregationStrategy.FIRST:
scores = entities[0]["scores"]
idx = scores.argmax()
score = scores[idx]
entity = self.config.label_vocab[idx]
elif aggregation_strategy == AggregationStrategy.MAX:
max_entity = max(entities, key=lambda entity: entity["scores"].max())
scores = max_entity["scores"]
idx = scores.argmax()
score = scores[idx]
entity = self.config.label_vocab[idx]
elif aggregation_strategy == AggregationStrategy.AVERAGE:
scores = np.stack([entity["scores"] for entity in entities])
average_scores = np.nanmean(scores, axis=0)
entity_idx = average_scores.argmax()
entity = self.config.label_vocab[entity_idx]
score = average_scores[entity_idx]
else:
raise ValueError("Invalid aggregation_strategy")
new_entity = {
"entity": entity,
"score": score,
"word": word,
"start": entities[0]["start"],
"end": entities[-1]["end"],
}
return new_entity
[docs] def group_sub_entities(self, entities: List[dict]) -> dict:
"""Group together the adjacent tokens with the same entity predicted.
Args:
entities: The entities predicted by the pipeline.
"""
# Get the first entity in the entity group
entity = entities[0]["entity"].split("-")[-1]
scores = np.nanmean([entity["score"] for entity in entities])
tokens = [entity["word"] for entity in entities]
entity_group = {
"entity_group": entity,
"score": np.mean(scores),
"word": " ".join(tokens),
"start": entities[0]["start"],
"end": entities[-1]["end"],
}
return entity_group
[docs] def get_tag(self, entity_name: str) -> Tuple[str, str]:
if entity_name.startswith("B-"):
bi = "B"
tag = entity_name[2:]
elif entity_name.startswith("I-"):
bi = "I"
tag = entity_name[2:]
else:
# It's not in B-, I- format
# Default to I- for continuation.
bi = "I"
tag = entity_name
return bi, tag
[docs] def group_entities(self, entities: List[dict]) -> List[dict]:
"""Find and group together the adjacent tokens with the same entity predicted.
Args:
entities: The entities predicted by the pipeline.
"""
entity_groups = []
entity_group_disagg = []
for entity in entities:
if not entity_group_disagg:
entity_group_disagg.append(entity)
continue
# If the current entity is similar and adjacent to the previous entity,
# append it to the disaggregated entity group
# The split is meant to account for the "B" and "I" prefixes
# Shouldn't merge if both entities are B-type
bi, tag = self.get_tag(entity["entity"])
last_bi, last_tag = self.get_tag(entity_group_disagg[-1]["entity"])
if tag == last_tag and bi != "B":
# Modify subword type to be previous_type
entity_group_disagg.append(entity)
else:
# If the current entity is different from the previous entity
# aggregate the disaggregated entity group
entity_groups.append(self.group_sub_entities(entity_group_disagg))
entity_group_disagg = [entity]
if entity_group_disagg:
# it's the last entity, add it to the entity groups
entity_groups.append(self.group_sub_entities(entity_group_disagg))
return entity_groups
[docs] def aggregate_words(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
"""Override tokens from a given word that disagree to force agreement on word boundaries.
Example:
micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with first strategy as microsoft|
company| B-ENT I-ENT
"""
if aggregation_strategy in {
AggregationStrategy.NONE,
AggregationStrategy.SIMPLE,
}:
raise ValueError("NONE and SIMPLE strategies are invalid for word aggregation")
word_entities = []
word_group = None
for entity in entities:
if word_group is None:
word_group = [entity]
elif entity["is_subword"]:
word_group.append(entity)
else:
word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
word_group = [entity]
# Last item
word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
return word_entities
[docs] def gather_pre_entities(
self,
sentence: str,
input_ids: np.ndarray,
scores: np.ndarray,
offset_mapping: Optional[List[Tuple[int, int]]],
special_tokens_mask: np.ndarray,
) -> List[dict]:
"""Fuse various numpy arrays into dicts with all the information needed for aggregation"""
pre_entities = []
for idx, token_scores in enumerate(scores):
# Filter special_tokens, they should only occur
# at the sentence boundaries since we're not encoding pairs of
# sentences so we don't have to keep track of those.
if special_tokens_mask[idx]:
continue
word = self.tokenizer.id_to_token(int(input_ids[idx]))
if offset_mapping is not None:
start_ind, end_ind = offset_mapping[idx]
word_ref = sentence[start_ind:end_ind]
if getattr(self.tokenizer.model, "continuing_subword_prefix", None):
# This is a BPE, word aware tokenizer, there is a correct way
# to fuse tokens
is_subword = len(word) != len(word_ref)
else:
# This is a fallback heuristic. This will fail most likely on any kind of text + punctuation mixtures that will be considered "words". Non word aware models cannot do better than this unfortunately.
is_subword = sentence[start_ind - 1 : start_ind] != " " if start_ind > 0 else False
else:
start_ind = None
end_ind = None
is_subword = False
pre_entity = {
"word": word,
"scores": token_scores,
"start": start_ind,
"end": end_ind,
"index": idx,
"is_subword": is_subword,
}
pre_entities.append(pre_entity)
return pre_entities