7. 多标签分类损失pytorch 实现

1. 多标签分类损失 pytorch 实现

在文本分类中使用多标签分类损失函数实现,公式为 (来自于 将“softmax+ 交叉熵”推广到多标签分类问题 ), 具体推导可看,论文GLOBAL POINTER: NOVEL EFFICIENT SPAN-BASED APPROACH FOR NAMED ENTITY RECOGNITION 中 3.Class Imbalance Loss 部分。:
$$
\begin{equation}\log \left(1 + \sum\limits_{i\in\Omega_{neg}} e^{s_i}\right) + \log \left(1 + \sum\limits_{j\in\Omega_{pos}} e^{-s_j}\right)\label{eq:final}\end{equation} \tag{1}
$$
其中 $\Omega_{pos}, \Omega_{neg}$ 分别是样本的正负类别集合。

import torch
import numpy as np

def multilabel_categorical_crossentropy(y_true, y_pred):
    """
    代码第 1, 2, 3 行解释
    1. 将真实标签 y_true 从 0 / 1 映射到 -1/1,即将正类设为 -1,负类设为 1。并将得到的结果与预测值相乘。这一步处理的目的是为了保证预测值 y_pred
    落在 [0, 1] 的范围内。2. 将正类位置的预测值设为负无穷。在这里采用了一个技巧,即将正类的预测值减去一个很大的数,这里是 1e12。这样在经过 logsumexp 计算后,正类的概率会趋近于 0,达到屏蔽正类的目的。3. 将负类位置的预测值设为负无穷
    同样采用上述技巧,将负类的预测值减去一个很大的数,使得在经过 logsumexp
    函数计算后,负类的概率会趋近于 0,达到屏蔽负类的目的。说明:y_true 和 y_pred 的 shape 一致,y_true 的元素非 0 即 1,1 表示对应的类为目标类,0 表示对应的类为非目标类。警告:请保证 y_pred 的值域是全体实数,换言之一般情况下 y_pred
         不用加激活函数,尤其是不能加 sigmoid 或者 softmax!预测
         阶段则输出 y_pred 大于 0 的类。如有疑问,请仔细阅读并理解
        参考 1。"""
    y_pred = (1 - 2 * y_true) * y_pred
    y_pred_neg = y_pred - y_true * 1e12 #mask pos
    y_pred_pos = y_pred - (1-y_true) * 1e12 #mask neg
    #构建 y_pred[..., :1]一样形状的全 0tensor 来替换对应位置的预测值
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return (neg_loss + pos_loss).mean()

y_pred = torch.tensor([[0.9, 0.1, 0.8, 0.3], [0.3, 0.2, 0.5, 0.7], [0.7, 0.6, 0.1, 0.3]])
y_true = torch.tensor([[1, 0, 1, 0], [0, 0, 1, 1], [1, 0, 0, 0]])
loss  = multilabel_categorical_crossentropy(y_true, y_pred)
print(loss.item())
#1.9803704023361206

注意:其只用于硬标签,像 label smoothing, mixup 都不能使用。详见,多标签“Softmax+ 交叉熵”。稀疏版见参考 2。

参考

[1] 将“softmax+ 交叉熵”推广到多标签分类问题

[2] 稀疏版多标签分类交叉熵损失函数

[3] 多标签损失之 Hamming Loss、Focal Loss、交叉熵和 ASL 损失

正文完
 
admin
版权声明:本站原创文章,由 admin 2023-11-26发表,共计1682字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请联系tensortimes@gmail.com。
评论(没有评论)
验证码