prepare for sync
This commit is contained in:
parent
6da55e8f87
commit
ef4d1ddda4
@ -122,17 +122,21 @@ def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seq
|
|||||||
x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
|
x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
batch_size = q.shape[0]
|
B, L, H, C = q.shape
|
||||||
q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
|
|
||||||
k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
|
q = q.flatten(0, 1)
|
||||||
v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
|
k = k.flatten(0, 1)
|
||||||
|
v = v.flatten(0, 1)
|
||||||
|
|
||||||
if sageattn_varlen is not None:
|
if sageattn_varlen is not None:
|
||||||
x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
|
x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
|
||||||
elif flash_attn_varlen_func is not None:
|
elif flash_attn_varlen_func is not None:
|
||||||
x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
|
x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('No Attn Installed!')
|
raise NotImplementedError('No Attn Installed!')
|
||||||
x = x.view(batch_size, max_seqlen_q, *x.shape[2:])
|
|
||||||
|
x = x.unflatten(0, (B, L))
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -926,23 +930,22 @@ class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterM
|
|||||||
encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
|
encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
|
||||||
encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
|
encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
|
||||||
|
|
||||||
with torch.no_grad():
|
if batch_size == 1:
|
||||||
if batch_size == 1:
|
# When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want
|
||||||
# When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want
|
# If they are not same, then their impls are wrong. Ours are always the correct one.
|
||||||
# If they are not same, then their impls are wrong. Ours are always the correct one.
|
text_len = encoder_attention_mask.sum().item()
|
||||||
text_len = encoder_attention_mask.sum().item()
|
encoder_hidden_states = encoder_hidden_states[:, :text_len]
|
||||||
encoder_hidden_states = encoder_hidden_states[:, :text_len]
|
attention_mask = None, None, None, None
|
||||||
attention_mask = None, None, None, None
|
else:
|
||||||
else:
|
img_seq_len = hidden_states.shape[1]
|
||||||
img_seq_len = hidden_states.shape[1]
|
txt_seq_len = encoder_hidden_states.shape[1]
|
||||||
txt_seq_len = encoder_hidden_states.shape[1]
|
|
||||||
|
|
||||||
cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
|
cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
|
||||||
cu_seqlens_kv = cu_seqlens_q
|
cu_seqlens_kv = cu_seqlens_q
|
||||||
max_seqlen_q = img_seq_len + txt_seq_len
|
max_seqlen_q = img_seq_len + txt_seq_len
|
||||||
max_seqlen_kv = max_seqlen_q
|
max_seqlen_kv = max_seqlen_q
|
||||||
|
|
||||||
attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
|
attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
|
||||||
|
|
||||||
if self.enable_teacache:
|
if self.enable_teacache:
|
||||||
modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
|
modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user