Together, DSHLNN and FGBDL cover much more than the hardness of
learning parity, so the purpose of this piece is twofold:
1. To give a slimmed-down version of the papers containing only the
parity result.
2. To give a version without a page limit, with all the steps
explicit.
Lemma 1, Lemma 3, and Lemma 4 are also left as exercises
for the reader, so I've filled those in. Any illusion of intelligence
found here should be credited to the authors.
Background
Let's start by considering the worst-case complexity of learning a parity
function from examples. The parity functions in this piece take as
input a bit string, pick a subset of the bits, then return −1 if the
number of 1s in the subset is odd and 1 otherwise. Formally, the
family of parity functions is

H := {x ↦ (−1)^{⟨v,x⟩} | v ∈ [2]^n},

where ⟨v,x⟩ := Σ_{i=1}^n v_i x_i is the inner
product and n is the input length.
Our task is, given an unknown h ∈ H and a set of examples
{(x^{(i)}, h(x^{(i)})) | 0 ≤ i < t}, to determine the value of
h. In the worst case, this takes t > 2^{n−1} examples. To see why,
letting x_j^{(i)} denote the jth element of x^{(i)}, consider
the case where ∃j ∀i, x_j^{(i)} = 0. In this case,
letting h = x ↦ (−1)^{⟨v,x⟩} and v′ be the element
of [2]^n satisfying v′_k = v_k ↔ j ≠ k (i.e. v with its jth bit flipped), for
h′ = x ↦ (−1)^{⟨v′,x⟩} we have

∀i, h′(x^{(i)}) = (−1)^{⟨v,x^{(i)}⟩ + (v′_j − v_j)x_j^{(i)}} = (−1)^{⟨v,x^{(i)}⟩} = h(x^{(i)}).

As h and h′ output the same value on every example input, it is
ambiguous which is the true function based on our examples. There is a
set of examples of size 2^{n−1} satisfying ∃j ∀i, x_j^{(i)} = 0, so in the worst case learning h takes t > 2^{n−1}
examples.
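To make the ambiguity concrete, here's a quick sketch (Python and NumPy are my choices; the papers include no code) checking that flipping bit j of v changes nothing when every example has x_j = 0:

```python
import numpy as np

n, j = 4, 2
rng = np.random.default_rng(0)

def parity(v, x):
    # (-1)^<v, x>, the parity of the bits of x selected by v
    return (-1) ** int(np.dot(v, x) % 2)

v = rng.integers(0, 2, size=n)
v_prime = v.copy()
v_prime[j] ^= 1  # flip bit j

# Every example with x_j = 0; there are 2^(n-1) of them.
examples = [np.array(x) for x in np.ndindex(*(2,) * n) if x[j] == 0]
assert all(parity(v, x) == parity(v_prime, x) for x in examples)
print(f"{len(examples)} examples, and v, v' agree on all of them")
```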
Shamir's work shows that, even when the examples are drawn from a nice
distribution (i.e. one without the ∃j ∀i, x_j^{(i)} = 0
property), gradient descent can take an exponential (in n) number of
steps to learn some h ∈ H. In my view, the star of the
show is Theorem 1, which establishes that when the gradient of the loss
is similar for all elements of H, the gradient carries
little information about which element of H should actually
be learned. This small ``signal'' is then lost in noise from finite
precision arithmetic and sampling.
Gradient Descent
In the typical machine learning setting, for some target function
h: R^n → R^m and neural network
architecture p_w parameterized by weights w ∈ R^{|w|},
we'd like to compute

min_w F_h(w) := E_x((1/2)∥h(x) − p_w(x)∥²),

where F_h(w) is the expected loss given a choice of w.
One approach is to use a variation of gradient descent. This starts
by selecting an initial value for w, call it w^{(0)}, then
proceeds to iteratively update the weights according to the formula

w^{(i+1)} := w^{(i)} − η(∂F_h/∂w)(w^{(i)}),

where η is the learning rate. Intuitively, this works because
(∂F_h/∂w)(w^{(i)}) is an element of
R^{|w|} pointing in the direction of steepest increase of
E_x∥h(x) − p_w(x)∥², i.e. the loss. By subtracting this value from
w^{(i)} we move the weights in a direction that decreases the loss.
In practice, computing E_x∥h(x) − p_w(x)∥², and in turn
∂F_h/∂w, is computationally infeasible, as
x's distribution is unknown. As such, the standard approach is to
sample x_1, x_2, …, x_t and approximate F_h(w) as

F_h(w) ≈ (1/t) Σ_{i=1}^t (1/2)∥h(x_i) − p_w(x_i)∥².

∂F_h/∂w can be approximated in turn, and
gradient descent run using the approximation.
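As a concrete illustration, here's a minimal sketch of this sampled approximation, assuming a toy linear model p_w(x) = ⟨w, x⟩ and a target of my own choosing (everything here is illustrative, not from the papers):

```python
import numpy as np

rng = np.random.default_rng(0)
n, t, eta = 8, 256, 0.1
w = rng.normal(size=n)            # w^(0)
h = lambda xs: xs @ np.ones(n)    # some known target, for demonstration

for _ in range(100):
    xs = rng.normal(size=(t, n))  # sample x_1, ..., x_t
    residual = xs @ w - h(xs)     # p_w(x_i) - h(x_i)
    grad = residual @ xs / t      # approximates (dF_h/dw)(w)
    w -= eta * grad               # the gradient descent update

print(np.round(w, 2))  # drifts toward the all-ones vector
```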
Approximate Gradient Descent
We'll model this approximation with two definitions: approximate gradient oracles capture the error involved in computing
∂F_h/∂w, and approximate gradient-based methods use these approximations.
An Approximate Gradient Oracle is a function

O_{F_h,ϵ}: R^{|w|} → R^{|w|}

satisfying

∀w, ∥O_{F_h,ϵ}(w) − (∂F_h/∂w)(w)∥ ≤ ϵ.
An Approximate Gradient-Based Method is an algorithm that
generates an initial guess w^{(0)}, then decides w^{(i+1)}
based on responses from an approximate gradient oracle:

w^{(i+1)} = f(w^{(0)}, {O_{F_h,ϵ}(w^{(j)}) | j < i+1}).
Parity Hardness
Now, consider a family of functions H = {h_1, h_2, …}
and the variance, over h ∈ H, of the gradient at w:

Var_{h∈H}((∂F_h/∂w)(w)) := E_h(∥(∂F_h/∂w)(w) − E_{h′}((∂F_{h′}/∂w)(w))∥²).
To show parity is difficult to learn, we'll show that when H
is the family of parity functions, this variance is exponentially
small w.r.t. the length of the parity functions' inputs. In turn, an
adversarial approximate gradient oracle can repeatedly return
E_{h′}((∂F_{h′}/∂w)(w)) instead of
(∂F_h/∂w)(w) while staying within its
ϵ error tolerance. Because
E_{h′}((∂F_{h′}/∂w)(w)) is
independent of the h ∈ H being learned, an approximate gradient-based method using this adversarial oracle can converge to
a value independent of the target function, h, unless it takes an
exponentially large number of steps.
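A sketch of this adversarial oracle (the names and structure are mine; the papers define it abstractly):

```python
import numpy as np

def adversarial_oracle(true_grad, mean_grad, eps):
    # true_grad: (dF_h/dw)(w) for the h actually being learned
    # mean_grad: E_{h'}((dF_{h'}/dw)(w)), which is independent of h
    if np.linalg.norm(mean_grad - true_grad) <= eps:
        return mean_grad  # a legal answer that reveals nothing about h
    return true_grad      # forced to leak information about h
```

When the variance over h is tiny, the first branch is taken for almost every h, so almost every query comes back h-independent.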
Theorem 1 (DSHLNN Theorem 10). For some family of functions H, if

∀w, Var_{h∈H}((∂F_h/∂w)(w)) ≤ ϵ³

then for any approximate gradient-based method learning
h ∈ H, there is a run such that the value of
w^{(⌊ϵ^{−1}⌋)} is independent of h.
Proof. By Chebyshev's inequality and the hypothesis, we have

P_h(∥(∂F_h/∂w)(w) − E_h((∂F_h/∂w)(w))∥ ≥ ϵ) ≤ ϵ³/ϵ² = ϵ.

Because E_h((∂F_h/∂w)(w)) is
independent of what h is being learned, and the oracle may answer with it
whenever it is within ϵ of the true gradient, the inequality above
bounds the likelihood that O_{F_h,ϵ}(w) is dependent on
h.
Now, for any approximate gradient-based method learning
h ∈ H, w^{(0)} is independent of h, as nothing has
been sampled from the gradient when it is chosen. As

w^{(1)} = f(w^{(0)}, O_{F_h,ϵ}(w^{(0)})),

evidently, w^{(1)} is dependent on the h being learned only if
O_{F_h,ϵ}(w^{(0)}) is, and, per the inequality above, the
likelihood of this is ≤ ϵ. Repeating this argument, let
A^{(i)} be the event that O_{F_h,ϵ}(w^{(i)}) is dependent on
h. We have P(A^{(i)}) ≤ ϵ, so by the union bound

P(⋁_{i=1}^I A^{(i)}) ≤ Σ_{i=1}^I P(A^{(i)}) ≤ Iϵ.
If P(⋁_{i=1}^I A^{(i)}) < 1, then there is an I-step
run of our gradient-based method where w^{(I)} is independent of the
target function, h. Solving for I using the inequality above gives
the desired bound: if I < 1/ϵ, then there is a run of the
gradient-based method where w^{(⌊ϵ^{−1}⌋)} is
independent of h (I am simplifying somewhat here because for the
case we're interested in, ϵ^{−1} will not be an integer and
flooring gives the strict < inequality we want, but if you're
feeling picky I = ⌈ϵ^{−1}⌉ − 1 will do).
■
Lemma 1.

Σ_{x∈[n]^d} Π_{i=1}^d f_i(x_i) = Π_{i=1}^d Σ_{x∈[n]} f_i(x)

Proof. Shamir wordlessly invokes this, but it took me several hours
on an airplane and help from ChatGPT to see. By induction on d. When
d = 2,

Σ_{x∈[n]²} f_1(x_1)f_2(x_2) = Σ_{x_1∈[n]} f_1(x_1) Σ_{x_2∈[n]} f_2(x_2) = Π_{i=1}^2 Σ_{x∈[n]} f_i(x),

since f_1(x_1) is constant with respect to the inner sum and can be
factored out. For the inductive step, split each x ∈ [n]^{d+1} into
(x′, x_{d+1}) with x′ ∈ [n]^d, factor out
Σ_{x_{d+1}∈[n]} f_{d+1}(x_{d+1}) in the same way, and apply the
inductive hypothesis to the remaining sum over [n]^d.
■
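A quick numeric sanity check of Lemma 1 (my own; the choices of n, d, and f_i are arbitrary):

```python
import itertools
import math

n, d = 3, 4
fs = [lambda x, i=i: (x + 1) ** (i + 1) for i in range(d)]  # arbitrary f_1, ..., f_d

# Sum of products over [n]^d ...
lhs = sum(math.prod(fs[i](x[i]) for i in range(d))
          for x in itertools.product(range(n), repeat=d))
# ... equals the product of per-coordinate sums.
rhs = math.prod(sum(f(x) for x in range(n)) for f in fs)
assert lhs == rhs
```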
Lemma 3. For any function f, the constant a minimizing
E_x∥f(x) − a∥² is E_x(f(x)).

Proof. Notice g(a) := E_x∥f(x) − a∥² is (strictly)
convex, so the value of a satisfying (∂g/∂a)(a) = 0 is a global minimum. Computing this derivative and solving
gives a = E_x(f(x)), as desired. This corresponds to the
intuitive idea that the mean minimizes the expected difference between
it and all other values.
■
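A numeric check of Lemma 3 (the distribution and the candidate values of a are arbitrary choices of mine):

```python
import numpy as np

rng = np.random.default_rng(0)
samples = rng.exponential(size=100_000)  # stand-in for f(x) under x's distribution

def expected_loss(a):
    return np.mean((samples - a) ** 2)

mean = samples.mean()
for a in (0.0, 0.5, mean, 2.0):
    print(f"a = {a:.3f}: E||f(x) - a||^2 = {expected_loss(a):.3f}")
# the a = mean row comes out smallest
```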
Lemma 4 (Bessel's inequality). For a Hilbert space H and finite set {e_i | i ∈ I} satisfying

⟨e_i, e_j⟩ = 1 if i = j, and 0 otherwise,

i.e. {e_i | i ∈ I} is an orthonormal family, we have

∀x ∈ H, Σ_{i∈I} ∥⟨x, e_i⟩∥² ≤ ∥x∥².

Proof. Intuitively, this says that projecting x onto an
orthonormal basis won't increase its size. Let a_i be ⟨x, e_i⟩ and y be Σ_{i∈I} a_i e_i, i.e. y is the
projection of x onto the basis formed by {e_i | i ∈ I}, and
a_i is the component of x lying on the ith element of the
basis. Now, let the residual r be x − y and note r is
perpendicular to each e_i:

⟨r, e_i⟩ = ⟨x, e_i⟩ − Σ_{j∈I} a_j⟨e_j, e_i⟩ = a_i − a_i = 0.

In particular, r is perpendicular to y, so by the Pythagorean theorem

∥x∥² = ∥y∥² + ∥r∥² ≥ ∥y∥² = Σ_{i∈I} ∥a_i∥²,

where the last equality again uses orthonormality.
■
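A numeric check of Bessel's inequality on a random orthonormal family (the dimensions are arbitrary choices of mine):

```python
import numpy as np

rng = np.random.default_rng(0)
dim, k = 10, 4
# random orthonormal family of k vectors via QR decomposition
q, _ = np.linalg.qr(rng.normal(size=(dim, k)))
x = rng.normal(size=dim)

coeffs = q.T @ x  # a_i = <x, e_i> for each column e_i of q
assert np.sum(coeffs ** 2) <= np.dot(x, x) + 1e-12
print(np.sum(coeffs ** 2), "<=", np.dot(x, x))
```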
Theorem 2 (FGBDL Theorem 1).
Let H be the family of parity functions and let p_w
and F_h be a neural network and MSE loss function, as before. If
p_w satisfies

∀w, E_x∥(∂p_w/∂w)(x)∥² ≤ G(w)²

for some scalar G(w), then

Var_h((∂F_h/∂w)(w)) ≤ G(w)²/|H|.
Proof. First, invoking Lemma 3 with
a = E_x(p_w(x)(∂p_w/∂w)(x)), we have

Var_h((∂F_h/∂w)(w)) ≤ E_h(∥(∂F_h/∂w)(w) − a∥²) = E_h(∥E_x(h(x)(∂p_w/∂w)(x))∥²),

since (∂F_h/∂w)(w) = E_x((p_w(x) − h(x))(∂p_w/∂w)(x)) and the
p_w term is exactly a. Now, with x uniform, the parity functions form an
orthonormal family with respect to the inner product
⟨f, g⟩ := E_x(f(x)g(x)), and each coordinate of
E_x(h(x)(∂p_w/∂w)(x)) is the coefficient ⟨(∂p_w/∂w)_k, h⟩.
Applying Lemma 4 in each coordinate and summing over coordinates,

Σ_{h∈H} ∥E_x(h(x)(∂p_w/∂w)(x))∥² ≤ E_x∥(∂p_w/∂w)(x)∥² ≤ G(w)².

Dividing by |H| to turn the sum over h into an expectation gives the
claimed bound.
■
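To see Theorem 2 at work, here's a small experiment of my own: for a toy model p_w(x) = tanh(⟨w, x⟩) with uniform inputs, the variance of the gradient over all 2^n parities shrinks rapidly as n grows.

```python
import itertools
import numpy as np

def grad_variance(n, rng):
    xs = np.array(list(itertools.product((0, 1), repeat=n)), dtype=float)
    w = rng.normal(size=n)
    z = np.tanh(xs @ w)                 # p_w(x) at every x
    dz = (1 - z ** 2)[:, None] * xs     # (dp_w/dw)(x) at every x
    vs = xs                             # every v in [2]^n
    hs = (-1.0) ** (xs @ vs.T % 2)      # hs[x, v] = h_v(x)
    # (dF_h/dw)(w) = E_x[(p_w(x) - h(x)) (dp_w/dw)(x)], one row per parity h
    grads = ((z[:, None] - hs)[..., None] * dz[:, None, :]).mean(axis=0)
    return grads.var(axis=0).sum()      # total variance over h

rng = np.random.default_rng(0)
for n in (4, 6, 8, 10):
    print(n, grad_variance(n, rng))     # decays roughly like G(w)^2 / 2^n
```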
This is a subtle result. For example, consider the family of
functions

G := {x ↦ x + 10^{−(100+k)} | 0 ≤ k < 2^n}.

As these differ only by a minuscule additive constant, the variance of their
gradients trivially satisfies the hypothesis of Theorem 1. In turn, in the worst case
gradient descent might converge to a value independent of the target
function unless we take an exponential number of steps! But this
actually isn't so bad in practice, because the functions are so
similar that following

E_{g∈G}((∂F_g/∂w)(w))

will likely yield a good approximation.
In turn, saying Theorem 1 and Theorem 2 give a result about
the ``hardness'' of learning parity requires subtle
assumptions. Really the result is: unless we take an exponential
number of steps, gradient descent might converge to a value
independent of the target element of H. For parity,
reading this as hardness is probably reasonable, because, unlike the
elements of G, parity functions are quite different from one another.
By inspection, the only two elements of H this can learn
are the trivial case when the subset of bits considered is empty, and
full parity. In fact, there are only 16 functions from
[2]² → [2], so in this case parity can easily be learned by
brute force.
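For n = 2 that brute force fits in a few lines (my sketch; the secret v is chosen arbitrarily):

```python
import itertools

n = 2
xs = list(itertools.product((0, 1), repeat=n))
parity = lambda v, x: (-1) ** (sum(vi * xi for vi, xi in zip(v, x)) % 2)

secret = (1, 1)  # full parity
examples = [(x, parity(secret, x)) for x in xs]

# check every candidate v against the examples
consistent = [v for v in itertools.product((0, 1), repeat=n)
              if all(parity(v, x) == y for x, y in examples)]
print(consistent)  # [(1, 1)]
```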