Snapshots of Incomplete Thoughts

Large Language Models process input with an ordered sequence of layers. Let $l_i(x)$ denote the result of calling the $i$th layer on some input $x$; the value returned by a model with three layers is denoted $M(x) = l_3(l_2(l_1(x)))$.
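As a toy illustration of this notation, with made-up layer functions standing in for real model layers:

```python
# Toy illustration: a model M is just an ordered composition of its layers.
# These layer functions are made up; they only demonstrate the notation.
def l_1(x):
    return x + 1

def l_2(x):
    return x * 2

def l_3(x):
    return x - 3

def M(x):
    # M(x) = l_3(l_2(l_1(x)))
    return l_3(l_2(l_1(x)))

print(M(5))  # l_1(5) = 6, l_2(6) = 12, l_3(12) = 9
```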

In this essay, I analyze the output of $l_{13}$ (of $32$) of Meta's seven-billion-parameter LLM (LLaMA-7b) using techniques from The Geometry of Truth: Emergent Linear Structure in Large Language Model Representations of True/False Datasets by Samuel Marks and Max Tegmark.

I curate (code) datasets of questions about math problems (for example, "What is $2/2+6$?"). I send those questions through LLaMA-7b and sample the output of $l_{13}$. This output is high-dimensional and difficult to visualize, so I perform principal component analysis on the samples to find a two-dimensional representation of $l_{13}$'s output.
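A minimal sketch of this pipeline, assuming the Hugging Face transformers library and the huggyllama/llama-7b checkpoint; the checkpoint name, the example questions, and the choice to keep the final token's representation are assumptions of this sketch, not necessarily what the linked code does:

```python
# Sketch: send each question through LLaMA-7b, sample the output of layer 13,
# then project the samples onto their top two principal components.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.decomposition import PCA

model_name = "huggyllama/llama-7b"  # assumed checkpoint; substitute your own weights
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
model.eval()

questions = ["What is 2/2+6?", "What is 9-3+1?"]  # illustrative examples

samples = []
for question in questions:
    inputs = tokenizer(question, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # hidden_states[0] is the embedding output, so hidden_states[13] is l_13's output.
    layer_13 = outputs.hidden_states[13]
    # Keep the representation at the final token position: one 4096-d vector per question.
    samples.append(layer_13[0, -1, :].numpy())

# Project the 4096-dimensional samples onto their top two principal components.
projected = PCA(n_components=2).fit_transform(samples)
```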

In this two-dimensional representation, LLaMA-7b groups math problems according to their order of operations.

Figure 1. Each point corresponds to a question (full list) in the form "What is $\circ$?", where $\circ$ is a math problem in the form $a+b-c$ or $a-b+c$ resulting in either $4$ or $7$. Each question is sent through LLaMA-7b and the output of $l_{13}$ is sampled. PCA is then performed on the samples, and each sample is projected onto the top two principal components to create this plot. The horizontal axis corresponds to the first principal component, and the vertical one to the second.

There is further separation for problems that begin with a negative number. And mysteriously, the top right cluster in the plot above corresponds to problems ending in $+1$. For example, "What is $(9-3)+1$?".

In Anthropic's research on Transformer mathematics, they describe how "layers can send different information to different layers by storing it in different subspaces [of the residual stream]." Is that what we see here in the separation of math problems by their order of operations? Do future layers know to look in these subspaces and how to evaluate problems within?

In the plot below, I expand the earlier plot to include a larger variety of math problem types. In the legend, $\circ_1\circ_2$ is shorthand for a problem with the structure $a \circ_1 b \circ_2 c$. For example, $(1+2)/3$ is classified as +/.

Figure 2. Each point corresponds to a question (full list) in the form "What is $\circ$?", where $\circ$ is a math problem in the form $a \circ_1 b \circ_2 c$, labeled $\circ_1\circ_2$, resulting in either $4$ or $7$. For example, a point labeled +- corresponds to a problem in the form $a+b-c$. Two-dimensional representations of questions are found using the methods described in Figure 1.
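One way such a labeled dataset could be generated, as a sketch; the number ranges, explicit parentheses, and string formatting here are illustrative rather than the exact generation code linked above:

```python
# Sketch: enumerate problems of the form a o1 b o2 c, keep those evaluating
# to 4 or 7, and label each by its operator pair (e.g. "+-").
from itertools import product

operators = ["+", "-", "*", "/"]
problems = []
for op1, op2 in product(operators, repeat=2):
    for a, b, c in product(range(1, 10), repeat=3):
        expression = f"({a}{op1}{b}){op2}{c}"
        value = eval(expression)  # parentheses make the left-to-right grouping explicit
        if value in (4, 7):
            problems.append((f"What is {expression}?", op1 + op2))

print(problems[:3])  # e.g. ('What is (1+1)+2?', '++'), ...
```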

As Raul Molina writes in Traveling Words: A Geometric Interpretation of Transformers, "layer normalization confines the latent features to a hyper-sphere". Is the circular shape of the plot due to that hypersphere?

There's deep mystery here. Consider what happens when reproducing Samuel Marks and Max Tegmark's experiment by inputting true and false statements about Spanish-to-English translation and sampling the output of layer 12. An example true statement in the below plot is "The Spanish word 'estudiante' means 'student'." An example false one, "The Spanish word 'dos' means 'enemy'."

Figure 3. Each point corresponds to a statement about Spanish-to-English translation (full list) in the form "The Spanish word '$\circ_1$' means '$\circ_2$'." $\circ_1$ is a Spanish word, and $\circ_2$ an English word or phrase. Points are labeled according to whether $\circ_2$ is $\circ_1$'s English translation. Two-dimensional representations of statements are found using the same method as Figure 1.

In the case of math problems, the model separates its inputs based on their structure. Here, it separates statements based on whether they are true or false! Does the model have an encoding of capital-T Truth? Unlikely, but how then is it performing this separation? What structure is found in true/false statements?

Figure 4. The same data and representation as Figure 2, but labeled according to the result of the math problem. Evidently, LLaMA-7b does not separate math problems based on their result.

Why are true and false statements separated by their truth value, while math problems are separated by their structure rather than by their result?

As a final confounding point, consider the result of running the same experiment on the second-to-last layer of a small neural network trained to classify handwritten digits (code).
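A sketch of that setup, assuming PyTorch and torchvision; the layer sizes and training details here are placeholders and may differ from the linked code:

```python
# Sketch: train a two-layer MNIST classifier, then run PCA on the output of
# its second-to-last layer for images of fours and sevens.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.decomposition import PCA

train = datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor())
loader = DataLoader(train, batch_size=128, shuffle=True)

hidden = nn.Sequential(nn.Flatten(), nn.Linear(784, 64), nn.ReLU())  # second-to-last layer
model = nn.Sequential(hidden, nn.Linear(64, 10))  # final layer maps to ten digit classes

optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
for _ in range(2):  # a couple of epochs is enough for a rough classifier
    for images, labels in loader:
        optimizer.zero_grad()
        loss_fn(model(images), labels).backward()
        optimizer.step()

# Sample the second-to-last layer for fours and sevens, then project with PCA.
mask = (train.targets == 4) | (train.targets == 7)
fours_and_sevens = train.data[mask].float() / 255.0
with torch.no_grad():
    samples = hidden(fours_and_sevens).numpy()
projected = PCA(n_components=2).fit_transform(samples)  # shape: (num_images, 2)
```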

When images corresponding to the handwritten digits four and seven (or any arbitrary pair of digits) are used to create the samples, there is clear separation.

Figure 5. Each point corresponds to an image of a handwritten four or seven from the MNIST dataset. Points are sent through a two-layer neural network trained to classify them, and the output of the second-to-last layer is projected onto the top two principal components as in Figure 1.

Figure 6. Each point corresponds to an image of a handwritten digit from the MNIST dataset. Points are sent through a two-layer neural network trained to classify them, and the output of the second-to-last layer is projected onto the top two principal components as in Figure 1. As with math problem structure, the separation becomes less clear when considering additional classifications, though hints remain.

This handwritten digit classifier is much less complex than LLaMA-7b. Yet, as with the true/false statements, it separates handwritten digits based on their classification when projected onto the principal components. Does this mean separation on the principal components is something fundamental to the operation of neural networks?

There is so much mystery here! And it is fascinating to glimpse how LLaMA-7b thinks about math problems. Inside LLaMA's representations of human concepts, I suspect there are deep truths about the ontology of human speech and knowledge.

Appendix I - Residuals

Claim: The fraction of the variance retained by a mean-centered, $n\times m$ matrix $X$ containing $n$ samples with $m$ features projected onto its $i$th principal component is

$$\frac{\lambda_i}{\sum_{i=1}^m\lambda_i}$$

where $\lambda_i$ is the $i$th eigenvalue (of $m$) of the covariance matrix of $X$.

  1. The principal components of $X$ are the eigenvectors of its covariance matrix, $\Sigma$, ordered by their corresponding eigenvalues.

    For example, the first principal component corresponds to the eigenvector of $\Sigma$ with the largest eigenvalue. I like Cosma Shalizi's proof of this (Section 18.1). By mean-centered I mean `X - np.mean(X, axis=0, keepdims=True)` in numpy-speak.

  2. $\sum_{i=1}^{m}\sigma^2_i = \sum_{i=1}^{m}\lambda_i$, where $\sigma^2_i$ denotes the variance of the $i$th feature and $\lambda_i$ the $i$th eigenvalue of $\Sigma$.

    As Ted Shifrin so concisely shows, the trace of a matrix is equal to the sum of its eigenvalues. Thus, as the diagonal of a covariance matrix contains each feature's variance, the total variance is equal to the sum of the eigenvalues.

  3. The variance of $X$ after being projected onto $\vec{v}$ is $\lambda_i$ when $\vec{v}$ is the $i$th normalized eigenvector of $\Sigma$ (i.e. the $i$th principal component).

    The projection of $X$ onto $\vec{v}$ is given by $X\vec{v}$ and the covariance matrix of $X$ by $\Sigma = X^TX/n$. The variance of $X\vec{v}$ is $(X\vec{v})^TX\vec{v}/n = \vec{v}^TX^TX\vec{v}/n = \vec{v}^T\Sigma\vec{v} = \vec{v}^T\lambda_i\vec{v} = \lambda_i$, as $\vec{v}$ is normalized ($|\vec{v}| = 1 = \vec{v}^T\vec{v}$).

  4. Thus, the fraction of the original variance retained by a projection of $X$ onto the $i$th principal component is $\lambda_i/\sum_{i=1}^m\lambda_i$. As a corollary, the fraction retained by a projection onto the top $N$ principal components is $\sum_{i=1}^N\lambda_i/\sum_{i=1}^m\lambda_i$.
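A quick numerical check of the claim on random data, as a sketch:

```python
# Numerical check: the variance retained by projecting a mean-centered X onto
# its i-th principal component equals lambda_i / sum(lambda_j).
import numpy as np

rng = np.random.default_rng(0)
n, m = 500, 10
X = rng.normal(size=(n, m)) @ rng.normal(size=(m, m))  # samples with correlated features
X = X - np.mean(X, axis=0, keepdims=True)               # mean-center

cov = X.T @ X / n                                        # covariance matrix, Sigma
eigvals, eigvecs = np.linalg.eigh(cov)                   # eigh returns ascending order
eigvals, eigvecs = eigvals[::-1], eigvecs[:, ::-1]       # sort descending

i = 0  # first principal component
projected = X @ eigvecs[:, i]
print(np.var(projected) / np.var(X, axis=0).sum())       # fraction of variance retained
print(eigvals[i] / eigvals.sum())                        # the claim: identical value
```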

Appendix II - PCA Fit

The argument from Appendix I is useful for analyzing how well the plots in this essay represent their high-dimensional residual stream data.

Figure 7. The fraction of variance captured (as defined in Appendix I) for varying numbers of principal components using the math problem dataset from Figure 2. For example, the point corresponding to $3$ on the x-axis denotes the fraction of the original variance captured when the data is projected onto its top three principal components. Past eleven principal components, the trend towards $1$ continues; I've cut off the plot as looking at all $4096$ principal components isn't terribly illuminating.
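For reference, this curve can be computed with scikit-learn's explained_variance_ratio_, which matches the fraction from Appendix I; the array of layer-13 samples below is a random placeholder:

```python
# Cumulative fraction of variance captured by the top N principal components,
# as plotted in Figure 7. X stands in for the stacked layer-13 samples.
import numpy as np
from sklearn.decomposition import PCA

X = np.random.default_rng(0).normal(size=(200, 4096))  # placeholder for real samples

pca = PCA().fit(X)
captured = np.cumsum(pca.explained_variance_ratio_)  # fraction for top 1, 2, ... components
print(captured[:11])
```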

Across datasets used in The Geometry of Truth: Emergent Linear Structure in Large Language Model Representations of True/False Datasets, there is significant variance in how "good" of a fit the first principal components provide. For example, I reproduced the authors' experiments testing true/false statements about the location of cities and number comparisons on my LLaMA-7b setup.

Figure 8. Each point corresponds to a true/false statement (full list) about what country a city is in. For example, "The city of Lu'an is in China." Crosses correspond to true statements and dots to false ones. Points are sent through the first thirteen layers of LLaMA-7b and projected onto their top two principal components as in Figure 1. As in Figure 3, true and false statements appear well separated.

Figure 9. The same plot as Figure 7, using the data from Figure 8. The top two principal components capture about the same fraction of the original variance as the top principal component of the math problem dataset.

Figure 10. Each point corresponds to a true/false statement (full list) about the relative size of two numbers. For example, "Fifty-one is smaller than fifty-two." Crosses correspond to true statements and dots to false ones. Points are sent through the first thirteen layers of LLaMA-7b and projected onto their top two principal components as in Figure 1.

Figure 11. The same plot as Figure 7, using the data from Figure 10. In Figure 10 the data appears much less separated than in Figure 8, and this plot confirms that only a small fraction of the variance is captured by the first two principal components.

This discussion is relegated to the appendix because I'm not entirely sure what to make of it. On one hand, the top two principal components of the math problem dataset from Figure 2 capturing 56% of the original variance is highly suggestive of a good fit, as the residual stream samples are $4096$-dimensional vectors. On the other hand, it's not incredible, and the fact that the cities and Spanish-to-English translation datasets capture 45% and 24% of the original variance on their top two principal components, respectively, makes me cautious about drawing hard conclusions from the data. There may well be something else going on here.


Thanks to Moses, N, Nick, and Noah for invaluable feedback and conversation.