In SATNet: Bridging deep learning and logical reasoning using a differentiable satisfiability solver [3], Wang et al. design a continuous relaxation of a SAT solver. Their solver takes as input a CNF formula $S$ and a partial assignment $I$ to $S$'s variables, and outputs an assignment $\bar{I}$ to the remaining variables that maximizes the number of satisfied clauses in $S$. Because the solver is continuous, given a target set of values for $I$ and $\bar{I}$, one can take the derivative of the difference between the solver's output on $I$ and the target outputs ($\bar{I}$) with respect to $S$ and use gradient descent to find a CNF formula whose satisfiability matches the behavior in the target set.
For example, [3] generates many example input/output assignments for whether the number of true variables in the input is even or odd, and uses gradient descent to find a value for $S$ with this behavior, namely a CNF formula for the parity function. This is the same idea, in spirit, as how Moitti et al. use a continuous relaxation of a 2D cellular automaton to learn the discrete ruleset for Conway's Game of Life.
I think this is pretty incredible, but most interesting to me was that
[3] cited [2] to invoke a theoretical difficulty learning
parity, saying "Learning parity functions from such single-bit
supervision is known to pose difficulties for conventional deep
learning approaches." At this point I had read neither [1] nor
[2], but had read Intractability of Learning the Discrete
Logarithm with Gradient-Based
Methods [4] by
Takhanov et al., which shows the discrete logarithm is difficult to learn via gradient descent in the same sense that [2] says parity is hard. So, a wrinkle in the literature: [3] claims to synthesize parity with gradient descent despite [4]'s claim that such a thing should be intractable.
The purpose of this essay is to smooth this out. In short: [3]'s
success learning parity does not contradict [2]'s claim it is
difficult, and [2] doesn't prove learning parity is difficult in
the colloquial sense. In particular, [2] considers a "family" of parity functions, each of which samples a different subset of a length-$n$ input and outputs whether the number of 1s in the sample is even or odd. Then, [2] shows that when any of these parity functions could be learned by the neural net, there is an exponential-in-$n$ run of gradient descent that converges to a value independent of the target function (i.e. gradient descent learns nothing about the target function). On the other hand, [3] learns parity with an inductive bias by weaving the input through their model. For example, on a length-4 input $x_1, \ldots, x_4 \in [2]$, [3]'s model ($m$ below) outputs $m(x_4, m(x_3, m(x_2, x_1)))$.
By inspection, the only two elements of the family of parity functions this can learn are the trivial case when the subset of bits considered is empty, and full parity. So while [2] predicts the existence of a worst-case-exponential run, practically speaking, because there are only 16 functions from $[2]^2 \to [2]$, the constants work out in [3]'s favor.
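To make the weaving concrete, here is a minimal Python sketch; the helper `weave` and the instantiation of $m$ as two-input XOR are my illustrative choices, not [3]'s actual learned model. With that choice, the woven model computes full parity:

```python
from itertools import product

def weave(m, bits):
    """Thread the input through a single two-input cell m, i.e.
    compute m(x4, m(x3, m(x2, x1))) for a length-4 input."""
    acc = bits[0]
    for b in bits[1:]:
        acc = m(b, acc)
    return acc

# If m happens to be two-input XOR, the weave computes full parity.
xor = lambda a, b: a ^ b
assert all(weave(xor, x) == sum(x) % 2
           for x in product([0, 1], repeat=4))
```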
If this introduction has piqued your curiosity about how [2]
analyzes parity difficulty, read on. I think it is a lovely bit of
math.
Background
Let's start by considering worst-case complexity of learning a parity
function from examples. The parity functions in this piece take as
input a bit string, pick a subset of the bits, then return $-1$ if the number of 1s in the subset is odd and $1$ otherwise. Formally, the family of parity functions is

$$\mathcal{H} = \left\{ h_S(x) := (-1)^{\sum_{j \in S} x_j} \;\middle|\; S \subseteq \{1, \ldots, n\} \right\}.$$

Suppose we learn from examples $x^{(1)}, \ldots, x^{(t)}$, and suppose some bit $j$ is $0$ in every one of them, i.e. $\exists j\, \forall i,\ x^{(i)}_j = 0$. Then $h = h_S$ and $h' = h_{S \cup \{j\}}$ output the same value on every example input, so it is ambiguous which is the true function based on our examples. There is a set of examples of size $2^{n-1}$ satisfying $\exists j\, \forall i,\ x^{(i)}_j = 0$, so in the worst case learning $h$ takes $t > 2^{n-1}$ examples.
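The ambiguity is easy to exhibit directly. This sketch (the names `parity` and `examples` are mine) enumerates the $2^{n-1}$ inputs whose $j$-th bit is $0$ and checks that $h_S$ and $h_{S \cup \{j\}}$ agree on every one:

```python
from itertools import product

n, j = 4, 2                      # input length, the bit position held at 0
S = {0, 3}                       # hypothesized subset of bits (0-indexed)

def parity(S, x):                # -1 if an odd number of 1s in S, else 1
    return -1 if sum(x[i] for i in S) % 2 else 1

# The 2^(n-1) examples in which bit j is always 0.
examples = [x for x in product([0, 1], repeat=n) if x[j] == 0]
assert len(examples) == 2 ** (n - 1)

# h = parity over S and h' = parity over S ∪ {j} agree on every example,
# so no learner can tell them apart from these examples alone.
assert all(parity(S, x) == parity(S | {j}, x) for x in examples)
```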
Shamir's work shows that even when the examples are drawn from a nice distribution (i.e. one without the $\exists j\, \forall i,\ x^{(i)}_j = 0$ property), gradient descent can take an exponential (in $n$) number of steps to learn some $h \in \mathcal{H}$. In my view, the star of the show is Theorem 1, which establishes that when the gradient of the loss is similar for all elements of $\mathcal{H}$, the gradient carries little information about which element of $\mathcal{H}$ should actually be learned. This small signal is lost in the noise from finite-precision arithmetic and sampling.
Gradient Descent
In the typical machine learning setting, for some target function $h : \mathbb{R}^n \to \mathbb{R}^m$ and neural network architecture $p_w$ parameterized by weights $w \in \mathbb{R}^{|w|}$, we'd like to compute

$$\operatorname*{argmin}_w F_h(w),$$

where $F_h(w) := \mathbb{E}_x \left\| h(x) - p_w(x) \right\|^2$ is the expected loss given a choice of $w$. One approach is to use a variation of gradient descent. This starts by selecting an initial value for $w$, call it $w^{(0)}$, then proceeds to iteratively update the weights according to the formula

$$w^{(i+1)} := w^{(i)} - \eta\, (\nabla_w F_h)(w^{(i)}),$$
where $\eta$ is the learning rate. Intuitively, this works because $(\nabla_w F_h)(w^{(i)})$ is an element of $\mathbb{R}^{|w|}$ pointing in the direction of steepest increase of $\mathbb{E}_x \|h(x) - p_w(x)\|^2$, i.e. the loss. By subtracting this value from $w^{(i)}$ we move the weights in a direction that decreases the loss.
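A minimal sketch of this loop, on an assumed toy objective with a closed-form gradient (real losses won't have one, as the next paragraph explains):

```python
import numpy as np

def gradient_descent(grad_F, w0, eta=0.1, steps=200):
    """Iterate w^(i+1) := w^(i) - eta * (grad F)(w^(i))."""
    w = np.asarray(w0, dtype=float)
    for _ in range(steps):
        w = w - eta * grad_F(w)
    return w

# Toy objective F(w) = ||w - 3||^2, whose gradient is 2(w - 3);
# descent converges to the minimizer w = 3.
w_star = gradient_descent(lambda w: 2 * (w - 3.0), w0=[0.0])
assert np.allclose(w_star, 3.0, atol=1e-3)
```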
In practice, computing $\mathbb{E}_x \|h(x) - p_w(x)\|^2$, and in turn $\nabla_w F_h$, is computationally infeasible, as $x$'s distribution is unknown. As such, the standard approach is to sample $x_1, x_2, \ldots, x_t$ and approximate $F_h(w)$ as

$$F_h(w) \approx \frac{1}{t} \sum_i \left\| h(x_i) - p_w(x_i) \right\|^2.$$

$\nabla_w F_h$ can be approximated in turn, and gradient descent run using the approximation.
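A sketch of the sampled approximation, with an assumed toy target $h(x) = 3x$ and a one-parameter model $p_w(x) = wx$ so that the whole pipeline fits in a few lines:

```python
import numpy as np

rng = np.random.default_rng(0)

def h(x):                        # target function (assumed for illustration)
    return 3.0 * x

def approx_grad_F(w, t=1000):
    """Monte Carlo estimate of d/dw E_x ||h(x) - p_w(x)||^2 from
    t samples, since x's true distribution is unknown to us."""
    xs = rng.standard_normal(t)
    # d/dw of (1/t) sum (h(x) - w x)^2  =  (1/t) sum -2 x (h(x) - w x)
    return np.mean(-2.0 * xs * (h(xs) - w * xs))

w = 0.0
for _ in range(200):
    w -= 0.1 * approx_grad_F(w)
print(w)  # close to 3, the w minimizing the expected loss
```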
Approximate Gradient Descent
We'll model this approximation with two definitions: approximate gradient oracles capture the error involved in computing $\nabla_w F_h$, and approximate gradient-based methods use these approximations.

An Approximate Gradient Oracle is a function $O_{F_h, \epsilon} : \mathbb{R}^{|w|} \to \mathbb{R}^{|w|}$ satisfying

$$\forall w,\ \left\| O_{F_h, \epsilon}(w) - (\nabla_w F_h)(w) \right\| \le \epsilon.$$
An Approximate Gradient-Based Method is an algorithm that generates an initial guess $w^{(0)}$, then decides $w^{(i+1)}$ based on responses from an approximate gradient oracle:

$$w^{(i+1)} = f\left( w^{(0)},\ \left\{ O_{F_h, \epsilon}(w^{(j)}) \;\middle|\; j < i+1 \right\} \right).$$
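One way to model the two definitions in code (a sketch; the `adversary` hook anticipates how the oracle will be abused in the next section):

```python
import numpy as np

def make_oracle(grad_F, eps, adversary=None):
    """An eps-approximate gradient oracle: free to return any vector
    within eps of the true gradient. Here an optional adversary
    proposes the answer, and we fall back to the truth otherwise."""
    def oracle(w):
        g = grad_F(w)
        if adversary is not None:
            proposal = adversary(w)
            if np.linalg.norm(proposal - g) <= eps:
                return proposal          # within tolerance: allowed
        return g
    return oracle

def run_method(oracle, w0, eta=0.1, steps=200):
    """Gradient descent as an approximate gradient-based method:
    each w^(i+1) is a function of w^(0) and past oracle responses."""
    w = np.asarray(w0, dtype=float)
    for _ in range(steps):
        w = w - eta * oracle(w)
    return w

# Example: an exact oracle (eps = 0) recovers plain gradient descent.
w_star = run_method(make_oracle(lambda w: 2 * (w - 3.0), eps=0.0), w0=[0.0])
assert np.allclose(w_star, 3.0, atol=1e-3)
```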
Parity Hardness
Now, consider a family of functions $\mathcal{H} = \{h_1, h_2, \ldots\}$ and the variance of the gradient at $w$ with respect to a uniformly random $h \in \mathcal{H}$:

$$\operatorname{Var}_h\left( (\nabla_w F_h)(w) \right) := \mathbb{E}_h \left\| (\nabla_w F_h)(w) - \mathbb{E}_{h'}\left[ (\nabla_w F_{h'})(w) \right] \right\|^2.$$
To show parity is difficult to learn, we'll show that when $\mathcal{H}$ is the family of parity functions, this variance is exponentially small with respect to the length of the parity functions' inputs. In turn, an adversarial approximate gradient oracle can repeatedly return $\mathbb{E}_{h'}\left[ (\nabla_w F_{h'})(w) \right]$ instead of $(\nabla_w F_h)(w)$ while staying within its $\epsilon$ error tolerance. Because $\mathbb{E}_{h'}\left[ (\nabla_w F_{h'})(w) \right]$ is independent of the $h \in \mathcal{H}$ being learned, an approximate gradient-based method using this adversarial oracle can converge to a value independent of the target function, $h$, unless it takes an exponentially large number of steps.
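To see how small this variance gets, here is a sketch that computes $\operatorname{Var}_h\left((\nabla_w F_h)(w)\right)$ exactly by enumeration, for an assumed toy linear model $p_w(x) = \langle w, x \rangle$ with $x$ uniform on $\{0,1\}^n$ (both the model and the distribution are my choices, made so the expectations are exactly computable):

```python
import numpy as np
from itertools import product

def grad_variance(n, w):
    """Exactly compute Var_h, over the 2^n parity functions, of
    grad_w F_h(w) for the linear model p_w(x) = <w, x> with x
    uniform on {0,1}^n.  Here F_h(w) = E_x (h(x) - <w,x>)^2, so
    grad F_h(w) = E_x [-2 (h(x) - <w,x>) x]."""
    X = np.array(list(product([0, 1], repeat=n)), dtype=float)
    grads = []
    for S in product([0, 1], repeat=n):   # each subset, as an indicator
        h = np.where((X @ np.array(S)) % 2 == 1, -1.0, 1.0)
        g = (-2.0 * (h - X @ w)[:, None] * X).mean(axis=0)
        grads.append(g)
    G = np.stack(grads)
    dev = G - G.mean(axis=0)
    return np.mean(np.sum(dev ** 2, axis=1))

rng = np.random.default_rng(0)
for n in (4, 6, 8):
    w = 0.1 * rng.standard_normal(n)
    print(n, grad_variance(n, w))   # shrinks rapidly as n grows
```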
Theorem 1 (DSHLNN Theorem 10). For some family of functions $\mathcal{H}$, if

$$\forall w,\ \operatorname{Var}_{h \in \mathcal{H}}\left( (\nabla_w F_h)(w) \right) \le \epsilon^3,$$

then for any approximate gradient-based method learning $h \in \mathcal{H}$, there is a run such that the value of $w^{(\lfloor \epsilon^{-1} \rfloor)}$ is independent of $h$.
Proof. By Chebyshev's inequality and the hypothesis, we have

$$P\left( \left\| (\nabla_w F_h)(w) - \mathbb{E}_{h'}\left[ (\nabla_w F_{h'})(w) \right] \right\| \ge \epsilon \right) \le \frac{\operatorname{Var}_h\left( (\nabla_w F_h)(w) \right)}{\epsilon^2} \le \epsilon.$$

Because $\mathbb{E}_h\left[ (\nabla_w F_h)(w) \right]$ is independent of what $h$ is being learned, the inequality above bounds the likelihood that $O_{F_h, \epsilon}(w)$ is dependent on $h$.
Now, for any approximate gradient-based method learning $h \in \mathcal{H}$, $w^{(0)}$ is independent of $h$, as nothing has been sampled from the gradient when it is chosen. As

$$w^{(1)} = f\left( w^{(0)},\ O_{F_h, \epsilon}(w^{(0)}) \right),$$

evidently $w^{(1)}$ is dependent on the $h$ being learned only if $O_{F_h, \epsilon}(w^{(0)})$ is, and, per the inequality above, the likelihood of this is $\le \epsilon$. Repeating this argument, let $A^{(i)}$ be the event that $O_{F_h, \epsilon}(w^{(i)})$ is dependent on $h$. We have $P(A^{(i)}) \le \epsilon$, so by the union bound

$$P\left( \bigcup_{i=1}^{I} A^{(i)} \right) \le \sum_{i=1}^{I} P\left( A^{(i)} \right) \le I\epsilon.$$
If $P\left( \bigcup_{i=1}^{I} A^{(i)} \right) < 1$, then there is an $I$-step run of our gradient-based method where $w^{(I)}$ is independent of the target function, $h$. Solving for $I$ using the inequality above gives the desired bound: if $I < 1/\epsilon$, then there is a run of the gradient-based method where $w^{(\lfloor \epsilon^{-1} \rfloor)}$ is independent of $h$. (I am simplifying somewhat here because, for the case we're interested in, $\epsilon^{-1}$ will not be an integer and flooring gives the strict $<$ inequality we want, but if you're feeling picky, $I = \lceil \epsilon^{-1} \rceil - 1$ will do.) ∎
Lemma 3. For any function $f$, $\mathbb{E}_x(f(x)) = \operatorname*{argmin}_a \mathbb{E}_x \|f(x) - a\|^2$.

Proof. Notice $g(a) := \mathbb{E}_x \|f(x) - a\|^2$ is (strictly) convex, so the value of $a$ satisfying $(\nabla_a g)(a) = 0$ is a global minimum. Computing this derivative and solving gives $a = \mathbb{E}_x(f(x))$, as desired. This corresponds to the intuitive idea that the mean minimizes the expected squared difference between it and all other values. ∎
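A quick numerical sanity check of the lemma, with exponentially distributed samples standing in for $f(x)$:

```python
import numpy as np

rng = np.random.default_rng(0)
samples = rng.exponential(size=10_000)     # arbitrary stand-in for f(x)

def g(a):                                  # empirical E ||f(x) - a||^2
    return np.mean((samples - a) ** 2)

mean = samples.mean()
# The mean beats every nearby perturbation, as the lemma predicts.
assert all(g(mean) <= g(mean + d) for d in (-0.5, -0.01, 0.01, 0.5))
```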
Lemma 4 (Bessel's inequality). For a Hilbert space $H$ and finite orthonormal set $\{e_i \mid i \in I\}$ (i.e. satisfying $\langle e_i, e_j \rangle = 1$ if $i = j$ and $0$ otherwise), every $x \in H$ satisfies

$$\sum_{i \in I} \left| \langle x, e_i \rangle \right|^2 \le \|x\|^2.$$
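And a numerical check of Bessel's inequality in $\mathbb{R}^{10}$, taking the orthonormal set from a QR factorization of a random matrix:

```python
import numpy as np

rng = np.random.default_rng(0)
dim, k = 10, 4
# A finite orthonormal set: the k columns of Q from a reduced QR.
e, _ = np.linalg.qr(rng.standard_normal((dim, k)))
x = rng.standard_normal(dim)

# Bessel: the energy captured by the coefficients <x, e_i>
# never exceeds ||x||^2.
coeffs = e.T @ x
assert np.sum(coeffs ** 2) <= np.dot(x, x) + 1e-12
```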
Theorem 2 (FGBDL Theorem 1). Let $\mathcal{H}$ be the family of parity functions and let $p_w$ and $F_h$ be a neural network and MSE loss function, as before. If $p_w$ satisfies

$$\forall w,\ \mathbb{E}_x \left\| (\nabla_w p_w)(x) \right\|^2 \le G(w)^2$$

for some scalar $G(w)$, then

$$\operatorname{Var}_h\left( (\nabla_w F_h)(w) \right) \le \frac{G(w)^2}{|\mathcal{H}|}.$$
Proof. First, invoking Lemma 3 with $a = \mathbb{E}_x\left( p_w(x)\, (\nabla_w p_w)(x) \right)$, we have