from dataclasses import dataclass

import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.distributed.tensor.parallel import (ColwiseParallel,
                                               PrepareModuleInput,
                                               RowwiseParallel,
                                               SequenceParallel,
                                               parallelize_module)


@dataclass
class ParallelDims:
    """Degrees of the parallelism dimensions used to build the device mesh."""

    dp_replicate_degree: int
    dp_shard_degree: int
    tp_degree: int
    ep_degree: int = 1

    def __str__(self) -> str:
        s = (f"dp_replicate-{self.dp_replicate_degree}_"
             f"dp_shard-{self.dp_shard_degree}_"
             f"tp-{self.tp_degree}")
        if self.ep_degree > 1:
            s += f"_ep-{self.ep_degree}"
        return s


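# Illustrative example (documentation only; nothing in this module evaluates
# it): on 16 ranks, 2-way replication, 4-way sharding, and 2-way tensor
# parallelism would be expressed as
#
#     ParallelDims(dp_replicate_degree=2, dp_shard_degree=4, tp_degree=2)
#
# and str() on that instance renders "dp_replicate-2_dp_shard-4_tp-2".

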
def _construct_device_mesh(parallel_dims: ParallelDims) -> DeviceMesh:
    """Constructs a DeviceMesh based on the given parallel dimensions.

    Args:
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        DeviceMesh: The constructed device mesh.
    """
    world_size = dist.get_world_size()
    expected_devices = (parallel_dims.dp_replicate_degree *
                        parallel_dims.dp_shard_degree *
                        parallel_dims.ep_degree * parallel_dims.tp_degree)
    if world_size < expected_devices:
        raise ValueError(
            f"Not enough devices: found {world_size}, "
            f"but expected at least {expected_devices}. ({parallel_dims})")

    degrees = [
        parallel_dims.dp_replicate_degree, parallel_dims.dp_shard_degree,
        parallel_dims.ep_degree, parallel_dims.tp_degree
    ]
    dim_names = ["dp_replicate", "dp_shard", "ep", "tp"]

    # Only dimensions with degree > 1 become mesh dimensions.
    mesh_shape = []
    mesh_dim_names = []
    for degree, dim_name in zip(degrees, dim_names):
        if degree > 1:
            mesh_shape.append(degree)
            mesh_dim_names.append(dim_name)

    device_mesh = dist.init_device_mesh("cuda",
                                        mesh_shape,
                                        mesh_dim_names=mesh_dim_names)

    return device_mesh


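# Illustrative example (documentation only): with
# ParallelDims(dp_replicate_degree=1, dp_shard_degree=4, tp_degree=2) on
# 8 ranks, the resulting mesh has dim names ("dp_shard", "tp") and shape
# (4, 2); the degree-1 dp_replicate dimension does not appear in the mesh.

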
def _apply_tp(
    model: torch.nn.Module,
    tp_mesh: DeviceMesh,
):
    """Apply tensor parallelism to a MotifForCausalLM model."""

    # The TP plan below hard-codes the Motif module layout, so fail fast on
    # anything else.
    assert type(model).__name__ == "MotifForCausalLM"

    parallelize_module(
        model,
        tp_mesh,
        {
            "model.norm":
            SequenceParallel(),
            "output":
            ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1),
                use_local_output=False,
            ),
        },
    )

    # Shard every transformer block with the same per-layer plan.
    for transformer_block in model.model.layers:
        layer_plan = {
            "input_layernorm":
            SequenceParallel(),
            "post_attention_layernorm":
            SequenceParallel(),
            "self_attn":
            PrepareModuleInput(
                # Gather the sequence-sharded hidden states (the first input)
                # before the attention projections; the second input stays
                # replicated and the remaining inputs are left untouched.
                input_layouts=(Shard(1), Replicate(), None, None, None),
                desired_input_layouts=(Replicate(), Replicate(), None, None,
                                       None),
            ),
            "self_attn.q_proj":
            ColwiseParallel(),
            "self_attn.k_proj":
            ColwiseParallel(),
            "self_attn.v_proj":
            ColwiseParallel(),
            "self_attn.o_proj":
            RowwiseParallel(output_layouts=Shard(1)),
            "mlp":
            PrepareModuleInput(
                input_layouts=(Shard(1), ),
                desired_input_layouts=(Replicate(), ),
            ),
            "mlp.gate_proj":
            ColwiseParallel(),
            "mlp.down_proj":
            RowwiseParallel(output_layouts=Shard(1)),
            "mlp.up_proj":
            ColwiseParallel(),
        }

        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )


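# For reference, the per-block activation layout the plan above encodes:
# activations enter each block sequence-sharded (Shard(1)), the norms run
# sequence-parallel, PrepareModuleInput gathers the hidden states to
# Replicate() for attention and MLP, the q/k/v/gate/up projections run
# column-parallel, and the row-parallel o_proj/down_proj layers re-shard
# their outputs back to Shard(1).

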
def _apply_fsdp(
    model: torch.nn.Module,
    dp_mesh: DeviceMesh,
):
    # Shard each transformer block as its own FSDP unit, then the root module
    # for any remaining parameters; reshard() frees the unsharded parameters
    # right away.
    for layer in model.model.layers:
        fully_shard(layer, mesh=dp_mesh)
        layer.reshard()
    fully_shard(model, mesh=dp_mesh)
    model.reshard()


def parallelize_llama4(model: torch.nn.Module,
                       parallel_dims: ParallelDims) -> torch.nn.Module:
    """Parallelize the torchtitan Llama4 MoE model using torchtitan's
    ``parallelize_llama`` directly.
    """
    from torchtitan.config import JobConfig
    from torchtitan.distributed import ParallelDims as TTParallelDims
    from torchtitan.models.llama4.infra.parallelize import parallelize_llama

    world_size = dist.get_world_size()

    # torchtitan derives its expert-parallel ranks from the dp_shard
    # dimension, so fold the EP degree into the shard degree handed over here
    # (e.g. dp_shard_degree=2 with ep_degree=4 becomes dp_shard=8).
    tt_dp_shard = parallel_dims.dp_shard_degree * parallel_dims.ep_degree

    tt_dims = TTParallelDims(
        dp_replicate=parallel_dims.dp_replicate_degree,
        dp_shard=tt_dp_shard,
        cp=1,
        tp=parallel_dims.tp_degree,
        pp=1,
        ep=parallel_dims.ep_degree,
        etp=1,
        world_size=world_size,
    )

    # Plain JobConfig: keep parameters in float32 and turn off activation
    # checkpointing, compile, and loss parallelism.
    job_config = JobConfig()
    job_config.training.mixed_precision_param = "float32"
    job_config.activation_checkpoint.mode = "none"
    job_config.compile.enable = False
    job_config.parallelism.disable_loss_parallel = True

    parallelize_llama(model, tt_dims, job_config)
    return model


def parallelize_motif(model: torch.nn.Module,
                      parallel_dims: ParallelDims) -> torch.nn.Module:
    """Parallelize the Motif model according to the given parallel dimensions.

    Args:
        model (torch.nn.Module): The Motif model to be parallelized.
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        torch.nn.Module: The parallelized Motif model.
    """

    mesh = _construct_device_mesh(parallel_dims)

    if parallel_dims.tp_degree > 1:
        _apply_tp(model, mesh["tp"])

    if parallel_dims.dp_shard_degree > 1:
        if parallel_dims.dp_replicate_degree > 1:
            dp_dim_names = ("dp_replicate", "dp_shard")
        else:
            dp_dim_names = ("dp_shard", )
        _apply_fsdp(model, mesh[dp_dim_names])

    return model


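# A minimal usage sketch (illustrative only; nothing in this module runs it).
# It assumes one process per GPU launched via torchrun and a caller-provided
# `build_model()` that constructs a MotifForCausalLM instance; both are
# assumptions, not part of this module:
#
#     dist.init_process_group("nccl")
#     torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
#     model = build_model().cuda()
#     dims = ParallelDims(dp_replicate_degree=1,
#                         dp_shard_degree=dist.get_world_size() // 2,
#                         tp_degree=2)
#     model = parallelize_motif(model, dims)

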
def parallelize_qk_logits(
    qk_logits: dict[int, torch.Tensor],
    parallel_dims: ParallelDims,
) -> dict[int, torch.Tensor]:
    """Parallelize the QK logits according to the given parallel dimensions.

    Args:
        qk_logits (dict[int, torch.Tensor]): The QK logits to be parallelized.
        parallel_dims (ParallelDims): The parallelism configuration.

    Returns:
        dict[int, torch.Tensor]: The parallelized QK logits.
    """

    mesh = _construct_device_mesh(parallel_dims)

    if parallel_dims.tp_degree > 1:
        tp_rank = mesh["tp"].get_local_rank()
        # Shard dim 0 along the TP mesh dimension and replicate along every
        # other mesh dimension.
        placements = [
            Shard(0) if dim_name == "tp" else Replicate()
            for dim_name in mesh.mesh_dim_names
        ]
        for layer_idx, logits in qk_logits.items():
            assert logits.size(0) % parallel_dims.tp_degree == 0
            # Keep only this rank's slice and wrap it as a DTensor.
            local_logits = logits.chunk(parallel_dims.tp_degree,
                                        dim=0)[tp_rank].contiguous()

            qk_logits[layer_idx] = DTensor.from_local(
                local_tensor=local_logits,
                device_mesh=mesh,
                placements=placements,
            )

    return qk_logits


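# Illustrative example (documentation only): with tp_degree=2, each logits
# tensor is split in half along dim 0 (assumed here to be the head dimension,
# matching how the TP plan shards attention); TP rank 0 keeps the first half,
# rank 1 the second, and each local slice is wrapped as a DTensor placed
# Shard(0) on "tp" and Replicate() on any other mesh dimensions.

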
def assert_params_equal(actual: torch.nn.Module,
                        expected: torch.nn.Module) -> None:
    """Asserts that the parameters of two models are equal.

    Args:
        actual (torch.nn.Module): The actual model.
        expected (torch.nn.Module): The expected model.

    Returns:
        None
    """

    def get_full_param(param: torch.nn.Parameter) -> torch.Tensor:
        if isinstance(param.data, DTensor):
            return param.data.full_tensor()
        return param.data

    for (name_p, p), (name_s, s) in zip(actual.named_parameters(),
                                        expected.named_parameters()):
        p = get_full_param(p.cuda())
        s = get_full_param(s.cuda())

        torch.testing.assert_close(p, s, atol=0, rtol=0)