二、用于预训练BERT的数据集( 三 )


class NextSentencePred(nn.Module):"""BERT的下一句预测任务"""def __init__(self, num_inputs, **kwargs):super(NextSentencePred, self).__init__(**kwargs)self.output = nn.Linear(num_inputs, 2)def forward(self, X):# X的形状:(batchsize,num_hiddens)return self.output(X)
encoded_X = torch.flatten(encoded_X, start_dim=1)# NSP的输入形状:(batchsize,num_hiddens)nsp = NextSentencePred(encoded_X.shape[-1])nsp_Y_hat = nsp(encoded_X)
上述两个预训练任务中的所有标签都可以从预训练语料库中获得,而无需人工标注 。原始的BERT已经在图书语料库和英文维基百科的连接上进行了预训练 。这两个文本语料库非常庞大:它们分别有8亿个单词和25亿个单词 。
5.整合代码
在预训练BERT时,最终的损失函数是掩蔽语言模型损失函数和下一句预测损失函数的线性组合 。通过实例化三个类、和来定义类 。前向推断返回编码后的BERT表示、掩蔽语言模型预测和下一句预测 。
class BERTModel(nn.Module):"""BERT模型"""def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,num_heads, num_layers, dropout, max_len=1000, key_size=768, query_size=768, value_size=768, hid_in_features=768, mlm_in_features=768, nsp_in_features=768):super(BERTModel, self).__init__()self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, max_len=max_len,key_size=key_size, query_size=query_size, value_size=value_size)self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens), nn.Tanh())self.nsp = NextSentencePred(nsp_in_features)def forward(self, tokens, segments, valid_lens=None, pred_positions=None):encoded_X = self.encoder(tokens, segments, valid_lens)if pred_positions:mlm_Y_hat = self.mlm(encoded_X, pred_positions)else:mlm_Y_hat = None# 用于下一句预测的多层感知机分类器的隐藏层,0是“”标记的索引nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))return encoded_X, mlm_Y_hat, nsp_Y_hat
二、用于预训练BERT的数据集
最初的BERT模型是在两个庞大的图书语料库和英语维基百科的合集上预训练的,现成的预训练BERT模型可能不适合医学等特定领域的应用 。因此,在定制的数据集上对BERT进行预训练变得越来越流行 。为了方便BERT预训练的演示,使用较小的语料库-2。
-2与用于预训练的PTB数据集相比,有以下不同:
(1)保留了原来的标点符号,适合于下一句预测;
(2)保留了原来的大小写和数字;
(3)大了一倍以上 。
1.下载并读取数据集
在-2数据集中,每行代表一个段落,其中在任意标点符号及其前面的词元之间插入空格 。为了简单起见,我们仅使用句号作为分隔符来拆分句子,保留至少有两句话的段落 。
import osimport randomimport torchfrom d2l import torch as d2ld2l.DATA_HUB['wikitext-2'] = ('https://s3.amazonaws.com/research.metamind.io/wikitext/''wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')def _read_wiki(data_dir):file_name = os.path.join(data_dir, 'wiki.train.tokens')with open(file_name, 'r') as f:lines = f.readlines()# 大写字母转换为小写字母paragraphs = [line.strip().lower().split(' . ')for line in lines if len(line.split(' . ')) >= 2]random.shuffle(paragraphs)return paragraphs
2.生成下一句预测任务的数据
def _get_next_sentence(sentence, next_sentence, paragraphs):"""生成二分类任务的训练样本"""if random.random()