现代循环神经网络实战:机器翻译( 二 )


该代码定义了一个名为 Vocab 的类,用于构建文本的词汇表 。下面对每个方法和属性进行解释:
init(self, =None, =0, =None):构造函数,创建一个词汇表 。参数是一个包含文本所有单词的列表,是单词在文本中最少出现的次数,是一个保留的单词列表 。构造函数首先使用 () 函数统计单词的出现次数,并将单词按照出现频率从高到低排序 。接着将预先定义的 单词加入到词汇表中,并根据单词出现的顺序构建一个单词到索引的字典 ,同时根据出现频率构建一个索引到单词的列表。对于出现次数小于的单词,直接跳过不加入词汇表中 。最终得到的词汇表中,索引 0 对应的是 ,索引 1 到 n-1 对应的是出现频率较高的单词,索引 n 及之后的值是出现频率较低的单词 。
len(self):返回词汇表的大小,即单词总数 。
(self, ):根据单词或单词列表返回其对应的索引或索引列表 。如果是一个单词,则返回其对应的索引 。如果是一个单词列表,则对其中的每个单词递归调用 () 方法,并返回索引列表 。
(self, ):根据索引或索引列表返回其对应的单词或单词列表 。如果是一个索引,则返回其对应的单词 。如果是一个索引列表,则对其中的每个索引递归调用 () 方法,并返回单词列表 。
unk:属性,返回 单词的索引,即 0 。
:属性,返回单词出现频率从高到低排序的列表 。该属性在创建词汇表时被初始化 。
():定义在 Vocab 类外部的函数,用于计算单词出现频率 。参数是一个包含文本所有单词的列表,可能是一个二维列表 。如果是一个二维列表,则首先将其转化为一个一维列表 。该函数使用 .() 函数统计每个单词在文本中出现的次数,并返回一个字典 。
src_vocab = Vocab(source, min_freq=2,reserved_tokens=['', '', ''])len(src_vocab)
填充
#@savedef truncate_pad(line, num_steps, padding_token):"""截断或填充文本序列"""if len(line) > num_steps:return line[:num_steps]# 截断return line + [padding_token] * (num_steps - len(line))# 填充truncate_pad(src_vocab[source[0]], 10, src_vocab[''])
组织数据集
#@savedef build_array_nmt(lines, vocab, num_steps):"""将机器翻译的文本序列转换成小批量"""lines = [vocab[l] for l in lines]#token to idlines = [l + [vocab['']] for l in lines]# 加上eos代表结束array = torch.tensor([truncate_pad(l, num_steps, vocab['']) for l in lines])# 转换为数组valid_len = (array != vocab['']).type(torch.int32).sum(1)#有效长度return array, valid_len
这个函数将机器翻译的文本序列转换为小批量,其中输入lines是一个包含文本序列的列表,vocab是一个词汇表对象,是每个序列中包含的最大令牌数 。该函数返回两个张量,第一个是包含个令牌的序列的张量,第二个是每个序列的有效长度 。对于一个序列而言,它的有效长度是指从开头算起,第一个填充符之前的所有标记的数量 。
函数首先将每个文本序列转换为词汇表中的整数列表,然后将每个序列追加一个标记 。接下来,它将所有序列截断或填充为长度为 。如果一个序列在个令牌之后仍有令牌,那么它的其余令牌将被截断 。如果一个序列少于个令牌,那么填充令牌将被添加到序列的末尾 。最后,函数计算每个序列的有效长度,并返回两个张量 。
#@savefrom torch.utils import datadef load_array(data_arrays, batch_size, is_train=True):"""Construct a PyTorch data iterator.Defined in :numref:`sec_linear_concise`"""dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)def load_data_nmt(batch_size, num_steps, num_examples=600):"""返回翻译数据集的迭代器和词表"""text = preprocess_nmt(read_data_nmt())source, target = tokenize_nmt(text, num_examples)src_vocab = Vocab(source, min_freq=2,reserved_tokens=['', '