1. 多标签分类损失 pytorch 实现
在文本分类中使用多标签分类损失函数实现,公式为 (来自于 将“softmax+ 交叉熵”推广到多标签分类问题 ), 具体推导可看,论文GLOBAL POINTER: NOVEL EFFICIENT SPAN-BASED APPROACH FOR NAMED ENTITY RECOGNITION 中 3.Class Imbalance Loss 部分。:
log(1+∑i∈Ωnegesi)+log(1+∑j∈Ωpose−sj)
其中 Ωpos,Ωneg 分别是样本的正负类别集合。
import torch | |
import numpy as np | |
def multilabel_categorical_crossentropy(y_true, y_pred): | |
""" | |
代码第 1, 2, 3 行解释 | |
1. 将真实标签 y_true 从 0 / 1 映射到 -1/1,即将正类设为 -1,负类设为 1。并将得到的结果与预测值相乘。这一步处理的目的是为了保证预测值 y_pred | |
落在 [0, 1] 的范围内。2. 将正类位置的预测值设为负无穷。在这里采用了一个技巧,即将正类的预测值减去一个很大的数,这里是 1e12。这样在经过 logsumexp 计算后,正类的概率会趋近于 0,达到屏蔽正类的目的。3. 将负类位置的预测值设为负无穷 | |
同样采用上述技巧,将负类的预测值减去一个很大的数,使得在经过 logsumexp | |
函数计算后,负类的概率会趋近于 0,达到屏蔽负类的目的。说明:y_true 和 y_pred 的 shape 一致,y_true 的元素非 0 即 1,1 表示对应的类为目标类,0 表示对应的类为非目标类。警告:请保证 y_pred 的值域是全体实数,换言之一般情况下 y_pred | |
不用加激活函数,尤其是不能加 sigmoid 或者 softmax!预测 | |
阶段则输出 y_pred 大于 0 的类。如有疑问,请仔细阅读并理解 | |
参考 1。""" | |
y_pred = (1 - 2 * y_true) * y_pred | |
y_pred_neg = y_pred - y_true * 1e12 #mask pos | |
y_pred_pos = y_pred - (1-y_true) * 1e12 #mask neg | |
#构建 y_pred[..., :1]一样形状的全 0tensor 来替换对应位置的预测值 | |
zeros = torch.zeros_like(y_pred[..., :1]) | |
y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1) | |
y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1) | |
neg_loss = torch.logsumexp(y_pred_neg, dim=-1) | |
pos_loss = torch.logsumexp(y_pred_pos, dim=-1) | |
return (neg_loss + pos_loss).mean() | |
y_pred = torch.tensor([[0.9, 0.1, 0.8, 0.3], [0.3, 0.2, 0.5, 0.7], [0.7, 0.6, 0.1, 0.3]]) | |
y_true = torch.tensor([[1, 0, 1, 0], [0, 0, 1, 1], [1, 0, 0, 0]]) | |
loss = multilabel_categorical_crossentropy(y_true, y_pred) | |
print(loss.item()) | |
#1.9803704023361206 |
注意:其只用于硬标签,像 label smoothing, mixup 都不能使用。详见,多标签“Softmax+ 交叉熵”。稀疏版见参考 2。
参考
[2] 稀疏版多标签分类交叉熵损失函数
正文完
发表至: NLP
2023-11-26