243 changes: 191 additions & 52 deletions paddlenlp/trainer/training_args.py
@@ -1180,7 +1180,125 @@ def __post_init__(self):
if self.optim == OptimizerNames.ADAMW_MINI and self.tensor_parallel_degree > 1:
raise ValueError("AdamW Mini currently doesn't support tensor parallelism.")

self._post_init_parallel_degree()
self.use_hybrid_parallel = False

if isinstance(self.sharding, bool):
self.sharding = "stage1" if self.sharding else ""
if isinstance(self.sharding, str):
self.sharding = [ShardingOption(s) for s in self.sharding.split()]
if self.sharding == [ShardingOption.OFFLOAD]:
raise ValueError(
"`--sharding offload` can't work on its own. It needs to be added to `--sharding stage2` or "
'`--sharding stage3`. For example, `--sharding "stage2 offload"`.'
)
elif len(self.sharding) > (ShardingOption.OFFLOAD in self.sharding) + 1:
raise ValueError("`--sharding` received too many arguments.")

if self.sharding_degree > 0:
warnings.warn("`sharding_degree` is deprecated, please use `sharding_parallel_degree`")
self.sharding_parallel_degree = max(self.sharding_degree, self.sharding_parallel_degree)
self.data_parallel_degree = 1

delattr(self, "sharding_degree")

if len(self.sharding) == 0 and self.sharding_parallel_degree > 0:
warnings.warn("`--sharding_parallel_degree` is useful only when `--sharding` is specified.")

world_size = paddle.distributed.get_world_size()

if world_size > 1:
tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
sep_parallel_degree = max(self.sep_parallel_degree, 1)
context_parallel_degree = max(self.context_parallel_degree, 1)
pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1)
expert_parallel_degree = max(self.expert_parallel_degree, 1)
expert_tensor_parallel_degree = max(self.expert_tensor_parallel_degree, 1)

# TODO(@gexiao): support expert_tensor_parallel_degree > 1 in the future
assert (
expert_tensor_parallel_degree == 1
), f"Currently only support expert_tensor_parallel_degree=1, but got expert_tensor_parallel_degree of {expert_tensor_parallel_degree}"

assert (
world_size % (self.tensor_parallel_degree * self.pipeline_parallel_degree) == 0
), f"Total world_size:{world_size} shoule be devided by tensor_parallel_degree: {self.tensor_parallel_degree} and pipeline_parallel_degree: {self.pipeline_parallel_degree}."

assert not (
sep_parallel_degree > 1 and context_parallel_degree > 1
), f"sep parallel and context parallel cannot be used together, sep_parallel_degree:{sep_parallel_degree}, context_parallel_degree:{context_parallel_degree}."

if self.sharding_parallel_degree == -1:
if len(self.sharding) > 0:
self.sharding_parallel_degree = world_size // (
tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
)

sharding_parallel_degree = max(self.sharding_parallel_degree, 1)
if sharding_parallel_degree == 1 and len(self.sharding) > 0:
logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!")
self.sharding = []

self.data_parallel_degree = world_size // (
sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
)

if expert_parallel_degree > 1:
moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree)
assert (
self.expert_tensor_parallel_degree <= 1
), "expert_tensor_parallel_degree > 1 is not supported when expert_parallel_degree > 1"
else:
moe_sharding_parallel_degree = 1
moe_sharding_parallel_degree = max(moe_sharding_parallel_degree, 1)
if moe_sharding_parallel_degree > 1 and self.data_parallel_degree > 1:
raise NotImplementedError(
f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. But got data_parallel_degree: {self.data_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}."
)

if sharding_parallel_degree > 1 and moe_sharding_parallel_degree > 1:
assert (
sharding_parallel_degree % moe_sharding_parallel_degree == 0
), f"sharding_parallel_degree should be divided by moe_sharding_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, moe_sharding_parallel_degree: {moe_sharding_parallel_degree}."

assert not (
self.data_parallel_degree > 1 and expert_parallel_degree > 1
), f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. Currently data_parallel_degree is {self.data_parallel_degree}."

if (
sharding_parallel_degree > 1
or tensor_parallel_degree > 1
or pipeline_parallel_degree > 1
or self.sep_parallel_degree > 1
or self.context_parallel_degree > 1
or expert_parallel_degree > 1
or expert_tensor_parallel_degree > 1
):
self.use_hybrid_parallel = True
self.sharding_parallel_degree = sharding_parallel_degree
self.tensor_parallel_degree = tensor_parallel_degree
self.pipeline_parallel_degree = pipeline_parallel_degree
self.sep_parallel_degree = sep_parallel_degree
self.context_parallel_degree = context_parallel_degree
self.expert_parallel_degree = expert_parallel_degree
self.expert_tensor_parallel_degree = expert_tensor_parallel_degree
self.moe_sharding_parallel_degree = moe_sharding_parallel_degree

if not self.use_hybrid_parallel:
self.sharding = []
self.sharding_parallel_degree = -1
self.tensor_parallel_degree = -1
self.pipeline_parallel_degree = -1
self.sep_parallel_degree = -1
self.context_parallel_degree = -1
self.expert_parallel_degree = -1
self.expert_tensor_parallel_degree = -1

if self.hybrid_parallel_topo_order is None:
self.hybrid_parallel_topo_order = "sharding_first"
assert self.hybrid_parallel_topo_order in ["pp_first", "sharding_first"]

if self.use_hybrid_parallel and self.enable_auto_parallel:
self.use_hybrid_parallel = False
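For reference, the degree arithmetic above reduces to a few integer divisions over world_size. The sketch below mirrors that logic in plain Python; the concrete degrees are illustrative assumptions, not values taken from this change.

# Sketch of the degree derivation in __post_init__ (illustrative values).
world_size = 64                   # total number of ranks
tensor_parallel_degree = 2        # mp
pipeline_parallel_degree = 2      # pp
sep_parallel_degree = 1           # segment parallel
expert_parallel_degree = 4        # ep

# sharding_parallel_degree == -1 with --sharding set means "use whatever is left after mp/sep/pp".
sharding_parallel_degree = world_size // (
    tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
)  # -> 16

# Whatever remains after sharding/mp/sep/pp becomes plain data parallelism.
data_parallel_degree = world_size // (
    sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
)  # -> 1 (dp > 1 together with ep > 1 is rejected above)

# With expert parallelism, ranks are regrouped into a moe_sharding group,
# whose degree must evenly divide the sharding degree.
moe_sharding_parallel_degree = world_size // (pipeline_parallel_degree * expert_parallel_degree)  # -> 8
assert sharding_parallel_degree % moe_sharding_parallel_degree == 0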

if self.to_static:
assert world_size == 1 or self.enable_auto_parallel, (
@@ -1383,6 +1501,17 @@ def is_segment_parallel_supported():
logger.warning("segment parallel is not supported!!!, Ignore it.")
return support_sep

def is_context_parallel_supported():
import inspect

members = [
name for (name, date) in inspect.getmembers(fleet.base.topology.EPHybridCommunicateGroup)
]
support_cp = "get_context_parallel_world_size" in members
if not support_cp:
logger.warning("context parallel is not supported!!! Ignore it.")
return support_cp

if self.hybrid_parallel_topo_order == "pp_first":
if is_segment_parallel_supported():
order = ["dp", "pp", "sharding", "sep", "mp"]
@@ -1394,17 +1523,31 @@ def is_segment_parallel_supported():
else:
order = ["dp", "sharding", "pp", "mp"]
if self.use_expert_parallel:
order = order[1:-1] + ["dp", "mp"]
if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
if is_context_parallel_supported():
order = ["sharding", "moe_sharding", "pp", "sep", "cp", "dp", "ep", "mp"]
else:
order = ["sharding", "moe_sharding", "pp", "sep", "dp", "ep", "mp"]
else:
order = ["sharding", "pp", "sep", "dp", "mp"]

if is_segment_parallel_supported():
if is_context_parallel_supported():
hybrid_configs = {
"dp_degree": self.data_parallel_degree,
"mp_degree": self.tensor_parallel_degree,
"pp_degree": self.pipeline_parallel_degree,
"sharding_degree": self.sharding_parallel_degree,
"sep_degree": self.sep_parallel_degree
if self.sep_parallel_degree > 1
else self.context_parallel_degree,
"sep_degree": self.sep_parallel_degree,
"cp_degree": self.context_parallel_degree,
"order": order,
}
elif is_segment_parallel_supported():
hybrid_configs = {
"dp_degree": self.data_parallel_degree,
"mp_degree": self.tensor_parallel_degree,
"pp_degree": self.pipeline_parallel_degree,
"sharding_degree": self.sharding_parallel_degree,
"sep_degree": self.sep_parallel_degree,
"order": order,
}
else:
Expand All @@ -1416,6 +1559,13 @@ def is_segment_parallel_supported():
"order": order,
}

if self.expert_parallel_degree > 1:
assert (
self.use_expert_parallel is True and self.moe_sharding_parallel_degree >= 0
), f"invalid expert_parallel_degree {self.expert_parallel_degree} and use_expert_paralle:{self.use_expert_parallel}."
hybrid_configs["ep_degree"] = self.expert_parallel_degree
hybrid_configs["moe_sharding_degree"] = self.moe_sharding_parallel_degree

try:
if self.split_norm_comm:
hybrid_configs["split_norm_comm"] = True
@@ -2052,47 +2202,12 @@ def _post_init_parallel_degree(self):
self.use_hybrid_parallel = False

def add_moe_comm_group(self):
hybrid_configs = fleet.fleet._user_defined_strategy.hybrid_configs
# NOTE(zhangweilong):move init_moe_group logic to paddle fleet.init
moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group()
moe_grad_group = fleet.get_hybrid_communicate_group().get_moe_sharding_parallel_group()
hcg = fleet.get_hybrid_communicate_group()
topo = hcg._topo
sharding_parallel_groups = topo.get_comm_list("sharding")
experts_replicas = self.sharding_parallel_degree // self.expert_parallel_degree

# init experts groups inside all sharding groups
for ranks_in_current_sharding_group in sharding_parallel_groups:
# init experts parallel groups (dispatch & combine)
for i in range(experts_replicas):
rank_indices = list(range(i * self.expert_parallel_degree, (i + 1) * self.expert_parallel_degree))
ranks = [ranks_in_current_sharding_group[i] for i in rank_indices]
if message2nccl_config is not None and hybrid_configs.get("ep_configs", None) is not None:
group = dist.new_group(
ranks=ranks, nccl_config=message2nccl_config(hybrid_configs["ep_configs"].nccl_config, "ep")
)
else:
group = dist.new_group(ranks=ranks)
if dist.get_rank() in ranks:
assert not hasattr(hcg, "expert_parallel_group"), "expert_parallel_group can not be set repeate"
setattr(hcg, "expert_parallel_group", group)

# init experts gradients comm groups
for i in range(self.expert_parallel_degree):
rank_indices = list(range(i, self.sharding_parallel_degree, self.expert_parallel_degree))
ranks = [ranks_in_current_sharding_group[i] for i in rank_indices]
if message2nccl_config is not None and hybrid_configs.get("ep_configs", None) is not None:
group = dist.new_group(
ranks=ranks,
nccl_config=message2nccl_config(hybrid_configs["ep_configs"].grad_nccl_config, "ep_grad"),
)
else:
group = dist.new_group(ranks=ranks)
if dist.get_rank() in ranks:
assert not hasattr(hcg, "expert_grad_comm_group"), "expert_grad_comm_group can not be set repeate"
setattr(hcg, "expert_grad_comm_group", group)

assert hasattr(hcg, "expert_parallel_group") and hasattr(hcg, "expert_grad_comm_group")
logger.info(
f"experts groups are created, expert_parallel_group: {hcg.expert_parallel_group}, expert_grad_comm_group: {hcg.expert_grad_comm_group}"
)
setattr(hcg, "expert_parallel_group", moe_group)
setattr(hcg, "expert_grad_comm_group", moe_grad_group)

def __str__(self):
self_as_dict = asdict(self)
@@ -2200,6 +2315,28 @@ def pipeline_parallel_rank(self):
else:
return 0

@property
def expert_parallel_rank(self):
if self.use_hybrid_parallel:
hcg = fleet.get_hybrid_communicate_group()
if hasattr(hcg, "get_expert_parallel_rank"):
return max(hcg.get_expert_parallel_rank(), 0)
else:
return 0
else:
return 0

@property
def context_parallel_rank(self):
if self.use_hybrid_parallel:
hcg = fleet.get_hybrid_communicate_group()
if hasattr(hcg, "get_context_parallel_rank"):
return max(hcg.get_context_parallel_rank(), 0)
else:
return 0
else:
return 0

def _format_name(self, prefix, rank, degree):
size = 2
return f"{prefix}{rank:0>{size}d}"
@@ -2214,7 +2351,7 @@ def optimizer_name_suffix(self):
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.sharding_parallel_degree > 1:
name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
if self.use_expert_parallel:
if self.use_expert_parallel and self.expert_parallel_degree <= 1:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
return "_".join(name)
else:
Expand All @@ -2230,7 +2367,7 @@ def weight_name_suffix(self):
name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
if self.pipeline_parallel_degree > 1:
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.use_expert_parallel:
if self.use_expert_parallel and self.expert_parallel_degree <= 1:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
return "_".join(name)

@@ -2239,7 +2376,9 @@ def weight_name_suffix(self):
return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
return None

def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None, sharding_parallel_degree=None):
if sharding_parallel_degree is None:
sharding_parallel_degree = self.sharding_parallel_degree
if self.use_hybrid_parallel:
name = []
if self.tensor_parallel_degree > 1:
@@ -2249,12 +2388,12 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
pp_id = self.pipeline_parallel_rank
assert isinstance(pp_id, int)
name.append(self._format_name("pp", pp_id, self.pipeline_parallel_degree))
if self.sharding_parallel_degree > 1:
if sharding_parallel_degree > 1:
if shard_id is None:
shard_id = self.sharding_parallel_rank
assert isinstance(shard_id, int)
name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
if self.use_expert_parallel:
name.append(self._format_name("shard", shard_id, sharding_parallel_degree))
if self.use_expert_parallel and self.expert_parallel_degree <= 1:
if moe_id is None:
moe_id = self.data_parallel_rank
assert isinstance(moe_id, int)
15 changes: 11 additions & 4 deletions paddlenlp/transformers/deepseek_v2/modeling.py
@@ -785,8 +785,10 @@ def __init__(self, config: DeepseekV2Config):
)

# (LiuTing) only support either tp or ep.
moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group()
expert_parallel_degree = dist.get_world_size(moe_group)
moe_grad_group = fleet.get_hybrid_communicate_group().get_moe_sharding_parallel_group()

expert_parallel_degree = 1 if expert_parallel_degree < 0 else expert_parallel_degree
act_tp_shard = config.tensor_parallel_degree > 1 and expert_parallel_degree <= 1
super().__init__(
@@ -800,7 +802,12 @@ def __init__(self, config: DeepseekV2Config):
},
gate=gate,
capacity=2.0,
moe_group="expert",
)

for p in self.experts.parameters():
setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})

self.alpha = config.aux_loss_alpha
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -838,8 +845,8 @@ def __init__(self, config: DeepseekV2Config):
)

hcg = fleet.get_hybrid_communicate_group()
moe_group = hcg.expert_parallel_group
moe_grad_group = hcg.expert_grad_comm_group
moe_group = hcg.get_expert_parallel_group()
moe_grad_group = hcg.get_moe_sharding_parallel_group()

super().__init__(
config=config,
@@ -1467,7 +1474,7 @@ def get_tensor_parallel_split_mappings(num_layers):
base_actions["layers.0.mlp.down_proj.weight.weight_scale_inv"] = partial(fn, is_column=False)

# moe unit routed experts
moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group()
expert_parallel_degree = dist.get_world_size(moe_group)
if expert_parallel_degree <= 1:
for e_i in range(config.n_routed_experts):
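The color attribute set on the routed-expert parameters in DeepseekV2MoE.__init__ above is presumably consumed downstream to pick out expert weights and the group their gradients are reduced over. Below is a dependency-free sketch of that tag-and-filter pattern; the stand-in class and placeholder group are assumptions for illustration, not PaddleNLP APIs.

class FakeParam:
    """Stand-in for a paddle Parameter, only to show the attribute tagging."""

    def __init__(self, name):
        self.name = name

moe_grad_group = "moe_sharding_group_placeholder"  # a distributed group in practice

params = [FakeParam("experts.0.w1"), FakeParam("experts.0.w2"), FakeParam("gate.weight")]

# Tag only the routed-expert parameters, mirroring the loop added in DeepseekV2MoE.__init__.
for p in params[:2]:
    setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})

# Downstream code can then split parameters by tag.
expert_params = [p for p in params if getattr(p, "color", {}).get("color") == "moe_expert"]
print([p.name for p in expert_params])  # ['experts.0.w1', 'experts.0.w2']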