Source code for dlk.data.postprocessors.txt_cls

# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hjson
import pickle as pkl
import os
import json
import pandas as pd
import numpy as np
from typing import Union, Dict, Any, List
from dlk.data.postprocessors import postprocessor_register, postprocessor_config_register, IPostProcessor, IPostProcessorConfig
from dlk.utils.logger import Logger
import torch
from dlk.utils.vocab import Vocabulary
from dlk.utils.io import open

logger = Logger.get_logger()


[docs]@postprocessor_config_register('txt_cls') class TxtClsPostProcessorConfig(IPostProcessorConfig): """Config for TxtClsPostProcessor Config Example: >>> { >>> "_name": "txt_cls", >>> "config": { >>> "meta": "*@*", >>> "meta_data": { >>> "label_vocab": 'label_vocab', >>> }, >>> "input_map": { >>> "logits": "logits", >>> "label_ids": "label_ids" >>> "_index": "_index", >>> }, >>> "origin_input_map": { >>> "sentence": "sentence", >>> "sentence_a": "sentence_a", // for pair >>> "sentence_b": "sentence_b", >>> "uuid": "uuid" >>> }, >>> "save_root_path": ".", //save data root dir >>> "top_k": 1, //the result return top k result >>> "data_type": "single", //single or pair >>> "save_path": { >>> "valid": "valid", // relative dir for valid stage >>> "test": "test", // relative dir for test stage >>> }, >>> "start_save_step": 0, // -1 means the last >>> "start_save_epoch": -1, >>> } >>> } """ def __init__(self, config: Dict): super(TxtClsPostProcessorConfig, self).__init__(config) self.data_type = self.config['data_type'] assert self.data_type in {'single', 'pair'} if self.data_type == 'pair': self.sentence_a = self.origin_input_map['sentence_a'] self.sentence_b = self.origin_input_map['sentence_b'] else: self.sentence = self.origin_input_map['sentence'] self.uuid = self.origin_input_map['uuid'] self.logits = self.input_map['logits'] self.label_ids = self.input_map['label_ids'] self._index = self.input_map['_index'] self.top_k = self.config['top_k'] if isinstance(self.config['meta'], str): with open(self.config['meta'], 'rb') as f: meta = pkl.load(f) else: raise PermissionError("You must provide meta data for txt_cls postprocess.") trace_path = [] trace_path_str = self.config['meta_data']['label_vocab'] if trace_path_str and trace_path_str.strip()!='.': trace_path = trace_path_str.split('.') self.label_vocab = meta for trace in trace_path: self.label_vocab = self.label_vocab[trace] self.label_vocab = Vocabulary.load(self.label_vocab) self.save_path = self.config['save_path'] self.save_root_path = self.config['save_root_path'] self.start_save_epoch = self.config['start_save_epoch'] self.start_save_step = self.config['start_save_step'] self.post_check(self.config, used=[ "meta", "meta_data", "input_map", "origin_input_map", "save_root_path", "top_k", "data_type", "save_path", "start_save_step", "start_save_epoch", ])
[docs]@postprocessor_register('txt_cls') class TxtClsPostProcessor(IPostProcessor): """postprocess for text classfication""" def __init__(self, config: TxtClsPostProcessorConfig): super(TxtClsPostProcessor, self).__init__() self.config = config self.label_vocab = self.config.label_vocab
[docs] def do_predict(self, stage: str, list_batch_outputs: List[Dict], origin_data: pd.DataFrame, rt_config: Dict)->List: """Process the model predict to human readable format Args: stage: train/test/etc. list_batch_outputs: a list of outputs origin_data: the origin pd.DataFrame data, there are some data not be able to convert to tensor rt_config: >>> current status >>> { >>> "current_step": self.global_step, >>> "current_epoch": self.current_epoch, >>> "total_steps": self.num_training_steps, >>> "total_epochs": self.num_training_epochs >>> } Returns: all predicts """ results = [] for outputs in list_batch_outputs: logits = outputs[self.config.logits].detach() assert len(logits.shape) == 2 # predict_indexes = list(torch.argmax(logits, 1)) indexes = list(outputs[self.config._index]) if self.config.label_ids in outputs: label_ids = outputs[self.config.label_ids] else: label_ids = [None] * len(indexes) for i, (one_logits, index, label_id) in enumerate(zip(logits, indexes, label_ids)): one_ins = {} one_logits = torch.softmax(one_logits, -1) one_origin = origin_data.iloc[int(index)] if self.config.data_type == 'single': sentence = one_origin[self.config.sentence] one_ins['sentence'] = sentence else: sentence_a = one_origin[self.config.sentence_a] one_ins['sentence_a'] = sentence_a sentence_b = one_origin[self.config.sentence_b] one_ins['sentence_b'] = sentence_b uuid = one_origin[self.config.uuid] label_values, label_indeies = torch.topk(one_logits, self.config.top_k, dim=-1) predict = {} for i, (label_value, label_index) in enumerate(zip(label_values, label_indeies)): label_name = self.label_vocab.get_word(label_index) predict[i] = [label_name, float(label_value)] ground_truth = [] if label_id is not None and label_id.shape: for one_label_id in label_id: ground_truth.append(self.label_vocab.get_word(one_label_id)) elif label_id: ground_truth.append(self.label_vocab.get_word(label_id)) one_ins['uuid'] = uuid one_ins['labels'] = ground_truth one_ins['predicts'] = predict one_ins['predict_extend_return'] = self.gather_predict_extend_data(outputs, i, self.config.predict_extend_return) results.append(one_ins) return results
[docs] def do_calc_metrics(self, predicts: List, stage: str, list_batch_outputs: List[Dict], origin_data: pd.DataFrame, rt_config: Dict)->Dict: """calc the scores use the predicts or list_batch_outputs Args: predicts: list of predicts stage: train/test/etc. list_batch_outputs: a list of outputs origin_data: the origin pd.DataFrame data, there are some data not be able to convert to tensor rt_config: >>> current status >>> { >>> "current_step": self.global_step, >>> "current_epoch": self.current_epoch, >>> "total_steps": self.num_training_steps, >>> "total_epochs": self.num_training_epochs >>> } Returns: the named scores, acc """ right_num = 0 for one_ins in predicts: labels = one_ins['labels'] assert len(labels) == 1, "We currently is not support multi label in txt_cls postprocess" label = labels[0] one_predicts = one_ins['predicts'] predict_label, predict_value = one_predicts[0] # the first predict if label == predict_label: right_num += 1 real_name = self.loss_name_map(stage) return {f'{real_name}_acc': right_num/len(predicts)}
[docs] def do_save(self, predicts: List, stage: str, list_batch_outputs: List[Dict], origin_data: pd.DataFrame, rt_config: Dict, save_condition: bool=False): """save the predict when save_condition==True Args: predicts: list of predicts stage: train/test/etc. list_batch_outputs: a list of outputs origin_data: the origin pd.DataFrame data, there are some data not be able to convert to tensor rt_config: >>> current status >>> { >>> "current_step": self.global_step, >>> "current_epoch": self.current_epoch, >>> "total_steps": self.num_training_steps, >>> "total_epochs": self.num_training_epochs >>> } save_condition: True for save, False for depend on rt_config Returns: None """ if self.config.start_save_epoch == -1 or self.config.start_save_step == -1: self.config.start_save_step = rt_config.get('total_steps', 0) - 1 self.config.start_save_epoch = rt_config.get('total_epochs', 0) - 1 if not save_condition and (rt_config['current_step']>=self.config.start_save_step or rt_config['current_epoch']>=self.config.start_save_epoch): save_condition = True if save_condition: save_path = os.path.join(self.config.save_root_path, self.config.save_path.get(stage, '')) if "current_step" in rt_config: save_file = os.path.join(save_path, f"step_{str(rt_config['current_step'])}_predict.json") else: save_file = os.path.join(save_path, 'predict.json') logger.info(f"Save the {stage} predict data at {save_file}") with open(save_file, 'w') as f: json.dump(predicts, f, indent=4, ensure_ascii=False)