
Conversation

@vlad17 (Contributor) commented Jun 3, 2022

This initial version is f32-only for accelerators, since it relies on an eigh call (which itself is f32 at most) in its inner loop.

For details, see the jax.experimental.linalg.standard_lobpcg documentation.
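A hypothetical usage sketch: the argument names `A`, `X`, `m`, and `tol` come from the review discussion below, and the two return values are an assumption; see the standard_lobpcg docstring for the actual signature.

```python
import jax
import jax.numpy as jnp
from jax.experimental import linalg

# Hypothetical sketch: A is a symmetric matrix, X holds k initial
# search directions; return arity is an assumption, not the real API.
n, k = 100, 4
A = jnp.diag(jnp.linspace(1.0, 2.0, n))                 # simple SPD test matrix
X = jax.random.normal(jax.random.PRNGKey(0), (n, k))    # random starting block
w, V = linalg.standard_lobpcg(A, X, m=100, tol=1e-5)    # top-k eigenpairs
```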

@google-cla bot commented Jun 3, 2022

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up-to-date status, view the checks section at the bottom of the pull request.

@vlad17 force-pushed the topk branch 3 times, most recently from 3d3cfc5 to 18b1d8e on June 3, 2022
@vlad17 (Contributor, Author) commented Jun 3, 2022

@shoyer @tabakg @lobpcg would any of you be interested in reviewing?

@vlad17 force-pushed the topk branch 2 times, most recently from 21173cc to 9e0ed50 on June 3, 2022
@jakevdp requested a review from shoyer on June 3, 2022
'jax_traceback_filtering': 'off',
}

# TODO(vladf): add f64 tests just to verify it compiles?

I think that would be good -- you could probably check whether it's running on CPU and disable them otherwise?

@vlad17 (Contributor, Author) commented Jun 5, 2022

It looks like there are f64 tests auto-triggered by the github action matrix (via env var), so I think all I should really need to do here is adjust the epsilons.

@shoyer (Collaborator) commented Jun 4, 2022

A couple of high-level thoughts:

  1. It would be nice to support matrix-free linear operators defined on pytrees, like the solvers in jax.scipy.sparse.linalg. For example, imagine solving for the largest eigenvalue pairs of a function that evaluates a linearized neural network.
  2. It would be nice to support an iterative interface with init and update functions, similar to the optimizers in Optax. This sort of inversion of control provides valuable flexibility in cases where a full eigenvalue solve from scratch is prohibitively expensive. For example, imagine performing a single LOBPCG step to re-estimate the top eigenvalue after each gradient descent step when training a neural net, as is done in spectral normalization.

Neither of these is a deal breaker for the first iteration, though.
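For concreteness, here is a purely hypothetical sketch of what suggestion 2 could look like. None of these names exist in this PR, and the update step is drastically simplified (plain QR in place of the PR's orthonormalization routine, and a naive P update):

```python
import jax.numpy as jnp
from typing import NamedTuple

class LOBPCGState(NamedTuple):
    X: jnp.ndarray  # (n, k) current orthonormal eigenvector estimates
    P: jnp.ndarray  # (n, k) previous search directions

def lobpcg_init(X0):
    Q, _ = jnp.linalg.qr(X0)
    return LOBPCGState(X=Q, P=jnp.zeros_like(Q))

def lobpcg_update(A_mul, state):
    X, P = state
    k = X.shape[1]
    theta = jnp.sum(X * A_mul(X), axis=0)        # Rayleigh quotients
    R = A_mul(X) - X * theta                     # block residuals
    S, _ = jnp.linalg.qr(jnp.hstack([X, R, P]))  # orthonormal search basis
    w, Q = jnp.linalg.eigh(S.T @ A_mul(S))       # Rayleigh-Ritz step
    V = S @ Q[:, -k:]                            # top-k Ritz vectors
    new_state = LOBPCGState(X=V, P=V - X @ (X.T @ V))
    return (w[-k:], V), new_state
```

Carrying P in the state is what would make a single warm-started update cheaper than restarting a full solve, per the exchange below.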

@vlad17 (Contributor, Author) commented Jun 4, 2022

> It would be nice to support matrix-free linear operators defined on pytrees, like the solvers in jax.scipy.sparse.linalg. For example, imagine solving for the largest eigenvalue pairs of a function that evaluates a linearized neural network.

Great idea; this should be easy to add. Now that I think about it, Shankar (skrishnan@google.com) is another user who'd immediately benefit from an interface like that.

> It would be nice to support an iterative interface with init and update functions, similar to the optimizers in Optax. This sort of inversion of control provides valuable flexibility in cases where a full eigenvalue solve from scratch is prohibitively expensive. For example, imagine performing a single LOBPCG step to re-estimate the top eigenvalue after each gradient descent step when training a neural net, as is done in spectral normalization.

Could you elaborate on this? I'd be very willing to do this as a follow-on if I had a user to work with on their ideal API, but as stated I don't quite see why setting the initial X to the previous value with maximum iteration m=1 wouldn't work.

@shoyer (Collaborator) commented Jun 5, 2022

> It would be nice to support an iterative interface with init and update functions, similar to the optimizers in Optax. This sort of inversion of control provides valuable flexibility in cases where a full eigenvalue solve from scratch is prohibitively expensive. For example, imagine performing a single LOBPCG step to re-estimate the top eigenvalue after each gradient descent step when training a neural net, as is done in spectral normalization.
>
> Could you elaborate on this? I'd be very willing to do this as a follow-on if I had a user to work with on their ideal API, but as stated I don't quite see why setting the initial X to the previous value with maximum iteration m=1 wouldn't work.

Wouldn't we want to calculate the matrix P from the previous iteration?

@vlad17 (Contributor, Author) commented Jun 5, 2022

Ah, gotcha. That makes sense, and it seems like it'd be a matter of exposing the body function of the current method. Maybe it'd be best to wait until the API solidifies a little, though, since a largest=False option is incoming, and it'd be good to figure out how an iterative interface would work with a preconditioner.

@vlad17 (Contributor, Author) commented Jun 8, 2022

@shoyer @tabakg just updated the PR on a flight home; let me know what you think of the new interface and tests.

@vlad17 (Contributor, Author) commented Jun 14, 2022

Here are some cool curves (which can be generated from the tests) of JAX vs SciPy f32 on 1000x1000 versions of the test matrices for top-10 eigs. (I set the convergence tol to 0, which is why nothing converges.)

[Plots: linear, geometric, and ring Laplacian]

[Plot: clustered eigenvalues]

@lobpcg commented Jun 14, 2022

> Here are some cool curves (which can be generated from the tests) of JAX vs SciPy f32 on 1000x1000 versions of the test matrices for top-10 eigs. (I set the convergence tol to 0, which is why nothing converges.)

I would like to reproduce these tests standalone in SciPy and check whether some bugs need to be fixed there, since the runs appear too unstable. Could you please upload the code that calls SciPy for these tests? Which version of SciPy did you use to get these plots?

@vlad17 (Contributor, Author) commented Jun 14, 2022

@lobpcg I used 1.8.0 (did anything change in 1.8.1?). The examples are all the same as the unit tests in the PR, but at size 1000.

The actual colab to make the viz depends on some Google-internal features for unrelated things, but I can try to find some time to clean up the notebook for a public-facing version. Would it be more appropriate to post that as a SciPy GitHub issue, just to keep this thread focused on JAX?

@lobpcg commented Jun 14, 2022

> @lobpcg I used 1.8.0 (did anything change in 1.8.1?). The examples are all the same as the unit tests in the PR, but at size 1000.
>
> The actual colab to make the viz depends on some Google-internal features for unrelated things, but I can try to find some time to clean up the notebook for a public-facing version. Would it be more appropriate to post that as a SciPy GitHub issue, just to keep this thread focused on JAX?

1.8.0 is representative. I made changes in lobpcg specifically for float32, but that was before.

Of course, if you create a reproducible issue in SciPy, that would be ideal. No need to include the code that makes the actual plots; just please add a reference to your post with the plots above, and a ping to me.

If SciPy fails on smaller sizes like 100, all of those would be good examples to include in SciPy as unit tests, if I find a fix so that they all run.

@vlad17 (Contributor, Author) commented Jun 15, 2022

@lobpcg filed scipy/scipy#16408 with just the cases, no viz.

@rmlarsen (Contributor) left a comment


- Despite increased iteration cost, we always maintain an orthonormal basis
for the block search directions.
- We change the convergence criterion; see the `tol` argument.
- Soft locking is intentionally not implemented; it relies on choosing an
@rmlarsen (Contributor):

Maybe refer to a paper where soft locking is introduced here?

@vlad17 (Contributor, Author):

There isn't a canonical link (a very long researchgate URL is all I could find since the original host is gone). So I left a DOI.

@vlad17 (Contributor, Author):

In particular, I believe the following link is the original source, but it's broken:

http://math.cudenver.edu/%CB%9Caknyazev/research/conf/cm04%20soft%20locking/cm04.pdf

ResearchGate has a not-so-pretty URL, but it works:
https://www.researchgate.net/publication/343530965_Hard_and_soft_locking_Hard_and_soft_locking_in_iterative_methods_for_symmetric_eigenvalue_problems

Perhaps @lobpcg might have a better reference for this?

action.
X : An `(n, k)` array representing the initial search directions for the `k`
desired top eigenvectors. This need not be orthogonal, but must be
linearly independent.
@rmlarsen (Contributor):

"linearly independent" is not a very precise term in finite precision. I assume the method gradually breaks down as cond(X) increases?

@vlad17 (Contributor, Author):

Yeah, I orthonormalize it upon entry, so technically it's "linearly independent enough, such that the orthogonalization routine does not decide it's rank-deficient according to its cutoffs". I'll try to phrase that concisely.

subspace of) the Krylov basis `{X, A X, A^2 X, ..., A^m X}`.
tol : A float convergence tolerance; an eigenpair `(lambda, v)` is converged
when its residual L2 norm `r = |A v - lambda v|` is below
`tol * 10 * n * (lambda + |A v|)`, which
@rmlarsen (Contributor):

The factor n is probably quite conservative. Using a probabilistic argument, it might be worth trying to scale as sqrt(n) instead. But there might be counterexamples.

@vlad17 (Contributor, Author):

Discussed offline: n just happened to work best for making the tolerances in the unit tests all land around floating-point epsilon (and SciPy usually uses this factor for eigenvalue routines).
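For reference, the documented criterion amounts to something like the following sketch; `A_mul`, `V`, and `lam` are hypothetical names, with `lam` holding the current eigenvalue estimates.

```python
import jax.numpy as jnp

def converged_sketch(A_mul, V, lam, tol):
    # Sketch of the documented test: eigenpair (lambda, v) is converged
    # when ||A v - lambda v|| < tol * 10 * n * (lambda + ||A v||).
    n = V.shape[0]
    AV = A_mul(V)
    resid = jnp.linalg.norm(AV - V * lam, axis=0)
    return resid < tol * 10 * n * (lam + jnp.linalg.norm(AV, axis=0))
```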

return w[::-1], V[:, ::-1]


def _svqb(X):
@rmlarsen (Contributor) commented Jun 28, 2022

Please add a docstring for this method; it is quite non-trivial. Please write out the math instead of just giving a reference.

@vlad17 (Contributor, Author):

Done.
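Roughly, SVQB boils down to something like the sketch below. This is a compressed illustration, not the PR's `_svqb`; in particular, the eigenvalue floor used to guard against rank deficiency here is an assumption.

```python
import jax.numpy as jnp

def svqb_sketch(X):
    # Scale columns so the Gram matrix has a unit diagonal.
    X = X / jnp.linalg.norm(X, axis=0, keepdims=True)
    C = X.T @ X                    # k x k Gram matrix: one big matmul
    w, V = jnp.linalg.eigh(C)      # small, accelerator-friendly eigh
    # Floor tiny eigenvalues so near-dependent columns don't blow up.
    w = jnp.maximum(w, jnp.finfo(X.dtype).eps * X.shape[1] * w[-1])
    return X @ (V / jnp.sqrt(w))   # orthonormal basis for range(X)
```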



def _project_out(basis, U):
# "twice is enough" from shoyer's reference:
@rmlarsen (Contributor):

docstring?

@vlad17 (Contributor, Author):

done
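A minimal sketch of the "twice is enough" projection, with hypothetical names and assuming `basis` has orthonormal columns:

```python
import jax.numpy as jnp

def project_out_sketch(basis, U):
    # Remove the components of U lying in span(basis). Repeating the
    # projection ("twice is enough") cleans up the floating-point
    # cancellation error left behind by the first pass.
    for _ in range(2):
        U = U - basis @ (basis.T @ U)
    return U
```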

#
# Usually it requires solving the complicated standard eigensystem
# U^-T S^T A S U^-1 @ Q = w * Q and then backsolving V = U^-1 Q,
# but if S is standard orthonormal then we just need to find
@rmlarsen (Contributor) commented Jun 28, 2022

I think this comment is probably more confusing than helpful. Maybe just mention that we keep S orthonormal, in which case we just need to solve the projected eigenvalue problem for S.T @ A @ S to obtain the Ritz values and vectors of A w.r.t. span(S).

@vlad17 (Contributor, Author):

done
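In code, the suggested simplification reads roughly as follows (a sketch with hypothetical names; `A_mul` stands in for multiplication by A):

```python
import jax.numpy as jnp

def rayleigh_ritz_sketch(A_mul, S):
    # With S orthonormal, the Ritz pairs of A w.r.t. span(S) come from
    # the small projected eigenproblem directly; no backsolve is needed.
    w, Q = jnp.linalg.eigh(S.T @ A_mul(S))  # ascending Ritz values
    return w, S @ Q                         # Ritz vectors in the original space
```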

#
# https://epubs.siam.org/doi/abs/10.1137/0725014
# https://www.jstage.jst.go.jp/article/ipsjdc/2/0/2_0_298/_article
n, k = X.shape
@rmlarsen (Contributor):

Please write out the math here instead of just citing the papers. The code is not particularly readable with all the concatenations etc.

@vlad17 (Contributor, Author):

done

"""Derives a truncated orthonormal basis for `X`.
SVQB [1] is an accelerator-friendly orthonormalization procedure, which
squares the matrix `C = X.T @ X` and computes an eigenbasis for a smaller
@rmlarsen (Contributor):

I think you could use QR instead of an eigendecomposition if X is full rank, and even Cholesky if X is sufficiently well-conditioned. Both of these options would be more accelerator-friendly.

See, e.g.: https://epubs.siam.org/doi/abs/10.1137/18M1218212

However, I don't think you can know that these are safe to use. But it's worth thinking about, as it might allow you to work with larger k.

@rmlarsen (Contributor) commented Jun 29, 2022

Rank-revealing QR would be a nice option, if it were available in JAX. It might be worth adding. A version based on the randomized-projection approach could be made accelerator-friendly.

@vlad17 (Contributor, Author):

By QR here, you mean QR-ing X itself, not its square, right? At one point early in the algorithm-development process I tried that but decided against it because the speed wasn't satisfactory. Cholesky on X.T @ X would be faster, but rank deficiency is specifically a case that needs to be handled here.

That said, I'll put trying the QR approach back on my TODO list. It's worth revisiting, I don't have hard numbers saying it's too slow, and it avoids squaring.

I'm really intrigued by the randomized projection you're mentioning. Do you have a reference?

@rmlarsen (Contributor):

Here is a reference for the RRQR with randomized projection: https://arxiv.org/abs/2008.04447

You are right that computing the QR decomposition of X is likely slower, since X.T@X is so fast on an accelerator. But avoiding the squaring would be nice. Just some things to experiment with, I guess.
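To make the trade-off concrete, here is a sketch of the CholeskyQR alternative discussed above (a hypothetical helper, not part of the PR):

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def cholesky_qr_sketch(X):
    # CholeskyQR: orthonormalize X through its Gram matrix. Only the
    # k x k Cholesky and the triangular solve are non-matmul work, so it
    # is very accelerator-friendly, but it fails outright when X is rank
    # deficient, which is exactly the case _svqb is designed to survive.
    L = jnp.linalg.cholesky(X.T @ X)               # X.T @ X = L @ L.T
    return solve_triangular(L, X.T, lower=True).T  # Q = X @ inv(L.T)
```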

are mutually orthonormal.
"""

# See Sec. 6.9 of The Symmetric Eigenvalue Problem by Beresford Parlett [1]
@rmlarsen (Contributor):

Thanks, very nice.

@google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels on June 29, 2022
@vlad17 force-pushed the topk branch 2 times, most recently from b7c5016 to 4bf25e5 on June 30, 2022
This is a partial implementation of the similar [scipy lobpcg
function](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lobpcg.html).
@copybara-service bot merged commit 4446c73 into jax-ml:main on Jun 30, 2022