1. 多标签分类损失 pytorch 实现
在文本分类中使用多标签分类损失函数实现,公式为 (来自于 将“softmax+ 交叉熵”推广到多标签分类问题 ), 具体推导可看,论文GLOBAL POINTER: NOVEL EFFICIENT SPAN-BASED APPROACH FOR NAMED ENTITY RECOGNITION 中 3.Class Imbalance Loss 部分。:
$$
\begin{equation}\log \left(1 + \sum\limits_{i\in\Omega_{neg}} e^{s_i}\right) + \log \left(1 + \sum\limits_{j\in\Omega_{pos}} e^{-s_j}\right)\label{eq:final}\end{equation} \tag{1}
$$
其中 $\Omega_{pos}, \Omega_{neg}$ 分别是样本的正负类别集合。
import torch
import numpy as np
def multilabel_categorical_crossentropy(y_true, y_pred):
"""
代码第 1, 2, 3 行解释
1. 将真实标签 y_true 从 0 / 1 映射到 -1/1,即将正类设为 -1,负类设为 1。并将得到的结果与预测值相乘。这一步处理的目的是为了保证预测值 y_pred
落在 [0, 1] 的范围内。2. 将正类位置的预测值设为负无穷。在这里采用了一个技巧,即将正类的预测值减去一个很大的数,这里是 1e12。这样在经过 logsumexp 计算后,正类的概率会趋近于 0,达到屏蔽正类的目的。3. 将负类位置的预测值设为负无穷
同样采用上述技巧,将负类的预测值减去一个很大的数,使得在经过 logsumexp
函数计算后,负类的概率会趋近于 0,达到屏蔽负类的目的。说明:y_true 和 y_pred 的 shape 一致,y_true 的元素非 0 即 1,1 表示对应的类为目标类,0 表示对应的类为非目标类。警告:请保证 y_pred 的值域是全体实数,换言之一般情况下 y_pred
不用加激活函数,尤其是不能加 sigmoid 或者 softmax!预测
阶段则输出 y_pred 大于 0 的类。如有疑问,请仔细阅读并理解
参考 1。"""
y_pred = (1 - 2 * y_true) * y_pred
y_pred_neg = y_pred - y_true * 1e12 #mask pos
y_pred_pos = y_pred - (1-y_true) * 1e12 #mask neg
#构建 y_pred[..., :1]一样形状的全 0tensor 来替换对应位置的预测值
zeros = torch.zeros_like(y_pred[..., :1])
y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
return (neg_loss + pos_loss).mean()
y_pred = torch.tensor([[0.9, 0.1, 0.8, 0.3], [0.3, 0.2, 0.5, 0.7], [0.7, 0.6, 0.1, 0.3]])
y_true = torch.tensor([[1, 0, 1, 0], [0, 0, 1, 1], [1, 0, 0, 0]])
loss = multilabel_categorical_crossentropy(y_true, y_pred)
print(loss.item())
#1.9803704023361206
注意:其只用于硬标签,像 label smoothing, mixup 都不能使用。详见,多标签“Softmax+ 交叉熵”。稀疏版见参考 2。
参考
[2] 稀疏版多标签分类交叉熵损失函数
正文完
发表至: NLP
2023-11-26