Mechanistic interpretability aims to reverse-engineer the algorithms learned by neural networks to make them understandable to humans. But while there is now a fair amount of evidence that neural networks learn interesting algorithms much of the time, there is no reason to think that the algorithms learned will be similar to the algorithms that we are used to1. This post aims to answer what properties make an algorithm 1) possible for the transformer architecture to realize in weights and 2) preferred over other algorithms by the optimization procedure.

Thesis: We're very accustomed to algorithms that are sequential, centralized, tightly coupled, discrete, and deterministic. But transformers don't like to realize or learn algorithms with any of these properties. Instead, transformers prefer algorithms with many parallel and loosely-coupled computation paths (i.e., an ensemble) that compute smooth and probabilistic functions.

Transformer computation is parallel and loosely-coupled

Parallel and loosely coupled computation is more natural to transformers than sequential computation because of the limitation of constant depth and the inductive biases of the optimization process. Transformers struggle to implement deeply sequential algorithms, which contrasts with programming, where sequential and tightly coupled algorithms are the norm.

Constant depth constrains the realizable space of functions

The constant number of layers of a transformer upper bounds the number of sequential operations a transformer can perform, as well as the number of communication rounds between tokens.

To perfectly simulate a boolean (logic) circuit of depth $d$ (requiring $d$ sequential steps), a transformer requires $O(d)$ layers, and to perfectly simulate a finite state automaton applied to an input string of length $T$ requires a transformer of depth $O(\log T)$ (in the general case).
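
The $O(\log T)$ bound comes from the fact that state transitions compose associatively, so they can be combined pairwise in parallel rounds. A minimal sketch (using a parity automaton as an illustrative example, not any construction from the literature):

```python
# Sketch: simulating a finite state automaton over T inputs in O(log T)
# parallel rounds by composing transition functions pairwise.
# The parity automaton here is an illustrative assumption: state 0/1,
# flipped by input symbol 1.

def compose(f, g):
    # "f then g" as a lookup table over the 2 states
    return tuple(g[f[s]] for s in range(2))

def run_parallel(symbols):
    # Each symbol becomes a transition function (table over states).
    fs = [(0 ^ sym, 1 ^ sym) for sym in symbols]
    # Doubling / prefix-scan: each round composes adjacent pairs,
    # so ceil(log2 T) rounds suffice instead of T sequential steps.
    rounds = 0
    while len(fs) > 1:
        nxt = [compose(fs[i], fs[i + 1]) for i in range(0, len(fs) - 1, 2)]
        if len(fs) % 2:
            nxt.append(fs[-1])
        fs = nxt
        rounds += 1
    return fs[0][0], rounds  # final state starting from state 0

symbols = [1, 0, 1, 1, 0, 1, 1, 1]
state, rounds = run_parallel(symbols)
assert state == sum(symbols) % 2  # parity of the input string
assert rounds == 3                # log2(8) rounds, not 8 steps
```

Each round of pairwise composition corresponds loosely to one layer of token-to-token communication, which is why depth logarithmic in sequence length suffices.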

This shows up in addition too! A fixed-depth transformer cannot perfectly compute multi-digit addition for arbitrarily long numbers: computing each digit of the result requires as input the carry from the digit column to its right, creating a chain of communication dependencies.2
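
A minimal sketch of the carry chain (standard ripple-carry addition, written out to make the dependency explicit):

```python
# Why multi-digit addition is inherently sequential: each output digit
# needs the carry from the column to its right, forming a dependency
# chain as long as the number itself.

def add_digits(a, b):
    # a, b: lists of decimal digits, least significant first
    out, carry = [], 0
    for da, db in zip(a, b):
        s = da + db + carry          # depends on the previous column's carry
        out.append(s % 10)
        carry = s // 10
    out.append(carry)
    return out

# 99999999 + 1 propagates a carry through every single column:
n = 8
digits = add_digits([9] * n + [0], [1] + [0] * n)
assert digits == [0] * n + [1, 0]  # i.e., 100000000
```

The worst-case inputs (long runs of 9s) are exactly the ones where no digit can be finalized without information from arbitrarily far to the right.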

SGD prefers short & non-interacting paths

Optimization prefers solutions that are faster and easier to learn, which largely means lower effective dimensionality and shorter paths. This is often explained as being due to the greedy nature of gradient descent or via optimization dynamics when starting from small initialization.

Intuition: Imagine gradient descent is given a choice between explaining a portion of the training set with two simple circuits or with a single more complex circuit. If the two simple circuits aren't strongly coupled, they can be learned independently and still bring down the loss. Since more complicated circuits are "harder" to learn, gradient descent will have a preference for the two loosely coupled circuits.
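
This can be made concrete in a toy setting (my own construction, not from any cited work): when a target decomposes into two decoupled pieces, gradient descent learns each piece essentially independently, and each independently reduces the loss.

```python
# Toy illustration: two decoupled "circuits" are learned independently
# by gradient descent. The linear setup is an illustrative assumption.
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 2))        # two uncorrelated input features
y = 3.0 * x[:, 0] - 2.0 * x[:, 1]         # target is a sum of two simple circuits

w = np.zeros(2)
lr = 0.1
for _ in range(200):
    grad = (x @ w - y) @ x / len(x)       # gradient of mean squared error
    # With uncorrelated features the Hessian is ~diagonal, so the update
    # to w[0] barely depends on w[1] and vice versa: two independent paths.
    w -= lr * grad

assert abs(w[0] - 3.0) < 0.1 and abs(w[1] + 2.0) < 0.1
```

Each coordinate's loss contribution shrinks on its own; nothing forces the two circuits to be learned jointly, which is the loose-coupling preference in miniature.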

This means that we expect transformers to prefer a strong form of parallelism: implementing algorithms as many computation paths that interact minimally. In other words, transformer computation might often look like an ensemble of loosely-related computation paths rather than one central mechanism, or even a few parallel mechanisms that interact sparsely.

There is empirical support for this. A classic paper shows that deep ResNets can be seen as a collection of loosely coupled short paths: removing a single layer causes only a small error, and the error increases smoothly with the number of layers removed, suggesting ensemble-like behavior. A more recent paper replicates this finding in LLMs (except for the very early and late layers, which seem to play a special role across many circuits).
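
The qualitative effect is easy to reproduce in a toy residual stack (illustrative sketch with random weights, not the cited papers' setup): because the output is a sum over many short paths through the skip connections, deleting any single layer only slightly perturbs it.

```python
# Toy sketch: layer ablation in a residual stack barely changes the output,
# because the identity path and the other branches carry most of the signal.
# Depth, width, and the 0.02 branch scale are arbitrary assumptions.
import numpy as np

rng = np.random.default_rng(1)
n_layers, d = 20, 16
Ws = [0.02 * rng.standard_normal((d, d)) for _ in range(n_layers)]

def forward(x, skip=None):
    for i, W in enumerate(Ws):
        if i != skip:
            x = x + np.tanh(W @ x)   # residual connection: identity + branch
    return x

x = rng.standard_normal(d)
full = forward(x)
errors = [np.linalg.norm(forward(x, skip=i) - full) / np.linalg.norm(full)
          for i in range(n_layers)]
assert max(errors) < 0.2             # no single layer is critical
```

A purely sequential (non-residual) composition of the same layers would be destroyed by removing any one of them; the skip connections are what make the computation ensemble-like.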

A preference for loosely-coupled algorithms also means a preference for algorithms/sub-networks that are only important for a relatively small subset of the input space. When aiming to explain a new portion of the input space, given a choice between making an existing algorithm more complex and learning a new simple algorithm, SGD prefers the latter.

Implications

  • We should expect NNs to struggle to learn exact implementations of highly sequential algorithms, and to instead approximate them with a solution consisting of many loosely coupled paths. This often looks like a bag of heuristics.
  • We should expect self-repair (robustness to layer/feature ablations) due to the robustness of large ensembles (ablating one mechanism among many other loosely-coupled ones that vote for the final result will only partially degrade the output).
    • If you ablate the output of one attention layer that upweights a particular logit, the next attention layer often upweights that logit more than without the ablation, partially repairing the effect of the ablation.
  • The preference towards loosely-coupled mechanisms makes interpretability a lot easier than one might think a priori. Parameter decomposition methods that attempt to decompose networks into these loosely coupled mechanisms are showing promising results.

Transformer computation is smooth

Transformers find it easier to represent smooth functions3 than rapidly changing ones. They struggle to represent discrete algorithms, even though discrete algorithms are the default in programming.

Intuitively, you need very large weights and/or a lot of neurons (each approximating a tiny region of the function) to approximate something like a step function well. Optimization tends to find smaller weights (due to weight decay, starting from small weights, and simplicity bias), and using a lot of neurons is expensive, so optimization will prefer smoother functions.
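
A quick check of this intuition with a single sigmoid neuron (the specific weights and error band are illustrative choices): the maximum slope of $\sigma(wx)$ is $w/4$, so matching a step function's jump forces the weight to blow up.

```python
# Sketch: approximating a step function with one sigmoid neuron requires a
# large weight w, since the max slope of sigmoid(w*x) is only w/4 -- a
# small-weight, smooth network can't match the step's unbounded Lipschitz
# constant.
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

xs = np.linspace(-1, 1, 2001)
step = (xs > 0).astype(float)
mask = np.abs(xs) > 0.05             # measure error away from the jump itself

errs = {w: np.abs(sigmoid(w * xs) - step)[mask].max() for w in (1.0, 10.0, 100.0)}
# The error shrinks only as the weight (and hence the slope) blows up:
assert errs[1.0] > errs[10.0] > 0.1 > 0.01 > errs[100.0]
```

Weight decay and small initialization push against exactly these large weights, which is one concrete reason optimization prefers the smooth approximation over the discrete function.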

Features might not be discrete directions

A central belief in mechanistic interpretability is the Linear Representation Hypothesis, which states that many (strong version: most/all) features represented in a network are linearly represented (can be decoded with a linear probe, intensity as scaling, composition as addition).
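
In miniature, the hypothesis looks like this (the feature directions and "activation" below are synthetic, purely for illustration):

```python
# Sketch of the Linear Representation Hypothesis: features as directions,
# intensity as scaling, composition as addition, decoding as a dot product.
# The two features and their intensities are invented for illustration.
import numpy as np

rng = np.random.default_rng(0)
d = 64
f_color = rng.standard_normal(d)
f_color /= np.linalg.norm(f_color)
f_size = rng.standard_normal(d)
f_size -= (f_size @ f_color) * f_color   # exactly orthogonal here; real
f_size /= np.linalg.norm(f_size)         # features are only approximately so

# An activation encoding "color" at intensity 2.0 and "size" at 0.5:
act = 2.0 * f_color + 0.5 * f_size + 0.01 * rng.standard_normal(d)

# A linear probe is just a dot product with the feature direction:
assert abs(act @ f_color - 2.0) < 0.05
assert abs(act @ f_size - 0.5) < 0.05
```

Superposition complicates this picture: with more features than dimensions, the directions can't all be orthogonal and probes pick up interference, but the additive structure is the core claim.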

For a while, this was taken to mean features were represented as individual directions in activation space. Now there's plenty of evidence for higher-dimensional features, such as the days of the week being positioned on a circle (in order!), with the representations of "late night Monday" and "early morning Tuesday" both lying between "Monday" and "Tuesday".
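
A toy version of that circular geometry (a hypothetical 2D construction consistent with the day-of-week findings, not extracted from a real model):

```python
# Sketch: seven days placed evenly on a circle; an intermediate time like
# "late Monday night" lands between the adjacent days' representations.
# The 2D layout and the 0.7 interpolation angle are illustrative assumptions.
import numpy as np

days = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
angles = 2 * np.pi * np.arange(7) / 7
reps = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # circular features

def nearest_days(v, k=2):
    order = np.argsort([np.linalg.norm(v - r) for r in reps])
    return [days[i] for i in order[:k]]

# "Late Monday night": 70% of the way from Mon (angle 0) toward Tue:
theta = 0.7 * 2 * np.pi / 7
late_monday = np.array([np.cos(theta), np.sin(theta)])
assert set(nearest_days(late_monday)) == {"Mon", "Tue"}
```

The point is that intermediate positions on the manifold carry meaning, which a "features = isolated directions" picture doesn't capture.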

I think it's possible that features aren't truly discrete directions, and that if you vary an SAE latent a bit it will still causally impact the model in a slightly different but meaningful way. Another way of putting this is that the geometry of activation space is truly how information is stored, as opposed to viewing activations as a sparse sum of feature intensities.4

Evidence for this:

Given this, I wouldn't be surprised if the default behavior of transformers is to represent information in ranges of meaning over some region.

Implications

Transformer computation is probabilistic

Instead of predicting a single output, transformers model a distribution over outputs. This differs from programming, where almost all algorithms produce a single, deterministic output. The difference comes from the training pressure to solve a stochastic problem: the next token of arbitrary text can't be perfectly predicted.
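
This is baked into the loss itself: under cross-entropy on stochastic data, the loss-minimizing prediction is the full conditional distribution, not any single "correct" token. A toy check (the three-token distribution is an invented example):

```python
# Sketch: gradient descent on next-token cross-entropy recovers the
# data-generating distribution, not a single deterministic answer.
import numpy as np

rng = np.random.default_rng(0)
true_p = np.array([0.6, 0.3, 0.1])           # ground-truth next-token dist.
samples = rng.choice(3, size=100_000, p=true_p)
freqs = np.bincount(samples, minlength=3) / len(samples)

logits = np.zeros(3)
lr = 0.5
for _ in range(300):
    probs = np.exp(logits) / np.exp(logits).sum()
    # Gradient of average cross-entropy: predicted probs minus empirical freqs.
    logits -= lr * (probs - freqs)

probs = np.exp(logits) / np.exp(logits).sum()
assert np.abs(probs - true_p).max() < 0.02   # the model learns the distribution
```

Any mechanism that deterministically committed to the most likely token would have strictly higher loss than one that reports calibrated probabilities, so the optimization pressure is toward probabilistic computation throughout.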

Implications

  • Suppression of logits might be just as important as the original upweighting. There is plenty of evidence that models learn mechanisms to downweight logits in cases of overconfidence. For example, there's a common circuit that downweights high-logit tokens that have appeared earlier in the sequence.
    • That the model learns largely decoupled suppression mechanisms, instead of a single (presumably more complicated) upweighting mechanism that is accurate in the first place, also fits the narrative of networks as a set of minimally coupled mechanisms.
  • In the context of the ensemble nature of models: the model wants to compute, for many candidate tokens, its belief that each is the next token, and prefers to do this in a parallel/distributed way. This suggests logits might be computed via largely independent paths.
  • Computation paths can't be fully understood without looking at "competing" paths, as relative logit magnitudes are what matter.
  • We should expect transformers to model beliefs over which "state" in the data-generating process they're in, representing information useful for future tokens as opposed to purely greedy next-token prediction.
  • Training pressure means patterns that appear more frequently in the training set have an outsized impact on algorithms learned.
  • If the activation of a particular heuristic isn't guaranteed, this might be further reason to expect ensemble-like behavior, as a sort of hedging.

Bits I'd like to look at more

  • We expect some parts of the network to be (1) complete noise and spurious correlations, a fair amount to be (2) a big bag of heuristics, and the rest to be (3) perfectly generalizing algorithms. Interpretability seems to increase along (1) < (2) < (3). A better understanding of the proportion of each seems useful.

A language of model computation

The current approach to representing and communicating circuits is ad hoc and often imprecise. I think the field would greatly benefit from a semi-formal "language" of model computation that can be used to express model algorithms in a way that reflects what is natural/unnatural for transformers to realize and learn. Defining a good representation space would probably be hard, but it seems almost a pre-requisite if we want to go after the ambitious goal of reverse-engineering most of the algorithms implemented in models. A standardized representation space of model computation would allow for better communication of circuits5, create a "decompilation target" for reverse-engineering, and enable the creation of a shared set of tools6 that seem necessary to scale interpretability past the massively manual stage it's in now.

1

They definitely won't look like Python!

2

To be fair, are there really that many >50-digit addition problems in the training set? Maybe the reason a different algorithm for addition is sometimes learned is instead simplicity bias and trouble representing the discrete nature of the canonical algorithm.

3

I mean smooth in the sense of a relatively small Lipschitz constant where, e.g., a step function requires an infinite Lipschitz constant.

4

This could be wrong, as there do seem to be incentives for models to produce discrete features. Much of language is discrete ("Michael Jordan" either appears or it doesn't in a passage), pressure from superposition to avoid interference might prioritize discrete-ish directions, and perhaps mechanisms that deal with discrete directions are simpler to learn.

5

More precise in terms of effects, more complete in terms of mechanisms, and better intuitions about what types of computation are natural for a model to implement.

6

Tools for visualizing and navigating the many computation paths of a model, verifying the extent to which a proposed representation explains a model's predictions, interventions to better control models, formal verification on properties of network output for a given subset of the input, ...