Source code for dlk.core.modules.crf

# Copyright 2021 cstsunfu. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
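# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.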
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from typing import Callable, Dict, List
from . import module_register, module_config_register, Module
from dlk.utils.config import BaseConfig

@module_config_register("crf")
class CRFConfig(BaseConfig):
    """Configuration object for the ``crf`` module.

    Config Example:
        >>> {
        >>>     "config": {
        >>>         "output_size": 2,
        >>>         "batch_first": true,
        >>>         "reduction": "mean", //none|sum|mean|token_mean
        >>>     },
        >>>     "_name": "crf",
        >>> }

    Only ``output_size`` is stored as an attribute by this class;
    ``batch_first`` and ``reduction`` are merely reported to ``post_check``
    as known keys.

    Raises:
        ValueError: if ``output_size`` is not a positive number.
    """

    def __init__(self, config: Dict):
        super(CRFConfig, self).__init__(config)
        # The module-specific options live under the nested "config" key.
        module_options = config['config']

        self.output_size = module_options['output_size']
        if self.output_size <= 0:
            raise ValueError(f'invalid number of tags: {self.output_size}')

        # post_check is inherited from BaseConfig; presumably it verifies
        # that no unexpected keys remain -- confirm against BaseConfig.
        recognized_keys = [
            "output_size",
            "batch_first",
            "reduction",
        ]
        self.post_check(module_options, used=recognized_keys)
@module_register("crf")
class ConditionalRandomField(Module):
    """Linear-chain CRF layer.

    ``training_step`` computes the training loss (mean negative
    log-likelihood over the batch); ``forward`` decodes the best tag path
    with the Viterbi algorithm.
    """

    def __init__(self, config: CRFConfig):
        super(ConditionalRandomField, self).__init__()
        self.num_tags = config.output_size
        # transitions[i, j] is the score of moving from tag i to tag j.
        self.transitions = nn.parameter.Parameter(torch.randn(self.num_tags, self.num_tags))
        # Score of starting a sequence with / ending a sequence on each tag.
        self.start_transitions = nn.parameter.Parameter(torch.randn(self.num_tags))
        self.end_transitions = nn.parameter.Parameter(torch.randn(self.num_tags))

    def init_weight(self, method: Callable):
        """Re-initialize transitions, start_transitions and end_transitions.

        ``transitions`` is drawn from a normal distribution with mean -1 and
        standard deviation 0.1; ``start_transitions`` and ``end_transitions``
        are drawn from a uniform distribution between -0.1 and 0.1.

        Args:
            method: init method; accepted for interface compatibility, not used

        Returns:
            None
        """
        nn.init.normal_(self.transitions, -1, 0.1)
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)

    def _normalizer_likelihood(self, logits: torch.FloatTensor, mask: torch.Tensor):
        """Compute the (batch_size,) denominator term of the log-likelihood.

        This is the log-sum-exp of the scores of all possible tag sequences.

        Args:
            logits: max_len*batch_size*num_tags
            mask: max_len*batch_size, zero means padding (any integer/bool dtype)

        Returns:
            batch_size tensor with the log partition value of every sequence
        """
        seq_len, batch_size, n_tags = logits.size()
        valid = mask.bool()
        alpha = logits[0] + self.start_transitions.view(1, -1)
        # Loop-invariant: the transition matrix broadcast over the batch.
        trans_score = self.transitions.view(1, n_tags, n_tags)
        for i in range(1, seq_len):
            emit_score = logits[i].view(batch_size, 1, n_tags)
            # tmp[b, prev, cur] = alpha[b, prev] + emit[b, cur] + trans[prev, cur]
            tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
            # Advance alpha only where this timestep is real; padded
            # positions carry the previous alpha forward unchanged.
            alpha = torch.where(valid[i].view(batch_size, 1), torch.logsumexp(tmp, 1), alpha)
        alpha = alpha + self.end_transitions.view(1, -1)
        return torch.logsumexp(alpha, 1)

    def _gold_score(self, logits: torch.FloatTensor, tags: torch.LongTensor, mask: torch.Tensor):
        """Compute the score of the gold tag path.

        Args:
            logits: max_len*batch_size*num_tags
            tags: max_len*batch_size
            mask: max_len*batch_size, zero means padding (any integer/bool dtype)

        Returns:
            batch_size tensor with the gold path score of every sequence
        """
        seq_len, batch_size, _ = logits.size()
        batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
        seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

        valid = mask.bool()
        padded = ~valid
        # Padded positions never contribute to the score (they are zeroed
        # below), but their tag ids are still used as indices. Map them to a
        # valid id so placeholder values such as -100 cannot raise an
        # out-of-range index error.
        tags = tags.masked_fill(padded, 0)

        # trans_score [L-1, B]: transition into step t+1, dropped when t+1 is padding
        trans_score = self.transitions[tags[:-1], tags[1:]].masked_fill(padded[1:], 0)
        # emit_score [L, B]: emission of the gold tag at every real step
        emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(padded, 0)
        # Sum transitions with the first L-1 emissions, then add the last emission.
        score = (trans_score + emit_score[:-1]).sum(0) + emit_score[-1]

        st_scores = self.start_transitions[tags[0]]
        # Index of the last real token of every sequence.
        last_idx = valid.long().sum(0) - 1
        ed_scores = self.end_transitions[tags[last_idx, batch_idx]]
        # return [B,]
        return score + st_scores + ed_scores

    def training_step(self, logits: torch.FloatTensor, tags: torch.LongTensor, mask: torch.LongTensor):
        """Training step: compute the CRF loss.

        Args:
            logits: emissions, batch_size*max_len*num_tags
            tags: batch_size*max_len
            mask: batch_size*max_len, mask==0 means padding

        Returns:
            loss: scalar, mean over the batch of (log partition - gold score)
        """
        # The scoring helpers work on time-major tensors.
        logits = logits.transpose(0, 1)
        tags = tags.transpose(0, 1).long()
        mask = mask.transpose(0, 1).bool()
        all_path_score = self._normalizer_likelihood(logits, mask)
        gold_path_score = self._gold_score(logits, tags, mask)
        loss = all_path_score - gold_path_score
        return loss.mean()

    def forward(self, logits: torch.FloatTensor, mask: torch.LongTensor):
        """Predict step: decode the best tag path.

        Args:
            logits: emissions, batch_size*max_len*num_tags
            mask: batch_size*max_len, mask==0 means padding

        Returns:
            batch*max_len long tensor of tag ids, -1 at padded positions
        """
        logits = logits.transpose(0, 1)  # L, B, H
        mask = mask.transpose(0, 1)
        return self._viterbi_decode(logits, mask)

    def _viterbi_decode(self, emissions: torch.FloatTensor, mask: torch.LongTensor) -> torch.Tensor:
        """Find the highest scoring tag path with the Viterbi algorithm.

        Args:
            emissions: max_len*batch_size*num_tags
            mask: max_len*batch_size, mask==0 means padding; the first
                timestep of every sequence must be unmasked

        Returns:
            batch*max_len long tensor of tag ids; positions after the end of
            a sequence are filled with -1
        """
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        # Convert once instead of on every timestep of the loop below.
        valid = mask.bool()
        assert valid[0].all()
        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        # history[t] holds, for every sample and every tag j at step t+1, the
        # previous tag that produced the best score; used for the traceback.
        history = torch.jit.annotate(List[torch.Tensor], [])

        # Recursive case: score[b, j] is the score of the best tag sequence
        # so far that ends with tag j.
        for i in range(1, seq_length):
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)
            # next_score[b, prev, cur]: best path ending in prev, then
            # transitioning to cur and emitting at step i
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission
            # Maximise over the previous tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)
            # Only advance the score where this timestep is real (mask == 1)
            score = torch.where(valid[i].unsqueeze(1), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score = score + self.end_transitions

        # Index of the last real token of each sample, as plain ints so they
        # can be used for list slicing.
        seq_ends = (valid.long().sum(dim=0) - 1).tolist()

        output = torch.jit.annotate(List[List[int]], [])
        for idx in range(batch_size):
            # The best final tag is the one maximising the final score.
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # Walk the back-pointers from the last real step to the first.
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # The path was collected back-to-front.
            best_tags.reverse()
            # Pad shorter sequences with -1 so the result is rectangular.
            best_tags.extend([-1] * (seq_length - len(best_tags)))
            output.append(best_tags)

        return torch.tensor(output, dtype=torch.long, device=mask.device)