117 changes: 79 additions & 38 deletions paddlenlp/trainer/trainer.py
@@ -159,6 +159,7 @@
get_last_checkpoint,
get_scheduler,
has_length,
init_optimizer,
set_seed,
should_skip_data,
speed_metrics,
@@ -197,7 +198,6 @@
if is_datasets_available():
import datasets


try:
from paddle.distributed.fleet.utils import mix_precision_utils
except:
@@ -914,7 +914,7 @@ def train(
self._memory_tracker.start()

if not self.args.enable_auto_parallel:
if not self.args.should_load_sharding_stage1_model:
if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
self._load_from_checkpoint(resume_from_checkpoint)

if self.args.should_load_sharding_stage1_model:
@@ -934,14 +934,32 @@
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self._load_optimizer_and_scheduler(resume_from_checkpoint)
else:
elif not self.args.using_flex_checkpoint:
model = self._wrap_model(self.model_wrapped)
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self._load_optimizer_and_scheduler(resume_from_checkpoint)
else:
assert self.args.using_flex_checkpoint, "default using flex_checkpoint!"

model = self._wrap_model(self.model_wrapped)
if model is not self.model:
self.model_wrapped = model

if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

if resume_from_checkpoint is not None:
model_sharded_state_dict = self.model.sharded_state_dict()
self.optimizer.sharded_state_dict(model_sharded_state_dict)
init_optimizer(self.optimizer)
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
dist.load_state_dict(sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config)
self._load_scheduler(resume_from_checkpoint)
else:
model = self.model_wrapped
if delay_optimizer_creation:
@@ -1342,6 +1360,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
logger.warning(
f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
)

elif isinstance(self.optimizer, HybridParallelOptimizer):
self.optimizer._step(parameters_list)
else:
@@ -1968,7 +1987,6 @@ def apply_decay_param_fun(x):
grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None,
**optimizer_kwargs,
)

return self.optimizer

def _apply_to_optimizer(self, action):
@@ -2234,7 +2252,6 @@ def _wrap_model(self, model, training=True):
mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
assert self.optimizer is not None, "optimizer is empty!"
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)

# Pipeline mode
if in_pipeline_parallel_mode:
if self.args.amp_master_grad:
@@ -2284,15 +2301,13 @@ def get_expected_keys(inputs, keys):
if self.args.amp_master_grad:
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

if (
hasattr(self.args, "enable_sharding_comm_overlap")
and self.args.enable_sharding_comm_overlap
and self.args.unified_checkpoint
and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
):
model.register_sharding_comm_overlap_hook(self.optimizer)

# No pipeline mode, sharding only
if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
# Sharded DDP!
@@ -2306,7 +2321,6 @@ def get_expected_keys(inputs, keys):
model = paddle.distributed.fleet.meta_parallel.TensorParallel(
model, hcg, strategy=fleet.fleet._user_defined_strategy
)

if ShardingOption.SHARD_OP in self.args.sharding:
if self.args.amp_master_grad:
mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use
@@ -2348,6 +2362,7 @@ def get_expected_keys(inputs, keys):
offload=cpu_offload,
**extra_kwargs,
)

if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad:
assert hasattr(optimizer, "use_main_grad"), (
"Current installed paddle doesn't support sharding stage 2 with main grad, "
@@ -2373,7 +2388,6 @@ def get_expected_keys(inputs, keys):
if self.args.amp_master_grad:
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

# stage1 has v1 and v2 version
if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
if "split_param" in self.args.sharding_parallel_config:
@@ -2388,7 +2402,6 @@ def get_expected_keys(inputs, keys):
and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
):
self.optimizer._set_broadcast_overlap(True, model)

return model

def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
@@ -2700,6 +2713,10 @@ def _save_checkpoint(self, model, metrics=None):
else:
self.save_model(output_dir)

model_sharded_state_dict = self.model.sharded_state_dict()
if self.args.using_flex_checkpoint:
os.makedirs(output_dir, exist_ok=True)

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
@@ -2763,23 +2780,32 @@
signal_dir,
)
else:
if self.dp_group.rank > 0: # this should only work for MoE saving
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
saved_signal_path,
)

else:
state_dict = self.optimizer.state_dict()
save_path = os.path.join(output_dir, optimizer_name)
if self.args.use_async_save:
assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
self._async_optimizer_saver.run(
state_dict, save_path, saved_signal_path=saved_signal_path
if not self.args.using_flex_checkpoint:
if self.dp_group.rank > 0: # this should only work for MoE saving
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
saved_signal_path,
)

else:
self._save_ckpt_func(state_dict, save_path, saved_signal_path)
state_dict = self.optimizer.state_dict()
save_path = os.path.join(output_dir, optimizer_name)
if self.args.use_async_save:
assert not strtobool(
os.getenv("FLAG_LLM_PDC", "False")
), "Dont support FLAG_LLM_PDC"
self._async_optimizer_saver.run(
state_dict, save_path, saved_signal_path=saved_signal_path
)
else:
self._save_ckpt_func(state_dict, save_path, saved_signal_path)
else:
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
)
else:
if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2800,7 +2826,7 @@
output_dir,
signal_dir,
)
else:
elif not self.args.using_flex_checkpoint:
if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
@@ -2814,6 +2840,13 @@
saved_signal_path,
)

else:
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
)

# FIXME: maybe only save one copy
paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

@@ -3077,6 +3110,24 @@ def _save(
with open(path, "w") as f:
json.dump(model_meta, f)

def _load_scheduler(self, checkpoint):
if checkpoint is None:
self.runtime_timer.stop()
return

if not self.args.ignore_load_lr_and_optim:
if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
self.lr_scheduler.set_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
)
else:
raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")

if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
self.scaler.load_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
)

def _load_optimizer_and_scheduler(self, checkpoint):
"""If optimizer and scheduler states exist, load them."""
self.runtime_timer.start("checkpoint loading time")
@@ -3118,6 +3169,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
):
model = self.model_wrapped

opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
model=model,
optimizer=self.optimizer,
@@ -3149,18 +3201,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")

if not self.args.ignore_load_lr_and_optim:
if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
self.lr_scheduler.set_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
)
else:
raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")

if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
self.scaler.load_state_dict(
paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
)
self._load_scheduler(checkpoint)

if self.args.offload_optim:
logger.info("Offloading optimizer state...")
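The FlexCheckpoint path added to train() and _save_checkpoint() above reduces to merging the model's and optimizer's sharded state dicts and handing them to dist.save_state_dict / dist.load_state_dict. The sketch below condenses that flow outside the Trainer as a reading aid, not the exact implementation; it assumes `dist` is paddle.distributed, that `model` and `optimizer` expose sharded_state_dict() as in this PR, and that the import path for init_optimizer and the ckpt_dir value are illustrative.

import paddle.distributed as dist

from paddlenlp.trainer.trainer_utils import init_optimizer  # helper added by this PR


def flex_save(model, optimizer, ckpt_dir):
    # Collect the sharded views of weights and optimizer states, then write one merged mapping.
    model_sd = model.sharded_state_dict()
    opt_sd = optimizer.sharded_state_dict(model_sd)
    dist.save_state_dict({**model_sd, **opt_sd}, ckpt_dir)


def flex_load(model, optimizer, ckpt_dir, aoa_config=None):
    # Mirrors the resume path in train(): build the sharded views, materialize empty
    # optimizer states with init_optimizer, then load the checkpoint into them.
    model_sd = model.sharded_state_dict()
    optimizer.sharded_state_dict(model_sd)  # same call order as in train() above
    init_optimizer(optimizer)
    opt_sd = optimizer.sharded_state_dict(model_sd)
    dist.load_state_dict({**model_sd, **opt_sd}, ckpt_dir, aoa_config=aoa_config)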
75 changes: 75 additions & 0 deletions paddlenlp/trainer/trainer_utils.py
@@ -53,6 +53,21 @@
from ..utils.pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool
from .utils.helper import distributed_file

try:
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizerV2,
)
except:
DygraphShardingOptimizerV2 = None

try:
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
)
except:
DygraphShardingOptimizer = None


__all__ = [
"TrainOutput",
"PredictionOutput",
@@ -1357,3 +1372,63 @@ def set_comm_config(configs, attr, dict_obj):
set_comm_config("moe_sharding_configs", "check_nccl_config", nccl_config.get("moe_sharding_check", None))
set_comm_config("default_comm_group_configs", "nccl_config", nccl_config.get("default", None))
return strategy


def init_optimizer(optimizer):
"""
Initialize the optimizer's states according to its type.

For DygraphShardingOptimizer (V1), initializes accumulators for local parameters.
For DygraphShardingOptimizerV2, manually initializes master weights and state dict for sharded parameters.
For other cases, initializes accumulators for all parameters.

Args:
optimizer: The optimizer instance to be initialized.
"""
if DygraphShardingOptimizer is not None and isinstance(optimizer._inner_opt, DygraphShardingOptimizer):
local_params = optimizer._rank2params[optimizer._sharding_rank]
optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), local_params)
return

elif DygraphShardingOptimizerV2 is not None and isinstance(optimizer._inner_opt, DygraphShardingOptimizerV2):

def init_param_optimizer_states(param_iter):
master_weights = {}
state_dict = {}
moments = ("moment1_0", "moment2_0")
betas = ("beta1_pow_acc_0", "beta2_pow_acc_0")
for static_name, shape, no_need_master_weights in param_iter:
if not no_need_master_weights:
master_weights[static_name] = paddle.zeros(shape, dtype="float32")
prefix = f"{static_name}_fp32_master_0_"
else:
prefix = f"{static_name}_"

for moment in moments:
key = f"{prefix}{moment}"
state_dict[key] = paddle.zeros(shape, dtype="float32")
for beta in betas:
key = f"{prefix}{beta}"
state_dict[key] = paddle.zeros((1,), dtype="float32")
return master_weights, state_dict

def buffer_params():
for buffer in optimizer._comm_buffer_list:
for param_name, grad_view in buffer._sharding_param_grad_view.items():
param_begin = grad_view._param_begin
param_end = grad_view._param_end
shape = (param_end - param_begin,)
no_need_master_weights = grad_view._param.dtype == paddle.float32

if shape[0] > 0:
yield param_name, shape, no_need_master_weights

master_weights, state_dict = init_param_optimizer_states(buffer_params())
state_dict["master_weights"] = master_weights
state_dict["LR_Scheduler"] = {"last_epoch": 1, "last_lr": 5e-06}

optimizer.set_state_dict(state_dict)
return
optimizer._create_accumulators(
paddle.base.framework.default_main_program().global_block(), optimizer._parameter_list
)
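The accumulator-key convention that init_param_optimizer_states builds above may not be obvious at a glance. The snippet below spells it out for a single hypothetical non-fp32 parameter slice; the name "linear_0.w_0" and its length are invented for illustration.

import paddle

# One bf16/fp16 parameter slice: it gets an fp32 master weight plus Adam-style accumulators.
static_name, shape = "linear_0.w_0", (8,)
master_weights = {static_name: paddle.zeros(shape, dtype="float32")}
prefix = f"{static_name}_fp32_master_0_"
state_dict = {f"{prefix}{m}": paddle.zeros(shape, dtype="float32") for m in ("moment1_0", "moment2_0")}
state_dict.update({f"{prefix}{b}": paddle.zeros((1,), dtype="float32") for b in ("beta1_pow_acc_0", "beta2_pow_acc_0")})

# Resulting keys:
#   master_weights: "linear_0.w_0"
#   state_dict:     "linear_0.w_0_fp32_master_0_moment1_0", "..._moment2_0",
#                   "..._beta1_pow_acc_0", "..._beta2_pow_acc_0"
# A float32 parameter skips master_weights and uses the plain "linear_0.w_0_" prefix instead.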
17 changes: 17 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -407,6 +407,10 @@ class TrainingArguments:
Whether to release gradients during training. Default is `False`.
ckpt_quant_stage (`str`, *optional*):
Whether activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0).
using_flex_checkpoint (`bool`, *optional*):
Whether to use FlexCheckpoint for save and load. Default is False.
aoa_config (`Optional[dict[str, list[str]]]`, *optional*):
The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None.
"""

output_dir: str = field(
@@ -921,6 +925,10 @@ class TrainingArguments:
default=False,
metadata={"help": "Whether to use async_save instead of paddle.save."},
)
using_flex_checkpoint: Optional[bool] = field(
default=False,
metadata={"help": "Whether to use FlexCheckpoint."},
)
ordered_save_group_size: int = field(
default=0,
metadata={
@@ -1082,6 +1090,13 @@ class TrainingArguments:
default=None, metadata={"help": "Path to the config file for fine-grained control of NCCL communication groups. Default is None, which disables this configuration."}
)

aoa_config: Optional[dict[str, list[str]]] = field(
default=None,
metadata={
"help": "The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None."
},
)

def __post_init__(self):
world_size = paddle.distributed.get_world_size()
if in_auto_parallel_align_mode():
@@ -2355,6 +2370,8 @@ def should_save_model_state(self):
return True
elif self.enable_auto_parallel:
return True
elif self.using_flex_checkpoint:
return False
elif self.use_hybrid_parallel:
# save on dataset rank 0
return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
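Both new arguments are opt-in. Below is a minimal sketch of enabling them when constructing TrainingArguments; the output_dir value is hypothetical, and aoa_config is left at its None default because its schema depends on the checkpoint being mapped.

from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",  # hypothetical path
    using_flex_checkpoint=True,  # route checkpoint save/load through FlexCheckpoint
    aoa_config=None,  # optional dict[str, list[str]] mapping model weights to checkpoint content
)
# Per the should_save_model_state change above, enabling the flag makes that property
# return False (outside the auto-parallel branches): checkpoints are written through
# dist.save_state_dict rather than the regular model-state files.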