word2vec PyTorch Hands-On Code Summary (Part 2)


The complete code is below. It builds skip-gram (center word, context word) training pairs from a tiny corpus, feeds one-hot center-word vectors through two linear maps W and V, trains with cross-entropy loss against the context-word index, and finally plots the learned 2-dimensional word vectors.
import matplotlib.pyplot as plt
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optimizer
import torch.utils.data as Data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.FloatTensor

sentences = ["i am a student ", "i am a boy ", "studying is not a easy work ",
             "japanese are bad guys ", "we need peace ",
             "computer version is increasingly popular ",
             "the word will get better and better "]
sentence_list = "".join(sentences).split()      # corpus token list -- contains repeated words
vocab = list(set(sentence_list))                # vocabulary -- no repeated words
word2idx = {w: i for i, w in enumerate(vocab)}  # dict of word-to-index pairs built from the vocabulary
vocab_size = len(vocab)

w_size = 2      # half-width of the context window
batch_size = 8
word_dim = 2    # dimension of the word vectors

skip_grams = []
for word_idx in range(w_size, len(sentence_list) - w_size):        # word_idx is the position of the center word in the corpus
    center_word_vocab_idx = word2idx[sentence_list[word_idx]]      # vocabulary index of the center word
    context_word_idx = list(range(word_idx - w_size, word_idx)) + list(range(word_idx + 1, word_idx + w_size + 1))  # corpus positions of the context words
    context_word_vocab_idx = [word2idx[sentence_list[i]] for i in context_word_idx]  # vocabulary indices of the context words
    for idx in context_word_vocab_idx:
        skip_grams.append([center_word_vocab_idx, idx])            # every stored pair is a pair of vocabulary indices

def make_data(skip_grams):
    input_data = []
    output_data = []
    for center, context in skip_grams:
        input_data.append(np.eye(vocab_size)[center])   # one-hot vector for the center word
        output_data.append(context)                     # vocabulary index of the context word
    return input_data, output_data

input_data, output_data = make_data(skip_grams)
input_data, output_data = torch.Tensor(input_data), torch.LongTensor(output_data)
dataset = Data.TensorDataset(input_data, output_data)
loader = Data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

class Word2Vec(nn.Module):
    def __init__(self):
        super(Word2Vec, self).__init__()
        self.W = nn.Parameter(torch.randn(vocab_size, word_dim).type(dtype))  # input-to-hidden weights; its rows are the word vectors
        self.V = nn.Parameter(torch.randn(word_dim, vocab_size).type(dtype))  # hidden-to-output weights

    def forward(self, X):
        hidden = torch.mm(X, self.W)       # [batch_size, word_dim]
        output = torch.mm(hidden, self.V)  # [batch_size, vocab_size] scores over the vocabulary
        return output

model = Word2Vec().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
optim = optimizer.Adam(model.parameters(), lr=1e-3)

for epoch in range(2000):
    for i, (batch_x, batch_y) in enumerate(loader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        pred = model(batch_x)
        loss = loss_fn(pred, batch_y)
        if (epoch + 1) % 1000 == 0:
            print(epoch + 1, i, loss.item())
        optim.zero_grad()
        loss.backward()
        optim.step()

for i, label in enumerate(vocab):
    W, WT = model.parameters()
    x, y = float(W[i][0]), float(W[i][1])
    plt.scatter(x, y)
    plt.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points', ha='right', va='bottom')
plt.show()
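After training, the rows of model.W can be read out directly as the word vectors. As a minimal sketch of how you might query them (the helper name nearest_words and the use of cosine similarity are my own additions, not part of the original post; it assumes model, vocab and word2idx from the code above are still in scope):

import torch.nn.functional as F

def nearest_words(query, k=3):
    # Return the k vocabulary words whose vectors are closest to the query word's vector.
    with torch.no_grad():
        emb = model.W.detach().cpu()               # [vocab_size, word_dim] trained word vectors
        q = emb[word2idx[query]].unsqueeze(0)      # [1, word_dim] vector of the query word
        sims = F.cosine_similarity(q, emb, dim=1)  # cosine similarity to every vocabulary word
        sims[word2idx[query]] = -1.0               # exclude the query word itself
        top = torch.topk(sims, k)
        return [(vocab[int(i)], float(s)) for s, i in zip(top.values, top.indices)]

print(nearest_words("peace"))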
Result figure: a 2D scatter plot of the learned word vectors, with each vocabulary word annotated at its (x, y) coordinate.
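A side note on the model itself: because each input is a one-hot vector, torch.mm(X, self.W) simply selects a row of W, so the same network can be written with nn.Embedding and nn.Linear and fed integer word indices instead of one-hot rows. The sketch below is my own equivalent formulation, not from the original post; the class name Word2VecEmbedding is hypothetical and it reuses vocab_size, word_dim, word2idx and device from the code above.

class Word2VecEmbedding(nn.Module):
    # Equivalent model that takes the integer index of the center word instead of a one-hot vector.
    def __init__(self, vocab_size, word_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, word_dim)          # plays the role of W
        self.out = nn.Linear(word_dim, vocab_size, bias=False)   # plays the role of V

    def forward(self, center_idx):       # center_idx: [batch_size] LongTensor of word indices
        hidden = self.embed(center_idx)  # [batch_size, word_dim]
        return self.out(hidden)          # [batch_size, vocab_size] scores over the vocabulary

# Usage: targets stay the same; inputs become word indices rather than one-hot rows.
# model2 = Word2VecEmbedding(vocab_size, word_dim).to(device)
# pred = model2(torch.LongTensor([word2idx["peace"]]).to(device))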