夕小瑶科技说 原创
作者 | 智商掉了一地、iven

用 Recurrent Memory Transformer 架构:可输入长度取决于内存大小

Transformer 因其在自然语言处理领域的成功应用而备受瞩目,同时在计算机视觉领域的研究中,诸多的多模态大模型如 ViT、CLIP、BLIP 等也都与 Transformer 的发展密不可分,从而推动着多个领域的进步。然而,Transformer 存在一个关键问题,即其注意力操作的二次复杂度,这导致将大模型应用于处理较长序列变得越来越困难。然而,通过利用特殊的记忆 token 实现记忆机制的 Recurrent Memory Transformer(RMT)模型,有效上下文长度能够增长到百万级,这带来了新的发展前景。

最近,RMT 的三位作者发表了一篇新论文,指出使用 RMT 可以将 Transformer 的有效上下文扩展到 200 万个 token。试想,如果未来内存允许则可以打破 Transformer 长度的限制。这将是一项令人兴奋的发展,相信会对 NLP 和诸多相关领域的研究产生重要影响。

论文题目:
Scaling Transformer to 1M tokens and beyond with RMT

论文链接:
https://arxiv.org/abs/2304.11062

代码地址
https://github.com/booydar/t5-experiments/tree/scaling-report

回顾 Recurrent Memory Transformer

记忆增强的分段级递归 Transformer(Recurrent Memory Transformer,RMT)的结构如图 1 所示,由于记忆机制允许存储和处理局部和全局信息,并通过递归在长序列的段之间传递信息。RMT的实现不需要改变 Transformer 模型,只需对模型的输入和输出序列进行修改即可实现记忆和循环。然后,模型被训练来控制记忆操作和序列表示处理。递归记忆 Transformer 是一种有前途的架构,适用于需要学习长期依赖关系和通用内存处理的应用,例如算法和推理任务。该模型在语言建模等任务上与 Transformer-XL 表现相当,甚至在需要处理更长序列的任务上表现更好。

具体来说,RMT 通过将记忆作为 token 添加到输入序列中,记忆 token 可为模型提供额外的容量,以处理与输入序列中任何元素无直接关联的信息。为了处理长序列,将其分割成段并将上一段的记忆状态传递给当前段。在训练期间,梯度从当前段通过记忆流向前一段。

▲图1 Recurrent Memory Transformer 模型图

图 2 展示了对于不同规模和序列长度的 RMT 和 Transformer 模型所需要的 FLOPs(浮点操作数)的估计。它们的配置信息(如词汇量大小、层数、隐藏层大小、中间隐藏层大小和注意力头数等)来自 OPT 模型家族。

▲图2 RMT推理与输入序列长度呈线性关系

结果显示,如果分段长度固定,RMT的规模随着任何模型大小的增加线性扩展。而大型 Transformer 模型则较少能实现线性扩展,更长的序列长度将导致 FLOPs 数呈现二次增长。同时,RMT模型在含有多个片段的序列中需要的 FLOPs 要少于非循环模型(包括 Transformer),能够减少 FLOPs 数高达 295 倍。总体而言, RMT 模型在处理包含许多片段的大型序列时可能比非循环模型更有效率

记忆任务(Memorization Tasks)

为了测试记忆能力,作者构建了需要记忆简单事实和基本推理的合成数据集。任务输入由一个或多个事实和一个问题组成,该问题只能通过使用所有这些事实来回答。为增加任务难度,还添加了与问题或答案无关的自然语言文本,该文本起到了加入噪声的作用。

所以模型的任务是将事实从无关的文本中分离出来,并用它们来回答问题。该任务被公式化为 6 个分类,每个类别代表一个单独的答案选项。

如图 3 所示,提出了综合任务和解决这些任务所需的 RMT 操作。在记忆任务中,事实陈述被放在序列的开头,在检测和记忆任务中,事实被随机放置在文本序列中,从而使检测更具挑战性。在推理任务中,提供答案所需的两个事实被随机放置在文本中。对于所有任务,问题都在序列的结尾。其中,‘mem’表示记忆符号,‘Q’表示问题,‘A’表示答案。

▲图3 记忆密集型合成任务

事实记忆

第一个任务是测试 RMT 在较长时间内在记忆中写入和存储信息的能力(如图 3 的顶部)。在最简单的情况下,事实始终位于输入的开头,而问题始终位于结尾。问题和答案之间的无关文本的数量逐渐增加,因此整个输入不适合单个模型输入。

  • Fact: Daniel went back to the hallway

  • Question: Where is Daniel” />Answer: hallway

事实检测 & 记忆

事实检测通过将事实移动到输入中的随机位置来增加任务难度(如图 3 的中间部分)。这要求模型首先将事实与不相关的文本区分开来,将其写入记忆,然后使用它来回答位于结尾的问题。

用记忆中的事实进行推理

另一个与记忆有关的重要操作是利用记忆的事实和当前上下文进行推理。为了评估这个函数,使用了一个更复杂的任务(如图 3 的底部),其中生成两个事实并在输入序列中随机定位。在序列结尾提出的问题以这样一种方式表述,即任何事实都必须用于正确回答问题

  • Fact1: The hallway is east of the bathroom.

  • Fact2: The bedroom is west of the bathroom.

  • Question: What is the bathroom east of?

  • Answer: bedroom

实验及分析

渐进学习

最初 RMT 在任务的较短版本上进行训练,在训练收敛时,通过增加一个片段来增加任务长度。渐进学习的过程将持续至达到所需的输入长度。

在实验中从适合单个片段的序列开始,实际的段大小为 499,因为从模型输入中保留了 3 个 BERT 特殊 token 和 10 个记忆占位符,大小共为 512。

外推能力

作者对在不同输入长度的 1-7 段任务上训练的 checkpoint 进行评估,图 4(a)是记忆任务的结果,图 4(b)是检测 & 记忆的结果,图 4(c)是推理的结果

▲图4 记忆检索的泛化

研究表明,在不同的序列长度上,RMT 的泛化能力表现出不同的效果。较短序列的任务中模型表现较好,但当模型训练的序列较长时,单段推理任务变得难以解决。可能的解释是,任务长度超过一个段落,模型会停止期望在第一个段落中看到问题,导致质量下降。但有趣的是,RMT 的泛化能力随着训练段数的增加而提高。在训练了 5 个或更多段落后,RMT 可以近乎完美地泛化两倍长度的任务。在验证任务长度增加至 4096 个段落或 2043904 个令牌时,RMT 在这样长的序列上仍然表现出惊人的性能,其中“事实检测 & 记忆”是最容易的任务,而“事实推理”任务则最为复杂。

记忆操作的注意力模式

通过对图 5 中特定片段上的 RMT 注意力分析,观察到记忆操作对应着注意力中的特定模式。此外,在处理极长的序列时,学习的记忆操作表现出了很强的泛化能力,即使执行成千上万次仍然有效。这些操作并没有明确地受到任务损失的影响,这令人印象深刻。

▲图5 记忆操作的注意力图

小结

有人指出[1],如果将这种技术迁移到 GPT4 等大模型上,未来用 RMT 架构来处理超过 100 万个 token,那么将会解锁自然语言处理乃至人工智能领域的一些新应用,为整个行业带来巨大发展潜力和机会。通过将 Transformer 模型扩展到处理更长的文本序列,可以带来更准确、更自然、更高效的自然语言处理应用,从而改善我们的生活和工作体验,比如:

  • 增强对上下文的理解:随着上下文长度的增加,模型将能够理解和处理更长的文本序列,提高其推理能力和在文本的远距离建立联系的能力。

  • 提高长文本任务的性能:需要处理长文本的任务,如摘要、翻译或对长文档的问答。

  • 更大的创造力潜力:有了更多的上下文,模型可以生成更连贯、更吸引人的长篇内容,比如写小说、剧本或研究论文。

  • 更高级的聊天机器人:改进的上下文理解将允许更自然和上下文感知的会话代理,因为模型将能够更有效地回忆起对话的早期部分。

  • 增强的多模态应用:更长的上下文长度将使模型能够更好地处理和理解复杂的多模态输入,例如处理用于内容生成或分析的文本、图像和视频的组合。

  • 更好的迁移学习能力:具有更大上下文窗口的模型可能在不同任务和领域之间更有效地学习和迁移知识,从而提高整体性能和适应性。

  • 海量数据的实时处理能力:由于能够处理大量 token,该模型可用于大规模数据集的实时处理,例如监控社交媒体或分析科学文献以获取特定趋势或信息。

然而,这一技术的实现需要大量的计算资源和大量的时间和精力来开发和调整模型,同时还需要考虑潜在的道德和风险问题。因此,我们必须认真对待这一技术,并在开发过程中积极探索如何解决其中的问题,以便最大程度地发挥其潜力。同时,也需要在应用这一技术的过程中保持警惕,避免生成偏见或有害内容,以确保我们能够在未来受益于它。

参考资料

[1]

twitter: https://twitter.com/promptsurfer/status/1650401661880516608