60行NumPy手搓GPT

本文约24000字 , 建议阅读30分钟
本文我们将仅仅使用60行Numpy[6] , 从0-1实现一个GPT 。
本文原载于尹志老师博客:[1] 。
本文还是来自Jay Mody[2] , 那篇被 手动点赞[3]的GPT in 60 Lines of NumPy[4](已获原文作者授权) 。
LLM大行其道 , 然而大多数GPT模型都像个黑盒子一般隐隐绰绰 , 甚至很多人都开始神秘化这个技术 。我觉得直接跳进数学原理和代码里看看真实发生了什么 , 才是最有效的理解某项技术的方法 。正如的 所说:
这些都是电脑程序 。
这篇文章细致的讲解了GPT模型的核心组成及原理 , 并且用Numpy手搓了一个完整的实现(可以跑的那种) , 读起来真的神清气爽 。项目代码也完全开源 , 叫做[5](pico , 果然是不能再小的GPT了) 。
关于译文几点说明:
翻译基本按照原作者的表述和逻辑 , 个别部分译者做了补充和看法;
在本文中 , 我们将仅仅使用60行Numpy[6] , 从0-1实现一个GPT 。然后我们将发布的GPT-2模型的权重加载进我们的实现并生成一些文本 。
注意:
GPT是什么?
GPT代表生成式预训练( Pre- ) 。这是一类基于[8]的神经网络架构 。Jay 的“GPT3是如何工作的”[9]一文在宏观视角下对GPT进行了精彩的介绍 。但这里简单来说:
译者注:就是一种特定的神经网络结构
类似的GPT-3[10], 谷歌的LaMDA[11]还有的 [12]的大语言模型的底层都是GPT模型 。让它们这么特殊的原因是
根本上来看 , 给定一组提示 , GPT能够基于此生成文本 。即使是使用如此简单的API(input = 文本 ,  = 文本) , 一个训练好的GPT能够完成很多出色的任务 , 比如帮你写邮件[13] , 总结一本书[14] , 给你的起标题[15] , 给5岁的小孩解释什么是黑洞[16] , 写SQL代码[17] , 甚至帮你写下你的遗嘱[18] 。
以上就是宏观视角下关于GPT的概览以及它能够做的事情 。现在让我们深入一些细节吧 。
输入/输入
一个GPT的函数签名基本上类似这样:
def gpt(inputs: list[int]) -> list[list[float]]:# inputs has shape [n_seq]# output has shape [n_seq, n_vocab]output = # beep boop neural network magicreturn output
输入
输入是一些文本 , 这些文本被表示成一串整数序列 , 每个整数都与文本中的token对应:

# integers represent tokens in our text, for example:# text= "not all heroes wear capes":# tokens = "not""all" "heroes" "wear" "capes"inputs =[1,0,2,4,6]
token是文本的小片段 , 它们由某种分词器()产生 。我们可以通过一个词汇表()将映射为整数:

# the index of a token in the vocab represents the integer id for that token# i.e. the integer id for "heroes" would be 2, since vocab[2] = "heroes"vocab = ["all", "not", "heroes", "the", "wear", ".", "capes"]# a pretend tokenizer that tokenizes on whitespacetokenizer = WhitespaceTokenizer(vocab)# the encode() method converts a str -> list[int]ids = tokenizer.encode("not all heroes wear") # ids = [1, 0, 2, 4]# we can see what the actual tokens are via our vocab mappingtokens = [tokenizer.vocab[i] for i in ids] # tokens = ["not", "all", "heroes", "wear"]# the decode() method converts back a list[int] -> strtext = tokenizer.decode(ids) # text = "not all heroes wear"
简单说:
在实际中 , 我们不仅仅使用简单的通过空白分隔去做分词 , 我们会使用一些更高级的方法 , 比如Byte-Pair [19]或者[20] , 但它们的原理是一样的: