Source code for dlk.core.losses.multi_loss

# Copyright 2021 cstsunfu. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: WIP, coming soon!
from typing import Dict
import torch.nn as nn
from . import loss_register, loss_config_register
import torch.nn as nn
from dlk.utils.config import ConfigTool


@loss_config_register("multi_loss")
class MultiLossConfig:
    """Config for MultiLoss (work in progress).

    Each ``loss@<name>`` entry describes one sub-loss; the top-level
    ``config.loss_list`` names the sub-losses to combine.

    Config Example:
        >>> {
        >>>     "loss@the_first": {
        >>>         config: {
        >>>             "ignore_index": -1,
        >>>             "weight": null, # or a list of value for every class
        >>>             "label_smoothing": 0.0, # torch>=1.10
        >>>             "pred_truth_pair": ["logits1", "label1"], # len(.) == 2, the 1st is the pred_name, 2nd is truth_name in __call__ inputs
        >>>             "schedule": [0.3, 0.6, 1],
        >>>             "scale": [1, 0, 0.5], # scale the loss for every schedule
        >>>             // "schedule": [0.3, 1.0],
        >>>             // "scale": [0, 1, 0.5], # scale the loss
        >>>         },
        >>>         _name: "cross_entropy",
        >>>     },
        >>>     "loss@the_second": {
        >>>         config: {
        >>>             "pred_truth_pair": ["logits2", "label2"], # len(.) == 2, the 1st is the pred_name, 2nd is truth_name in __call__ inputs
        >>>             "schedule": [0.3, 0.6, 1],
        >>>             "scale": [0, 1, 0.5], # scale the loss for every schedule
        >>>             // "schedule": [0.3, 1.0],
        >>>             // "scale": [0, 1, 0.5], # scale the loss
        >>>         },
        >>>         _base: "cross_entropy", // _name or _base is all ok
        >>>     },
        >>>     config: {
        >>>         "loss_list": ['the_first', 'the_second'],
        >>>     },
        >>>     _name: "cross_entropy",
        >>> }

    NOTE(review): the top-level ``_name`` in the example is "cross_entropy"
    while this config is registered as "multi_loss" -- confirm which one is
    intended.
    """

    def __init__(self, config: Dict):
        """Look up the nested ``"config"`` section of ``config``.

        Args:
            config: the raw configuration mapping for this loss.
        """
        super().__init__()
        # The section is fetched (defaulting to an empty dict) but nothing is
        # stored on the instance yet: the class is still a work in progress.
        inner_config = config.get('config', {})
@loss_register("multi_loss")
class MultiLoss(object):
    """Placeholder for a loss that combines several sub-losses.

    This module is not implemented yet, don't use it: constructing an
    instance always raises ``NotImplementedError``.
    """

    def __init__(self, config: MultiLossConfig):
        """Refuse construction until the module is implemented.

        Args:
            config: the MultiLossConfig for this loss (currently unused).

        Raises:
            NotImplementedError: always.
        """
        super(MultiLoss, self).__init__()
        # Fail fast with an explicit message. The original assigned
        # ``self.config`` *after* the raise, which could never execute; when
        # this class is implemented the config must be stored on
        # ``self.config`` because ``calc`` reads it from there.
        raise NotImplementedError(
            "MultiLoss is not implemented yet, don't use it."
        )

    def get_loss(self, config):
        """Use config to init the loss

        Delegates to ``ConfigTool.get_leaf_module`` with the loss registries
        and the ``"loss"`` module kind.

        Args:
            config: loss config

        Returns:
            the Loss and the LossConfig

        """
        return ConfigTool.get_leaf_module(loss_register, loss_config_register, "loss", config)

    def calc(self, result, inputs, rt_config):
        """calc the loss

        The prediction is taken from ``result`` and the ground truth from
        ``inputs``; the per-pair losses are scaled and summed.

        Args:
            result: the model predict dict
            inputs: the all inputs for model
            rt_config: provide the current training status (not used by the
                current body)
                >>> {
                >>>     "current_step": self.global_step,
                >>>     "current_epoch": self.current_epoch,
                >>>     "total_steps": self.num_training_steps,
                >>>     "total_epochs": self.num_training_epochs
                >>> }

        Returns:
            loss

        """
        # NOTE(review): ``self.cross_entropy`` is not defined in this class,
        # and ``pred_truth_pair`` / ``loss_scale`` are not set by
        # MultiLossConfig as visible here -- this body cannot run until the
        # module is finished.
        loss = 0
        for i, (pred, truth) in enumerate(self.config.pred_truth_pair):
            # Scale each pair's loss by the factor at the same index.
            loss = loss + self.cross_entropy(result[pred], inputs[truth]) * self.config.loss_scale[i]
        return loss

    def __call__(self, result, inputs, rt_config):
        """same as self.calc
        """
        return self.calc(result, inputs, rt_config)