Skip to content

Conversation

oneflyingfish
Copy link

What does this PR do?

Although PR#12335 has corrected the issue of incorrect input shape, there are still problems that remain unsolved. For instance, after tile_decode, the absence of unpatchify leads to incorrect shapes, and it has not been noticed that the first frame of the WAN2.2 VAE decoding stage has different parameters, otherwise it would cause shape errors. Additionally, in PR#12335, the meaning of self.spatial_compression_ratio has been modified, resulting in incorrect tiled partitioning. In the new submission, I merged the previous PR and it passed the tests successfully.

When I use enable_tiling() in autoencoder_kl_wan.AutoencoderKLWan, the inference would report compute error. I have identified the cause of this error and noticed other potential implementation issues during the repair process. Therefore, I have implemented the repair code.

output shape: torch.Size([1, 3, 81, 736, 1280])
fail to inference vae with vae.enable_tiling() Given groups=1, weight of size [160, 12, 3, 3, 3], expected input[1, 3, 3, 258, 258] to have 12 channels, but got 3 channels instead

Bug reproduction code:

from diffusers.models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
import torch
import os

dtype = torch.bfloat16
device = torch.device(f"cuda:0")
cpu_device = torch.device("cpu")

weight = "/path/to/vae"
vae = (
    AutoencoderKLWan.from_pretrained(
        weight,
        torch_dtype=dtype,
    )
    .eval()
    .to(dtype)
)

with torch.no_grad():
    torch.manual_seed(0)
    dummy_input = (torch.randn((1, 3, 81, 736, 1280),device=device,dtype=dtype)-0.5)/0.5    # B,C,F,H,W

    torch.manual_seed(0)
    # encode
    latent = vae.encode(dummy_input).latent_dist.mode() # type: torch.Tensor

    # decode
    gen_raw = vae.decode(latent, return_dict=False)[0]
    print(f"output shape: {gen_raw.shape}")
    
    # run tiling
    try:
      	torch.manual_seed(0)
        vae.enable_tiling()
        # encode
        latent = vae.encode(dummy_input).latent_dist.mode() # type: torch.Tensor

        # decode
        tile_gen = vae.decode(latent, return_dict=False)[0]
    except Exception as ex:
        print("fail to inference vae with vae.enable_tiling(), error info", ex)

Verification

via tool VCmpTool

比对

raw video:

10120244.mp4

gen video by tiled:

Only the first 81 frames

decode_gen_0_tiling.mp4

Fixes

  • shape error while use enable_tiling() that indicates the input channel number is incorrect
  • tiling error caused by modufy the self.spatial_compression_ratio rate mean in PR#12335
  • abnormal decode process in tiled_encode (without considering the first frame, default implementation is inconsistent), also cause shape errors
  • The positions of the quant_conv and post_quant_conv operators were adjusted to minimize the error as much as possible.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @yiyixuxu @DN6

@sayakpaul sayakpaul requested a review from yiyixuxu September 17, 2025 06:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant