DL-Paper精读:LSTM + Transformer 架构模型

with LSTM-based Cross-
?
近来 , 源于某个神奇的需求 , 需要研究和LSTM相结合的模型架构 。这两者作为自然语言领域两个时代的王者 , 似乎对立的戏份远大于合作 。常理来说 , 在刚刚被提出来的一两年内 , 应该有很多关于这方面的研究工作 , 但很奇怪地是并未搜索到比较出名的工作 。难道是这两者组合效果不佳 , 水火不容?这篇文章是收录于的一个工作 , 旨在将LSTM结合到结构中 , 通过一种交叉的信息表达 , 来获得更强大更鲁棒的语言模型 。
对该工作的研究 , 主要集中在其网络架构的设计和代码的实现方面 。由于对于语言方面的不了解 , 不太清楚文中所给出的0.9%, 0.6% and 0.8%WERon AMI 代表怎样的意义 。
文中针对常见的(TLM)和TLM-XL(一种使用分段递归来实现超长序列预测的方法)进行改造 , 具体结构如下 。TLM的核心部分是重复的模块 , 由多头自适应( MHA)和FFN模块组成 。而TLM-XL的区别在于 , 在计算MHA时将上个block的输入与本次的输入进行 , 共同计算 。

DL-Paper精读:LSTM + Transformer 架构模型

文章插图
文中提出的LSTM + TLM架构(称为R-TLM)如下 , 网络其他部分不未进行改动 , 主要是在MHA模块的前端插入了LSTM模块 , 对于输入X , 首先通过LSTM进行处理 , 输出与原输入进行一个之后 , 再作为输入传到MHA中 , 执行政策的 block的操作 。其中LSTM的h/c等使用上一个block的输出 。
对于R-TLM架构的优势 , 文中解释如下:
在测试re-score阶段 , 对于 past (对于实在理解不够)中的单词错误 , LSTM模块的隐含信息能够有效缓解其影响 , 提高鲁棒性;
R-TLM能够同时提供LSTM和模块所提供的补充历史信息和基于注意力的信息表示 , 提高了模型能力;
R-TLM作为两种模型的结合体 , 能够解决单个模型在不同大小数据集上性能表现不同的问题(一般LSTM在小数据集上比表现更好 , 但在预训练后表现很突出) 。

LSTM
DL-Paper精读:LSTM + Transformer 架构模型

文章插图
if rnnenc and rnndim != 0: # merge_type为gating或project时 , 将输入输出进行concat再通过linear层统一维度 if merge_type in ['gating', 'project']: self.rnnproj = nn.Linear(rnndim + d_model, d_model) self.rnn_list = nn.ModuleList([nn.LSTM(d_model, rnndim, 1) for i in range(len(self.rnnlayer_list))])
for i in range(n_layer): # dropatt = dropatt * 2 if (i == 0 and rnnenc) else dropatt use_penalty = i in self.attn_pen_layers self.layers.append( RelPartialLearnableDecoderLayer( n_head, d_model, d_head, d_inner, dropout, tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, dropatt=dropatt, pre_lnorm=pre_lnorm, penalty=use_penalty) )
两者的层数一致 , 也就是说每个 block前分别建立一个LSTM

其中第一层的LSTM被单独拿出来在循环之前进行了计算 。所以每个LSTM的输入为, 经过LSTM模块之后 , 输出的与原来的进行组合作为新的输入喂给下一个模块;而LSTM输出的被作为下一个LSTM的输入 。
for i, layer in enumerate(self.layers):if self.future_len != 0:dec_attn_mask = dec_attn_mask_future if i in self.layer_list else dec_attn_mask_normalmems_i = None if mems is None else mems[i]# Perform Attention Layercore_out, attn_pen = layer(core_out, pos_emb, self.r_w_bias,self.r_r_bias, dec_attn_mask=dec_attn_mask,mems=mems_i)# gs534 - rnn in the middle of a transformer layerif self.rnnenc and i+1 in self.rnnlayer_list:# performe LSTM modulernn_out, rnn_hidden = self.forward_rnn(i+1, core_out, rnn_hidden, stepwise, future_seqlen)# merge input and output of LSTMif self.merge_type == 'project':core_out = torch.relu(self.rnnproj(torch.cat([rnn_out, core_out], dim=-1)))# core_out = (self.rnnproj(torch.cat([rnn_out, core_out], dim=-1)))elif self.merge_type == 'gating':core_gating = torch.sigmoid(self.rnnproj(torch.cat([rnn_out, core_out], dim=-1)))core_out = rnn_out * core_gating + core_out * (1 - core_gating)else:core_out = rnn_outattn_pen_list.append(attn_pen)hids.append(core_out)