Processing math: 100%

9. GlobalPointer 和 Efficient GlobalPointer 原理

1. GlobalPointer

GlobalPointer 来自于论文 GLOBAL POINTER: NOVEL EFFICIENT SPAN-BASED APPROACH FOR NAMED ENTITY RECOGNITION。也可阅读参考 1 来辅助理解。开源代码地址, Pytorch 版本地址。GlobalPointer 思路是:

  1. 对于可以嵌套的 NER,我们如果假定要识别的文本长度为 n, 简单地假定只有一个实体,那么这个文本序列有多少个候选的实体呢?答案是 n(n+1)/2 个。即长度为 n 的序列有 n(n+1)/2 个子序列
  2. 在上面 n(n+1)/2 个子序列选取真正的 k 个实体,这就变成了 n(n+1)/2k 的多标签分类问题。
  3. 进一步来讲,如果有 m 种实体呢,就变成了 mn(n+1)/2k 的多标签分类问题

GlobalPointer 主要特点有:

  1. 区间预测
  2. ROPE 旋转式位置编码
  3. 平衡的多标签分类损失

Approach

问题定义:NER 任务是从文本序列中抽取实体区间和识别其对应的类型。

假设 S=[s1,s2,,sm] 是文本序列,片段 s[i:j] 是一个区间,NER 目标就是识别 sE, 其中 E 是实体类型。

9. GlobalPointer 和 Efficient GlobalPointer 原理

如上图 1 所示,Global Pointer 由两层构成,包括 token 表示区间预测

  1. token 表示:给定一个输入 n 个字符的句子 X=[x1,x2,,xn],经过 PLM 像 BERT 后能得到一个新的 token 的表示矩阵 HRn×v, 这里 v 是表示的维度。具体来说,将文本输入到 PLM 后得到 n 个词 v 维的表示,如 BERT 输入文本后取 last_hidden_state 就得到 n×768 的矩阵,用式子表示为:

h1,h2,hn=PLM(x1,x2,,xn)

  1. 区间预测:

接下来对获取到的句子表示 H, 计算其区间。作者使用两层前馈层来表示区间起始和结束的位置。
qi,α=Wq,αhi+bq,α ki,α=Wq,αhi+bq,α
其中 qi,αRd, ki,αRd 是用来识别实体类型 α 的 token 的向量表示。具体来说,是对类型 α 片段 s[i:j] 的起始和结束位置的表示为 qi,αki,α

那么片段 s[i:j]α 实体的分数可以计算为:
sα(i,j)=qTi,αki,α
另外为了处理边界信息引入了 旋转式位置编码 Rotary Position Embedding。其实就是一个变换矩阵

Ri,满足 RiRj=Rji。引入到式 3 有:

sα(i,j)=(Riqi,α)(Rjkj,α)=qi,αRiRjkj,α=qi,αRjikj,α
这样就注入了相对位置信息。

另外还引入了 多标签分类损失

代码

Pytorch 实现如下代码:

  1. qw,kw就是式 2 中的 qi,αRd, ki,αRd,实际上式 PLM 的 token 表示经过线性映射成 实体数目 x 2inner_dim(一般 bert 选 64)。然后拆分堆叠切片得到(batch_size, seq_len, ent_type_size, inner_dim)shape 的 tensor。
  2. 对 qw,kw 引入 ROPE
  3. 爱因斯坦求和得到分数 logits。
class GlobalPointer(nn.Module):
def __init__(self, encoder, ent_type_size, inner_dim, RoPE=True):
# encodr: RoBerta-Large as encoder
# inner_dim: 64
# ent_type_size: ent_cls_num
super().__init__()
self.encoder = encoder
self.ent_type_size = ent_type_size
self.inner_dim = inner_dim
self.hidden_size = encoder.config.hidden_size
self.dense = nn.Linear(self.hidden_size, self.ent_type_size * self.inner_dim * 2)
self.RoPE = RoPE
def sinusoidal_position_embedding(self, batch_size, seq_len, output_dim):
position_ids = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(-1)
indices = torch.arange(0, output_dim // 2, dtype=torch.float)
indices = torch.pow(10000, -2 * indices / output_dim)
embeddings = position_ids * indices
embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
embeddings = embeddings.repeat((batch_size, *([1] * len(embeddings.shape))))
embeddings = torch.reshape(embeddings, (batch_size, seq_len, output_dim))
embeddings = embeddings.to(self.device)
return embeddings
def forward(self, input_ids, attention_mask, token_type_ids):
self.device = input_ids.device
context_outputs = self.encoder(input_ids, attention_mask, token_type_ids)
# last_hidden_state:(batch_size, seq_len, hidden_size)
last_hidden_state = context_outputs[0]
batch_size = last_hidden_state.size()[0]
seq_len = last_hidden_state.size()[1]
# outputs:(batch_size, seq_len, ent_type_size*inner_dim*2)
outputs = self.dense(last_hidden_state)
#上一步已经转换为 bsxseq_lenxent_type_size*inner_dim*2
#按照最后一个维度拆分成 ent_type_size 个 bs x seq_len inner_dim* 2 的 tensor
outputs = torch.split(outputs, self.inner_dim * 2, dim=-1)
# 将其堆叠成:(batch_size, seq_len, ent_type_size, inner_dim*2)
outputs = torch.stack(outputs, dim=-2)
# qw,kw:(batch_size, seq_len, ent_type_size, inner_dim)
qw, kw = outputs[..., :self.inner_dim], outputs[..., self.inner_dim:]
if self.RoPE:
# pos_emb:(batch_size, seq_len, inner_dim)
pos_emb = self.sinusoidal_position_embedding(batch_size, seq_len, self.inner_dim)
# cos_pos,sin_pos: (batch_size, seq_len, 1, inner_dim)
cos_pos = pos_emb[..., None, 1::2].repeat_interleave(2, dim=-1)
sin_pos = pos_emb[..., None, ::2].repeat_interleave(2, dim=-1)
qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], -1)
qw2 = qw2.reshape(qw.shape)
qw = qw * cos_pos + qw2 * sin_pos
kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], -1)
kw2 = kw2.reshape(kw.shape)
kw = kw * cos_pos + kw2 * sin_pos
# logits:(batch_size, ent_type_size, seq_len, seq_len)
logits = torch.einsum('bmhd,bnhd->bhmn', qw, kw)
# padding mask
pad_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, self.ent_type_size, seq_len, seq_len)
logits = logits * pad_mask - (1 - pad_mask) * 1e12
# 排除下三角
mask = torch.tril(torch.ones_like(logits), -1)
logits = logits - mask * 1e12
return logits / self.inner_dim ** 0.5

参考

[1] GlobalPointer:用统一的方式处理嵌套和非嵌套 NER]

[2] Efficient GlobalPointer:少点参数,多点效果

 
正文完
 
admin
版权声明:本站原创文章,由 admin 2023-11-26发表,共计4605字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请联系tensortimes@gmail.com。
评论(没有评论)
验证码