Understanding neural networks through sparse circuits

Understanding neural networks through sparse circuits

OpenAI News

神经网络驱动着当今最强大的人工智能系统,但它们仍然难以解释。我们并不是用明确的逐步指令去“编写”这些模型,而是通过调整数十亿条内部连接——也就是所谓的 “ weights ”——让模型在特定任务上达到精通。我们设计训练规则,但并不直接决定模型会表现出哪些具体行为,最终形成的是一张普通人难以拆解的密集连接网。

我们如何看待可解释性

随着 AI 系统能力提升,并在科学、教育和医疗等实际决策中产生影响,弄清它们是如何工作的就变得至关重要。所谓可解释性,指的是帮助我们理解模型为何给出某个输出的方法。达成这一目标的途径有很多。

例如,一些推理型模型被激励在给出最终答案的过程中解释自己的推理。所谓的 “ Chain of thought interpretability ” 就利用这些解释来监测模型行为。这类方法立刻就有用:目前的推理模型在链式思路中提供的信息,对于识别像欺骗之类的令人关切的行为是有帮助的。但完全依赖这一性质是脆弱的,随着时间推移可能会失效。

另一类,本文所聚焦的则是 “ mechanistic interpretability ”,它试图对模型的计算过程进行彻底的逆向工程。到目前为止,这一方向的即时实用性不如链式解释,但理论上能够为模型行为提供更完整的解释。因为它追求在最细粒度上说明模型行为,假设更少、能带来更高的确信度。但要把低层次的细节串联成对复杂行为的解释,路途要远得多也难得多。

可解释性支持若干关键目标,比如便于更好的监督、以及为不安全或策略上不对齐的行为提供早期预警。它也与我们的其他安全工作互为补充,比如 “ scalable oversight ”、 “ adversarial training ” 和 “ red-teaming ”。

在这项工作中,我们展示了可以通过训练方式让模型更容易被解释。我们将这看作对密集网络事后分析的有希望的补充。

这是一项非常雄心勃勃的尝试:从现在的工作到完整理解最强模型的复杂行为还有很长的路要走。不过,对于简单行为,我们发现用我们的方法训练出的稀疏模型(即 “ sparse models ”)通常包含一些小而解耦的电路,这些电路既可理解又足以完成该行为。这表明,或许存在一条可行路径,可以训练出规模更大且其内部机制更易被理解的系统。

一种新方法:学习稀疏模型

以往的“机制可解释性”研究通常从密集、缠结的网络入手,试图把它们解开。在那些网络里,每个神经元连接到成千上万个其他神经元;大多数神经元似乎承担着多种不同功能,令人难以理解。

但如果我们从一开始就训练“未被缠结”的神经网络——层数或者神经元更多,但每个神经元只保留几十条连接呢?也许结果网络会更简单,更容易理解。这就是我们研究的核心假设。

基于这一理念,我们训练了与现有语言模型类似架构的模型,例如 GPT‑2 ,但做了一点小改动:我们将绝大多数权重强制设为零。也就是说,模型只能使用神经元之间极少数可能的连接。这个简单改动显著地使内部计算解耦,变得更易分析。

(图示对比密集电路与稀疏电路:密集版显示两行节点之间大量连线;稀疏版则在相同布局下仅保留较少且更有选择性的连接。)

在传统的密集神经网络里,每个神经元会连到下一层的每个神经元;而在我们的稀疏模型中,每个神经元仅连到下一层的少量神经元。我们希望这能让单个神经元及整个网络都更容易被理解。

评估可解释性

我们希望衡量这些稀疏模型的计算在多大程度上被解耦。为此我们挑选了若干简单的模型行为,检验是否能把负责每种行为的模型部分隔离出来——我们称之为 “ circuits ”(电路)。

我们手工挑选了一套简单的算法任务。对每个任务,我们将模型剪枝到仍能完成该任务的最小电路,并考察该电路有多简单。(详情见我们的论文,链接由原文提供。)结果显示,通过训练更大且更稀疏的模型,我们可以得到能力更强且其电路更简单的模型。

(散点图:横轴为模型能力( “ pretraining loss ”),纵轴为可解释性(修剪后电路规模)。点表示不同尺寸与稀疏程度的模型,颜色表示参数总量,点大小表示非零参数数量。右上方向标注为“更好”。)

我们将可解释性与能力绘于同一坐标系(左下角更好)。在固定的稀疏模型规模下,提高稀疏度——即设更多权重为零——会降低能力但提升可解释性。而通过扩大模型规模,这一前沿可以向外推进,暗示我们可以构建既强大又更易解释的更大模型。

举个具体例子:考虑一个在 Python 代码上训练的模型,它要补全一个字符串时使用哪种引号。在 Python 里,‘hello’ 必须以单引号结尾,“hello” 则以双引号结尾。模型可以通过记住字符串起始时使用的引号类型,并在末尾复现它来解决这个任务。

我们最易解释的模型似乎包含了解耦的电路,恰好实现了这个算法。

(图示举例:稀疏 transformer 电路展示了特定神经元和注意力头如何对输入标记如“(”和“circuits”激活,标注了正负权重、乘法、非线性,以及 MLP 与注意力层之间的连接,最终形成输出标记概率。)

在这个例子里,一个稀疏 transformer 的电路只用了五条残差通道(“ residual channels ”)、第 0 层的两个 MLP 神经元(“ MLP ”)、以及第 10 层的一个查询-键通道和一个值通道(“ Q/K/V ” 指标栏目)。模型的工作流程是:(1)把单引号编码到一条残差通道,把双引号编码到另一条;(2)用 MLP 层把它们转换成一条检测任何引号的通道和一条区分单双引号的通道;(3)用注意力操作忽略中间的标记,定位之前的引号并复制其类型到结尾;(4)预测匹配的闭合引号。

按我们的定义,上述那些具体连接足以完成任务——把模型其余部分删掉,这个小电路仍能工作;它们也是必要的——删除这些边中的任意若干会导致模型失败。

我们也研究了一些更复杂的行为。对于这些行为(例如下文展示的变量绑定),对应的电路更难以完全解释。即便如此,我们仍能得到相对简单的部分解释,这些解释对预测模型行为是有效的。

(另一图示:在名为 get_neighbors 的 Python 函数中,突出显示稀疏-transformer 的一个电路。两个对 current = set() 的赋值被框出,彩色箭头标示激活并连接每个变量 current 出现处与循环中使用该变量的注意力头(用 “ Q/K/V ” 索引标注)。)

另一个例子(较少细节):为判断名为 current 的变量类型,一个注意力操作在定义时把变量名复制进 set() 标记,随后另一个注意力操作又把 set() 标记中的类型复制到变量的后续使用处,从而让模型推断出正确的下一个标记。

前路

这项工作是朝着让模型计算更易理解的更大目标迈出的早期一步,但路仍漫长。我们的稀疏模型规模远小于前沿模型,且大部分计算仍未被解释清楚。

下一步,我们希望把技术扩展到更大的模型,并解释更多模型行为。通过枚举那些支撑更复杂推理的电路模式(circuit motifs),我们或能形成一套理解框架,用以更有针对性地研究最先进的模型。

为克服训练稀疏模型的低效性,我们看到两条可行路径。一是从现有的密集模型中抽取稀疏电路,而不是从零开始训练稀疏模型——因为密集模型在部署上本质上更高效。二是开发更高效的训练技术,使模型更易解释,从而更容易投入生产环境。

需要强调的是,我们在此的发现并不能保证这种方法必然适用于更强大的系统,但早期结果令人鼓舞。我们的目标是逐步扩大可可靠解释的模型范围,并构建能让未来系统更易分析、调试和评估的工具。



Neural networks power today’s most capable AI systems, but they remain difficult to understand. We don’t write these models with explicit, step-by-step instructions. Instead, they learn by adjusting billions of internal connections, or “weights,” until they master a task. We design the rules of training, but not the specific behaviors that emerge, and the result is a dense web of connections that no human can easily decipher. 


How we view interpretability




As AI systems become more capable and have real-world impact on decisions in science, education, and healthcare, understanding how they work is essential. Interpretability refers to methods that help us understand why a model produced a given output. There are many ways we might achieve this. 


For example, reasoning models are incentivized to explain their work on the way to a final answer. Chain of thought interpretability leverages these explanations to monitor the model’s behavior. This is immediately useful: current reasoning models’ chains of thought seem to be informative with respect to concerning behaviors like deception. However, fully relying on this property is a brittle strategy, and this may break down over time.


On the other hand, mechanistic interpretability, which is the focus of this work, seeks to completely reverse engineer a model’s computations. It has so far been less immediately useful, but in principle, could offer a more complete explanation of the model’s behavior. By seeking to explain model behavior at the most granular level, mechanistic interpretability can make fewer assumptions and give us more confidence. But the path from low-level details to explanations of complex behaviors is much longer and more difficult.


Interpretability supports several key goals, for example enabling better oversight and providing early warning signs of unsafe or strategically misaligned behavior. It also complements our other safety efforts, such as scalable oversight, adversarial training, and red-teaming. 


In this work, we show that we can often train models in ways that make them easier to interpret. We see our work as a promising complement to post-hoc analysis of dense networks. 


This is a very ambitious bet; there is a long path from our work to fully understanding the complex behaviors of our most powerful models. Still, for simple behaviors, we find that sparse models trained with our method contain small, disentangled circuits that are both understandable and sufficient to perform the behavior. This suggests there may be a tractable path toward training larger systems whose mechanisms we can understand.


A new approach: learning sparse models




Previous mechanistic interpretability work has started from dense, tangled networks, and tried to untangle them. In these networks, each individual neuron is connected to thousands of other neurons. Most neurons seem to perform many distinct functions, making it seemingly impossible to understand. 


But what if we trained untangled neural networks, with many more neurons, but where each neuron has only a few dozen connections? Then maybe the resulting network will be simpler, and easier to understand. This is the central research bet of our work.


With this principle in mind, we trained language models with a very similar architecture to existing language models like GPT‑2, with one small modification: we force the vast majority of the model’s weights to be zeros. This constrained the model to use only very few of the possible connections between its neurons. This is a simple change which we argue substantially disentangles the model’s internal computations.





In normal dense neural networks, each neuron is connected to every neuron in the next layer. In our sparse models, each neuron only connects to a few neurons in the next layer. We hope that this makes the neurons, and the network as a whole, easier to understand.











Evaluating interpretability




We wish to measure the extent to which our sparse models’ computations are disentangled. We considered various simple model behaviors, and checked whether we could isolate the parts of the model responsible for each behavior—which we term circuits.


We hand-curated a suite of simple algorithmic tasks. For each, we pruned the model down to the smallest circuit that can still perform the task, and examined how simple that circuit is. (For details, see our paper⁠.) We found that by training bigger and sparser models, we could produce increasingly capable models with increasingly simple circuits.





We plot interpretability versus capability across models (lower-left is better). For a fixed sparse model size, increasing sparsity—setting more weights to zero—reduces capability but increases interpretability. Scaling up model size shifts this frontier outward, suggesting we can build larger models that are both capable and interpretable.











To make this concrete, consider a task where a model trained on Python code has to complete a string with the correct type of quote. In Python, ‘hello’ must end with a single quote, and “hello” must end with a double quote. The model can solve this by remembering which quote type opened the string and reproducing it at the end.


Our most interpretable models appear to contain disentangled circuits which implement exactly that algorithm.





Example circuit in a sparse transformer that predicts whether to end a string in a single or double quote. This circuit uses just five residual channels (vertical gray lines), two MLP neurons in layer 0, and one attention query-key channel and one value channel in layer 10. The model (1) encodes single quotes in one residual channel and double quotes in another; (2) uses an MLP layer to convert this into one channel that detects any quote and another that classifies between single and double quotes; (3) uses an attention operation to ignore intervening tokens, find the previous quote, and copy its type to the final token; and (4) predicts the matching closing quote.











In our definition, the exact connections shown above are sufficient to perform the task—if we remove the rest of the model, this small circuit still works. They are also necessary–deleting these few edges causes the model to fail.


We also looked at some more complicated behaviors. Our circuits for these behaviors (for example variable binding shown below) are harder to explain completely. Even then, we can still achieve relatively simple partial explanations which are predictive of model behavior.





Another example circuit, in less detail. To determine the type of a variable called current, one attention operation copies the variable name into the set() token when it’s defined, and another later operation copies the type from the set() token into a subsequent use of the variable, allowing the model to infer the correct next token.











The road ahead




This work is an early step toward a larger goal: making model computations easier to understand. But, there’s still a long way to go. Our sparse models are much smaller than frontier models, and large parts of their computation remain uninterpreted. 


Next, we hope to scale our techniques to larger models, and to explain more of the models’ behavior. By enumerating circuit motifs underlying more complex reasoning in capable sparse models, we could develop an understanding that helps us better target investigations of frontier models.


To overcome the inefficiency of training sparse models, we see two paths forward. One is to extract sparse circuits from existing dense models, rather than training sparse models from scratch. Dense models are fundamentally more efficient to deploy than sparse models. The other path is to develop more efficient techniques to train models for interpretability, which might be easier to put in production.


Note that our findings here are no guarantee that this approach will extend to more capable systems, but these early results are promising. Our aim is to gradually expand how much of a model we can reliably interpret, and to build tools that make future systems easier to analyze, debug, and evaluate.



Generated by RSStT. The copyright belongs to the original author.

Source

Report Page