[tests] Add test slices for Wan #11920
Changes from all commits: 503ca81, 7c91644, 2a6efc0, ff27aea, c18b63c, 3f6fb29, 2addc4b, 9657717
The first hunk touches the import section (the `-14,7 +14,6` header indicates one line removed; the `PIL` import is no longer used):

```diff
@@ -14,7 +14,6 @@
 import unittest

 import numpy as np
 import torch
-from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel

```
The second hunk replaces the placeholder `1e10` tolerance check in `test_inference` with a real slice comparison:

```diff
@@ -123,11 +122,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]

         self.assertEqual(generated_video.shape, (17, 3, 16, 16))
-        expected_video = torch.randn(17, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4522, 0.4534, 0.4532, 0.4553, 0.4526, 0.4538, 0.4533, 0.4547, 0.513, 0.5176, 0.5286, 0.4958, 0.4955, 0.5381, 0.5154, 0.5195])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
```

Review comment: Curious: why do we have to slice and cat?

Author reply: Just a convention I've been following in most of my test PRs. It checks some start and end values of the tensor for correctness, which I think is better than checking only one side. For example, when writing Triton kernels, part of the logic may be correct and work for most of the tensor, but if the load masks and default values are wrong, the values at the end of the tensor are computed incorrectly; checking only the start of the tensor would give a false impression that the algorithm is correct. (I think I picked up this way of checking slices from Triton unit tests or somewhere similar; I don't remember anymore.)

Review comment: Wow, that is indeed better. Thanks for the note! I will adapt this.
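The slice-and-cat convention discussed above can be sketched as a small standalone helper (the function name `head_tail_slice` is illustrative, not part of the PR):

```python
import torch

def head_tail_slice(t: torch.Tensor, n: int = 8) -> torch.Tensor:
    """Take the first and last `n` values of a flattened tensor.

    Checking both ends guards against bugs (e.g. bad load masks in a
    Triton kernel) that corrupt only the tail of an otherwise-correct
    output, which a check on the start alone would miss.
    """
    flat = t.flatten()
    return torch.cat([flat[:n], flat[-n:]])

# Usage mirroring the test above (dummy data, not real pipeline output):
video = torch.arange(17 * 3 * 16 * 16, dtype=torch.float32).reshape(17, 3, 16, 16)
slice16 = head_tail_slice(video)
```

The resulting 16-element slice is what gets compared against `expected_slice` with `torch.allclose`.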
|
Review comment: Not important for this PR, but at some point I think we should switch to `torch_device` instead of fixing devices like this. I see downsides and upsides: `pipe` can be run on different accelerators (depending on what we get for `torch_device`), but then `expected_slice` will vary across devices; I guess we can leverage a combination of cosine-similarity-based checks and enforcement of determinism (like this). No need to fret for this PR. We will get to it later.