From b5a79a818275d1743a39d6ed3dba484cda68b301 Mon Sep 17 00:00:00 2001
From: blacksheep-Aristotle
Date: Mon, 18 Aug 2025 14:55:21 +0800
Subject: [PATCH] update expert parallel init logic

---
 paddlenlp/trainer/training_args.py            | 243 ++++++++++++++----
 .../transformers/deepseek_v2/modeling.py      |  15 +-
 paddlenlp/transformers/moe_layer.py           |  13 +-
 paddlenlp/transformers/moe_utils.py           |  16 +-
 4 files changed, 218 insertions(+), 69 deletions(-)

diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 30a3e7b3dc62..1cd0fae524aa 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1180,7 +1180,125 @@ def __post_init__(self):
         if self.optim == OptimizerNames.ADAMW_MINI and self.tensor_parallel_degree > 1:
             raise ValueError("AdamW Mini currently doesn't support tensor parallelism.")

-        self._post_init_parallel_degree()
+        self.use_hybrid_parallel = False
+
+        if isinstance(self.sharding, bool):
+            self.sharding = "stage1" if self.sharding else ""
+        if isinstance(self.sharding, str):
+            self.sharding = [ShardingOption(s) for s in self.sharding.split()]
+        if self.sharding == [ShardingOption.OFFLOAD]:
+            raise ValueError(
+                "`--sharding offload` can't work on its own. It needs to be added to `--sharding stage2` or "
+                '`--sharding stage3`. For example, `--sharding "stage2 offload"`.'
+            )
+        elif len(self.sharding) > (ShardingOption.OFFLOAD in self.sharding) + 1:
+            raise ValueError("`--sharding` received too many arguments.")
+
+        if self.sharding_degree > 0:
+            warnings.warn("`sharding_degree` is deprecated, please use `sharding_parallel_degree`")
+            self.sharding_parallel_degree = max(self.sharding_degree, self.sharding_parallel_degree)
+            self.data_parallel_degree = 1
+
+        delattr(self, "sharding_degree")
+
+        if len(self.sharding) == 0 and self.sharding_parallel_degree > 0:
+            warnings.warn("`--sharding_parallel_degree` is useful only when `--sharding` is specified.")
+
+        world_size = paddle.distributed.get_world_size()
+
+        if world_size > 1:
+            tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
+            sep_parallel_degree = max(self.sep_parallel_degree, 1)
+            context_parallel_degree = max(self.context_parallel_degree, 1)
+            pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1)
+            expert_parallel_degree = max(self.expert_parallel_degree, 1)
+            expert_tensor_parallel_degree = max(self.expert_tensor_parallel_degree, 1)
+
+            # TODO(@gexiao): support expert_tensor_parallel_degree > 1 in the future
+            assert (
+                expert_tensor_parallel_degree == 1
+            ), f"Currently only support expert_tensor_parallel_degree=1, but got expert_tensor_parallel_degree of {expert_tensor_parallel_degree}"
+
+            assert (
+                world_size % (self.tensor_parallel_degree * self.pipeline_parallel_degree) == 0
+            ), f"Total world_size:{world_size} should be divisible by tensor_parallel_degree: {self.tensor_parallel_degree} and pipeline_parallel_degree: {self.pipeline_parallel_degree}."
+
+            assert not (
+                sep_parallel_degree > 1 and context_parallel_degree > 1
+            ), f"sep parallel and context parallel cannot be used together, sep_parallel_degree:{sep_parallel_degree}, context_parallel_degree:{context_parallel_degree}."
+
+            if self.sharding_parallel_degree == -1:
+                if len(self.sharding) > 0:
+                    self.sharding_parallel_degree = world_size // (
+                        tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
+                    )
+
+            sharding_parallel_degree = max(self.sharding_parallel_degree, 1)
+            if sharding_parallel_degree == 1 and len(self.sharding) > 0:
+                logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!")
+                self.sharding = []
+
+            self.data_parallel_degree = world_size // (
+                sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
+            )
+
+            if expert_parallel_degree > 1:
+                moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree)
+                assert (
+                    self.expert_tensor_parallel_degree <= 1
+                ), "expert_tensor_parallel_degree > 1 is not supported when expert_parallel_degree > 1"
+            else:
+                moe_sharding_parallel_degree = 1
+            moe_sharding_parallel_degree = max(moe_sharding_parallel_degree, 1)
+            if moe_sharding_parallel_degree > 1 and self.data_parallel_degree > 1:
+                raise NotImplementedError(
+                    f"The expert_data_parallel strategy can currently only be used together with the sharding_parallel strategy, not with the data_parallel strategy. Got data_parallel_degree: {self.data_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}."
+                )
+
+            if sharding_parallel_degree > 1 and moe_sharding_parallel_degree > 1:
+                assert (
+                    sharding_parallel_degree % moe_sharding_parallel_degree == 0
+                ), f"sharding_parallel_degree should be divisible by moe_sharding_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}."
+
+            assert not (
+                self.data_parallel_degree > 1 and expert_parallel_degree > 1
+            ), f"The expert_data_parallel strategy can currently only be used together with the sharding_parallel strategy, not with the data_parallel strategy. Current data_parallel_degree is {self.data_parallel_degree}."
+
+            if (
+                sharding_parallel_degree > 1
+                or tensor_parallel_degree > 1
+                or pipeline_parallel_degree > 1
+                or self.sep_parallel_degree > 1
+                or self.context_parallel_degree > 1
+                or expert_parallel_degree > 1
+                or expert_tensor_parallel_degree > 1
+            ):
+                self.use_hybrid_parallel = True
+                self.sharding_parallel_degree = sharding_parallel_degree
+                self.tensor_parallel_degree = tensor_parallel_degree
+                self.pipeline_parallel_degree = pipeline_parallel_degree
+                self.sep_parallel_degree = sep_parallel_degree
+                self.context_parallel_degree = context_parallel_degree
+                self.expert_parallel_degree = expert_parallel_degree
+                self.expert_tensor_parallel_degree = expert_tensor_parallel_degree
+                self.moe_sharding_parallel_degree = moe_sharding_parallel_degree
+
+        if not self.use_hybrid_parallel:
+            self.sharding = []
+            self.sharding_parallel_degree = -1
+            self.tensor_parallel_degree = -1
+            self.pipeline_parallel_degree = -1
+            self.sep_parallel_degree = -1
+            self.context_parallel_degree = -1
+            self.expert_parallel_degree = -1
+            self.expert_tensor_parallel_degree = -1
+
+        if self.hybrid_parallel_topo_order is None:
+            self.hybrid_parallel_topo_order = "sharding_first"
+        assert self.hybrid_parallel_topo_order in ["pp_first", "sharding_first"]
+
+        if self.use_hybrid_parallel and self.enable_auto_parallel:
+            self.use_hybrid_parallel = False

         if self.to_static:
             assert world_size == 1 or self.enable_auto_parallel, (
@@ -1383,6 +1501,17 @@ def is_segment_parallel_supported():
                 logger.warning("segment parallel is not supported!!!, Ignore it.")
             return support_sep

+        def is_context_parallel_supported():
+            import inspect
+
+            members = [
+                name for (name, date) in inspect.getmembers(fleet.base.topology.EPHybridCommunicateGroup)
+            ]
+            support_cp = "get_context_parallel_world_size" in members
+            if not support_cp:
+                logger.warning("context parallel is not supported!!!, Ignore it.")
+            return support_cp
+
         if self.hybrid_parallel_topo_order == "pp_first":
             if is_segment_parallel_supported():
                 order = ["dp", "pp", "sharding", "sep", "mp"]
@@ -1394,17 +1523,31 @@ def is_segment_parallel_supported():
            else:
                order = ["dp", "sharding", "pp", "mp"]
         if self.use_expert_parallel:
-            order = order[1:-1] + ["dp", "mp"]
+            if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
+                if is_context_parallel_supported():
+                    order = ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"]
+                else:
+                    order = ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"]
+            else:
+                order = ["sharding", "pp", "sep", "dp", "mp"]

-        if is_segment_parallel_supported():
+        if is_context_parallel_supported():
             hybrid_configs = {
                 "dp_degree": self.data_parallel_degree,
                 "mp_degree": self.tensor_parallel_degree,
                 "pp_degree": self.pipeline_parallel_degree,
                 "sharding_degree": self.sharding_parallel_degree,
-                "sep_degree": self.sep_parallel_degree
-                if self.sep_parallel_degree > 1
-                else self.context_parallel_degree,
+                "sep_degree": self.sep_parallel_degree,
+                "cp_degree": self.context_parallel_degree,
+                "order": order,
+            }
+        elif is_segment_parallel_supported():
+            hybrid_configs = {
+                "dp_degree": self.data_parallel_degree,
+                "mp_degree": self.tensor_parallel_degree,
+                "pp_degree": self.pipeline_parallel_degree,
+                "sharding_degree": self.sharding_parallel_degree,
+                "sep_degree": self.sep_parallel_degree,
                 "order": order,
             }
         else:
@@ -1416,6 +1559,13 @@ def is_segment_parallel_supported():
                 "order": order,
             }

+        if self.expert_parallel_degree > 1:
+            assert (
+                self.use_expert_parallel is True and self.moe_sharding_parallel_degree >= 0
+            ), f"invalid expert_parallel_degree {self.expert_parallel_degree} and use_expert_parallel:{self.use_expert_parallel}."
+            hybrid_configs["ep_degree"] = self.expert_parallel_degree
+            hybrid_configs["moe_sharding_degree"] = self.moe_sharding_parallel_degree
+
         try:
             if self.split_norm_comm:
                 hybrid_configs["split_norm_comm"] = True
@@ -2052,47 +2202,12 @@ def _post_init_parallel_degree(self):
         self.use_hybrid_parallel = False

     def add_moe_comm_group(self):
-        hybrid_configs = fleet.fleet._user_defined_strategy.hybrid_configs
+        # NOTE(zhangweilong):move init_moe_group logic to paddle fleet.init
+        moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group()
+        moe_grad_group = fleet.get_hybrid_communicate_group().get_moe_sharding_parallel_group()
         hcg = fleet.get_hybrid_communicate_group()
-        topo = hcg._topo
-        sharding_parallel_groups = topo.get_comm_list("sharding")
-        experts_replicas = self.sharding_parallel_degree // self.expert_parallel_degree
-
-        # init experts groups inside all sharding groups
-        for ranks_in_current_sharding_group in sharding_parallel_groups:
-            # init experts parallel groups (dispatch & combine)
-            for i in range(experts_replicas):
-                rank_indices = list(range(i * self.expert_parallel_degree, (i + 1) * self.expert_parallel_degree))
-                ranks = [ranks_in_current_sharding_group[i] for i in rank_indices]
-                if message2nccl_config is not None and hybrid_configs.get("ep_configs", None) is not None:
-                    group = dist.new_group(
-                        ranks=ranks, nccl_config=message2nccl_config(hybrid_configs["ep_configs"].nccl_config, "ep")
-                    )
-                else:
-                    group = dist.new_group(ranks=ranks)
-                if dist.get_rank() in ranks:
-                    assert not hasattr(hcg, "expert_parallel_group"), "expert_parallel_group can not be set repeate"
-                    setattr(hcg, "expert_parallel_group", group)
-
-            # init experts gradients comm groups
-            for i in range(self.expert_parallel_degree):
-                rank_indices = list(range(i, self.sharding_parallel_degree, self.expert_parallel_degree))
-                ranks = [ranks_in_current_sharding_group[i] for i in rank_indices]
-                if message2nccl_config is not None and hybrid_configs.get("ep_configs", None) is not None:
-                    group = dist.new_group(
-                        ranks=ranks,
-                        nccl_config=message2nccl_config(hybrid_configs["ep_configs"].grad_nccl_config, "ep_grad"),
-                    )
-                else:
-                    group = dist.new_group(ranks=ranks)
-                if dist.get_rank() in ranks:
-                    assert not hasattr(hcg, "expert_grad_comm_group"), "expert_grad_comm_group can not be set repeate"
-                    setattr(hcg, "expert_grad_comm_group", group)
-
-        assert hasattr(hcg, "expert_parallel_group") and hasattr(hcg, "expert_grad_comm_group")
-        logger.info(
-            f"experts groups are created, expert_parallel_group: {hcg.expert_parallel_group}, expert_grad_comm_group: {hcg.expert_grad_comm_group}"
-        )
+        setattr(hcg, "expert_parallel_group", moe_group)
+        setattr(hcg, "expert_grad_comm_group", moe_grad_group)

     def __str__(self):
         self_as_dict = asdict(self)
@@ -2200,6 +2315,28 @@ def pipeline_parallel_rank(self):
         else:
             return 0

+    @property
+    def expert_parallel_rank(self):
+        if self.use_hybrid_parallel:
+            hcg = fleet.get_hybrid_communicate_group()
+            if hasattr(hcg, "get_expert_parallel_rank"):
+                return max(hcg.get_expert_parallel_rank(), 0)
+            else:
+                return 0
+        else:
+            return 0
+
+    @property
+    def context_parallel_rank(self):
+        if self.use_hybrid_parallel:
+            hcg = fleet.get_hybrid_communicate_group()
+            if hasattr(hcg, "get_context_parallel_rank"):
+                return max(hcg.get_context_parallel_rank(), 0)
+            else:
+                return 0
+        else:
+            return 0
+
     def _format_name(self, prefix, rank, degree):
         size = 2
         return f"{prefix}{rank:0>{size}d}"
@@ -2214,7 +2351,7 @@ def optimizer_name_suffix(self):
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.sharding_parallel_degree > 1:
                 name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
-            if self.use_expert_parallel:
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)
         else:
@@ -2230,7 +2367,7 @@ def weight_name_suffix(self):
                 name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
             if self.pipeline_parallel_degree > 1:
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
-            if self.use_expert_parallel:
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)

@@ -2239,7 +2376,9 @@ def weight_name_suffix(self):
                 return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
             return None

-    def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
+    def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None, sharding_parallel_degree=None):
+        if sharding_parallel_degree is None:
+            sharding_parallel_degree = self.sharding_parallel_degree
         if self.use_hybrid_parallel:
             name = []
             if self.tensor_parallel_degree > 1:
@@ -2249,12 +2388,12 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
                 pp_id = self.pipeline_parallel_rank
             assert isinstance(pp_id, int)
             name.append(self._format_name("pp", pp_id, self.pipeline_parallel_degree))
-            if self.sharding_parallel_degree > 1:
+            if sharding_parallel_degree > 1:
                 if shard_id is None:
                     shard_id = self.sharding_parallel_rank
                 assert isinstance(shard_id, int)
-                name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
-            if self.use_expert_parallel:
+                name.append(self._format_name("shard", shard_id, sharding_parallel_degree))
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 if moe_id is None:
                     moe_id = self.data_parallel_rank
                 assert isinstance(moe_id, int)
diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index ca71b478ee49..c9568409be62 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -785,8 +785,10 @@ def __init__(self, config: DeepseekV2Config):
         )

         # (LiuTing) only support either tp or ep.
-        moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
+        moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group()
         expert_parallel_degree = dist.get_world_size(moe_group)
+        moe_grad_group = fleet.get_hybrid_communicate_group().get_moe_sharding_parallel_group()
+        expert_parallel_degree = 1 if expert_parallel_degree < 0 else expert_parallel_degree
         act_tp_shard = config.tensor_parallel_degree > 1 and expert_parallel_degree <= 1

         super().__init__(
@@ -800,7 +802,12 @@ def __init__(self, config: DeepseekV2Config):
             },
             gate=gate,
             capacity=2.0,
+            moe_group="expert",
         )
+
+        for p in self.experts.parameters():
+            setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
+
         self.alpha = config.aux_loss_alpha
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -838,8 +845,8 @@ def __init__(self, config: DeepseekV2Config):
         )

         hcg = fleet.get_hybrid_communicate_group()
-        moe_group = hcg.expert_parallel_group
-        moe_grad_group = hcg.expert_grad_comm_group
+        moe_group = hcg.get_expert_parallel_group()
+        moe_grad_group = hcg.get_moe_sharding_parallel_group()

         super().__init__(
             config=config,
@@ -1467,7 +1474,7 @@ def get_tensor_parallel_split_mappings(num_layers):
                 base_actions["layers.0.mlp.down_proj.weight.weight_scale_inv"] = partial(fn, is_column=False)

             # moe unit routed experts
-            moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
+            moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group()
             expert_parallel_degree = dist.get_world_size(moe_group)
             if expert_parallel_degree <= 1:
                 for e_i in range(config.n_routed_experts):
diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py
index 340fba1f5245..861bea8d2e0e 100644
--- a/paddlenlp/transformers/moe_layer.py
+++ b/paddlenlp/transformers/moe_layer.py
@@ -176,12 +176,13 @@ def __init__(
         except AttributeError:
             is_fleet_init = False

-        if (
-            is_fleet_init
-            and dist.fleet.get_hybrid_communicate_group().get_data_parallel_world_size() > 1
-            and moe_group == "data"
-        ):
-            self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
+        if is_fleet_init and dist.get_world_size() > 1:
+            if moe_group == "data":
+                self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
+            elif moe_group == "expert":
+                self.moe_group = dist.fleet.get_hybrid_communicate_group().expert_parallel_group
+            else:
+                raise NotImplementedError("moe_group can only be data or expert, but got {}".format(moe_group))
             self.moe_rank = dist.get_rank(self.moe_group)
             self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank
             self.expert_parallel_degree = dist.get_world_size(self.moe_group)
diff --git a/paddlenlp/transformers/moe_utils.py b/paddlenlp/transformers/moe_utils.py
index 466591b0638d..d82654f6375b 100644
--- a/paddlenlp/transformers/moe_utils.py
+++ b/paddlenlp/transformers/moe_utils.py
@@ -18,6 +18,11 @@

 import paddle

+try:
+    from paddle import scatter_add_
+except ImportError:
+    scatter_add_ = None
+

 def permute(
     tokens,
@@ -91,11 +96,8 @@ def unpermute(
     # Create an output tensor filled with zeros
     output_tokens = paddle.zeros(restore_shape, dtype=permuted_tokens.dtype)
     # Scatter add the permuted_input back to the original positions
-    output_tokens.put_along_axis_(
-        axis=0,
-        indices=sorted_indices.unsqueeze(1).expand([-1, hidden]),
-        values=permuted_tokens,
-        reduce="add",
-        include_self=True,
-    )
+    if scatter_add_ is not None:
+        scatter_add_(output_tokens, sorted_indices, permuted_tokens)
+    else:
+        output_tokens.scatter_(index=sorted_indices, updates=permuted_tokens, overwrite=False)
     return output_tokens
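
For context, a minimal sketch (not part of the patch) of how the reworked initialization is meant to be consumed. It assumes a Paddle build whose fleet topology exposes the EP-aware hybrid communicate group relied on above (get_expert_parallel_group / get_moe_sharding_parallel_group, EPHybridCommunicateGroup) and accepts the new "ep_degree" / "moe_sharding_degree" hybrid_configs keys wired up in training_args.py; the concrete degree values below are placeholders for a 4-GPU run, not values taken from the patch.

# sketch.py -- illustrative only; degree values and script framing are assumptions
import paddle.distributed as dist
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 4,
    "sep_degree": 1,
    # new keys consumed by the patched training_args.py
    "ep_degree": 4,
    "moe_sharding_degree": 1,
    "order": ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"],
}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
# The patch reads both groups straight from the fleet topology instead of
# hand-building them with dist.new_group() as the old add_moe_comm_group() did.
moe_group = hcg.get_expert_parallel_group()             # expert dispatch/combine (all-to-all)
moe_grad_group = hcg.get_moe_sharding_parallel_group()  # expert-gradient communication
print(dist.get_world_size(moe_group), dist.get_world_size(moe_grad_group))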