Loading [MathJax]/jax/output/CommonHTML/jax.js

7. 多标签分类损失pytorch 实现

1. 多标签分类损失 pytorch 实现

在文本分类中使用多标签分类损失函数实现,公式为 (来自于 将“softmax+ 交叉熵”推广到多标签分类问题 ), 具体推导可看,论文GLOBAL POINTER: NOVEL EFFICIENT SPAN-BASED APPROACH FOR NAMED ENTITY RECOGNITION 中 3.Class Imbalance Loss 部分。:
log(1+iΩnegesi)+log(1+jΩposesj)
其中 Ωpos,Ωneg 分别是样本的正负类别集合。

import torch
import numpy as np
def multilabel_categorical_crossentropy(y_true, y_pred):
"""
代码第 1, 2, 3 行解释
1. 将真实标签 y_true 从 0 / 1 映射到 -1/1,即将正类设为 -1,负类设为 1。并将得到的结果与预测值相乘。这一步处理的目的是为了保证预测值 y_pred
落在 [0, 1] 的范围内。2. 将正类位置的预测值设为负无穷。在这里采用了一个技巧,即将正类的预测值减去一个很大的数,这里是 1e12。这样在经过 logsumexp 计算后,正类的概率会趋近于 0,达到屏蔽正类的目的。3. 将负类位置的预测值设为负无穷
同样采用上述技巧,将负类的预测值减去一个很大的数,使得在经过 logsumexp
函数计算后,负类的概率会趋近于 0,达到屏蔽负类的目的。说明:y_true 和 y_pred 的 shape 一致,y_true 的元素非 0 即 1,1 表示对应的类为目标类,0 表示对应的类为非目标类。警告:请保证 y_pred 的值域是全体实数,换言之一般情况下 y_pred
不用加激活函数,尤其是不能加 sigmoid 或者 softmax!预测
阶段则输出 y_pred 大于 0 的类。如有疑问,请仔细阅读并理解
参考 1。"""
y_pred = (1 - 2 * y_true) * y_pred
y_pred_neg = y_pred - y_true * 1e12 #mask pos
y_pred_pos = y_pred - (1-y_true) * 1e12 #mask neg
#构建 y_pred[..., :1]一样形状的全 0tensor 来替换对应位置的预测值
zeros = torch.zeros_like(y_pred[..., :1])
y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
return (neg_loss + pos_loss).mean()
y_pred = torch.tensor([[0.9, 0.1, 0.8, 0.3], [0.3, 0.2, 0.5, 0.7], [0.7, 0.6, 0.1, 0.3]])
y_true = torch.tensor([[1, 0, 1, 0], [0, 0, 1, 1], [1, 0, 0, 0]])
loss = multilabel_categorical_crossentropy(y_true, y_pred)
print(loss.item())
#1.9803704023361206

注意:其只用于硬标签,像 label smoothing, mixup 都不能使用。详见,多标签“Softmax+ 交叉熵”。稀疏版见参考 2。

参考

[1] 将“softmax+ 交叉熵”推广到多标签分类问题

[2] 稀疏版多标签分类交叉熵损失函数

[3] 多标签损失之 Hamming Loss、Focal Loss、交叉熵和 ASL 损失

正文完
 
admin
版权声明:本站原创文章,由 admin 2023-11-26发表,共计1682字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请联系tensortimes@gmail.com。
评论(没有评论)
验证码