[tests] Add test slices for Wan #11920


Merged
merged 8 commits on Jul 23, 2025
17 changes: 9 additions & 8 deletions tests/pipelines/wan/test_wan.py
@@ -15,7 +15,6 @@
import gc
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

@@ -29,9 +28,7 @@
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
)
from ..test_pipelines_common import PipelineTesterMixin


enable_full_determinism()
@@ -127,11 +124,15 @@ def test_inference(self):
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]

self.assertEqual(generated_video.shape, (9, 3, 16, 16))
expected_video = torch.randn(9, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)

# fmt: off
expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
# fmt: on

generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
62 changes: 56 additions & 6 deletions tests/pipelines/wan/test_wan_image_to_video.py
@@ -14,7 +14,6 @@

import unittest

import numpy as np
import torch
from PIL import Image
from transformers import (
@@ -147,11 +146,15 @@ def test_inference(self):
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]

self.assertEqual(generated_video.shape, (9, 3, 16, 16))
expected_video = torch.randn(9, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)

# fmt: off
expected_slice = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.452, 0.4529, 0.454, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 0.5214])
# fmt: on

generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
@@ -162,7 +165,25 @@ def test_inference_batch_single_identical(self):
pass


class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = WanImageToVideoPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_xformers_attention = False
supports_dduf = False

def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLWan(
@@ -247,3 +268,32 @@ def get_dummy_inputs(self, device, seed=0):
"output_type": "pt",
}
return inputs

def test_inference(self):
device = "cpu"

components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
Comment on lines +273 to +277
Member

Not important for this PR. But at some point, I think we should switch to torch_device instead of fixing devices like this. I see downsides and upsides:

  • Upside: the same code can be used to quickly check whether the pipe runs on different accelerators (depending on what we get for torch_device).
  • Downside: the expected_slice will vary, but I guess we can leverage a combination of cosine similarity based checks and enforcement of determinism (like this).

No need to fret for this PR. We will get to it later.
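As a rough, non-authoritative sketch of the cosine-similarity idea (not part of this PR; the helper name and the 0.999 threshold are made up for illustration):

```python
import torch
import torch.nn.functional as F


def assert_slice_close_enough(generated_slice, expected_slice, min_similarity=0.999):
    # Hypothetical device-agnostic check: instead of exact allclose against
    # hard-coded values, compare the direction of the generated slice to a
    # reference slice with cosine similarity, which tolerates small numerical
    # differences across accelerators.
    similarity = F.cosine_similarity(
        generated_slice.flatten().unsqueeze(0),
        expected_slice.flatten().unsqueeze(0),
    ).item()
    assert similarity >= min_similarity, f"cosine similarity too low: {similarity:.4f}"
```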

pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 16, 16))

# fmt: off
expected_slice = torch.tensor([0.4531, 0.4527, 0.4498, 0.4542, 0.4526, 0.4527, 0.4534, 0.4534, 0.5061, 0.5185, 0.5283, 0.5181, 0.5309, 0.5365, 0.5113, 0.5244])
# fmt: on

generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
pass

@unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
def test_inference_batch_single_identical(self):
pass
13 changes: 8 additions & 5 deletions tests/pipelines/wan/test_wan_video_to_video.py
@@ -14,7 +14,6 @@

import unittest

import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
@@ -123,11 +122,15 @@ def test_inference(self):
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]

self.assertEqual(generated_video.shape, (17, 3, 16, 16))
expected_video = torch.randn(17, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)

# fmt: off
expected_slice = torch.tensor([0.4522, 0.4534, 0.4532, 0.4553, 0.4526, 0.4538, 0.4533, 0.4547, 0.513, 0.5176, 0.5286, 0.4958, 0.4955, 0.5381, 0.5154, 0.5195])
# fmt:on

generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
Member

Curious: why do we have to slice and cat?

Member Author

Just a convention I've been following in most of my test PRs. It checks some start and end values of the tensor for correctness, which I think is better than checking only one side. For example, when writing triton, part of the logic may be correct and work for most of the tensor, but if the load masks and default value are not correct, the values at the end of the tensor can be computed incorrectly. By checking only the start or only the end, you may get a false impression that the algorithm is correct. I think I picked up this way of checking slices from triton unit tests or somewhere similar; I don't remember anymore.

Member

Wow, that is indeed better. Thanks for the note! I will adapt this.
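A small self-contained toy (not from the PR, values chosen arbitrarily) showing why checking both ends helps: corrupt only the tail of a tensor, as a kernel with a wrong load mask might, and compare a head-only check against the head-plus-tail slice.

```python
import torch

reference = torch.linspace(0, 1, 100)
buggy = reference.clone()
buggy[-4:] = 0.0  # only the last few values are wrong

# Head-only check passes and misses the bug.
head_only_ok = torch.allclose(buggy[:8], reference[:8], atol=1e-3)

# Head-plus-tail slice (the convention used in these tests) catches it.
head_and_tail_ok = torch.allclose(
    torch.cat([buggy[:8], buggy[-8:]]),
    torch.cat([reference[:8], reference[-8:]]),
    atol=1e-3,
)

print(head_only_ok, head_and_tail_ok)  # True False
```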

self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):