In SATNet: Bridging deep learning and logical reasoning using a
differentiable satisfiability solver [3], Wang
et al. design a continuous relaxation of a SAT solver. Their solver
takes as input a CNF formula S and a partial assignment I to S's
variables, and outputs an assignment to the remaining variables,
∖I, that maximizes the number of satisfied clauses in S. Because the
solver is continuous, given a target set of values for I and
∖I, one can take the derivative, with respect to S, of the
difference between the solver's output on I and the target values
for ∖I, and use gradient descent to find a CNF formula whose
satisfiability matches the behavior in the target set.
For example, [3] generates
many example input/output assignments for the function indicating
whether the number of true variables in the input is even or odd, and
uses gradient descent to find a value for S with this behavior,
namely a CNF formula for the parity function. This
is similar in spirit to how Moitti et al.
use a continuous relaxation of a 2D cellular automaton to learn the
discrete ruleset for Conway's Game of Life.
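To make the setup concrete, here is a minimal sketch (my own, not
[3]'s code) of generating the kind of single-bit-supervision examples
described above; the −1-for-odd labeling convention matches the parity
functions defined later in this piece.

```python
import random

def parity_example(n):
    """Sample one training pair: n input bits and their parity,
    labeled -1 if the number of 1s is odd and 1 otherwise."""
    x = [random.randint(0, 1) for _ in range(n)]
    return x, (-1) ** sum(x)

# A dataset of single-bit-supervision examples for the parity function.
dataset = [parity_example(4) for _ in range(1000)]
```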
I think this is pretty incredible, but most interesting to me was that
[3] cited [2] to invoke a theoretical difficulty learning
parity, saying "Learning parity functions from such single-bit
supervision is known to pose difficulties for conventional deep
learning approaches." At this point I had read neither [1] nor
[2], but had read Intractability of Learning the Discrete
Logarithm with Gradient-Based
Methods [4] by
Takhanov et al., which shows the discrete logarithm is difficult to
learn via gradient descent in the same sense that [2] says parity
is hard. So, a wrinkle in the literature: [3] claims to synthesize
parity with gradient descent despite [2]'s claim that such a thing
should be intractable.
The purpose of this essay is to smooth this out. In short: [3]'s
success learning parity does not contradict [2]'s claim it is
difficult, and [2] doesn't prove learning parity is difficult in
the colloquial sense. In particular, [2] considers a "family" of
parity functions, each of which samples a different subset of a
length-n input and outputs whether the number of 1s in the sample is
even or odd. Then, [2] shows that, whichever of these parity
functions the neural net is meant to learn, there is an
exponential-in-n-length run of gradient descent that converges to a
value independent of the target function (i.e. a run in which gradient
descent learns nothing about the target function). On the other hand, [3] learns
parity with an inductive bias by
weaving
the input through their model. For example, on a length-4 input
x_1, …, x_4 ∈ [2], [3]'s model (m below) outputs
m(x_4, m(x_3, m(x_2, x_1))).
By inspection, the only two elements of the family of parity functions
this model can learn are the trivial case, when the subset of bits
considered is empty, and full parity. So while [2] predicts the
existence of a worst-case-exponential run, practically speaking,
because there are only 16 functions from [2]² → [2], the constants
work out in [3]'s favor.
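To make the weaving concrete, here is a sketch of the composition
pattern alone; [3]'s m is a learned, differentiable SAT layer, whereas
the stand-in gate here is a hard XOR.

```python
from functools import reduce

def weave(m, xs):
    """Thread the input through a two-argument model m:
    weave(m, [x1, x2, x3, x4]) == m(x4, m(x3, m(x2, x1)))."""
    return reduce(lambda acc, x: m(x, acc), xs[1:], xs[0])

# With a hard XOR gate standing in for the learned m, the weave
# computes full parity (here over {0, 1} rather than {-1, 1}).
xor = lambda a, b: a ^ b
assert weave(xor, [1, 0, 1, 1]) == 1  # three 1s, so odd parity

# Of the 16 possible two-input gates m, an XOR-like gate makes the
# weave compute full parity; a constant gate gives the trivial
# (empty-subset) parity.
```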
If this introduction has piqued your curiosity about how [2]
analyzes parity difficulty, read on. I think it is a lovely bit of
math.
Background
Let's start by considering worst-case complexity of learning a parity
function from examples. The parity functions in this piece take as
input a bit string, pick a subset of the bits, then return −1 if the
number of 1s in the subset is odd and 1 otherwise. Formally, the
family of parity functions is
H := {x ↦ (−1)^⟨v,x⟩ ∣ v ∈ [2]^n},
where ⟨v,x⟩ := Σ_{i=1}^n v_i x_i is the inner
product and n is the input length.
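Concretely, for small n we can enumerate the whole family (a quick
sketch of the definition above; the names are mine):

```python
from itertools import product

def parity_family(n):
    """H = {x -> (-1)^<v,x> | v in {0,1}^n}, indexed by v."""
    def make_h(v):
        return lambda x: (-1) ** sum(vi * xi for vi, xi in zip(v, x))
    return {v: make_h(v) for v in product((0, 1), repeat=n)}

H = parity_family(3)       # 2^3 = 8 parity functions
h = H[(1, 0, 1)]           # samples bits 1 and 3 only
assert h((1, 1, 1)) == 1   # two 1s in the sample: even
assert h((1, 1, 0)) == -1  # one 1 in the sample: odd
```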
Our task is, given an unknown h ∈ H and a set of examples
{(x^(i), h(x^(i))) ∣ 0 ≤ i < t}, to determine the value of
h. In the worst case, this takes t > 2^(n−1) examples. To see why,
letting x_j^(i) denote the jth element of x^(i), consider
the case where ∃j ∀i, x_j^(i) = 0. In this case,
letting h = x ↦ (−1)^⟨v,x⟩ and v′ be the element
of [2]^n that agrees with v everywhere except position j
(i.e. v′_k = v_k ↔ k ≠ j), for
h′ = x ↦ (−1)^⟨v′,x⟩ we have
∀i, ⟨v′, x^(i)⟩ = ⟨v, x^(i)⟩ ± x_j^(i) = ⟨v, x^(i)⟩, so h′(x^(i)) = h(x^(i)).
As h and h′ output the same value on every example input, it is
ambiguous which is the true function based on our examples. There is a
set of examples of size 2^(n−1) satisfying ∃j ∀i, x_j^(i) = 0, so in the worst case learning h takes t > 2^(n−1)
examples.
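Here is the ambiguity in code: on examples where bit j is always 0,
v and the v′ obtained by flipping v_j induce indistinguishable parity
functions.

```python
from itertools import product

n, j = 4, 2  # bit j is 0 in every example below
examples = [x for x in product((0, 1), repeat=n) if x[j] == 0]  # 2^(n-1) of them

v = (1, 0, 1, 1)
v_prime = tuple(b ^ (k == j) for k, b in enumerate(v))  # flip only v_j

h       = lambda x: (-1) ** sum(a * b for a, b in zip(v, x))
h_prime = lambda x: (-1) ** sum(a * b for a, b in zip(v_prime, x))

# h and h' agree on all 2^(n-1) examples, so no such example set
# can tell them apart.
assert all(h(x) == h_prime(x) for x in examples)
```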
Shamir's work shows that even when the examples are drawn from a nice
distribution (i.e. don't have the ∃j ∀i, x_j^(i) = 0
property), gradient descent can take an exponential (in n) number of
steps to learn some h ∈ H. In my view, the star of the
show is Theorem 1, which establishes that when the gradient of the loss
is similar for all elements of H, the gradient carries
little information about which element of H should actually
be learned. This small signal is lost in the noise from
finite-precision arithmetic and sampling.
Gradient Descent
In the typical machine learning setting, for some target function
h : R^n → R^m and neural network
architecture p_w parameterized by weights w ∈ R^|w|,
we'd like to compute
min_w F_h(w) := E_x(½‖h(x) − p_w(x)‖²),
where F_h(w) is the expected loss given a choice of w.
One approach is to use a variation of gradient descent. This starts
by selecting an initial value for w, call it w^(0), then
proceeds to iteratively update the weights according to the formula
w^(i+1) := w^(i) − η(∂F_h/∂w)(w^(i)),
where η is the learning rate. Intuitively, this works because
(∂F_h/∂w)(w^(i)) is an element of
R^|w| pointing in the direction of steepest increase of
E_x‖h(x) − p_w(x)‖², i.e. the loss. By subtracting this value
from w^(i), we move the weights in a direction
that decreases the loss.
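In code, the update rule is only a few lines. A toy sketch, using a
one-dimensional quadratic loss whose gradient is available in closed
form:

```python
def gradient_descent(grad_F, w0, lr=0.1, steps=100):
    """Iterate w <- w - lr * grad_F(w), i.e. the update rule above."""
    w = w0
    for _ in range(steps):
        w = w - lr * grad_F(w)
    return w

# F(w) = (w - 3)^2 has gradient 2(w - 3) and its minimum at w = 3.
w_star = gradient_descent(lambda w: 2 * (w - 3), w0=0.0)
assert abs(w_star - 3) < 1e-6
```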
In practice, computing E_x‖h(x) − p_w(x)‖², and in turn
∂F_h/∂w, is computationally infeasible, as
x's distribution is unknown. As such, the standard approach is to
sample x_1, x_2, …, x_t and approximate F_h(w) as
F_h(w) ≈ (1/t) Σ_{i=1}^t ½‖h(x_i) − p_w(x_i)‖².
∂F_h/∂w can be approximated in turn, and
gradient descent run using the approximation.
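As a sketch of the sampled approximation, take the linear model
p_w(x) = ⟨w, x⟩ (my choice for illustration, not anything from the
papers), for which the per-example gradient of ½‖h(x) − p_w(x)‖² is
(p_w(x) − h(x))·x:

```python
import numpy as np

def sampled_grad(h, w, xs):
    """Approximate (dF_h/dw)(w) by averaging per-example gradients
    of (1/2)||h(x) - p_w(x)||^2 for the linear model p_w(x) = <w, x>."""
    return np.mean([(w @ x - h(x)) * x for x in xs], axis=0)

h = lambda x: (-1) ** int(x.sum())  # full parity as the target
xs = [np.random.randint(0, 2, size=4).astype(float) for _ in range(256)]
g = sampled_grad(h, np.zeros(4), xs)  # feed this into gradient descent
```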
Approximate Gradient Descent
We'll model this approximation with two definitions: approximate gradient oracles capture the error involved in computing
∂F_h/∂w, and approximate gradient-based methods use these approximations.
An Approximate Gradient Oracle is a function
O_{F_h,ϵ} : R^|w| → R^|w|
satisfying
∀w, ‖O_{F_h,ϵ}(w) − (∂F_h/∂w)(w)‖ ≤ ϵ.
An Approximate Gradient-Based Method is an algorithm that
generates an initial guess w^(0), then decides w^(i+1)
based on the responses of an approximate gradient oracle:
w^(i+1) = f(w^(0), {O_{F_h,ϵ}(w^(j)) ∣ j ≤ i}).
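These two definitions translate directly to code. A sketch: here the
oracle's error is random, but the definition permits any error of norm
at most ϵ, including adversarially chosen ones.

```python
import numpy as np

def make_oracle(grad_F, eps, rng=np.random.default_rng(0)):
    """An approximate gradient oracle: the true gradient plus an
    arbitrary error vector of norm at most eps."""
    def oracle(w):
        noise = rng.normal(size=w.shape)
        noise *= eps / max(np.linalg.norm(noise), 1e-12)
        return grad_F(w) + noise
    return oracle

def gradient_method(oracle, w0, lr=0.1, steps=100):
    """An approximate gradient-based method: each w^(i+1) is a
    function only of w^(0) and the oracle's responses so far."""
    w = w0
    for _ in range(steps):
        w = w - lr * oracle(w)
    return w
```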
Parity Hardness
Now, consider a family of functions H = {h_1, h_2, …}
and the variance, over h ∈ H, of the gradient at w.
To show parity is difficult to learn, we'll show that when H
is the family of parity functions, this variance is exponentially
small w.r.t. the length of the parity functions' inputs. In turn, an
adversarial approximate gradient oracle can repeatedly return
E_h′((∂F_h′/∂w)(w)) instead of
(∂F_h/∂w)(w) while staying within its
ϵ error tolerance. Because
E_h′((∂F_h′/∂w)(w)) is
independent of the h ∈ H being learned, an approximate gradient-based method using this adversarial oracle can converge to
a value independent of the target function, h, unless it takes an
exponentially large number of steps.
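The adversary is easy to write down (a sketch; the names are mine): it
answers with the family-averaged gradient whenever that answer lies
within ϵ of the truth, and such answers reveal nothing about h.

```python
import numpy as np

def adversarial_oracle(grad_F_family, h_index, eps):
    """Answer E_h'(grad F_h') whenever it lies within eps of the true
    gradient; that answer is independent of which h is being learned."""
    def oracle(w):
        grads = np.stack([g(w) for g in grad_F_family])
        mean = grads.mean(axis=0)
        truth = grads[h_index]
        return mean if np.linalg.norm(mean - truth) <= eps else truth
    return oracle
```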
Theorem 1 (DSHLNN Theorem 10). For some family of functions H, if
∀w, Var_{h∈H}((∂F_h/∂w)(w)) ≤ ϵ³,
then for any approximate gradient-based method learning
h ∈ H, there is a run such that the value of
w^(⌊1/ϵ⌋) is independent of h.
Proof. By Chebyshev's inequality and the hypothesis, we have
∀w, P_h(‖(∂F_h/∂w)(w) − E_h((∂F_h/∂w)(w))‖ ≥ ϵ) ≤ Var_{h∈H}((∂F_h/∂w)(w)) / ϵ² ≤ ϵ.
Because E_h((∂F_h/∂w)(w)) is
independent of what h is being learned, the inequality above
bounds the likelihood that O_{F_h,ϵ}(w) must be dependent on
h: whenever the true gradient is within ϵ of this h-independent
mean, the oracle may legally respond with the mean.
Now, for any approximate gradient-based method learning
h ∈ H, w^(0) is independent of h, as nothing has
been sampled from the gradient when it is chosen. As
w^(1) = f(w^(0), O_{F_h,ϵ}(w^(0))),
evidently, w^(1) is dependent on the h being learned only if
O_{F_h,ϵ}(w^(0)) is, and, per the inequality above, the
likelihood of this is ≤ ϵ. Repeating this argument, let
A^(i) be the event that O_{F_h,ϵ}(w^(i)) is dependent on
h. We have P(A^(i)) ≤ ϵ, so by the union bound
P(⋁_{i=1}^I A^(i)) ≤ Σ_{i=1}^I P(A^(i)) ≤ Iϵ.
If P(⋁_{i=1}^I A^(i)) < 1, then there is an I-step
run of our gradient-based method where w^(I) is independent of the
target function, h. Solving for I using the inequality above gives
the desired bound: if I < 1/ϵ, then there is a run of the
gradient-based method where w^(⌊1/ϵ⌋) is
independent of h. (I am simplifying somewhat here: for the
case we're interested in, 1/ϵ will not be an integer, and
flooring gives the strict < inequality we want, but if you're
feeling picky, I = ⌈1/ϵ⌉ − 1 will do.)
■
Lemma 1.
Σ_{x∈[n]^d} Π_{i=1}^d f_i(x_i) = Π_{i=1}^d Σ_{x∈[n]} f_i(x).
Proof. Shamir wordlessly invokes this, but it took me several hours
on an airplane and help from ChatGPT to see it. We proceed by induction
on d. When d = 2,
Σ_{x∈[n]²} f_1(x_1)f_2(x_2) = Σ_{x_1∈[n]} Σ_{x_2∈[n]} f_1(x_1)f_2(x_2) = Σ_{x_1∈[n]} f_1(x_1) Σ_{x_2∈[n]} f_2(x_2) = Π_{i=1}^2 Σ_{x∈[n]} f_i(x).
For d > 2, split off the last coordinate,
Σ_{x∈[n]^d} Π_{i=1}^d f_i(x_i) = (Σ_{x∈[n]^{d−1}} Π_{i=1}^{d−1} f_i(x_i)) Σ_{x∈[n]} f_d(x),
and apply the induction hypothesis to the first factor.
■
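If, like me, you want to see the identity with numbers before trusting
the induction, a quick spot check (arbitrary choices of f_i, n, and d):

```python
import math
from itertools import product

n, d = 3, 4
fs = [lambda x, i=i: (x + 1) ** (i + 1) for i in range(d)]  # arbitrary f_i
# (the i=i default argument pins each lambda to its own i)

lhs = sum(math.prod(f(xi) for f, xi in zip(fs, x))
          for x in product(range(n), repeat=d))  # sum over [n]^d of products
rhs = math.prod(sum(f(x) for x in range(n)) for f in fs)  # product of sums
assert lhs == rhs
```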
Lemma 3. For a function f of a random variable x, the value of a
minimizing E_x‖f(x) − a‖² is a = E_x(f(x)).
Proof. Notice g(a) := E_x‖f(x) − a‖² is (strictly)
convex, so the value of a satisfying (∂g/∂a)(a) = 0 is a global minimum. Computing this derivative and solving
gives a = E_x(f(x)), as desired. This corresponds to the
intuitive idea that the mean minimizes the expected difference between
it and all other values.
■
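A numeric check of the intuition (synthetic samples, grid search over
candidate values of a; up to grid and sampling error, the minimizer is
the sample mean):

```python
import numpy as np

rng = np.random.default_rng(0)
samples = rng.normal(loc=2.0, scale=1.5, size=10_000)

g = lambda a: np.mean((samples - a) ** 2)  # empirical E||f(x) - a||^2
grid = np.linspace(-5, 5, 1001)
best = grid[np.argmin([g(a) for a in grid])]

assert abs(best - samples.mean()) < 0.01  # minimizer ~ the sample mean
```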
Lemma 4 (Bessel's inequality). For a Hilbert space H and finite set {e_i ∣ i ∈ I} satisfying
⟨e_i, e_j⟩ = 1 if i = j, and 0 otherwise,
i.e. {e_i ∣ i ∈ I} is an orthonormal family, we have
∀x ∈ H, Σ_{i∈I} ‖⟨x, e_i⟩‖² ≤ ‖x‖².
Proof. Intuitively, this says that projecting x onto an
orthonormal basis won't increase its size. Let a_i be ⟨x, e_i⟩ and y be Σ_{i∈I} a_i e_i, i.e. y is the
projection of x onto the basis formed by {e_i ∣ i ∈ I}, and
a_i is the component of x lying on the ith element of the
basis. Now, let the residual r be x − y and note r is
perpendicular to each e_i:
⟨r, e_i⟩ = ⟨x, e_i⟩ − ⟨y, e_i⟩ = a_i − a_i = 0.
In turn, r is perpendicular to y, so by the Pythagorean theorem
‖x‖² = ‖y + r‖² = ‖y‖² + ‖r‖² ≥ ‖y‖² = Σ_{i∈I} ‖a_i‖² = Σ_{i∈I} ‖⟨x, e_i⟩‖²,
as desired.
■
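And a quick numeric spot check (the orthonormal family comes from a QR
factorization; the projected mass never exceeds ‖x‖²):

```python
import numpy as np

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(8, 8)))
es = Q[:, :3].T  # 3 orthonormal vectors in R^8

x = rng.normal(size=8)
projected_mass = sum(np.dot(x, e) ** 2 for e in es)
assert projected_mass <= np.dot(x, x) + 1e-12  # Bessel's inequality
```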
Theorem 2 (FGBDL Theorem 1).
Let H be the family of parity functions and let p_w
and F_h be a neural network and MSE loss function, as before. If
p_w satisfies
∀w, E_x‖(∂p_w/∂w)(x)‖² ≤ G(w)²
for some scalar G(w), then
Var_h((∂F_h/∂w)(w)) ≤ G(w)² / ∣H∣.
Proof. First, invoking Lemma 3 with
a = E_x(p_w(x)(∂p_w/∂w)(x)), we have