diff --git a/diffusers_helper/models/hunyuan_video_packed.py b/diffusers_helper/models/hunyuan_video_packed.py
index f879799..1cb42ab 100644
--- a/diffusers_helper/models/hunyuan_video_packed.py
+++ b/diffusers_helper/models/hunyuan_video_packed.py
@@ -122,17 +122,21 @@ def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seq
         x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
         return x
 
-    batch_size = q.shape[0]
-    q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
-    k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
-    v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
+    B, L, H, C = q.shape
+
+    q = q.flatten(0, 1)
+    k = k.flatten(0, 1)
+    v = v.flatten(0, 1)
+
     if sageattn_varlen is not None:
         x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
     elif flash_attn_varlen_func is not None:
         x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
     else:
         raise NotImplementedError('No Attn Installed!')
-    x = x.view(batch_size, max_seqlen_q, *x.shape[2:])
+
+    x = x.unflatten(0, (B, L))
+
     return x
 
 
@@ -926,23 +930,22 @@ class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterM
             encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
             encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
 
-        with torch.no_grad():
-            if batch_size == 1:
-                # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want
-                # If they are not same, then their impls are wrong. Ours are always the correct one.
-                text_len = encoder_attention_mask.sum().item()
-                encoder_hidden_states = encoder_hidden_states[:, :text_len]
-                attention_mask = None, None, None, None
-            else:
-                img_seq_len = hidden_states.shape[1]
-                txt_seq_len = encoder_hidden_states.shape[1]
+        if batch_size == 1:
+            # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want
+            # If they are not same, then their impls are wrong. Ours are always the correct one.
+            text_len = encoder_attention_mask.sum().item()
+            encoder_hidden_states = encoder_hidden_states[:, :text_len]
+            attention_mask = None, None, None, None
+        else:
+            img_seq_len = hidden_states.shape[1]
+            txt_seq_len = encoder_hidden_states.shape[1]
 
-                cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
-                cu_seqlens_kv = cu_seqlens_q
-                max_seqlen_q = img_seq_len + txt_seq_len
-                max_seqlen_kv = max_seqlen_q
+            cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
+            cu_seqlens_kv = cu_seqlens_q
+            max_seqlen_q = img_seq_len + txt_seq_len
+            max_seqlen_kv = max_seqlen_q
 
-                attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
+            attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
 
         if self.enable_teacache:
             modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
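
Reviewer note (not part of the patch): a minimal sketch checking that the new flatten(0, 1) / unflatten(0, (B, L)) calls are shape-for-shape equivalent to the old view-based reshapes, assuming max_seqlen_q equals the padded sequence length L as it does on this call path. The tensor sizes below are made up for illustration.

import torch

B, L, H, C = 2, 16, 4, 8  # hypothetical batch, padded seq len, heads, head dim
q = torch.randn(B, L, H, C)

# Old collapse: spell out the merged shape with view().
old_flat = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
# New collapse: flatten(0, 1) merges the batch and sequence dims directly.
new_flat = q.flatten(0, 1)
assert torch.equal(old_flat, new_flat) and new_flat.shape == (B * L, H, C)

# Old restore: view() back using the remembered batch size and max_seqlen_q (== L here).
old_back = new_flat.view(B, L, *new_flat.shape[1:])
# New restore: unflatten(0, (B, L)) splits the leading dim back out.
new_back = new_flat.unflatten(0, (B, L))
assert torch.equal(old_back, new_back) and new_back.shape == (B, L, H, C)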