Skip to content

Commit a7cc468

Browse files
authored
AutoencoderKL: clamp indices of blend_h and blend_v to input size (#2660)
1 parent 07a0c1c commit a7cc468

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

‎src/diffusers/models/autoencoder_kl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,12 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode
190190
return DecoderOutput(sample=decoded)
191191

192192
def blend_v(self, a, b, blend_extent):
193-
for y in range(blend_extent):
193+
for y in range(min(a.shape[2], b.shape[2], blend_extent)):
194194
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
195195
return b
196196

197197
def blend_h(self, a, b, blend_extent):
198-
for x in range(blend_extent):
198+
for x in range(min(a.shape[3], b.shape[3], blend_extent)):
199199
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
200200
return b
201201

‎tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,12 @@ def test_stable_diffusion_vae_tiling(self):
445445

446446
assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 5e-1
447447

448+
# test that tiled decode works with various shapes
449+
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
450+
for shape in shapes:
451+
zeros = torch.zeros(shape).to(device)
452+
sd_pipe.vae.decode(zeros)
453+
448454
def test_stable_diffusion_negative_prompt(self):
449455
device = "cpu" # ensure determinism for the device-dependent torch.Generator
450456
components = self.get_dummy_components()

0 commit comments

Comments (0)