The Matrix Multiply Flip Graph

Two months of matrix multiply, of regret, loss, spring, sunburn, sore muscles, and a return to my bike. This post resumes the search for ways to multiply matrices with few multiplications. Some nihilistic hedonism propels me to continue this, but lately I reflect on how much I miss while absorbed by these things. I describe an interesting object called the matrix multiply flip graph, in which many world-record matrix multiply schemes have been found. We'll prove something new about its connectivity, then consider how to efficiently explore it using a GPU. The GPU-accelerated search procedure is quite fast, recovering the existing world-record scheme for $4\times 4$ matrix multiply in around thirty seconds on an RTX Pro 6000, compared to many GPU-weeks in the original publication. As of yet though, I have no new world-record schemes to report, hence this being a blog post.

The vertices of the matrix multiply flip graph are ways to multiply matrices. Some ways require fewer multiplications than others; because multiplication is expensive in hardware, considerable effort has been spent exploring the flip graph to find them.

Concretely, the vertices of the matrix multiply flip graph are rank-$R$ decompositions of the matrix multiply tensor.

Definition 1. The matrix multiply tensor for multiplying two $n$ by $n$ matrices (note 1) is the unique $M^{(n)}$ satisfying, for any $n$ by $n$ matrices $D,E$,

$$(DE)_{k,l} = \sum_{g,h,i,j=1}^{n}M^{(n)}_{g,h,i,j,k,l}D_{g,h}E_{i,j}.\tag{1}$$

That is, $M^{(n)}$ has shape $(n,n,n,n,n,n)$ and picks which products of elements of $D$ and $E$ sum to each element of $DE$.

$M^{(n)}$ is said to have a rank-$R$ decomposition if

$$M^{(n)}=\sum_{r=1}^R A^{(r)}\otimes B^{(r)}\otimes C^{(r)},\tag{2}$$

for matrices $A^{(i)},B^{(i)},C^{(i)}$ with shape $(n,n)$, where $(A\otimes B)_{i,j,k,l}=A_{i,j}B_{k,l}$ denotes the outer product. It is not too difficult to see that $DE$ can be computed with $R$ multiplications iff $M^{(n)}$ has a rank-$R$ decomposition, as after substituting (2) into (1) one arrives at

$$\begin{aligned} &(DE)_{k,l} \\ &= \sum_{g,h,i,j} (\sum_rA^{(r)}\otimes B^{(r)}\otimes C^{(r)})_{g,h,i,j,k,l}D_{g,h}E_{i,j} \\ &= \sum_{g,h,i,j} \sum_rA^{(r)}_{g,h}D_{g,h} B^{(r)}_{i,j}E_{i,j} C^{(r)}_{k,l} \\ &= \sum_rC^{(r)}_{k,l}\sum_{g,h,i,j}A^{(r)}_{g,h}D_{g,h} B^{(r)}_{i,j}E_{i,j} \\ &= \sum_rC^{(r)}_{k,l}\underbrace{(\sum_{g,h}A^{(r)}_{g,h}D_{g,h})(\sum_{i,j}B^{(r)}_{i,j}E_{i,j})}_{m_r}. \end{aligned}$$

As the $m_r$ terms do not depend on the values of $k,l$, they can be precomputed. The $R$ $m_r$ terms are the only multiplications involving elements of $D$ and $E$ required for computing $(DE)_{k,l}$, so $DE$ is said to be computable with $R$ multiplies iff $M^{(n)}$ has a rank-$R$ decomposition.
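To make the correspondence concrete, here is a minimal host-side sketch that evaluates a scheme via the $m_r$ products. Assumptions of mine, not a published implementation: row-major $n\times n$ matrices flattened into arrays, with factor $r$ stored at offset `r*n*n`.

```cuda
// Compute DE from a rank-R scheme (A, B, C) per equations (1) and (2).
// Scheme entries are typically 0 or Β±1, so u and v below are signed sums;
// the only multiplications mixing D and E are the R products u * v.
void multiply_via_scheme(int n, int R,
                         const float *A, const float *B, const float *C,
                         const float *D, const float *E, float *DE) {
  for (int i = 0; i < n * n; ++i) DE[i] = 0.0f;
  for (int r = 0; r < R; ++r) {
    float u = 0.0f, v = 0.0f;  // u = <A^(r), D>, v = <B^(r), E>
    for (int i = 0; i < n * n; ++i) {
      u += A[r * n * n + i] * D[i];
      v += B[r * n * n + i] * E[i];
    }
    const float m_r = u * v;   // the r-th multiplication
    for (int i = 0; i < n * n; ++i)
      DE[i] += C[r * n * n + i] * m_r;  // (DE)_{k,l} = sum_r C^(r)_{k,l} m_r
  }
}
```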

It is natural then to say,

Definition 2. The vertices of the matrix multiply flip graph for multiplication of $n$ by $n$ matrices are all multisets

$$S=\{A^{(1)}\otimes B^{(1)}\otimes C^{(1)},\dots,A^{(|S|)}\otimes B^{(|S|)}\otimes C^{(|S|)}\}$$

satisfying $\sum_{s\in S}s = M^{(n)}$ and having no elements that evaluate to zero.

It would be convenient if $S$ were connected to $S'$ iff $|S'|<|S|$, but it is, as of yet, not known how to compute the out-neighborhood of a vertex in this graph. Instead, work thus far uses some subset of easily-computable flip, plus, and reduction edges. An interesting property of an edge-choice is whether the resulting graph is connected, i.e. from any vertex every other vertex is reachable, as in such a graph all unknown ways to multiply matrices are reachable from all known ones. In their work introducing the flip graph, Kauers and Moosbauer [1] showed that with flip and reduction edges the resulting graph is weakly connected; after treating reduction edges as undirected, the graph is connected. Then, Arai, Ichikawa, and Hukushima [2] showed that full connectivity was achievable with a third edge, plus.

Despite the connectivity proofs holding for arbitrary fields, to our knowledge all existing implementations search for schemes over $\mathbb{F}_2$, as over this field extremely fast search implementations are possible. Hence, this note investigates which edges are required for connectivity over $\mathbb{F}_2$. It is shown that over $\mathbb{F}_2$ flip and plus suffice.

Definition 3. If $S$ can be transformed into $S'$ via a flip or plus transform, then there is a flip/plus edge from $S$ to $S'$ in the matrix multiply flip graph.

A flip on the $A$ position transforms $S$ into $S'$ if there exist $A\otimes B\otimes C, A'\otimes B'\otimes C'\in S$ such that $A=A'$ and

$$\begin{aligned} S'=S&\setminus\{A\otimes B\otimes C,A'\otimes B'\otimes C'\}\\&\cup \{A\otimes B\otimes (C+C'),A'\otimes (B'-B)\otimes C'\}. \end{aligned}$$

A plus on the $A$ position transforms $S$ into $S'$ if there exist $A\otimes B\otimes C, A'\otimes B'\otimes C'\in S$ such that

$$\begin{aligned} S'=S&\setminus\{A'\otimes B'\otimes C'\}\\&\cup\{(A'-A)\otimes B'\otimes C', A\otimes B'\otimes C'\}. \end{aligned}$$

Flip and plus transforms are defined likewise to operate on the $B$ and $C$ positions of summands. If a flip transform results in a zero summand, that summand is dropped. Note that over $\mathbb{F}_2$ subtraction and addition coincide, so $B'-B=B'+B$.
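As a small hand-worked example over $\mathbb{F}_2$, writing $E_{ij}$ for the matrix with a single $1$ at position $ij$, a flip on the $A$ position applied to two summands sharing $A=E_{11}$ gives

$$\{E_{11}\otimes E_{11}\otimes E_{11},\ E_{11}\otimes E_{12}\otimes E_{21}\} \mapsto \{E_{11}\otimes E_{11}\otimes(E_{11}+E_{21}),\ E_{11}\otimes(E_{12}+E_{11})\otimes E_{21}\},$$

and expanding confirms both sides have the same sum: the cross term $E_{11}\otimes E_{11}\otimes E_{21}$ appears twice on the right and cancels over $\mathbb{F}_2$.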

Theorem 4. Over $\mathbb{F}_2$, the matrix multiply flip graph with flip and plus edges is connected.

Connectivity Proof

Lemma 5. *The matrix multiply tensor for multiplying two $n$ by $n$ matrices is,

$$M^{(n)}=\sum_{i=1}^n \sum_{j=1}^n \sum_{k=1}^n E_{ij} \otimes E_{jk} \otimes E_{ik}$$

where $E_{ij}$ is an $n$ by $n$ matrix with a $1$ at position $ij$ and zeros everywhere else.*

Proof. As a consequence of (1), $M^{(n)}_{ghijkl}=\delta_{g,k}\delta_{h,i}\delta_{j,l}$ where $\delta_{g,h}$ denotes the Kronecker delta, so inspecting the basis-decomposition of $M^{(n)}$ we have,

$$\begin{aligned} M^{(n)} &= \sum_{i,j,j',k,i',k'}\delta_{i,i'}\delta_{k,k'}\delta_{j,j'}E_{i,j}\otimes E_{j',k} \otimes E_{i',k'} \\ &= \sum_{i,k,j}E_{i,j}\otimes E_{j,k}\otimes E_{i,k}. \end{aligned}$$
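For instance, at $n=2$ Lemma 5 expands to the schoolbook scheme with $R=n^3=8$ summands (Strassen's algorithm shows $R=7$ suffices):

$$\begin{aligned} M^{(2)} ={}& E_{11}\otimes E_{11}\otimes E_{11} + E_{11}\otimes E_{12}\otimes E_{12} + E_{12}\otimes E_{21}\otimes E_{11} + E_{12}\otimes E_{22}\otimes E_{12} \\ &+ E_{21}\otimes E_{11}\otimes E_{21} + E_{21}\otimes E_{12}\otimes E_{22} + E_{22}\otimes E_{21}\otimes E_{21} + E_{22}\otimes E_{22}\otimes E_{22}. \end{aligned}$$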

Lemma 6. *If $S=\{A^{(1)}\otimes B^{(1)}\otimes C^{(1)},\dots\}$ is a decomposition of $M^{(n)}$, then

$$\begin{aligned} &\text{span}(A^{(1)},\dots,A^{(|S|)})\\&=\text{span}(B^{(1)},\dots,B^{(|S|)})\\&=\text{span}(C^{(1)},\dots,C^{(|S|)})=\mathbb{F}^{n\times n}. \end{aligned}$$

Proof. Write each $B^{(r)}$ and $C^{(r)}$ in the standard basis

$$B^{(r)}=\sum_{j,k'}^{m,p}\beta^{(r)}_{j,k'}E_{j,k'}, \quad C^{(r)}=\sum_{i,k}^{n,p}\gamma^{(r)}_{i,k}E_{i,k}.$$

By the definition of $S$ and the basis-form of $M^{(n)}$, we have

$$\begin{aligned} &\sum_{i,j,k}^{n,m,p}E_{i,j}\otimes E_{j,k} \otimes E_{i,k}\\ &= \sum_r^RA^{(r)}\otimes B^{(r)}\otimes C^{(r)} \\ &= \sum_{r}^RA^{(r)}\otimes(\sum_{j,k'}^{m,p}\beta^{(r)}_{j,k'}E_{j,k'})\otimes (\sum_{i,k}^{n,p}\gamma^{(r)}_{i,k}E_{i,k}) \\ &= \sum_{r}^R\sum_{j,k',i,k}^{m,p,n,p}(A^{(r)}\beta^{(r)}_{j,k'}\gamma^{(r)}_{i,k})\otimes E_{j,k'} \otimes E_{i,k} \\ &= \sum_{j,k',i,k}^{m,p,n,p}(\sum_{r}^{R}A^{(r)}\beta^{(r)}_{j,k'}\gamma^{(r)}_{i,k})\otimes E_{j,k'} \otimes E_{i,k}. \end{aligned}$$

Comparing the coefficient of the basis tensor $E_{j,k'}\otimes E_{i,k}$ on the left and right hand sides of the equality, for any $i,j$ when $k=k'$ we have $E_{i,j}=\sum_{r}^{R}A^{(r)}\beta^{(r)}_{j,k}\gamma^{(r)}_{i,k}$, so every $E_{i,j}$ is a linear combination of the $A^{(r)}$'s. A symmetric argument applies to the $B$ and $C$ terms.

Lemma 7. *Let $S$ be a decomposition of $M^{(n)}$ over $\mathbb{F}_2$ and $A\otimes B\otimes C$, $A'\otimes B'\otimes C'$ be elements of $S$. There is a path from $S$ to

$$S_{A+} = S \cup \{ (A+A')\otimes B\otimes C,\; (A+A')\otimes B\otimes C \},$$

in the matrix multiply flip graph; likewise for $S_{B+}$ and $S_{C+}$.*

Figure 1. Derivation of Lemma 7's $S_{A+}$. $X$ denotes $A+A'$, $P_A$ a plus on the $A$ position of its inputs, and $F_A$, likewise, a flip on the $A$ position.

Proof. Figure 1 shows a derivation of $S_{A+}$. $S_{B+}$ and $S_{C+}$ are symmetric.

Lemma 8. *Let $S$ be a decomposition of $M^{(n)}$ over $\mathbb{F}_2$. For arbitrary, nonzero $X,Y,Z$, there is a path from $S$ to

$$S \cup \{ X\otimes Y\otimes Z,\; X\otimes Y\otimes Z \},$$

in the matrix multiply flip graph.*

Proof. Fix a summand $A\otimes B\otimes C \in S$. By Lemma 6, the $A$-components of the summands of $S$ span $\mathbb{F}_2^{n\times n}$, so there are terms $A^{(1)},\dots,A^{(m)}$ satisfying

$$X = A + A^{(1)}+\cdots+A^{(m)}.$$

Thus, Lemma 7 gives a procedure to add two copies of

$$(A+A^{(1)})\otimes B\otimes C,$$

then

$$(A+A^{(1)}+A^{(2)})\otimes B\otimes C,$$

and so on, until two copies of $X\otimes B\otimes C$ are generated, at which point flips eliminate the intermediate terms, yielding $S\cup\{X\otimes B\otimes C, X\otimes B\otimes C\}$. This is then repeated on the $B$ and $C$ positions.

Lemma 9. Let $S$ be a decomposition of $M^{(n)}$ over $\mathbb{F}_2$. If $R\subseteq S$ sums to zero, i.e. $\sum_{r\in R}r=0$, then there is a path from $S$ to $S\setminus R$ in the matrix multiply flip graph.

Proof. Let $E_{i,j}$ denote the $n$ by $n$ basis matrix with a $1$ at position $i,j$ and zeros everywhere else. First, decompose every element of $R$ into basis matrices as follows: If the $A$ term of $A\otimes B\otimes C\in R$ is not already a basis matrix, then $A=E_{i_1,j_1}+\dots+E_{i_t,j_t}$, so add $E_{i_1,j_1}\otimes B\otimes C$ twice via Lemma 8

$$A\otimes B\otimes C, \quad E_{i_1,j_1}\otimes B\otimes C, \quad E_{i_1,j_1}\otimes B\otimes C$$

then do a flip on the $B$ position of the first and second terms above for

$$(A-E_{i_1,j_1})\otimes B\otimes C, \quad E_{i_1,j_1}\otimes B\otimes C.$$

Repeating this for each element of $A$'s decomposition and for the $B$ and $C$ positions decomposes $A\otimes B\otimes C\in R$ into its basis matrices.

After decomposing $R$ the resulting scheme is $(S\setminus R)\cup R'$, where $R'$ is a multiset of basis-matrix outer products. Notice that the basis-matrix outer products are linearly independent, as $E_{i,j}\otimes E_{k,l}\otimes E_{m,n}$ is itself a basis tensor with a one at position $(i,j,k,l,m,n)$ and zeros everywhere else. Thus, for $\sum_{r\in R'}r$ to be zero, every basis-matrix outer product in $R'$ must appear an even number of times (i.e. its coefficient must be $0$ over $\mathbb{F}_2$). Hence, as over $\mathbb{F}_2$ a flip removes both inputs if they are the same, applying flips to identical elements removes $R'$, giving $S\setminus R$, as desired.

Proof of Theorem 4. To travel from any scheme $S$ to any scheme $S'$, first transition to $S\cup S' \cup S'$ by adding every element of $S'$ to $S$ twice via Lemma 8. Then, as $\sum_{s\in S}s=\sum_{s\in S'}s=M^{(n)}$, $\sum_{s\in S\cup S'}s=0$ over $\mathbb{F}_2$, so $R=S\cup S'$ is a subset of $S\cup S'\cup S'$ that sums to zero and can be removed per Lemma 9, leaving $S'$, as desired.

GPU-Accelerated Search

We now, somewhat unceremoniously, turn our focus towards how to efficiently search the flip graph. As of yet, no better way to do this than random walks is known.

The most performance-sensitive part of random walks on the flip graph is identifying flip opportunities, i.e. summands that match in the $A$, $B$, or $C$ position.

$$\overbrace{A\hspace{0.4em}\otimes \underbrace{B}_{\mathclap{\text{Term at $B$ position}}}\otimes\hspace{0.4em} C}^{\text{Summand}}$$

Over $\mathbb{F}_2$ an $8\times 8$ term fits in a single $64$-bit word, so existing flip graph search procedures for small matrix sizes maintain, for each position $A,B,C$, a map that stores the indices of summands with a particular value at that position. For example, if summands at indices $1,2,3$ have the value $x$ for their $A$ term and $m_A$ is the map for $A$ terms, $m_A(x)=[1,2,3]$. For very small matrix sizes, $4\times 4$ or less, this map can be implemented with an array by interpreting a term's bit-level representation as an index; larger sizes require hashing. However, this approach is not GPU friendly: it is branchy, requires operating on sizable buffers in shared memory, and, generally, it is unclear how to adapt it to the single-instruction-multiple-threads model used by GPUs (note 2).
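A minimal host-side sketch of the array-as-map variant, assuming $4\times 4$ terms packed into 16-bit words (one bit per entry); the names here are illustrative, not from any particular implementation:

```cuda
#include <cstdint>
#include <vector>

// m_A[x] holds the indices of summands whose A-term has bit pattern x.
static std::vector<int> m_A[1 << 16];

void index_a_terms(const std::vector<uint16_t> &a_terms) {
  for (int i = 0; i < (int)a_terms.size(); ++i)
    m_A[a_terms[i]].push_back(i);
  // Any bucket with two or more indices is a flip opportunity. For
  // n > 4 a 2^(n*n)-entry array is infeasible, so a hash map is used.
}
```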

Fortunately, recent CUDA GPUs have a primitive, __match_any_sync, that returns a bitmask of threads within a warp that have the same value for a given variable.

$$\begin{aligned} &\texttt{\_\_match\_any\_sync(}x\texttt{)}_i=1 \\&\quad\leftrightarrow \text{Thread $i$ called }\texttt{\_\_match\_any\_sync}\text{ on $x$} \end{aligned}$$

When there are fewer summands than threads in a warp this can be used as follows: Each thread stores a single summand. At each step all threads agree on a term position ($A$, $B$, or $C$) and execute __match_any_sync on their summand's value at that position. Ones in the resulting bitmasks correspond to flip opportunities. This means flips are identified with a hardware primitive instead of a hash table.
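A minimal device-side sketch, assuming one summand per lane with its term at the agreed position packed into a 64-bit word:

```cuda
// Lanes holding equal terms receive identical masks; more than one set
// bit means this lane's summand can flip with a peer lane's.
__device__ unsigned find_flip_peers(unsigned long long term) {
  unsigned peers = __match_any_sync(0xffffffffu, term);
  return (__popc(peers) > 1) ? peers : 0u;  // 0 means no flip here
}
```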

In practice, available GPUs have $32$ threads per warp. As world-record decompositions of $M^{(4)}$ and $M^{(5)}$ involve $47$ and $93$ summands respectively, multiple summands must be stored per thread. Flip search then becomes probabilistic: At each step warps agree on a term position then, for $k$ rounds, each thread picks a random summand and calls __match_any_sync. Increasing $k$ increases the likelihood that available flip opportunities are found.
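A hedged sketch of the probabilistic probe; SUMMANDS_PER_LANE, the inlined LCG, and the handling of matches are illustrative assumptions rather than the original implementation:

```cuda
#define SUMMANDS_PER_LANE 3  // e.g. 3 * 32 = 96 slots covers 93 summands

// Each lane owns SUMMANDS_PER_LANE summands; `terms` holds their values
// at the agreed position. Empty slots need a lane-unique sentinel value
// so that they never spuriously match across lanes.
__device__ void probe_flips(const unsigned long long terms[SUMMANDS_PER_LANE],
                            unsigned &rng, int k) {
  for (int round = 0; round < k; ++round) {
    rng = rng * 1664525u + 1013904223u;           // per-lane LCG step
    unsigned r = (rng >> 16) % SUMMANDS_PER_LANE; // random local summand
    unsigned peers = __match_any_sync(0xffffffffu, terms[r]);
    if (__popc(peers) > 1) {
      // Flip opportunity among the matched lanes; record or apply it.
    }
  }
}
```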

Generalized Flips

With __match_any_sync it is zero-cost to identify multiple summands that share a term, as this corresponds to more than two bits being set in the returned bitmask. Flip semantics can be generalized to operate on multiple summands and used to look ahead for rank-reductions.

Let's sketch a flip opportunity as matrix multiplication over subterms, with $\otimes$ as multiply.

$$\begin{aligned} &A\otimes (B^{(1)}\otimes C^{(1)} + B^{(2)}\otimes C^{(2)}) \\&= A\otimes \bigl(\underbrace{\begin{bmatrix} B^{(1)} & B^{(2)} \end{bmatrix}}_{B}\underbrace{\begin{bmatrix} C^{(1)} \\ C^{(2)} \end{bmatrix}}_{C}\bigr) \end{aligned}$$

From this perspective, applying a flip has the same effect as transforming $BC$, above, into $BGG^{-1}C$ for a particular $G$.

$$\begin{aligned} &A\otimes \bigl(\begin{bmatrix} B^{(1)} & B^{(2)} \end{bmatrix} \overbrace{\begin{bmatrix} 1 & 0 \\ -1 & 1 \end{bmatrix}}^{G}\overbrace{\begin{bmatrix} 1 & 0 \\ 1 & 1 \end{bmatrix}}^{G^{-1}}\begin{bmatrix} C^{(1)} \\ C^{(2)} \end{bmatrix}\bigr) \\ &= A\otimes(B^{(1)}-B^{(2)})\otimes C^{(1)}+A\otimes B^{(2)}\otimes(C^{(1)}+C^{(2)}) \end{aligned}$$

This is easily generalized to multiple summands with a shared term and arbitrary, invertible $G$.

$$\begin{aligned} &A\otimes \bigl(\sum_{i=1}^N B^{(i)}\otimes C^{(i)}\bigr) \\&= A\otimes \bigl(\begin{bmatrix} B^{(1)} & \dots & B^{(N)} \end{bmatrix}GG^{-1}\begin{bmatrix} C^{(1)} \\ \vdots \\ C^{(N)} \end{bmatrix}\bigr) \end{aligned}$$

A consequence of this framing is that, if the $B^{(i)}$ or $C^{(i)}$ terms are not linearly independent, then an appropriate choice of $G$ results in a zero in the $B$ or $C$ vectors, and a zero corresponds to a rank-reduction. Thus appropriate choice of $G$ can skip ahead in the search space to rank-reductions that would otherwise require many pairwise flip and plus transforms to set up the needed linear combination.
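For example, over $\mathbb{F}_2$, if $B^{(3)}=B^{(1)}+B^{(2)}$, the appropriate $G$ folds $C^{(3)}$ into the other two summands and eliminates the third outright:

$$\begin{aligned} &A\otimes B^{(1)}\otimes C^{(1)}+A\otimes B^{(2)}\otimes C^{(2)}+A\otimes(B^{(1)}+B^{(2)})\otimes C^{(3)} \\ &= A\otimes B^{(1)}\otimes(C^{(1)}+C^{(3)})+A\otimes B^{(2)}\otimes(C^{(2)}+C^{(3)}). \end{aligned}$$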

It remains to discuss how to efficiently determine on a GPU whether $B^{(1)},\dots,B^{(N)}$ are linearly dependent. First, it seems to be the case that sets of summands sharing a term are overwhelmingly of size $\leq 5$ on the $5\times 5$ flip graph: in 3,273 schemes sampled from random walks there were 43,386 sets of more than two summands sharing a term and 99.93% had size $\leq 5$; Figure 2 visualizes this. Over $\mathbb{F}_2$, determining whether $B^{(1)},\dots,B^{(5)}$ are linearly dependent reduces to determining whether the sum (i.e. XOR, i.e. $\oplus$) of any subset is zero. So, $5$ is a useful number: sets of size $5$ have $31$ non-empty subsets and warps have $32$ threads, so we can assign each subset to a thread.

Figure 2. Distribution of sizes of sets of summands sharing a term, from 3,273 schemes sampled from random walks on the $5\times 5$ flip graph.

More precisely, let $s$ be the index of a thread in a warp and $D_s$ be the indices of the set bits of $s$. Thread zero is idle while each thread $s\neq 0$ checks whether

$$\bigoplus_{i\in D_s}B^{(i)}=0.$$

A warp ballot then finds the lowest-index thread for which the above holds. If one exists, call it thread $m$ and let $r=\min D_m$ be the index of its first set bit; threads then rewrite their summands from

$$\sum_{i=0}^{N-1} A\otimes B^{(i)}\otimes C^{(i)} \quad\text{to}\quad \sum_{i\neq r} A\otimes B^{(i)}\otimes C'^{(i)}$$

where

$$C'^{(i)}=\begin{cases} C^{(i)}\oplus C^{(r)} \quad &i\in D_m, i\neq r \\ C^{(i)} \quad &i\not\in D_m \end{cases}$$

and the $r$th summand is removed, i.e. set to zero.

To see the correctness of this,

$$\begin{aligned} &\sum_{i\neq r} A\otimes B^{(i)}\otimes C'^{(i)} \\ &= \Bigl(\sum_{i\neq r} A\otimes B^{(i)}\otimes C^{(i)}\Bigr) \oplus \sum_{i\in D_m, i\neq r} A\otimes B^{(i)}\otimes C^{(r)} \\ &= \Bigl(\sum_{i\neq r} A\otimes B^{(i)}\otimes C^{(i)}\Bigr) \oplus A\otimes B^{(r)}\otimes C^{(r)} \\ &= \sum_{i=0}^{N-1} A\otimes B^{(i)}\otimes C^{(i)}. \end{aligned}$$
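A sketch of the warp-level check, assuming the $N\leq 5$ matched $B$- and $C$-terms have already been broadcast into per-lane copies `b[]`, `c[]` (e.g. via __shfl_sync), so every lane redundantly computes the same update:

```cuda
__device__ void reduce_if_dependent(unsigned long long b[5],
                                    unsigned long long c[5], int n) {
  const unsigned lane = threadIdx.x & 31;
  unsigned long long x = 0;
  for (int i = 0; i < n; ++i)         // XOR of the subset of B-terms
    if (lane & (1u << i)) x ^= b[i];  // picked out by this lane's bits
  const bool valid = lane >= 1 && lane < (1u << n);  // non-empty subsets
  unsigned zero = __ballot_sync(0xffffffffu, valid && x == 0);
  if (zero == 0) return;              // the B-terms are independent
  unsigned m = __ffs(zero) - 1;       // lowest lane with a zero subset
  int r = __ffs(m) - 1;               // r = min D_m
  for (int i = 0; i < n; ++i)         // fold C^(r) into the subset's
    if ((m & (1u << i)) && i != r)    // other C-terms...
      c[i] ^= c[r];
  b[r] = c[r] = 0;                    // ...and remove summand r
}
```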

Notes

  1. The exposition is on square matrix multiplication but is easily generalized.
  2. At the level of abstraction necessary for this section, GPUs are composed of many processors, each composed of threads that all execute the same instruction at the same time. For example, two threads on a processor can simultaneously compute $x+y$, but one cannot compute $x+y$ while the other computes $x-y$, as $+$ and $-$ are different instructions. Branchy code performs poorly on GPUs because one thread cannot execute the false branch of an if statement while another executes the true branch. Some primitives, like __match_any_sync, allow threads in a warp to communicate with one another. If you'd like to understand GPUs in more detail, I made a considerable effort to explain them in this post.

References

  1. Flip Graphs for Matrix Multiplication. Manuel Kauers, Jakob Moosbauer. arXiv preprint arXiv:2212.01175. 2022.
  2. Adaptive Flip Graph Algorithm for Matrix Multiplication. Yamato Arai, Yuma Ichikawa, Koji Hukushima. arXiv preprint arXiv:2312.16960. 2023.

Thanks to Ally, Anagha, and Quan.