Large Language Models process input with an ordered sequence of layers. Let $l_i(x)$ denote the result of calling the $i$th layer on some input $x$; the value returned by a model with three layers is then $M(x) = l_3(l_2(l_1(x)))$.
In this essay, I analyze the output of $l_{12}$ (of $32$) of Meta's seven-billion-parameter LLM (LLaMA-7b) using techniques from The Geometry of Truth: Emergent Linear Structure in Large Language Model Representations of True/False Datasets by Samuel Marks and Max Tegmark.
I curate (code) datasets of math questions (for example, "What is $2/2+6$?"). I send those questions through LLaMA-7b and sample the output of $l_{12}$. This output is high-dimensional and difficult to visualize, so I perform principal component analysis on the samples to find a two-dimensional representation of $l_{12}$'s output.
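The sampling itself requires a LLaMA-7b setup, but the projection step can be sketched in numpy. Here `H` is a stand-in (random data) for the sampled layer-12 outputs; in the real experiment each row would be a 4096-dimensional residual-stream vector, shrunk to 64 dimensions here for brevity.

```python
import numpy as np

# Stand-in for the sampled layer-12 outputs: one row per question.
rng = np.random.default_rng(0)
H = rng.normal(size=(100, 64))

# Mean-center, then find the top two principal components via the
# eigendecomposition of the covariance matrix.
X = H - H.mean(axis=0, keepdims=True)
cov = X.T @ X / X.shape[0]
eigvals, eigvecs = np.linalg.eigh(cov)  # eigenvalues in ascending order
top2 = eigvecs[:, [-1, -2]]             # top two principal components

# Each row of `points` is one question's 2-D coordinate in the plots.
points = X @ top2
print(points.shape)  # (100, 2)
```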
In this two-dimensional representation, LLaMA-7b groups math problems according to their order of operations.
Figure 1. Each point corresponds to a question (full list) in the form "What is $\circ$", where $\circ$ is a math problem in the form $a+b-c$ or $a-b+c$ resulting in either $4$ or $7$. Each question is sent through LLaMA-7b and the output of $l_{12}$ is sampled. PCA is then performed on the samples, and each sample is projected onto the top two principal components to create this plot. The horizontal axis corresponds to the first principal component, and the vertical one to the second.
There is further separation for problems that begin with a negative number. And mysteriously, the top right cluster in the plot above corresponds to problems ending in $+1$. For example, "What is $(9-3)+1$?".
In Anthropic's research on Transformer mathematics, they describe how "layers can send different information to different layers by storing it in different subspaces [of the residual stream]." Is that what we see here in the separation of math problems by their order of operations? Do future layers know to look in these subspaces and how to evaluate problems within?
In the plot below, I expand the earlier plot to include a larger variety of math problem types. In the legend, $\circ_1\circ_2$ is shorthand for a problem with the structure $a \circ_1 b \circ_2 c$. For example, $(1+2)/3$ is classified as $+/$.
Figure 2. Each point corresponds to a question (full list) in the form "What is $\circ$", where $\circ$ is a math problem in the form $a \circ_1 b \circ_2 c$ labeled $\circ_1\circ_2$ resulting in either $4$ or $7$. For example, a point labeled $+-$ corresponds to a problem in the form $a+b-c$. Two-dimensional representations of questions are found using the methods described in Figure 1.
As Raul Molina writes in Traveling Words: A Geometric Interpretation of Transformers, "layer normalization confines the latent features to a hyper-sphere". Is the circular shape of the plot due to that hypersphere?
There's a deep mystery here. Consider what happens when reproducing Samuel Marks and Max Tegmark's experiment by inputting true and false statements about Spanish-to-English translation and sampling the output of layer 12. An example true statement in the below plot is "The Spanish word 'estudiante' means 'student'." An example false one: "The Spanish word 'dos' means 'enemy'."
Figure 3. Each point corresponds to a statement about Spanish-to-English translation (full list) in the form "The Spanish word '$\circ_1$' means '$\circ_2$'." $\circ_1$ is a Spanish word, and $\circ_2$ an English word or phrase. Points are labeled according to whether $\circ_2$ is $\circ_1$'s English translation. Two-dimensional representations of statements are found using the same method as Figure 1.
In the case of math problems, the model separates statements based on the structure. Here, it separates statements based on whether they are true or false! Does the model have an encoding of capital-T Truth? Unlikely, but how then is it performing this separation? What structure is found in true/false statements?
Figure 4. The same data and representation as Figure 2, but labeled according to the result of the math problem. Evidently, LLaMA-7b does not separate math problems based on their result.
Why are translation statements separated by their truth value, while math problems are separated by their structure rather than their result?
As a final confounding point, consider the result of running the same experiment on the second-to-last layer of a small neural network trained to classify handwritten digits (code).
When images corresponding to the handwritten digits four and seven (or any arbitrary pair of digits) are used to create the samples, there is clear separation.
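The structure of this experiment can be sketched in numpy. Everything here is a stand-in: random weights in place of the trained two-layer classifier, and random vectors in place of MNIST images. The shape of the computation (forward pass to the second-to-last layer, then PCA) mirrors the experiment; actual separation of course requires trained weights and real digits.

```python
import numpy as np

rng = np.random.default_rng(0)
images = rng.random(size=(200, 784))              # stand-in for 28x28 MNIST images
W1 = rng.normal(size=(784, 128)) * 0.05           # hypothetical first-layer weights
b1 = np.zeros(128)

# Forward pass through the first layer only: in a two-layer network, its
# output is the second-to-last layer.
hidden = np.maximum(images @ W1 + b1, 0)          # ReLU activations

# Project the sampled activations onto their top two principal components.
X = hidden - hidden.mean(axis=0, keepdims=True)
eigvals, eigvecs = np.linalg.eigh(X.T @ X / X.shape[0])
points = X @ eigvecs[:, [-1, -2]]
print(points.shape)  # (200, 2)
```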
Figure 5. Each point corresponds to an image of a handwritten four or seven from the MNIST dataset. Points are sent through a two-layer neural network trained to classify them, and projected onto their top two principal components as in Figure 1.
Figure 6. Each point corresponds to an image of a handwritten digit from the MNIST dataset. Points are sent through a two-layer neural network trained to classify them, and projected onto their top two principal components as in Figure 1. As with math problem structure, the separation becomes less clear when considering additional classifications, though hints remain.
This handwritten digit classifier is much less complex than LLaMA-7b. Yet, like true/false statements, it separates handwritten digits based on their classification when projected onto the principal components. Does this mean separation on the principal components is something fundamental to the operation of neural networks?
There is so much mystery here! And it is fascinating to glimpse how LLaMA-7b thinks about math problems. Inside LLaMA's representations of human concepts, I suspect there are deep truths about the ontology of human speech and knowledge.
Claim: The fraction of the variance retained when a mean-centered, $n\times m$ matrix $X$ containing $n$ samples with $m$ features is projected onto its $i$th principal component is
$$\frac{\lambda_i}{\sum_{j=1}^{m}\lambda_j},$$
where $\lambda_i$ is the $i$th eigenvalue (of $m$) of the covariance matrix of $X$.
The principal components of $X$ are the eigenvectors of its covariance matrix, $\Sigma$, ordered by their corresponding eigenvalues.
For example, the first principal component corresponds to the eigenvector of $\Sigma$ with the largest eigenvalue. I like Cosma Shalizi's proof of this (Section 18.1). By mean-centered I mean X - np.mean(X, axis=0, keepdims=True) in numpy-speak.
$\sum_{i=1}^{m}\sigma^2_i = \sum_{i=1}^{m}\lambda_i$, where $\sigma^2_i$ denotes the variance of the $i$th feature and $\lambda_i$ the $i$th eigenvalue of $\Sigma$.
As Ted Shifrin so concisely shows, the trace of a matrix is equal to the sum of its eigenvalues. Thus, as the diagonal of a covariance matrix contains each feature's variance, the total variance is equal to the sum of the eigenvalues.
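This step is easy to check numerically with numpy on arbitrary data: the trace of the covariance matrix (the total variance) matches the sum of its eigenvalues.

```python
import numpy as np

# Check: total variance (trace of the covariance matrix) equals
# the sum of the covariance matrix's eigenvalues.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 8))
X = X - X.mean(axis=0, keepdims=True)   # mean-center
cov = X.T @ X / X.shape[0]
total_variance = np.trace(cov)           # sum of per-feature variances
eigvals = np.linalg.eigvalsh(cov)
print(np.isclose(total_variance, eigvals.sum()))  # True
```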
The variance of $X$ after being projected onto $\vec{v}$ is $\lambda_i$ when $\vec{v}$ is the $i$th normalized eigenvector of $\Sigma$ (i.e. the $i$th principal component).
The projection of $X$ onto $\vec{v}$ is given by $X\vec{v}$, and the covariance matrix of $X$ by $\Sigma = X^TX/n$. The variance of $X\vec{v}$ is then $(X\vec{v})^TX\vec{v}/n = \vec{v}^T(X^TX/n)\vec{v} = \vec{v}^T\Sigma\vec{v} = \vec{v}^T\lambda_i\vec{v} = \lambda_i$, as $\vec{v}$ is normalized ($|\vec{v}| = 1 = \vec{v}^T\vec{v}$).
Thus, the fraction of the original variance retained by a projection of $X$ onto the $i$th principal component is $\lambda_i/\sum_{j=1}^m\lambda_j$. As a corollary, the fraction retained by a projection onto the top $N$ principal components is $\sum_{i=1}^N\lambda_i/\sum_{j=1}^m\lambda_j$.
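The claim can be verified numerically: on arbitrary (correlated) data, the variance of the projection onto each principal component matches the corresponding eigenvalue, so the retained fraction follows directly.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(1000, 6)) @ rng.normal(size=(6, 6))  # correlated features
X = X - X.mean(axis=0, keepdims=True)                     # mean-center
cov = X.T @ X / X.shape[0]
eigvals, eigvecs = np.linalg.eigh(cov)
eigvals, eigvecs = eigvals[::-1], eigvecs[:, ::-1]        # descending order

# The variance of X projected onto the i-th principal component is lambda_i.
for i in range(6):
    projected = X @ eigvecs[:, i]
    assert np.isclose(projected.var(), eigvals[i])

# Fraction of the variance retained by the top two components,
# as plotted in Figure 7.
print(eigvals[:2].sum() / eigvals.sum())
```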
The argument from Appendix I is useful for analyzing how well the plots in this essay represent their high-dimensional residual stream data.
Figure 7. The fraction of variance captured (as defined in Appendix I) for varying numbers of principal components using the math problem dataset from Figure 2. For example, the point corresponding to $3$ on the x-axis denotes the fraction of the original variance captured when the data is projected onto its top three principal components. Past eleven principal components, the trend towards $1$ continues. I've cut off the plot as looking at all $4096$ principal components isn't terribly illuminating.
Across datasets used in The Geometry of Truth: Emergent Linear Structure in Large Language Model Representations of True/False Datasets there is significant variance in how "good" of a fit the first principal components provide. For example, I reproduced the authors' experiments testing true/false statements about the location of cities and number comparisons on my LLaMA-7b setup.
Figure 8. Each point corresponds to a true/false statement (full list) about what country a city is in. For example, "The city of Lu'an is in China." Crosses correspond to true statements and dots to false ones. Points are sent through the first twelve layers of LLaMA-7b and projected onto their top two principal components as in Figure 1. Like Figure 3, true and false statements appear well separated.
Figure 9. The same plot as Figure 7, using the data from Figure 8. The top two principal components capture about the same fraction of the original variance as the top principal component of the math problem dataset.
Figure 10. Each point corresponds to a true/false statement (full list) about the relative size of two numbers. For example, "Fifty-one is smaller than fifty-two." Crosses correspond to true statements and dots to false ones. Points are sent through the first twelve layers of LLaMA-7b and projected onto their top two principal components as in Figure 1.
Figure 11. The same plot as Figure 7, using the data from Figure 10. In Figure 10 the data appears much less separated than in Figure 8, and this plot confirms only a small fraction of the variance is captured by the first two principal components.
This discussion is relegated to the appendix because I'm not entirely sure what to make of it. On one hand, the top two principal components of the math problem dataset from Figure 2 capturing 56% of the original variance is highly suggestive of a good fit, as the residual stream samples are $4096$-dimensional vectors. On the other hand, it's not incredible, and the cities and Spanish-to-English translation datasets capturing 45% and 24% of the original variance on their top two principal components, respectively, makes me cautious about drawing hard conclusions from the data. There may well be something else going on here.
Thanks to Moses, N, Nick, and Noah for invaluable feedback and conversation.