Commit 21173cc

Vladimir Feinberg (vlad17) authored and committed
Add initial LOBPCG top-k eigenvalue solver (#3112)
This initial version is f32-only for accelerators, since it relies on an eigh call (which itself is f32 at most) in its inner loop. For details, see the jax.experimental.sparse.linalg.lobpcg_standard documentation.
1 parent bc877fa commit 21173cc
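
As a quick usage sketch of the new entry point (an editor's illustration, not part of the commit: the random test matrix, PRNG keys, and dimensions below are assumptions chosen for the example):

import jax
import jax.numpy as jnp
from jax.experimental.sparse import linalg as sparse_linalg

n, k = 100, 5
key_a, key_x = jax.random.split(jax.random.PRNGKey(0))

# A symmetric PSD matrix whose top-k eigenpairs we want.
B = jax.random.normal(key_a, (n, n), dtype=jnp.float32)
A = B @ B.T

# k linearly independent (not necessarily orthogonal) starting directions;
# note the solver requires k * 5 < n.
X = jax.random.normal(key_x, (n, k), dtype=jnp.float32)

theta, U, iters = sparse_linalg.lobpcg_standard(A, X, m=100)
# theta: (k,) top eigenvalues; U: (n, k) eigenvectors; iters: iterations run.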

File tree

3 files changed: +562 -0 lines changed


jax/experimental/sparse/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -259,3 +259,5 @@
     sparsify as sparsify,
     SparseTracer as SparseTracer,
 )
+
+from jax.experimental.sparse import linalg

jax/experimental/sparse/linalg.py

Lines changed: 344 additions & 0 deletions
@@ -0,0 +1,344 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sparse linear algebra routines."""

from typing import Union
import functools

import jax
import jax.numpy as jnp

def lobpcg_standard(
    A: jnp.ndarray,
    X: jnp.ndarray,
    m: int = 100,
    tol: Union[jnp.ndarray, float, None] = None):
  """Compute the top-k standard eigenvalues using the LOBPCG routine.

  LOBPCG [1] stands for Locally Optimal Block Preconditioned Conjugate
  Gradient. The method enables finding top-k eigenvectors in an
  accelerator-friendly manner.

  This initial experimental version has several caveats.

    - Only the standard eigenvalue problem `A U = lambda U` is supported;
      the generalized eigenvalue problem is not.
    - Gradient code is not available.
    - f64 will only work where jnp.linalg.eigh is supported for that type.
    - Finding the smallest eigenvectors is not yet supported. As a result,
      we don't yet support preconditioning, which is mostly needed for this
      case.

  The implementation is based on [2] and [3]; however, we deviate from these
  sources in several ways to improve robustness or facilitate implementation:

    - Despite increased iteration cost, we always maintain an orthonormal
      basis for the block search directions.
    - We change the convergence criterion; see the `tol` argument.
    - Soft locking is not implemented.

  [1] http://ccm.ucdenver.edu/reports/rep149.pdf
  [2] https://arxiv.org/abs/1704.07458
  [3] https://arxiv.org/abs/0705.2626

  Args:
    A : An `(n, n)` array representing a square matrix.
    X : An `(n, k)` array representing the initial search directions for the
      `k` desired top eigenvectors. These need not be orthogonal, but must
      be linearly independent.
    m : Maximum integer iteration count; LOBPCG will only ever explore (a
      subspace of) the Krylov basis `{X, A X, A^2 X, ..., A^m X}`.
    tol : A float convergence tolerance; an eigenpair `(lambda, v)` is
      converged when its residual L2 norm `r = |A v - lambda v|` is below
      `tol * 10 * n * (lambda + |A v|)`, which roughly estimates the
      worst-case floating point error for an ideal eigenvector. If all `k`
      eigenvectors satisfy the tolerance comparison, then LOBPCG exits
      early. If left as None, then this is set to the float epsilon of
      `A.dtype`.

  Returns:
    `theta, U, i [, diagnostics]`, where `theta` is a `(k,)` array of
    eigenvalues, `U` is an `(n, k)` array of eigenvectors, `i` is the number
    of iterations performed, and `diagnostics` is a dictionary with debug
    information, returned only when the internal `debug` flag is set.

  Raises:
    ValueError : if `A, X` dtypes or `n` dimensions do not match, or `k` is
      too large (only `k * 5 < n` supported), or `k == 0`.
  """
  return _lobpcg_standard(A, X, m, tol, debug=False)

@functools.partial(jax.jit, static_argnames=['m', 'debug'])
def _lobpcg_standard(
    A: jnp.ndarray,
    X: jnp.ndarray,
    m: int,
    tol: Union[jnp.ndarray, float, None],
    debug: bool = False):
  """Computes lobpcg_standard(), possibly with debug diagnostics."""

  # TODO(vladf): support mixed_precision flag, which allows f64 Rayleigh-Ritz
  # with f32 inputs.
  mixed_precision = False

  n, k = X.shape
  dt = X.dtype

  _check_inputs(A, X)

  if tol is None:
    tol = jnp.finfo(dt).eps

  X = _svqb(X, mixed_precision)
  P = _extend_basis(X, X.shape[1])

  # We maintain X, our current list of best eigenvectors,
  # P, our search direction, and
  # R, our residuals, in a large joint array XPR, column-stacked, so (n, 3*k).

  AX = _mm(A, X)
  theta = jnp.sum(X * AX, axis=0, keepdims=True)
  R = AX - theta * X

  def cond(state):
    i, _X, _P, _R, converged, _ = state
    return jnp.logical_and(i < m, converged < k)

  def body(state):
    i, X, P, R, _, theta = state
    # Invariants: X, P, R are kept orthonormal.
    # Some R, P columns may be 0, but not X's.

    # TODO(vladf): support preconditioning for bottom-k eigenvectors
    # if M is not None:
    #   R = M(R)

    # Residual basis selection.
    R = _project_out(jnp.concatenate((X, P), axis=1), R, mixed_precision)
    XPR = jnp.concatenate((X, P, R), axis=1)

    # Projected eigensolve.
    theta, Q = _rayleigh_ritz_orth(A, XPR, mixed_precision)

    # Eigenvector X extraction.
    B = Q[:, :k]
    normB = jnp.linalg.norm(B, ord=2, axis=0, keepdims=True)
    B /= normB
    X = _mm(XPR, B)
    normX = jnp.linalg.norm(X, ord=2, axis=0, keepdims=True)
    X /= normX

    # Difference terms P extraction.
    #
    # In the next step of LOBPCG, naively, we'd set
    # P = S[:, k:] @ Q[k:, :k] to achieve span(X, P) == span(X, previous X)
    # (this is not obvious; see section 4 of [1]).
    #
    # Instead we orthogonalize concat(0, Q[k:, :k]) against Q[:, :k]
    # in the standard basis before mapping with XPR. Since XPR is itself
    # orthonormal, the resulting directions are themselves orthonormalized.
    #
    # [2] leverages Q's existing orthogonality to derive
    # an analytic expression for this value based on the quadrant Q[:k, k:]
    # (see section 4.2 of [2]).
    q, _ = jnp.linalg.qr(Q[:k, k:].T)
    diff_rayleigh_ortho = _mm(Q[:, k:], q)
    P = _mm(XPR, diff_rayleigh_ortho)
    normP = jnp.linalg.norm(P, ord=2, axis=0, keepdims=True)
    P /= jnp.where(normP == 0, 1.0, normP)

    # Compute new residuals.
    AX = _mm(A, X)
    R = AX - theta[jnp.newaxis, :k] * X
    resid_norms = jnp.linalg.norm(R, ord=2, axis=0)

    # I tried many variants of hard and soft locking [3]. All of them seemed
    # to worsen performance relative to no locking.
    #
    # Further, I found a more experimental convergence formula than what is
    # suggested in the literature, loosely based on floating-point
    # expectations.
    #
    # [2] discusses various strategies for this in Sec 5.3. The solution
    # they end up with, which estimates the operator norm |A| via Gaussian
    # products, was too crude in practice (and overly lax). The Gaussian
    # approximation seems like an estimate of the average eigenvalue.
    #
    # Instead, we test convergence via self-consistency of the eigenpair,
    # i.e., the residual norm |r| should be small relative to the floating
    # point error we'd expect from computing just the residuals given
    # candidate vectors.
    #
    # sqrt(n) - random walk error from the AX multiply
    reltol = jnp.linalg.norm(AX, ord=2, axis=0) + theta[:k]
    reltol *= n
    # Allow some margin for a few element-wise operations.
    reltol *= 10
    res_converged = resid_norms < tol * reltol
    converged = jnp.sum(res_converged)

    new_state = i + 1, X, P, R, converged, theta[jnp.newaxis, :k]
    if debug:
      diagnostics = _generate_diagnostics(
          XPR, X, P, R, theta, converged, resid_norms / reltol)
      new_state = (new_state, diagnostics)
    return new_state

  converged = 0
  state = (0, X, P, R, converged, theta)
  if debug:
    state, diagnostics = jax.lax.scan(
        lambda state, _: body(state), state, xs=None, length=m)
  else:
    state = jax.lax.while_loop(cond, body, state)
  i, X, _P, _R, _converged, theta = state

  if debug:
    return theta[0, :], X, i, diagnostics
  return theta[0, :], X, i


def _check_inputs(A, X):
  n, k = X.shape
  dt = X.dtype

  if k == 0:
    raise ValueError(f'must have search dim > 0, got {k}')

  if A.dtype != dt:
    raise ValueError(f'A, X must have same dtypes (were {A.dtype}, {dt})')

  if A.shape != (n, n):
    raise ValueError(f'A must be an ({n}, {n}) matrix, got {A.shape}')

  if k * 5 >= n:
    raise ValueError(f'expected search dim * 5 < matrix dim (got {k * 5}, {n})')


def _mm(a, b, precision=jax.lax.Precision.HIGHEST):
  return jax.lax.dot(a, b, (precision, precision))

def _generate_diagnostics(prev_XPR, X, P, R, theta, converged, adj_resid):
  k = X.shape[1]
  assert X.shape == P.shape

  diagdiag = lambda x: jnp.diag(jnp.diag(x))
  abserr = lambda x: jnp.abs(x).sum() / (k ** 2)

  XTX = _mm(X.T, X)
  DX = diagdiag(XTX)
  orthX = abserr(XTX - DX)

  PTP = _mm(P.T, P)
  DP = diagdiag(PTP)
  orthP = abserr(PTP - DP)

  PX = abserr(X.T @ P)

  prev_basis = prev_XPR.shape[1] - jnp.sum(jnp.all(prev_XPR == 0.0, axis=0))

  return {
      'basis rank': prev_basis,
      'X zeros': jnp.sum(jnp.all(X == 0.0, axis=0)),
      'P zeros': jnp.sum(jnp.all(P == 0.0, axis=0)),
      'lambda history': theta[:k],
      'residual history': jnp.linalg.norm(R, axis=0, ord=2),
      'converged': converged,
      'adjusted residual max': jnp.max(adj_resid),
      'adjusted residual p50': jnp.median(adj_resid),
      'adjusted residual min': jnp.min(adj_resid),
      'X orth': orthX,
      'P orth': orthP,
      'P.X': PX}

def _eigh_possibly_mixed(A, mixed_precision):
  assert not mixed_precision, 'mixed precision not yet supported'
  w, V = jnp.linalg.eigh(A)
  # eigh returns eigenvalues in ascending order; flip to descending for top-k.
  return w[::-1], V[:, ::-1]


def _svqb(X, mixed_precision):
  # https://sdm.lbl.gov/~kewu/ps/45577.html

  norms = jnp.linalg.norm(X, ord=2, axis=0, keepdims=True)
  X /= jnp.where(norms == 0, 1.0, norms)

  inner = _mm(X.T, X)

  w, V = _eigh_possibly_mixed(inner, mixed_precision)

  tau = jnp.finfo(X.dtype).eps * w[0]
  padded = jnp.maximum(w, tau)
  sqrted = jnp.where(tau > 0, padded, 1.0) ** (-0.5)
  scaledV = V * sqrted[jnp.newaxis, :]
  orthoX = _mm(X, scaledV)

  keep = ((w > tau) * (jnp.diag(inner) > 0.0))[jnp.newaxis, :]
  orthoX *= keep
  norms = jnp.linalg.norm(orthoX, ord=2, axis=0, keepdims=True)
  keep *= norms > 0.0
  orthoX /= jnp.where(keep, norms, 1.0)
  return orthoX


def _project_out(basis, U, mixed_precision):
  # "Twice is enough" from shoyer's reference:
  # http://slepc.upv.es/documentation/reports/str1.pdf

  for _ in range(2):
    U -= _mm(basis, _mm(basis.T, U))
  for _ in range(2):
    U = _svqb(U, mixed_precision)

  return U


def _rayleigh_ritz_orth(A, S, mixed_precision):
  # Classical Rayleigh-Ritz returns w, V satisfying
  #   (1) S.T A S @ V ~= w * V
  # such that (2) V is (S.T S)-orthonormal.
  # https://www.netlib.org/lapack/lug/node54.html
  #
  # Usually it requires solving the complicated standard eigensystem
  # U^-T S^T A S U^-1 @ Q = w * Q and then backsolving V = U^-1 Q,
  # but if S is standard orthonormal then we just need to find
  # the eigenvalues of S.T A S.

  SAS = _mm(S.T, _mm(A, S))

  # Solve the projected subsystem.
  # If we could tell eigh to stop after the first k, we would.
  return _eigh_possibly_mixed(SAS, mixed_precision)


def _extend_basis(X, m):
  # Use a block Householder reflector to generate an orthogonal extension
  # to X. There's nothing too special about this, and we could choose
  # any random extension to X's basis, but this is a deterministic choice.
  #
  # https://epubs.siam.org/doi/abs/10.1137/0725014
  # https://www.jstage.jst.go.jp/article/ipsjdc/2/0/2_0_298/_article
  n, k = X.shape
  Xupper, Xlower = jnp.split(X, [k], axis=0)
  u, s, vt = jnp.linalg.svd(Xupper)
  y = jnp.concatenate([Xupper + u @ vt, Xlower], axis=0)
  other = jnp.concatenate(
      [jnp.eye(m, dtype=X.dtype),
       jnp.zeros((n - k - m, m), dtype=X.dtype)], axis=0)
  w = _mm(y, vt.T * ((2 * (1 + s)) ** (-1 / 2))[jnp.newaxis, :])
  h = -2 * jnp.linalg.multi_dot(
      [w, w[k:, :].T, other], precision=jax.lax.Precision.HIGHEST)
  return h.at[k:].add(other)
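
As a follow-up to the usage sketch after the commit header (again an editor's illustration, not part of the commit, continuing from the variables `A`, `theta`, `U`, `n`, and `k` defined there), one can sanity-check the returned eigenpairs against the dense solver and replay the docstring's convergence bound `tol * 10 * n * (lambda + |A v|)`:

# Compare against the dense eigensolver's top-k eigenvalues (descending).
dense_w = jnp.linalg.eigh(A)[0][::-1][:k]
print(jnp.max(jnp.abs(theta - dense_w)))  # should be small relative to theta

# Replay the residual test used for early exit: an eigenpair is converged
# when |A v - lambda v| < tol * 10 * n * (lambda + |A v|), with tol
# defaulting to the f32 epsilon.
AU = A @ U
resid_norms = jnp.linalg.norm(AU - theta[jnp.newaxis, :] * U, axis=0)
bound = jnp.finfo(jnp.float32).eps * 10 * n * (
    jnp.linalg.norm(AU, axis=0) + theta)
print(bool(jnp.all(resid_norms < bound)))  # True if the solver exited early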
