This is a review of work by Ohad Shamir et al. on the difficulty of learning parity with gradient descent. The papers Distribution-Specific Hardness of Learning Neural Networks [1] and Failures of Gradient-Based Deep Learning [2] cover much more than the hardness of learning parity, so this is a slimmed-down review containing only the parity-difficulty part.


In SATNet: Bridging deep learning and logical reasoning using a differentiable satisfiability solver [3], Wang et al. design a continuous relaxation of a SAT solver. Their solver takes as input a CNF formula $S$ and a partial assignment $I$ to $S$'s variables, and outputs an assignment to the remaining variables, $\setminus I$, that maximizes the number of satisfied clauses in $S$. Because the solver is continuous, given target values for $I$ and $\setminus I$, one can take the derivative of the difference between the solver's output on $I$ and the target outputs ($\setminus I$) with respect to $S$ and use gradient descent to find a CNF formula whose satisfiability matches the behavior in the target set.

For example, [3] generates many example input/output assignments for whether the number of true variables in the input is even or odd, and uses gradient descent to find a value of $S$ with this behavior, namely a CNF formula for the parity function. This is similar in spirit to how Moitti et al. use a continuous relaxation of a 2D cellular automaton to learn the discrete ruleset for Conway's Game of Life.

I think this is pretty incredible, but most interesting to me was that [3] cited [2] to invoke a theoretical difficulty of learning parity, saying "Learning parity functions from such single-bit supervision is known to pose difficulties for conventional deep learning approaches." At this point I had read neither [1] nor [2], but I had read Intractability of Learning the Discrete Logarithm with Gradient-Based Methods [4] by Takhanov et al., which shows the discrete logarithm is difficult to learn via gradient descent in the same sense that [2] says parity is hard. So, a wrinkle in the literature: [3] claims to synthesize parity with gradient descent despite [4]'s claim that such a thing should be intractable.

The purpose of this essay is to smooth this out. In short: [3]'s success learning parity does not contradict [2]'s claim that it is difficult, and [2] doesn't prove learning parity is difficult in the colloquial sense. In particular, [2] considers a "family" of parity functions, each of which samples a different subset of a length-$n$ input and outputs whether the number of $1$s in the sample is even or odd. Then, [2] shows that when the function to be learned could be any member of this family, there is an exponential-in-$n$ run of gradient descent that converges to a value independent of the target function (i.e. gradient descent learns nothing about the target function). On the other hand, [3] learns parity with an inductive bias by weaving the input through their model. For example, on a length-$4$ input $x_1,\dots,x_4 \in [2]$, [3]'s model ($m$ below) outputs

$$m(x_4, m(x_3, m(x_2, x_1))).$$

By inspection, the only two elements of the family of parity functions this can learn are the trivial case, when the subset of bits considered is empty, and full parity. So while [2] predicts the existence of a worst-case-exponential run, practically speaking, because there are only $16$ functions from $[2]^2\rightarrow[2]$, the constants work out in [3]'s favor.

If this introduction has piqued your curiosity about how [2] analyzes parity difficulty, read on. I think it is a lovely bit of math.

Background

Let's start by considering the worst-case complexity of learning a parity function from examples. The parity functions in this piece take as input a bit string, pick a subset of the bits, then return $-1$ if the number of $1$s in the subset is odd and $1$ otherwise. Formally, the family of parity functions is

$$\mathcal{H}\coloneqq\{x\mapsto (-1)^{\langle v, x\rangle}\mid v\in [2]^n\},$$

where $\langle v, x\rangle\coloneqq\Sigma_{i=1}^n v_i x_i$ is the inner product and $n$ is the input length.

Our task is, given an unknown $h\in\mathcal{H}$ and a set of examples $\{(x^{(i)},h(x^{(i)}))\mid 0\leq i < t\}$, to determine the value of $h$. In the worst case, this takes $t>2^{n-1}$ examples. To see why, letting $x^{(i)}_j$ denote the $j$th element of $x^{(i)}$, consider the case where $\exists j\forall i,\ x^{(i)}_j=0$. In this case, letting $h=x\mapsto(-1)^{\langle v, x\rangle}$ and $v'$ be the element of $[2]^n$ satisfying $v'_k\neq v_k\leftrightarrow j=k$, for $h'=x\mapsto(-1)^{\langle v', x\rangle}$ we have

$$\begin{aligned} \forall i,\quad h'(x^{(i)}) &= (-1)^{\Sigma_k v'_k x^{(i)}_k} \\ &= (-1)^{v'_j x^{(i)}_j+\Sigma_{k\neq j} v'_k x^{(i)}_k} \\ &= (-1)^{0+\Sigma_{k\neq j} v_k x^{(i)}_k} \\ &= (-1)^{v_j x^{(i)}_j+\Sigma_{k\neq j} v_k x^{(i)}_k} \\ &= h(x^{(i)}). \end{aligned}$$

As $h$ and $h'$ output the same value on every example input, it is ambiguous which is the true function based on our examples. There is a set of examples of size $2^{n-1}$ satisfying $\exists j\forall i,\ x^{(i)}_j=0$, so in the worst case learning $h$ takes $t>2^{n-1}$ examples.
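
To make the ambiguity concrete, here is a small Python check (my own sketch, not from [1] or [2]): for $n=4$ and $j=0$, it builds the $2^{n-1}$ inputs with $x_j=0$ and verifies that the parity functions indexed by $v$ and by $v$ with bit $j$ flipped agree on every one of them.

```python
from itertools import product

n, j = 4, 0

def parity(v, x):
    # h_v(x) = (-1)^{<v, x>}
    return (-1) ** sum(vi * xi for vi, xi in zip(v, x))

# All 2^(n-1) inputs whose j-th bit is 0.
examples = [x for x in product([0, 1], repeat=n) if x[j] == 0]

v = (1, 0, 1, 1)                                        # an arbitrary choice of v
v_prime = tuple(b ^ (k == j) for k, b in enumerate(v))  # flip bit j

# h_v and h_{v'} are indistinguishable on these examples.
assert all(parity(v, x) == parity(v_prime, x) for x in examples)
print(f"{len(examples)} examples, and h_v agrees with h_v' on all of them")
```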

Shamir's work shows that even when the examples are drawn from a nice distribution (i.e. one without the $\exists j\forall i,\ x^{(i)}_j=0$ property), gradient descent can take an exponential (in $n$) number of steps to learn some $h\in\mathcal{H}$. In my view, the star of the show is Theorem 1, which establishes that when the gradient of the loss is similar for all elements of $\mathcal{H}$, the gradient carries little information about which element of $\mathcal{H}$ should actually be learned. This small signal is lost in the noise from finite-precision arithmetic and sampling.

Gradient Descent

In the typical machine learning setting, for some target function $h:\mathbb{R}^n\rightarrow\mathbb{R}^m$ and neural network architecture $p_w$ parameterized by weights $w\in\mathbb{R}^{|w|}$, we'd like to compute

$$\min_w F_h(w)\coloneqq\mathbb{E}_x(\frac{1}{2}\|h(x)-p_w(x)\|^2),$$

where $F_h(w)$ is the expected loss given a choice of $w$.

One approach is to use a variation of gradient descent. This starts by selecting an initial value for $w$, call it $w^{(0)}$, then proceeds to iteratively update the weights according to the formula

$$w^{(i+1)}\coloneqq w^{(i)}-\eta(\frac{\partial}{\partial w}F_h)(w^{(i)}),$$

where $\eta$ is the learning rate. Intuitively, this works because $(\frac{\partial}{\partial w}F_h)(w^{(i)})$ is an element of $\mathbb{R}^{|w|}$ pointing in the direction of steepest increase of $\mathbb{E}_x\|h(x)-p_w(x)\|^2$, i.e. the loss. By subtracting a small multiple of this value from $w^{(i)}$, we move the weights in a direction that decreases the loss.

In practice, computing $\mathbb{E}_x\|h(x)-p_w(x)\|^2$, and in turn $\frac{\partial}{\partial w}F_h$, is computationally infeasible, as $x$'s distribution is unknown. As such, the standard approach is to sample $x_1,x_2,\dots,x_t$ and approximate $F_h(w)$ as

$$F_h(w) \approx \frac{1}{t}\Sigma_i \frac{1}{2}\|h(x_i)-p_w(x_i)\|^2.$$

$\frac{\partial}{\partial w}F_h$ can be approximated in turn, and gradient descent run using the approximation.
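
To make the sampling approximation concrete, here is a minimal sketch (my own; the target $h$ and the linear stand-in for $p_w$ are arbitrary choices, not taken from any of the papers) that estimates $F_h(w)$ and $\frac{\partial}{\partial w}F_h(w)$ from samples.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8

def h(x):
    # Stand-in target function (an arbitrary choice for this sketch).
    return np.tanh(np.sum(x))

def p(w, x):
    # Toy "network": a linear model, so the gradient below is easy to verify.
    return w @ x

def sampled_loss_and_grad(w, t=1000):
    # Approximate F_h(w) = E_x[(1/2)||h(x) - p_w(x)||^2] and its gradient
    # by averaging over t sampled inputs.
    xs = rng.standard_normal((t, n))
    residuals = np.array([h(x) - p(w, x) for x in xs])
    loss = np.mean(0.5 * residuals ** 2)
    grad = np.mean([-r * x for r, x in zip(residuals, xs)], axis=0)
    return loss, grad

w = rng.standard_normal(n)
loss, grad = sampled_loss_and_grad(w)
print(f"estimated loss {loss:.3f}, estimated gradient norm {np.linalg.norm(grad):.3f}")
```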

Approximate Gradient Descent

We'll model this approximation with two definitions: approximate gradient oracles capture the error involved in computing $\frac{\partial}{\partial w}F_h$, and approximate gradient-based methods use these approximations.

An Approximate Gradient Oracle is a function $O_{F_h,\epsilon}:\mathbb{R}^{|w|}\rightarrow \mathbb{R}^{|w|}$ satisfying

$$\forall w,\ \|O_{F_h,\epsilon}(w) - \frac{\partial}{\partial w}F_h(w)\| \leq \epsilon.$$

An Approximate Gradient-Based Method is an algorithm that generates an initial guess $w^{(0)}$, then decides $w^{(i+1)}$ based on responses from an approximate gradient oracle:

$$w^{(i+1)} = f(w^{(0)}, \{O_{F_h,\epsilon}(w^{(j)}) \mid j \leq i\})$$
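
As a concrete instance of this definition, here is a minimal sketch (my own naming; `oracle` can be any function satisfying the $\epsilon$-closeness condition above) of plain gradient descent driven by an approximate gradient oracle. The $f$ from the definition is just the map folding the oracle responses into the usual update rule.

```python
import numpy as np

def approximate_gradient_descent(oracle, w0, eta=0.1, steps=100):
    """An approximate gradient-based method: each iterate depends only on
    w^(0) and the oracle's responses at the previous iterates."""
    w = np.array(w0, dtype=float)
    for _ in range(steps):
        w = w - eta * oracle(w)  # oracle(w) is within epsilon of (d/dw)F_h(w)
    return w
```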

Parity Hardness

Now, consider a family of functions $\mathcal{H}=\{h_1,h_2,\dots\}$ and the variance of the gradient at $w$ over the choice of $h\in\mathcal{H}$:

$$\text{Var}_{h\in\mathcal{H}}(\frac{\partial}{\partial w}F_h(w))\coloneqq \mathbb{E}_h\| (\frac{\partial}{\partial w}F_h)(w) - \mathbb{E}_{h'}((\frac{\partial}{\partial w}F_{h'})(w)) \|^2$$

To show parity is difficult to learn, we'll show that when $\mathcal{H}$ is the family of parity functions, this variance is exponentially small w.r.t. the length of the parity function's inputs. In turn, an adversarial approximate gradient oracle can repeatedly return $\mathbb{E}_{h'}((\frac{\partial}{\partial w}F_{h'})(w))$ instead of $(\frac{\partial}{\partial w}F_h)(w)$ while staying within its $\epsilon$ error tolerance. Because $\mathbb{E}_{h'}((\frac{\partial}{\partial w}F_{h'})(w))$ is independent of the $h\in\mathcal{H}$ being learned, an approximate gradient-based method using this adversarial oracle can converge to a value independent of the target function, $h$, unless it takes an exponentially large number of steps.

Theorem 1 (DSHLNN Theorem 10). For a family of functions $\mathcal{H}$, if

$$\forall w,\ \text{Var}_{h\in\mathcal{H}}((\frac{\partial}{\partial w}F_h)(w)) \leq \epsilon^3$$

then for any approximate gradient-based method learning $h\in\mathcal{H}$, there is a run such that the value of $w^{(\lfloor \epsilon^{-1}\rfloor)}$ is independent of $h$.

Proof. By Chebyshev's inequality and the hypothesis, we have

$$\begin{aligned} \forall w, \quad &\mathbb{P}_h(\|(\frac{\partial}{\partial w}F_h)(w)-\mathbb{E}_h((\frac{\partial}{\partial w}F_h)(w))\| > \epsilon) \\ &\leq \text{Var}_h((\frac{\partial}{\partial w}F_h)(w))/\epsilon^2 \\ &\leq \epsilon. \end{aligned}$$

So, consider the adversarial approximate gradient oracle

$$O_{F_h,\epsilon}(w) \coloneqq \begin{cases} \mathbb{E}_h((\frac{\partial}{\partial w}F_h)(w)) &\text{if }\|(\frac{\partial}{\partial w}F_h)(w)-\mathbb{E}_h((\frac{\partial}{\partial w}F_h)(w))\| \leq \epsilon \\ (\frac{\partial}{\partial w}F_h)(w) &\text{otherwise.} \end{cases}$$

Using our Chebyshev inequality, we can bound the likelihood of $O_{F_h,\epsilon}$'s otherwise case.

$$\begin{aligned} \forall w,\ &\mathbb{P}_h(O_{F_h,\epsilon}(w)=(\frac{\partial}{\partial w}F_h)(w)) \\ &= \mathbb{P}_h(\|(\frac{\partial}{\partial w}F_h)(w)-\mathbb{E}_h((\frac{\partial}{\partial w}F_h)(w))\| > \epsilon) \\ &\leq \epsilon. \end{aligned}$$

Because $\mathbb{E}_h((\frac{\partial}{\partial w}F_h)(w))$ is independent of what $h$ is being learned, the inequality above bounds the likelihood that $O_{F_h,\epsilon}(w)$ is dependent on $h$.

Now, for any approximate gradient-based method learning $h\in\mathcal{H}$, $w^{(0)}$ is independent of $h$, as no gradient information has been sampled when it is chosen. As

$$w^{(1)}=f(w^{(0)}, O_{F_h,\epsilon}(w^{(0)})),$$

evidently, $w^{(1)}$ is dependent on the $h$ being learned only if $O_{F_h,\epsilon}(w^{(0)})$ is, and, per the inequality above, the likelihood of this is $\leq\epsilon$. Repeating this argument, let $A^{(i)}$ be the event that $O_{F_h,\epsilon}(w^{(i)})$ is dependent on $h$. We have $\mathbb{P}(A^{(i)})\leq\epsilon$, so by the union bound

$$\begin{aligned} \mathbb{P}(\bigvee_{i=1}^I A^{(i)})&\leq\Sigma_{i=1}^I\mathbb{P}(A^{(i)})\\ &\leq I\epsilon. \end{aligned}$$

If $\mathbb{P}(\bigvee_{i=1}^I A^{(i)})<1$, then there is an $I$-step run of our gradient-based method where $w^{(I)}$ is independent of the target function, $h$. Solving for $I$ using the inequality above gives the desired bound: if $I<1/\epsilon$, then there is a run of the gradient-based method where $w^{(\lfloor \epsilon^{-1} \rfloor)}$ is independent of $h$. (I am simplifying somewhat here: for the case we're interested in, $\epsilon^{-1}$ will not be an integer, so flooring gives the strict $<$ inequality we want; if you're feeling picky, $I=\lceil \epsilon^{-1}\rceil -1$ will do.)

$\blacksquare$
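
For concreteness, the adversarial oracle from the proof looks like this as code (my own sketch; `grad_h` and `mean_grad` stand for $(\frac{\partial}{\partial w}F_h)(w)$ and $\mathbb{E}_h((\frac{\partial}{\partial w}F_h)(w))$, which the sketch assumes are supplied):

```python
import numpy as np

def adversarial_oracle(grad_h, mean_grad, epsilon):
    # Answer with the h-independent mean gradient whenever doing so stays
    # within the epsilon tolerance; otherwise fall back to the true gradient.
    if np.linalg.norm(grad_h - mean_grad) <= epsilon:
        return mean_grad
    return grad_h
```

When the variance over $\mathcal{H}$ is at most $\epsilon^3$, Chebyshev's inequality says the first branch is taken with probability at least $1-\epsilon$ over the choice of $h$, so the response usually carries no information about which $h$ is being learned.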

Lemma 1.

$$\Sigma_{x\in[n]^d}\Pi_{i=1}^d f_i(x_i) = \Pi_{i=1}^d\Sigma_{x\in[n]}f_i(x)$$

Proof. Shamir wordlessly invokes this, but it took me several hours on an airplane and help from ChatGPT to see why it holds. By induction on $d$. When $d=2$,

$$\begin{aligned} \Sigma_{x\in[n]^2}\Pi_{i=1}^2 f_i(x_i) &= \Sigma_{x\in[n]^2}f_1(x_1)f_2(x_2) \\ &= \Sigma_{x_1\in[n]}\Sigma_{x_2\in[n]}f_1(x_1)f_2(x_2) \\ &= \Sigma_{x_1\in[n]}f_1(x_1)(\Sigma_{x_2\in[n]}f_2(x_2)) \\ &= (\Sigma_{x_1\in[n]}f_1(x_1))(\Sigma_{x_2\in[n]}f_2(x_2)) \\ &= \Pi_{i=1}^2\Sigma_{x\in[n]}f_i(x). \end{aligned}$$

For the inductive step,

Ξ£x∈[n]dΞ i=1dfi(xi)=Ξ£x1∈[n]Ξ£x2∈[n]…Σxd∈[n]Ξ i=1dfi(xi)=Ξ£xd∈[n](Ξ£x1∈[n]…Σxdβˆ’1∈[n]Ξ i=1dβˆ’1fi(xi)⏟inductiveΒ hypothesis)fd(xd)=Ξ£xd∈[n](Ξ i=1dβˆ’1Ξ£x∈[n]fi(x))fd(xd)=(Ξ i=1dβˆ’1Ξ£x∈[n]fi(x))(Ξ£xd∈[n]fd(xd))=Ξ i=1dΞ£x∈[n]fi(x).\begin{aligned} \Sigma_{x\in[n]^d}\Pi_{i=1}^df_i(x_i) &= \Sigma_{x_1\in[n]}\Sigma_{x_2\in[n]}\dots\Sigma_{x_d\in[n]}\Pi_{i=1}^df_i(x_i) \\ &= \Sigma_{x_d\in[n]}(\underbrace{\Sigma_{x_1\in[n]}\dots\Sigma_{x_{d-1}\in[n]}\Pi_{i=1}^{d-1}f_i(x_i)}_{\substack{\text{inductive hypothesis}}})f_d(x_d) \\ &= \Sigma_{x_d\in[n]}(\Pi_{i=1}^{d-1}\Sigma_{x\in[n]}f_i(x))f_d(x_d) \\ &= (\Pi_{i=1}^{d-1}\Sigma_{x\in[n]}f_i(x))(\Sigma_{x_d\in[n]}f_d(x_d)) \\ &= \Pi_{i=1}^d\Sigma_{x\in[n]}f_i(x). \end{aligned}

$\blacksquare$
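
Here is a quick numeric sanity check of Lemma 1 (my own; $n$, $d$, and the lookup tables for the $f_i$ are arbitrary choices):

```python
from itertools import product
from math import prod
import random

random.seed(0)
n, d = 3, 4

# Arbitrary functions f_i : [n] -> R, represented as lookup tables.
fs = [[random.uniform(-1, 1) for _ in range(n)] for _ in range(d)]

# Left side: sum over x in [n]^d of the product of f_i(x_i).
lhs = sum(prod(fs[i][x[i]] for i in range(d)) for x in product(range(n), repeat=d))
# Right side: product over i of the sum of f_i over [n].
rhs = prod(sum(f) for f in fs)

assert abs(lhs - rhs) < 1e-9
print(lhs, rhs)
```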

Lemma 2. For the family of parity functions, $\mathcal{H}$, and any $h,h'\in\mathcal{H}$,

$$\mathbb{E}_x(h(x)h'(x))=\begin{cases} 1 & (h=h') \\ 0 & (h\neq h')\end{cases}$$

Proof.

$$\begin{aligned} \mathbb{E}_x(h(x)h'(x)) &= \mathbb{E}_x((-1)^{\langle v,x\rangle}(-1)^{\langle v',x\rangle}) \\ &= \mathbb{E}_x((-1)^{\langle v+v',x\rangle}) \\ &= \mathbb{E}_x(\Pi_{i=1}^n (-1)^{(v_i+v'_i)x_i}) \\ &= \frac{1}{2^n}\Sigma_{x\in[2]^n}\Pi_{i=1}^n (-1)^{(v_i+v'_i)x_i} \\ &= \frac{1}{2^n}\Pi_{i=1}^n\Sigma_{x\in[2]}(-1)^{(v_i+v'_i)x} \quad (\text{Lemma }1) \\ &= \frac{1}{2^n}\Pi_{i=1}^n 2\cdot\frac{1}{2}\Sigma_{x\in[2]}(-1)^{(v_i+v'_i)x} \\ &= \Pi_{i=1}^n\mathbb{E}_{x_i}((-1)^{(v_i+v'_i)x_i}) \\ &= \Pi_{i=1}^n((-1)^{0\cdot(v_i+v'_i)} + (-1)^{1\cdot(v_i+v'_i)})/2 \\ &= \Pi_{i=1}^n(1 + (-1)^{v_i+v'_i})/2. \end{aligned}$$

If $h=h'$ then $v_i=v'_i$ for every $i$, so each factor is $(1+1)/2=1$ and the product is $1$; if $h\neq h'$ then some $v_i\neq v'_i$, making that factor $(1-1)/2=0$ and the product $0$.

$\blacksquare$
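
And a brute-force check of Lemma 2 for a small $n$ (my own sketch):

```python
from itertools import product

n = 4
inputs = list(product([0, 1], repeat=n))

def parity(v, x):
    return (-1) ** sum(vi * xi for vi, xi in zip(v, x))

for v in product([0, 1], repeat=n):
    for v_prime in product([0, 1], repeat=n):
        # E_x[h(x) h'(x)] with x uniform over [2]^n.
        e = sum(parity(v, x) * parity(v_prime, x) for x in inputs) / len(inputs)
        assert e == (1 if v == v_prime else 0)
print("Lemma 2 holds for n =", n)
```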

Lemma 3. For any $n$, $m$, and $f:\mathbb{R}^n\rightarrow\mathbb{R}^m$,

$$\forall a\in\mathbb{R}^m,\ \mathbb{E}_x\|f(x)-\mathbb{E}_x(f(x))\|^2 \leq \mathbb{E}_x\|f(x)-a\|^2.$$

Proof. Notice $g(a)\coloneqq\mathbb{E}_x\|f(x)-a\|^2$ is (strictly) convex, so the value of $a$ satisfying $(\frac{\partial}{\partial a}g)(a)=0$ is a global minimum. Computing this derivative and solving gives $a=\mathbb{E}_x(f(x))$, as desired. This corresponds to the intuitive idea that the mean minimizes the expected squared distance to all other values.

$\blacksquare$
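
A small numeric check of Lemma 3 (my own; the sample size and the candidate values of $a$ are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
fx = rng.standard_normal((1000, 3))  # samples of f(x) in R^3
mean = fx.mean(axis=0)

best = np.mean(np.sum((fx - mean) ** 2, axis=1))
for a in rng.standard_normal((5, 3)):  # arbitrary alternative choices of a
    assert best <= np.mean(np.sum((fx - a) ** 2, axis=1))
print("the mean attains the smallest expected squared distance")
```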

Lemma 4 (Bessel's inequality). For a Hilbert space $H$ and finite set $\{e_i\mid i\in I\}$ satisfying

$$\langle e_i,e_j\rangle=\begin{cases}1 &(i=j)\\0 &(i\neq j)\end{cases},$$

i.e. $\{e_i\mid i\in I\}$ is an orthonormal family, we have

$$\forall x\in H,\ \Sigma_{i\in I}\|\langle x, e_i\rangle\|^2\leq\|x\|^2.$$

Proof. Intuitively, this says that projecting $x$ onto an orthonormal family won't increase its size. Let $a_i$ be $\langle x, e_i\rangle$ and $y$ be $\Sigma_{i\in I}a_i e_i$, i.e. $y$ is the projection of $x$ onto the span of $\{e_i\mid i\in I\}$, and $a_i$ is the component of $x$ lying along the $i$th element of the family. Now, let the residual $r$ be $x-y$ and note $r$ is perpendicular to each $e_i$:

$$\begin{aligned} \langle r, e_i\rangle &= \langle x, e_i\rangle - \langle y, e_i\rangle \\ &= a_i - \Sigma_{j\in I}a_j\langle e_j, e_i\rangle \\ &= a_i - a_i \\ &= 0. \end{aligned}$$

In turn, $r$ is perpendicular to $y$:

$$\langle r, y\rangle = \Sigma_{i\in I}a_i\langle r, e_i\rangle = 0.$$

So, because $x=y+r$ and $y$ and $r$ are perpendicular, by Pythagoras',

$$\begin{aligned} \|x\|^2&=\|y+r\|^2 \\ &=\|y\|^2+\|r\|^2 \\ &\geq \|y\|^2 \\ &= \langle \Sigma_{i\in I} a_ie_i, \Sigma_{i\in I} a_ie_i\rangle \\ &= \Sigma_{i\in I}a_i^2 \\ &= \Sigma_{i\in I}\|\langle x, e_i\rangle\|^2. \end{aligned}$$

$\blacksquare$
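
A numeric check of Bessel's inequality (my own sketch; the orthonormal family comes from the Q factor of a QR decomposition of a random matrix):

```python
import numpy as np

rng = np.random.default_rng(0)
dim, k = 10, 4

# An orthonormal family of k vectors in R^dim.
q, _ = np.linalg.qr(rng.standard_normal((dim, k)))
e = q.T  # rows are the e_i

x = rng.standard_normal(dim)
projection_energy = np.sum((e @ x) ** 2)  # sum_i <x, e_i>^2

assert projection_energy <= np.dot(x, x) + 1e-12
print(projection_energy, np.dot(x, x))
```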

Theorem 2 (FGBDL Theorem 1). Let $\mathcal{H}$ be the family of parity functions and let $p_w$ and $F_h$ be a neural network and MSE loss function, as before. If $p_w$ satisfies

$$\forall w,\ \mathbb{E}_x\|(\frac{\partial}{\partial w}p_w)(x)\|^2\leq G(w)^2$$

for some scalar $G(w)$, then

$$\text{Var}_h((\frac{\partial}{\partial w}F_h)(w))\leq\frac{G(w)^2}{|\mathcal{H}|}.$$

Proof. First, invoking Lemma 3 with $a=\mathbb{E}_{x}(p_w(x)(\frac{\partial}{\partial w}p_w)(x))$, we have

$$\begin{aligned} &\text{Var}_h((\frac{\partial}{\partial w}F_h)(w)) \\ &= \mathbb{E}_h\| (\frac{\partial}{\partial w}F_h)(w) - \mathbb{E}_{h'}((\frac{\partial}{\partial w}F_{h'})(w)) \|^2 \\ &\leq \mathbb{E}_h\| (\frac{\partial}{\partial w}F_h)(w) - \mathbb{E}_{x}(p_w(x)(\frac{\partial}{\partial w}p_w)(x)) \|^2 \\ &= \mathbb{E}_h\|\mathbb{E}_x((p_w(x)-h(x))(\frac{\partial}{\partial w}p_w)(x)) - \mathbb{E}_{x}(p_w(x)(\frac{\partial}{\partial w}p_w)(x))\|^2 \\ &= \mathbb{E}_h\|\mathbb{E}_x(h(x)(\frac{\partial}{\partial w}p_w)(x))\|^2. \end{aligned}$$

Next, note

$$\begin{aligned} \mathbb{E}_x(h(x)(\frac{\partial}{\partial w}p_w)(x))&=\frac{1}{2^n}\Sigma_{x\in[2]^n}h(x)(\frac{\partial}{\partial w}p_w)(x) \\ &= \frac{1}{2^n}\langle h, \frac{\partial}{\partial w}p_w\rangle. \end{aligned}$$

So, using Lemma 2 to invoke Lemma 4 we get the desired bound.

$$\begin{aligned} &\mathbb{E}_h\|\mathbb{E}_x(h(x)(\frac{\partial}{\partial w}p_w)(x))\|^2 \\ &= \mathbb{E}_h\Sigma_{i=1}^{|w|}\mathbb{E}_x(h(x)(\frac{\partial}{\partial w}p_w)_i(x))^2 \\ &= \Sigma_{i=1}^{|w|}\mathbb{E}_h((\frac{1}{2^n}\langle h, (\frac{\partial}{\partial w}p_w)_i\rangle)^2) \\ &= \Sigma_{i=1}^{|w|}\frac{1}{2^{2n}}\frac{1}{|\mathcal{H}|}\Sigma_{h\in\mathcal{H}}\langle h, (\frac{\partial}{\partial w}p_w)_i\rangle^2 \\ &\leq \Sigma_{i=1}^{|w|}\frac{1}{2^n}\frac{1}{|\mathcal{H}|}\|(\frac{\partial}{\partial w}p_w)_i\|^2 \quad \text{(Bessel's inequality)} \\ &= \Sigma_{i=1}^{|w|}\frac{1}{|\mathcal{H}|}\mathbb{E}_x\|(\frac{\partial}{\partial w}p_w)_i(x)\|^2 \\ &= \frac{1}{|\mathcal{H}|}\mathbb{E}_x\|(\frac{\partial}{\partial w}p_w)(x)\|^2 \\ &\leq \frac{G(w)^2}{|\mathcal{H}|}. \end{aligned}$$

$\blacksquare$
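
Theorem 2 is also easy to check numerically. The sketch below (my own; the tiny tanh network, its width, and the seed are arbitrary choices, not anything from [1] or [2]) enumerates every parity function on $n=6$ bits, computes each $(\frac{\partial}{\partial w}F_h)(w)$ exactly at a random $w$, and compares the variance over $h$ against $\mathbb{E}_x\|(\frac{\partial}{\partial w}p_w)(x)\|^2/|\mathcal{H}|$.

```python
import numpy as np
from itertools import product

rng = np.random.default_rng(0)
n, k = 6, 5  # input length, hidden width

# Random weights for a tiny one-hidden-layer net p_w(x) = a . tanh(W x).
W = rng.standard_normal((k, n))
a = rng.standard_normal(k)

xs = np.array(list(product([0, 1], repeat=n)), dtype=float)  # all 2^n inputs
vs = np.array(list(product([0, 1], repeat=n)))               # all parity indices v

def model_and_grad(x):
    # Returns p_w(x) and the gradient of p_w(x) w.r.t. (W, a), flattened.
    hid = np.tanh(W @ x)
    out = a @ hid
    dW = np.outer(a * (1 - hid ** 2), x)
    da = hid
    return out, np.concatenate([dW.ravel(), da])

outs, grads = zip(*(model_and_grad(x) for x in xs))
outs, grads = np.array(outs), np.array(grads)

# Gradient of F_h(w) = E_x[(1/2)(h(x) - p_w(x))^2] for every parity h_v.
hs = (-1.0) ** (vs @ xs.T)                # hs[v_idx, x_idx] = h_v(x)
grad_F = ((outs - hs) @ grads) / len(xs)  # one gradient vector per h

variance = np.mean(np.sum((grad_F - grad_F.mean(axis=0)) ** 2, axis=1))
bound = np.mean(np.sum(grads ** 2, axis=1)) / len(vs)  # E_x||dp/dw||^2 / |H|

print(f"Var_h of gradient: {variance:.3e}  <=  bound: {bound:.3e}")
assert variance <= bound
```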

Theorem 2 bounds the variance of the gradient over the family of parity functions, so together with Theorem 1 we get our goal result on the hardness of learning parity. For another cool use of Theorem 1, see Intractability of Learning the Discrete Logarithm with Gradient-Based Methods by Takhanov et al.