# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sparse linear algebra routines."""

import functools
from typing import Union

import jax
import jax.numpy as jnp


def lobpcg_standard(
    A: jnp.ndarray,
    X: jnp.ndarray,
    m: int = 100,
    tol: Union[jnp.ndarray, float, None] = None):
  """Compute the top-k standard eigenvalues using the LOBPCG routine.

  LOBPCG [1] stands for Locally Optimal Block Preconditioned Conjugate
  Gradient. The method enables finding top-k eigenvectors in an
  accelerator-friendly manner.

  This initial experimental version has several caveats.

    - Only the standard eigenvalue problem `A U = lambda U` is supported;
      generalized eigenvalue problems are not.
    - Gradient code is not available.
    - f64 will only work where jnp.linalg.eigh is supported for that type.
    - Finding the smallest eigenvectors is not yet supported. As a result,
      we don't yet support preconditioning, which is mostly needed for this
      case.

  The implementation is based on [2] and [3]; however, we deviate from these
  sources in several ways to improve robustness or facilitate implementation:

    - Despite increased iteration cost, we always maintain an orthonormal
      basis for the block search directions.
    - We change the convergence criterion; see the `tol` argument.
    - Soft locking is not implemented.

  [1] http://ccm.ucdenver.edu/reports/rep149.pdf
  [2] https://arxiv.org/abs/1704.07458
  [3] https://arxiv.org/abs/0705.2626

  Args:
    A: An `(n, n)` array representing a square matrix.
    X: An `(n, k)` array representing the initial search directions for the
      `k` desired top eigenvectors. This need not be orthogonal, but must be
      linearly independent.
    m: Maximum integer iteration count; LOBPCG will only ever explore (a
      subspace of) the Krylov basis `{X, A X, A^2 X, ..., A^m X}`.
    tol: A float convergence tolerance; an eigenpair `(lambda, v)` is
      converged when its residual L2 norm `r = |A v - lambda v|` is below
      `tol * 10 * n * (lambda + |A v|)`, which roughly estimates the
      worst-case floating point error for an ideal eigenvector. If all `k`
      eigenvectors satisfy the tolerance comparison, then LOBPCG exits early.
      If left as None, then this is set to the float epsilon of `A.dtype`.

  Returns:
    `theta, U, i [, diagnostics]`, where `theta` is a `(k,)` array of
    eigenvalues, `U` is an `(n, k)` array of eigenvectors, `i` is the number
    of iterations performed, and `diagnostics` is a dictionary of debug
    information that is only returned when the internal entry point is
    invoked with `debug=True` (this public wrapper always passes
    `debug=False`).

  Raises:
    ValueError: if `A, X` dtypes or `n` dimensions do not match, or `k` is
      too large (only `k * 5 < n` supported), or `k == 0`.
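
  Example (an illustrative usage sketch; the random symmetric test matrix
  below is only for demonstration)::

    import jax
    import jax.numpy as jnp

    n, k = 100, 10
    key_a, key_x = jax.random.split(jax.random.PRNGKey(0))
    M = jax.random.normal(key_a, (n, n), dtype=jnp.float32)
    A = M @ M.T  # any symmetric matrix works; this one is easy to build
    X = jax.random.normal(key_x, (n, k), dtype=jnp.float32)
    theta, U, iters = lobpcg_standard(A, X)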
  """
  return _lobpcg_standard(A, X, m, tol, debug=False)


@functools.partial(jax.jit, static_argnames=['m', 'debug'])
def _lobpcg_standard(
    A: jnp.ndarray,
    X: jnp.ndarray,
    m: int,
    tol: Union[jnp.ndarray, float, None],
    debug: bool = False):
  """Computes lobpcg_standard(), possibly with debug diagnostics."""

  # TODO(vladf): support mixed_precision flag, which allows f64 Rayleigh-Ritz
  # with f32 inputs.
  mixed_precision = False

  n, k = X.shape
  dt = X.dtype

  _check_inputs(A, X)

  if tol is None:
    tol = jnp.finfo(dt).eps

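  # Orthonormalize the initial search directions, then pick an initial P
  # block that is orthonormal and orthogonal to X (see _extend_basis below).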
  X = _svqb(X, mixed_precision)
  P = _extend_basis(X, X.shape[1])

  # We maintain X, our current list of best eigenvectors,
  # P, our search direction, and
  # R, our residuals, in a large joint array XPR, column-stacked, so (n, 3*k).

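  # Initial Rayleigh quotients: since X now has orthonormal columns, the
  # per-column quotient diag(X.T A X) reduces to a column-wise sum of X * AX.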
  AX = _mm(A, X)
  theta = jnp.sum(X * AX, axis=0, keepdims=True)
  R = AX - theta * X

  def cond(state):
    i, _X, _P, _R, converged, _ = state
    return jnp.logical_and(i < m, converged < k)

  def body(state):
    i, X, P, R, _, theta = state
    # Invariants: X, P, R kept orthonormal.
    # Some R, P columns may be 0, but not X.

    # TODO(vladf): support preconditioning for bottom-k eigenvectors
    # if M is not None:
    #   R = M(R)

    # Residual basis selection.
    R = _project_out(jnp.concatenate((X, P), axis=1), R, mixed_precision)
    XPR = jnp.concatenate((X, P, R), axis=1)

    # Projected eigensolve.
    theta, Q = _rayleigh_ritz_orth(A, XPR, mixed_precision)

    # Eigenvector X extraction.
    B = Q[:, :k]
    normB = jnp.linalg.norm(B, ord=2, axis=0, keepdims=True)
    B /= normB
    X = _mm(XPR, B)
    normX = jnp.linalg.norm(X, ord=2, axis=0, keepdims=True)
    X /= normX

    # Difference terms P extraction.
    #
    # In the next step of LOBPCG, naively, we'd set
    # P = XPR[:, k:] @ Q[k:, :k] to achieve span(X, P) == span(X, previous X)
    # (this is not obvious; see section 4 of [1]).
    #
    # Instead, we orthogonalize concat(0, Q[k:, :k]) against Q[:, :k]
    # in the standard basis before mapping with XPR. Since XPR is itself
    # orthonormal, the resulting directions are themselves orthonormalized.
    #
    # [2] leverages Q's existing orthogonality to derive
    # an analytic expression for this value based on the quadrant Q[:k, k:]
    # (see section 4.2 of [2]).
    q, _ = jnp.linalg.qr(Q[:k, k:].T)
    diff_rayleigh_ortho = _mm(Q[:, k:], q)
    P = _mm(XPR, diff_rayleigh_ortho)
    normP = jnp.linalg.norm(P, ord=2, axis=0, keepdims=True)
    P /= jnp.where(normP == 0, 1.0, normP)

    # Compute new residuals.
    AX = _mm(A, X)
    R = AX - theta[jnp.newaxis, :k] * X
    resid_norms = jnp.linalg.norm(R, ord=2, axis=0)

    # I tried many variants of hard and soft locking [3]. All of them seemed
    # to worsen performance relative to no locking.
    #
    # Further, I use a more experimental convergence formula than what is
    # suggested in the literature, loosely based on floating-point error
    # expectations.
    #
    # [2] discusses various strategies for this in Sec 5.3. The solution
    # they end up with, which estimates the operator norm |A| via Gaussian
    # products, was too crude in practice (and overly lax). The Gaussian
    # approximation seems like an estimate of the average eigenvalue.
    #
    # Instead, we test convergence via self-consistency of the eigenpair,
    # i.e., the residual norm |r| should be small relative to the floating
    # point error we'd expect from computing just the residuals, given
    # candidate vectors.
    #
    # The factor n below bounds the worst-case rounding error of the A @ X
    # multiply (a random-walk model would suggest sqrt(n) instead).
    reltol = jnp.linalg.norm(AX, ord=2, axis=0) + theta[:k]
    reltol *= n
    # Allow some margin for a few element-wise operations.
    reltol *= 10
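    # Worked example of the threshold (illustrative numbers only): for f32
    # inputs, tol defaults to eps ~= 1.2e-7, so with n = 1000 and an
    # eigenpair where lambda + |A v| ~= 2, the residual must satisfy
    # |r| < 1.2e-7 * 1000 * 10 * 2 ~= 2.4e-3 to count as converged.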
    res_converged = resid_norms < tol * reltol
    converged = jnp.sum(res_converged)

    new_state = i + 1, X, P, R, converged, theta[jnp.newaxis, :k]
    if debug:
      diagnostics = _generate_diagnostics(
          XPR, X, P, R, theta, converged, resid_norms / reltol)
      new_state = (new_state, diagnostics)
    return new_state

  converged = 0
  state = (0, X, P, R, converged, theta)
  if debug:
    state, diagnostics = jax.lax.scan(
        lambda state, _: body(state), state, xs=None, length=m)
  else:
    state = jax.lax.while_loop(cond, body, state)
  i, X, _P, _R, _converged, theta = state

  if debug:
    return theta[0, :], X, i, diagnostics
  return theta[0, :], X, i


def _check_inputs(A, X):
  n, k = X.shape
  dt = X.dtype

  if k == 0:
    raise ValueError(f'must have search dim > 0, got {k}')

  if A.dtype != dt:
    raise ValueError(f'A, X must have same dtypes (were {A.dtype}, {dt})')

  if A.shape != (n, n):
    raise ValueError(f'A must be an ({n}, {n}) matrix, got {A.shape}')

  if k * 5 >= n:
    raise ValueError(f'expected search dim * 5 < matrix dim (got {k * 5}, {n})')


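# Matrix multiply helper; Precision.HIGHEST avoids the reduced-precision
# accumulation some accelerator backends (e.g. TPU) use by default.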
def _mm(a, b, precision=jax.lax.Precision.HIGHEST):
  return jax.lax.dot(a, b, (precision, precision))


def _generate_diagnostics(prev_XPR, X, P, R, theta, converged, adj_resid):
  k = X.shape[1]
  assert X.shape == P.shape

  diagdiag = lambda x: jnp.diag(jnp.diag(x))
  abserr = lambda x: jnp.abs(x).sum() / (k ** 2)

  XTX = _mm(X.T, X)
  DX = diagdiag(XTX)
  orthX = abserr(XTX - DX)

  PTP = _mm(P.T, P)
  DP = diagdiag(PTP)
  orthP = abserr(PTP - DP)

  PX = abserr(X.T @ P)

  prev_basis = prev_XPR.shape[1] - jnp.sum(jnp.all(prev_XPR == 0.0, axis=0))

  return {
      'basis rank': prev_basis,
      'X zeros': jnp.sum(jnp.all(X == 0.0, axis=0)),
      'P zeros': jnp.sum(jnp.all(P == 0.0, axis=0)),
      'lambda history': theta[:k],
      'residual history': jnp.linalg.norm(R, axis=0, ord=2),
      'converged': converged,
      'adjusted residual max': jnp.max(adj_resid),
      'adjusted residual p50': jnp.median(adj_resid),
      'adjusted residual min': jnp.min(adj_resid),
      'X orth': orthX,
      'P orth': orthP,
      'P.X': PX}


def _eigh_possibly_mixed(A, mixed_precision):
  assert not mixed_precision, 'mixed precision not yet supported'
  w, V = jnp.linalg.eigh(A)
  # jnp.linalg.eigh returns ascending eigenvalues; reverse to descending.
  return w[::-1], V[:, ::-1]


def _svqb(X, mixed_precision):
  # https://sdm.lbl.gov/~kewu/ps/45577.html
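  #
  # SVQB orthonormalizes X against itself: normalize the columns, form the
  # Gram matrix X.T X = V diag(w) V.T, and rescale via X @ V @ diag(w)^-0.5.
  # Columns whose Gram eigenvalue falls at or below the cutoff tau (i.e.
  # numerically dependent directions) are zeroed out rather than kept.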

  norms = jnp.linalg.norm(X, ord=2, axis=0, keepdims=True)
  X /= jnp.where(norms == 0, 1.0, norms)

  inner = _mm(X.T, X)

  w, V = _eigh_possibly_mixed(inner, mixed_precision)

  tau = jnp.finfo(X.dtype).eps * w[0]
  padded = jnp.maximum(w, tau)
  sqrted = jnp.where(tau > 0, padded, 1.0) ** (-0.5)
  scaledV = V * sqrted[jnp.newaxis, :]
  orthoX = _mm(X, scaledV)

  keep = ((w > tau) * (jnp.diag(inner) > 0.0))[jnp.newaxis, :]
  orthoX *= keep
  norms = jnp.linalg.norm(orthoX, ord=2, axis=0, keepdims=True)
  keep *= norms > 0.0
  orthoX /= jnp.where(keep, norms, 1.0)
  return orthoX


def _project_out(basis, U, mixed_precision):
  # "twice is enough" from shoyer's reference:
  # http://slepc.upv.es/documentation/reports/str1.pdf
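  #
  # A single projection pass can leave a nontrivial component of U inside
  # span(basis) due to rounding; repeating the projection (and the SVQB
  # re-orthonormalization) is the standard "twice is enough" remedy.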

  for _ in range(2):
    U -= _mm(basis, _mm(basis.T, U))
  for _ in range(2):
    U = _svqb(U, mixed_precision)

  return U


def _rayleigh_ritz_orth(A, S, mixed_precision):
  # Classical Rayleigh-Ritz returns w, V satisfying
  # (1) S.T A S @ V ~= w * V
  # such that (2) V is (S.T S)-orthonormal.
  # https://www.netlib.org/lapack/lug/node54.html
  #
  # Usually it requires solving the complicated standard eigensystem
  # U^-T S^T A S U^-1 @ Q = w * Q and then backsolving V = U^-1 Q,
  # but if S is orthonormal in the standard inner product then we just need
  # to find the eigenvalues of S.T A S.
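  #
  # In this file S is the XPR block, which body() keeps orthonormal, so the
  # Ritz values are w[:k] and the corresponding approximate eigenvectors of A
  # are recovered by the caller as XPR @ V[:, :k].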

  SAS = _mm(S.T, _mm(A, S))

  # Solve the projected subsystem.
  # If we could tell eigh to stop after the first k eigenpairs, we would.
  return _eigh_possibly_mixed(SAS, mixed_precision)


def _extend_basis(X, m):
  # Use a block Householder reflector to generate an orthogonal extension
  # to X. There's nothing too special about this, and we could choose
  # any random extension to X's basis, but this is a deterministic choice.
  #
  # https://epubs.siam.org/doi/abs/10.1137/0725014
  # https://www.jstage.jst.go.jp/article/ipsjdc/2/0/2_0_298/_article
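  #
  # Writing the reflector as H = I - 2 W W.T with W.T W = I (W is `w` below),
  # H maps span(X) into the span of the first k coordinate axes, so columns
  # k:k+m of H are orthonormal and orthogonal to X; the returned array is
  # exactly those columns (`other` selects them).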
  n, k = X.shape
  Xupper, Xlower = jnp.split(X, [k], axis=0)
  u, s, vt = jnp.linalg.svd(Xupper)
  y = jnp.concatenate([Xupper + u @ vt, Xlower], axis=0)
  other = jnp.concatenate(
      [jnp.eye(m, dtype=X.dtype),
       jnp.zeros((n - k - m, m), dtype=X.dtype)], axis=0)
  w = _mm(y, vt.T * ((2 * (1 + s)) ** (-1/2))[jnp.newaxis, :])
  h = -2 * jnp.linalg.multi_dot(
      [w, w[k:, :].T, other], precision=jax.lax.Precision.HIGHEST)
  return h.at[k:].add(other)