import copy

import torch
import torch.distributed as dist
from diffusers import LTXVideoTransformer3DModel
from torch._utils import _get_device_module
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    PrepareModuleOutput,
    RowwiseParallel,
    SequenceParallel,
)

# Shard, PrepareModuleOutput, and SequenceParallel are only used by the
# commented-out parallelization plans further down.
# from torch.utils._python_dispatch import TorchDispatchMode

DEVICE_TYPE = "cuda"
PG_BACKEND = "nccl"
DEVICE_COUNT = _get_device_module(DEVICE_TYPE).device_count()
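
# Launch with torchrun, which sets the env vars that init_process_group() reads,
# e.g.: torchrun --nproc_per_node=<num_gpus> <this_script>.py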

def main(world_size: int, rank: int):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(rank)

    CHANNELS = 128
    CROSS_ATTENTION_DIM = 2048
    CAPTION_CHANNELS = 4096
    NUM_LAYERS = 28
    NUM_ATTENTION_HEADS = 32
    ATTENTION_HEAD_DIM = 64

    # Smaller configuration for quick debugging:
    # CHANNELS = 4
    # CROSS_ATTENTION_DIM = 32
    # CAPTION_CHANNELS = 64
    # NUM_LAYERS = 1
    # NUM_ATTENTION_HEADS = 4
    # ATTENTION_HEAD_DIM = 8

    config = {
        "in_channels": CHANNELS,
        "out_channels": CHANNELS,
        "patch_size": 1,
        "patch_size_t": 1,
        "num_attention_heads": NUM_ATTENTION_HEADS,
        "attention_head_dim": ATTENTION_HEAD_DIM,
        "cross_attention_dim": CROSS_ATTENTION_DIM,
        "num_layers": NUM_LAYERS,
        "activation_fn": "gelu-approximate",
        "qk_norm": "rms_norm_across_heads",
        "norm_elementwise_affine": False,
        "norm_eps": 1e-6,
        "caption_channels": CAPTION_CHANNELS,
        "attention_bias": True,
        "attention_out_bias": True,
    }
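    # The transformer's inner dim is NUM_ATTENTION_HEADS * ATTENTION_HEAD_DIM = 32 * 64 = 2048,
    # which matches the Linear shapes in the module structure printed at the bottom of this file.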

    # Normal model
    torch.manual_seed(0)
    model = LTXVideoTransformer3DModel(**config).to(DEVICE_TYPE)

    # TP model
    model_tp = copy.deepcopy(model)
    device_mesh = DeviceMesh(DEVICE_TYPE, torch.arange(world_size))
    print(f"Device mesh: {device_mesh}")

    transformer_tp_plan = {
        # ===== Condition embeddings =====
        # "time_embed.emb.timestep_embedder.linear_1": ColwiseParallel(),
        # "time_embed.emb.timestep_embedder.linear_2": RowwiseParallel(output_layouts=Shard(-1)),
        # "time_embed.linear": ColwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate()),
        # "time_embed": PrepareModuleOutput(output_layouts=(Replicate(), Shard(-1)), desired_output_layouts=(Replicate(), Replicate())),
        # "caption_projection.linear_1": ColwiseParallel(),
        # "caption_projection.linear_2": RowwiseParallel(),
        # "rope": PrepareModuleOutput(output_layouts=(Replicate(), Replicate()), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False),
        # ===== =====
    }

    for block in model_tp.transformer_blocks:
        block_tp_plan = {}

        # ===== Attention =====
        # 8 all-to-all, 3 all-reduce
        # block_tp_plan["attn1.to_q"] = ColwiseParallel(use_local_output=False)
        # block_tp_plan["attn1.to_k"] = ColwiseParallel(use_local_output=False)
        # block_tp_plan["attn1.to_v"] = ColwiseParallel(use_local_output=False)
        # block_tp_plan["attn1.norm_q"] = SequenceParallel()
        # block_tp_plan["attn1.norm_k"] = SequenceParallel()
        # block_tp_plan["attn1.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
        # block_tp_plan["attn2.to_q"] = ColwiseParallel(use_local_output=False)
        # block_tp_plan["attn2.to_k"] = ColwiseParallel(use_local_output=False)
        # block_tp_plan["attn2.to_v"] = ColwiseParallel(use_local_output=False)
        # block_tp_plan["attn2.norm_q"] = SequenceParallel()
        # block_tp_plan["attn2.norm_k"] = SequenceParallel()
        # block_tp_plan["attn2.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
        # ===== =====

        # Feed-forward: Megatron-style colwise -> rowwise sharding
        block_tp_plan["ff.net.0.proj"] = ColwiseParallel()
        block_tp_plan["ff.net.2"] = RowwiseParallel()
        parallelize_module(block, device_mesh, block_tp_plan)

    parallelize_module(model_tp, device_mesh, transformer_tp_plan)
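    # With only the feed-forward layers sharded, each block should contribute a
    # single all-reduce in the forward pass (NUM_LAYERS all-reduces in total);
    # the top-level transformer_tp_plan is currently empty, so it adds no sharding.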

    comm_mode = CommDebugMode()

    batch_size = 2
    num_frames, height, width = 49, 512, 512
    temporal_compression_ratio, spatial_compression_ratio = 8, 32
    latent_num_frames, latent_height, latent_width = (
        (num_frames - 1) // temporal_compression_ratio + 1,
        height // spatial_compression_ratio,
        width // spatial_compression_ratio,
    )
    video_sequence_length = latent_num_frames * latent_height * latent_width
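    # For these settings: latent_num_frames = (49 - 1) // 8 + 1 = 7 and
    # latent_height = latent_width = 512 // 32 = 16, so
    # video_sequence_length = 7 * 16 * 16 = 1792 tokens.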
    caption_sequence_length = 64

    hidden_states = torch.randn(batch_size, video_sequence_length, CHANNELS, device=DEVICE_TYPE)
    encoder_hidden_states = torch.randn(batch_size, caption_sequence_length, CAPTION_CHANNELS, device=DEVICE_TYPE)
    encoder_attention_mask = None
    timestep = torch.randint(0, 1000, (batch_size, 1), device=DEVICE_TYPE)

    inputs = {
        "hidden_states": hidden_states,
        "encoder_hidden_states": encoder_hidden_states,
        "encoder_attention_mask": encoder_attention_mask,
        "timestep": timestep,
        "num_frames": latent_num_frames,
        "height": latent_height,
        "width": latent_width,
        "rope_interpolation_scale": [1 / (8 / 25), 8, 8],
        "return_dict": False,
    }

    output = model(**inputs)[0]
    with comm_mode:
        output_tp = model_tp(**inputs)[0]
    output_tp = (
        output_tp.redistribute(output_tp.device_mesh, [Replicate()]).to_local()
        if isinstance(output_tp, DTensor)
        else output_tp
    )
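    # With the plan above, RowwiseParallel's default output layout is Replicate,
    # so output_tp is usually already a plain tensor; the DTensor branch is a
    # safeguard for plans whose final output stays sharded.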
| print("Output shapes:", output.shape, output_tp.shape) | |
| print( | |
| "Comparing output:", | |
| rank, | |
| torch.allclose(output, output_tp, atol=1e-5, rtol=1e-5), | |
| (output - output_tp).abs().max(), | |
| ) | |
| print(f"Max memory reserved ({rank=}): {torch.cuda.max_memory_reserved(rank) / 1024**3:.2f} GB") | |
| if rank == 0: | |
| print() | |
| print("get_comm_counts:", comm_mode.get_comm_counts()) | |
| # print() | |
| # print("get_parameter_info:", comm_mode.get_parameter_info()) # Too much noise | |
| print() | |
| print("Sharding info:\n" + "".join(f"{k} - {v}\n" for k, v in comm_mode.get_sharding_info().items())) | |
| print() | |
| print("get_total_counts:", comm_mode.get_total_counts()) | |
| comm_mode.generate_json_dump("dump_comm_mode_log.json", noise_level=1) | |
| comm_mode.log_comm_debug_tracing_table_to_file("dump_comm_mode_tracing_table.txt", noise_level=1) | |
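        # The JSON dump and the tracing table attribute each collective to the
        # module that issued it, which helps when iterating on the TP plan.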


dist.init_process_group(PG_BACKEND)
WORLD_SIZE = dist.get_world_size()
RANK = dist.get_rank()
torch.cuda.set_device(RANK)

if RANK == 0:
    print(f"World size: {WORLD_SIZE}")
    print(f"Device count: {DEVICE_COUNT}")

try:
    with torch.no_grad():
        main(WORLD_SIZE, RANK)
finally:
    dist.destroy_process_group()
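
# For reference, the module structure of the full-size model (output of print(model)):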
# LTXVideoTransformer3DModel(
#   (proj_in): Linear(in_features=128, out_features=2048, bias=True)
#   (time_embed): AdaLayerNormSingle(
#     (emb): PixArtAlphaCombinedTimestepSizeEmbeddings(
#       (time_proj): Timesteps()
#       (timestep_embedder): TimestepEmbedding(
#         (linear_1): Linear(in_features=256, out_features=2048, bias=True)
#         (act): SiLU()
#         (linear_2): Linear(in_features=2048, out_features=2048, bias=True)
#       )
#     )
#     (silu): SiLU()
#     (linear): Linear(in_features=2048, out_features=12288, bias=True)
#   )
#   (caption_projection): PixArtAlphaTextProjection(
#     (linear_1): Linear(in_features=4096, out_features=2048, bias=True)
#     (act_1): GELU(approximate='tanh')
#     (linear_2): Linear(in_features=2048, out_features=2048, bias=True)
#   )
#   (rope): LTXVideoRotaryPosEmbed()
#   (transformer_blocks): ModuleList(
#     (0-27): 28 x LTXVideoTransformerBlock(
#       (norm1): RMSNorm()
#       (attn1): Attention(
#         (norm_q): RMSNorm()
#         (norm_k): RMSNorm()
#         (to_q): Linear(in_features=2048, out_features=2048, bias=True)
#         (to_k): Linear(in_features=2048, out_features=2048, bias=True)
#         (to_v): Linear(in_features=2048, out_features=2048, bias=True)
#         (to_out): ModuleList(
#           (0): Linear(in_features=2048, out_features=2048, bias=True)
#           (1): Dropout(p=0.0, inplace=False)
#         )
#       )
#       (norm2): RMSNorm()
#       (attn2): Attention(
#         (norm_q): RMSNorm()
#         (norm_k): RMSNorm()
#         (to_q): Linear(in_features=2048, out_features=2048, bias=True)
#         (to_k): Linear(in_features=2048, out_features=2048, bias=True)
#         (to_v): Linear(in_features=2048, out_features=2048, bias=True)
#         (to_out): ModuleList(
#           (0): Linear(in_features=2048, out_features=2048, bias=True)
#           (1): Dropout(p=0.0, inplace=False)
#         )
#       )
#       (ff): FeedForward(
#         (net): ModuleList(
#           (0): GELU(
#             (proj): Linear(in_features=2048, out_features=8192, bias=True)
#           )
#           (1): Dropout(p=0.0, inplace=False)
#           (2): Linear(in_features=8192, out_features=2048, bias=True)
#         )
#       )
#     )
#   )
#   (norm_out): LayerNorm((2048,), eps=1e-06, elementwise_affine=False)
#   (proj_out): Linear(in_features=2048, out_features=128, bias=True)
# )