Two months of matrix multiply, of regret, loss, spring, sunburn, sore
muscles, and a return to my bike. This post resumes the
search for ways to multiply matrices with few multiplications. Some
nihilistic hedonism propels me to continue this, but lately I reflect
on how much I miss while absorbed by these things. I describe an
interesting object called the matrix multiply flip graph, in which
many world-record matrix multiply schemes have been found. We'll prove
something new about its connectivity, then consider how to efficiently
explore it using a GPU. The GPU-accelerated search procedure is quite
fast, recovering the existing world-record scheme for 4×4
matrix multiply in around thirty seconds on an RTX Pro 6000, compared
to many GPU-weeks in the original
publication. As yet, though, I have no new world-record schemes to report, hence
this being a blog post.
The vertices of the matrix multiply flip graph are ways to multiply matrices. Some ways require fewer multiplications than others; because multiplication is expensive in hardware, considerable effort has been spent exploring the flip graph to find them.
Concretely, the vertices of the matrix multiply flip graph are rank-R decompositions of the matrix multiply tensor.
Definition 1. The matrix multiply tensor for multiplying two n by n matrices (note 1) is the unique M(n) satisfying, for any n by n matrices D,E,

(DE)_{k,l} = ∑_{g,h,i,j} M(n)_{g,h,i,j,k,l} D_{g,h} E_{i,j}.   (1)
That is, M(n) has shape (n,n,n,n,n,n) and picks what products of elements of D and E sum to each element of DE.
M(n) is said to have a rank-R decomposition if
M(n) = ∑_{r=1}^{R} A^{(r)} ⊗ B^{(r)} ⊗ C^{(r)},   (2)
for matrices A^{(i)}, B^{(i)}, C^{(i)} with shape (n,n), where (A⊗B)_{i,j,k,l} = A_{i,j} B_{k,l} denotes the outer product. It is not too difficult to see that DE can be computed with R multiplications iff M(n) has a rank-R decomposition, as after substituting (2) into (1) one arrives at

(DE)_{k,l} = ∑_{r=1}^{R} m_r C^{(r)}_{k,l},  where m_r = (∑_{g,h} A^{(r)}_{g,h} D_{g,h}) (∑_{i,j} B^{(r)}_{i,j} E_{i,j}).   (3)
As the m_r terms do not depend on the values of k,l, they can be precomputed. The R products m_r are the only multiplications involving elements of D and E required for computing (DE)_{k,l}, so DE is said to be computable with R multiplies iff M(n) has a rank-R decomposition.
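To make this concrete, here is a small Python sketch (illustration only, not search code) that builds the naive rank-n³ decomposition of M(2) and multiplies two matrices using one multiplication m_r per summand:

```python
from itertools import product

n = 2

def basis(i, j):
    """E_ij: the n-by-n matrix with a 1 at (i, j) and zeros elsewhere."""
    return tuple(tuple(1 if (r, c) == (i, j) else 0 for c in range(n))
                 for r in range(n))

# Naive rank-n^3 decomposition of M(n): summands E_ij (x) E_jk (x) E_ik.
decomposition = [(basis(i, j), basis(j, k), basis(i, k))
                 for i, j, k in product(range(n), repeat=3)]

def multiply(D, E):
    """Compute DE using one multiplication m_r per summand."""
    out = [[0] * n for _ in range(n)]
    for A, B, C in decomposition:
        # m_r depends only on D, E, and the A/B terms, not the output index.
        m = sum(A[g][h] * D[g][h] for g, h in product(range(n), repeat=2)) \
          * sum(B[i][j] * E[i][j] for i, j in product(range(n), repeat=2))
        for k, l in product(range(n), repeat=2):
            out[k][l] += m * C[k][l]
    return out

D = ((1, 2), (3, 4))
E = ((5, 6), (7, 8))
print(multiply(D, E))  # → [[19, 22], [43, 50]], the ordinary product DE
```

The naive decomposition has n³ = 8 summands; Strassen-style schemes are exactly decompositions with fewer.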
It is natural then to say,
Definition 2. The vertices of the matrix multiply flip graph for multiplication of n by n matrices are all multisets S of outer products A ⊗ B ⊗ C
satisfying ∑_{s∈S} s = M(n) and having no elements that evaluate to zero.
It would be convenient if S were connected to S′ iff |S′| < |S|, but it is not yet known how to compute the out-neighborhood of a vertex in this graph. Instead, work thus far uses some subset of easily-computable flip, plus, and reduction edges. An interesting property of an edge-choice is whether the resulting graph is connected, i.e. whether every vertex is reachable from every other vertex, as in such a graph all unknown ways to multiply matrices are reachable from all known ones. In their work introducing the flip graph, Kauers and Moosbauer [1] showed that with flip and reduction edges the resulting graph is weakly connected; after treating reduction edges as undirected, the graph is connected. Then, Arai, Ichikawa, and Hukushima [2] showed that full connectivity is achievable with a third edge type, plus.
Despite the connectivity proofs holding for arbitrary fields, to our knowledge all existing implementations search for schemes over F₂, as over this field extremely fast search implementations are possible. Hence, this note investigates which edges are required for connectivity over F₂. It is shown that over F₂ flip and plus edges suffice.
Definition 3. If S can be transformed into S′ via a flip or plus transform, then there is a flip/plus edge from S to S′ in the matrix multiply flip graph.
A flip on the A position transforms S into S′ if there exist A⊗B⊗C, A′⊗B′⊗C′ ∈ S such that A = A′ and

S′ = (S ∖ {A⊗B⊗C, A′⊗B′⊗C′}) ∪ {A⊗(B+B′)⊗C, A′⊗B′⊗(C+C′)}.

A plus on the A position transforms S into S′ if there exist A⊗B⊗C, A′⊗B′⊗C′ ∈ S such that

S′ = (S ∖ {A⊗B⊗C}) ∪ {(A+A′)⊗B⊗C, A′⊗B⊗C}.
Flip and plus transforms are defined likewise to operate on the B and C positions of summands. If a flip transform results in a zero summand, that summand is dropped.
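Over F₂ these transforms are a handful of XORs. A minimal Python sketch of a flip on the A position (my own toy representation: each term bit-packed into an integer; `flip_a` and `tensor_sum` are hypothetical helpers, not from any existing implementation), checking that the decomposition's sum is preserved:

```python
# Each term (A, B, or C) is an n*n-bit integer; addition over F2 is XOR.
# A summand A (x) B (x) C is a tuple (a, b, c) of such integers.

def flip_a(s, t):
    """Flip two summands sharing their A term: replaces A(x)B(x)C and
    A(x)B'(x)C' with A(x)(B+B')(x)C and A(x)B'(x)(C+C').
    Returns the new summands, dropping any that evaluate to zero."""
    (a, b, c), (a2, b2, c2) = s, t
    assert a == a2, "flip on the A position needs matching A terms"
    new = [(a, b ^ b2, c), (a2, b2, c ^ c2)]
    return [x for x in new if all(term != 0 for term in x)]

def tensor_sum(summands, nbits):
    """F2 sum of the outer products, as the set of positions holding a 1."""
    acc = set()
    for a, b, c in summands:
        for i in range(nbits):
            for j in range(nbits):
                for k in range(nbits):
                    if a >> i & 1 and b >> j & 1 and c >> k & 1:
                        acc ^= {(i, j, k)}  # toggling = addition over F2
    return acc

before = [(0b0001, 0b0011, 0b0101), (0b0001, 0b0110, 0b1001)]
after = flip_a(before[0], before[1])
assert tensor_sum(before, 4) == tensor_sum(after, 4)  # flips preserve the sum
```

The preserved-sum assertion is exactly why flips are edges of the graph: both endpoints decompose the same tensor.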
Theorem 4. Over F₂, the matrix multiply flip graph with flip and plus edges is connected.
Connectivity Proof
Lemma 5. *The matrix multiply tensor for multiplying two n by n matrices is

M(n) = ∑_{i,j,k=1}^{n} E_{i,j} ⊗ E_{j,k} ⊗ E_{i,k},

where E_{i,j} is an n by n matrix with a 1 at position (i,j) and zeros everywhere else.*
Proof. As a consequence of (1), M(n)_{g,h,i,j,k,l} = δ_{g,k} δ_{h,i} δ_{j,l}, where δ_{g,h} denotes the Kronecker delta, so inspecting the basis-decomposition of M(n) we have

M(n) = ∑_{i,j,k=1}^{n} E_{i,j} ⊗ E_{j,k} ⊗ E_{i,k}.

Lemma 6. *Let S = {A^{(r)} ⊗ B^{(r)} ⊗ C^{(r)}}_{r=1}^{R} be a decomposition of M(n). The A-components A^{(1)},…,A^{(R)} span the space of n by n matrices; likewise for the B- and C-components.*

Proof. Write B^{(r)} = ∑_{j,k} β^{(r)}_{j,k} E_{j,k} and C^{(r)} = ∑_{i,k} γ^{(r)}_{i,k} E_{i,k}. Comparing the coefficient of the basis matrix E_{j,k′} ⊗ E_{i,k} on the left and right hand side of the equality, for any i,j when k = k′ we have E_{i,j} = ∑_{r}^{R} β^{(r)}_{j,k} γ^{(r)}_{i,k} A^{(r)}, so every E_{i,j} is a linear combination of the A^{(r)}'s. A symmetric argument applies to the C and B terms.
Lemma 7. *Let S be a decomposition of M(n) over F₂ and A⊗B⊗C, A′⊗B′⊗C′ be elements of S. There is a path from S to

S_{A+} = S ∪ {(A+A′)⊗B⊗C, (A+A′)⊗B⊗C},

in the matrix multiply flip graph; likewise for S_{B+} and S_{C+}.*
Figure 1. Derivation of Lemma 7's S_{A+}. X denotes A+A′, P_A a plus on the A position of its inputs, and F_A, likewise, a flip on the A position.
Proof. Figure 1 shows a derivation of S_{A+}. S_{B+} and S_{C+} are symmetric.
Lemma 8. *Let S be a decomposition of M(n) over F₂. For arbitrary nonzero X, Y, Z, there is a path from S to

S ∪ {X⊗Y⊗Z, X⊗Y⊗Z},

in the matrix multiply flip graph.*
Proof. Fix a summand A⊗B⊗C ∈ S. By Lemma 6, the A-components of the summands of S span F₂^{n×n}, so there are terms A^{(1)},…,A^{(m)} satisfying

X = A + A^{(1)} + ⋯ + A^{(m)}.
Thus, Lemma 7 gives a procedure to add two copies of

(A + A^{(1)})⊗B⊗C,

then

(A + A^{(1)} + A^{(2)})⊗B⊗C,

and so on, until two copies of X⊗B⊗C are generated, at which point flips eliminate the intermediate terms, yielding S ∪ {X⊗B⊗C, X⊗B⊗C}. This is then repeated on the B and C positions.
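Finding terms A^{(1)},…,A^{(m)} whose sum with A equals X is linear algebra over F₂. A Python sketch (bit-packed matrices as integers; `combination` is a hypothetical helper name) using Gaussian elimination:

```python
def combination(target, vectors):
    """Express `target` as an XOR of a subset of `vectors` (bit-packed
    matrices over F2), returning the chosen indices, or None if `target`
    is not in their span. Plain Gaussian elimination over F2."""
    pivots = {}  # pivot bit -> (reduced vector, set of input indices used)
    for idx, v in enumerate(vectors):
        used = {idx}
        while v:
            p = v.bit_length() - 1
            if p not in pivots:
                pivots[p] = (v, used)
                break
            pv, pused = pivots[p]
            v ^= pv
            used ^= pused  # symmetric difference tracks contributing inputs
    chosen = set()
    while target:
        p = target.bit_length() - 1
        if p not in pivots:
            return None
        pv, pused = pivots[p]
        target ^= pv
        chosen ^= pused
    return sorted(chosen)

# Lemma 6 guarantees the A-components span the full space, so over F2
# any X is expressible; here 0b1011 = vectors[0] ^ vectors[1] ^ vectors[3].
vs = [0b0001, 0b0010, 0b0100, 0b1000, 0b0011]
idxs = combination(0b1011, vs)
acc = 0
for i in idxs:
    acc ^= vs[i]
assert acc == 0b1011
```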
Lemma 9. Let S be a decomposition of M(n) over F₂. If R ⊆ S sums to zero, i.e. ∑_{r∈R} r = 0, then there is a path from S to S ∖ R in the matrix multiply flip graph.
Proof. Let E_{i,j} denote the n by n basis matrix with a 1 at position (i,j) and zeros everywhere else. First, decompose every element of R into basis matrices as follows: If the A term of A⊗B⊗C ∈ R is not already a basis matrix, then A = E_{i₁,j₁} + ⋯ + E_{i_t,j_t}, so add E_{i₁,j₁}⊗B⊗C twice via Lemma 8; flipping one copy against A⊗B⊗C then leaves (A + E_{i₁,j₁})⊗B⊗C alongside the remaining copy E_{i₁,j₁}⊗B⊗C.
Repeating this for each element of A's decomposition and for the B and C positions decomposes A⊗B⊗C ∈ R into its basis matrices.
After decomposing R the resulting scheme is (S ∖ R) ∪ R′, where R′ is a multiset of basis-matrix outer products. Notice that the basis-matrix outer products are linearly independent, as E_{i,j} ⊗ E_{k,l} ⊗ E_{m,n} is itself a basis tensor with a one at position (i,j,k,l,m,n) and zeros everywhere else. Thus, for ∑_{r∈R′} r to be zero, every basis-matrix outer product in R′ must appear an even number of times (i.e. its coefficient must be 0 over F₂). Hence, as over F₂ a flip removes both inputs if they are identical, applying flips to identical elements removes R′, giving S ∖ R, as desired.
Proof of Theorem 4. To travel from any scheme S to any scheme S′, first transition to S ∪ S′ ∪ S′ by adding every element of S′ to S twice via Lemma 8. Then, as ∑_{s∈S} s = ∑_{s∈S′} s = M(n), ∑_{s∈S∪S′} s = 0, so R = S ∪ S′ is a subset of S ∪ S′ ∪ S′ that sums to zero and can be removed per Lemma 9, leaving S′, as desired.
GPU-Accelerated Search
We now, somewhat unceremoniously, turn our focus towards how to efficiently search the flip graph. As yet, no better way to do this than random walks is known.
The most performance-sensitive part of random walks on the flip graph is identifying flip opportunities, i.e. summands that match in the A, B, or C position.
A ⊗ B ⊗ C,  where B is the term at the B position of the summand A⊗B⊗C.
Over F₂ an 8×8 term fits in a single 64-bit word, so existing flip graph search procedures for small matrix sizes maintain, for each position A, B, C, a map that stores the indices of summands with a particular value at that position. For example, if the summands at indices 1, 2, 3 have the value x for their A term and m_A is the map for A terms, then m_A(x) = [1,2,3]. For very small matrix sizes, 4×4 or less, this map can be implemented with an array by interpreting a term's bit-level representation as an index; larger sizes require hashing. However, this approach is not GPU friendly as it is branchy, requires operating on sizable buffers in shared memory, and, generally, it is unclear how to adapt it to the single-instruction-multiple-threads model used by GPUs (note 2).
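The map-based bookkeeping can be sketched in a few lines of Python (a toy stand-in for the optimized array/hash-map implementations; the dict plays the role of m_A, m_B, m_C):

```python
from collections import defaultdict

# Summands as (a, b, c) tuples of bit-packed terms over F2.
summands = [(0b01, 0b10, 0b11), (0b01, 0b11, 0b01), (0b10, 0b10, 0b10)]

# For each position, map a term's value to the indices of summands holding it.
maps = {pos: defaultdict(list) for pos in range(3)}  # 0 = A, 1 = B, 2 = C
for idx, s in enumerate(summands):
    for pos in range(3):
        maps[pos][s[pos]].append(idx)

# Any bucket with two or more indices is a flip opportunity.
flips = [(pos, idxs) for pos in range(3)
         for idxs in maps[pos].values() if len(idxs) >= 2]
print(flips)  # → [(0, [0, 1]), (1, [0, 2])]
```

Summands 0 and 1 share their A term and summands 0 and 2 share their B term, so both pairs are flippable. Every insertion and lookup here is a branchy map operation, which is the part that does not translate well to GPUs.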
Fortunately, recent CUDA GPUs have a primitive, __match_any_sync, that returns a bitmask of threads within a warp that have the same value for a given variable.
When there are fewer summands than threads in a warp this can be used as follows: Each thread stores a single summand. At each step all threads agree on a term position (A, B, or C) and execute __match_any_sync on their summand's value at that position. Bitmasks with more than one bit set correspond to flip opportunities. This means flips are identified with a hardware primitive instead of a hash table.
In practice, available GPUs have 32 threads per warp. As world-record decompositions of M(4) and M(5) involve 47 and 93 summands respectively, multiple summands must be stored per thread. Flip search then becomes probabilistic: At each step warps agree on a term position then, for k rounds, each thread picks a random summand and calls __match_any_sync. Increasing k increases the likelihood that available flip opportunities are found.
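The semantics of __match_any_sync are easy to model on the CPU. In the Python sketch below, `match_any_sync` is my stand-in for the CUDA primitive (a sequential model of what the hardware computes in parallel):

```python
def match_any_sync(values):
    """CPU model of CUDA's __match_any_sync: values[t] is thread t's input;
    returns masks[t], a bitmask of all threads whose value equals values[t]."""
    masks = []
    for v in values:
        mask = 0
        for t, other in enumerate(values):
            if other == v:
                mask |= 1 << t
        masks.append(mask)
    return masks

# 8 "threads", each holding one summand's A term.
a_terms = [7, 3, 7, 9, 3, 7, 1, 2]
masks = match_any_sync(a_terms)
# Threads 0, 2, and 5 all hold the value 7, so each receives 0b100101.
assert masks[0] == 0b100101
# A thread has found a flip opportunity iff its mask has more than one bit set.
flip_threads = [t for t, m in enumerate(masks) if bin(m).count("1") > 1]
print(flip_threads)  # → [0, 1, 2, 4, 5]
```

On hardware the whole thing is a single warp-wide instruction rather than this O(threads²) loop, which is what makes the approach fast.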
Generalized Flips
With __match_any_sync it is zero-cost to identify multiple summands that share a term, as this corresponds to more than two bits being set in the returned bitmask. Flip semantics can be generalized to operate on multiple summands and used to look ahead for rank-reductions.
Let's sketch a flip opportunity as matrix multiplication over subterms with ⊗ as multiply. Writing N summands that share their A term as a row vector B = (B^{(1)} ⋯ B^{(N)}) and a column vector C = (C^{(1)} ⋯ C^{(N)})ᵀ of subterms,

∑_{i=1}^{N} A ⊗ B^{(i)} ⊗ C^{(i)} = A ⊗ (B ⊗ C) = A ⊗ ((B G) ⊗ (G⁻¹ C)),

for any invertible N×N matrix G. A consequence of this framing is that if the B^{(i)} or C^{(i)} terms are not linearly independent, then an appropriate choice of G results in a zero in the B or C vectors, and a zero corresponds to a rank-reduction. Thus an appropriate choice of G can skip ahead in the search space to rank-reductions that would otherwise require many pairwise flip and plus transforms to set up the needed linear combination.
It remains to discuss how to efficiently determine whether B^{(1)},…,B^{(N)} are linearly dependent on a GPU. First, it seems that sets of summands sharing a term are overwhelmingly of size ≤ 5 on the 5×5 flip graph: in 3,273 schemes sampled from random walks there were 43,386 sets of more than two summands sharing a term, and 99.93% had size ≤ 5; Figure 2 visualizes this. Over F₂, determining whether B^{(1)},…,B^{(5)} are linearly dependent reduces to determining whether the sum (i.e. XOR, i.e. ⊕) of any non-empty subset is zero. 5 is a convenient number: sets of size 5 have 31 non-empty subsets and warps have 32 threads, so we can assign each subset to a thread.
Figure 2. Distribution of sizes of sets of summands sharing a term, from 3,273 schemes sampled from random walks on the 5×5 flip graph.
More precisely, let s be the index of a thread in a warp and D_s be the indices of the set bits of s. Thread zero is idle while threads s ≠ 0 compute whether

⊕_{i∈D_s} B^{(i)} = 0.

A warp ballot then finds the lowest-index thread for which the above is true. If one exists, call it thread m and the index of its first set bit r = min D_m; threads rewrite their summands from A⊗B^{(i)}⊗C^{(i)} to A⊗B^{(i)}⊗(C^{(i)} + C^{(r)}) for i ∈ D_m ∖ {r}, and summand r is dropped, reducing the rank by one.
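The subset check itself is small enough to model in Python. The sketch below (`dependent_subset` is a hypothetical name; a sequential stand-in for the 31 parallel threads) checks each subset D_s encoded by the set bits of s:

```python
def dependent_subset(bs):
    """Given up to 5 bit-packed B terms over F2, return the indices of a
    non-empty subset XORing to zero (a linear dependence), or None.
    Mirrors the warp computation: "thread" s checks the subset D_s given
    by the set bits of s; a ballot would then pick the lowest such s."""
    n = len(bs)
    for s in range(1, 1 << n):  # thread 0 (the empty subset) is idle
        acc = 0
        for i in range(n):
            if s >> i & 1:
                acc ^= bs[i]
        if acc == 0:
            return [i for i in range(n) if s >> i & 1]
    return None

# B terms where b0 ^ b1 ^ b3 == 0: a rank-reduction opportunity.
assert dependent_subset([0b0110, 0b0101, 0b1000, 0b0011]) == [0, 1, 3]
# Linearly independent terms: no subset sums to zero.
assert dependent_subset([0b0001, 0b0010, 0b0100]) is None
```

On the GPU the loop over s disappears: each thread evaluates one subset, and a single ballot replaces the scan for the first zero sum.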
Note 1. The exposition is on square matrix multiplication but is easily generalized.
Note 2. At the necessary level of abstraction for this section, GPUs are composed of many processors, each running groups of threads that all execute the same instruction at the same time. For example, two threads on a processor can simultaneously compute x+y, but one cannot compute x+y while the other computes x−y, as + and − are different instructions. Branchy code performs poorly on GPUs because one thread cannot execute the false branch of an if statement while another executes the true branch. Some primitives, like __match_any_sync, allow threads in a warp to communicate with one another. If you'd like to understand GPUs in more detail, I made a considerable effort to explain them in this post.