
Commit 92bb80e

adapter flex_checkpoint
1 parent a205bc3 commit 92bb80e

File tree

4 files changed: +179 -43 lines changed

paddlenlp/trainer/trainer.py

Lines changed: 84 additions & 42 deletions
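The new behaviour is gated on args.using_flex_checkpoint; the flag itself is defined in one of the other changed files, which this page does not show. A hedged sketch of how a caller might opt in, assuming the flag is a plain boolean field on TrainingArguments:

from paddlenlp.trainer import Trainer, TrainingArguments

# `using_flex_checkpoint=True` is assumed wiring for the `args.using_flex_checkpoint`
# checks introduced in trainer.py below; the actual argument definition is not shown here.
args = TrainingArguments(
    output_dir="./flex_ckpt",
    using_flex_checkpoint=True,
)
# trainer = Trainer(model=model, args=args, train_dataset=train_ds, ...)
# trainer.train(resume_from_checkpoint="./flex_ckpt/checkpoint-1000")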
@@ -31,6 +31,7 @@
 import warnings
 from collections import OrderedDict
 from collections.abc import Mapping
+
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -163,9 +164,12 @@
     should_skip_data,
     speed_metrics,
     split_parallel_config,
+    init_optimizer,
 )
 from .training_args import TrainingArguments
 from .unified_checkpoint import UnifiedCheckpointHandler
+from .unified_checkpoint.utils import generate_base_static_name
+from paddle.distributed.checkpoint.sharded_tensor import ShardedTensor, build_sharded_state_dict
 from .utils import reshard as reshard_util
 from .utils.async_save import AsyncSaver

@@ -197,7 +201,6 @@
 if is_datasets_available():
     import datasets

-
 try:
     from paddle.distributed.fleet.utils import mix_precision_utils
 except:

@@ -914,7 +917,7 @@ def train(
         self._memory_tracker.start()

         if not self.args.enable_auto_parallel:
-            if not self.args.should_load_sharding_stage1_model:
+            if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
                 self._load_from_checkpoint(resume_from_checkpoint)

             if self.args.should_load_sharding_stage1_model:

@@ -934,14 +937,32 @@ def train(
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            else:
+            elif not self.args.using_flex_checkpoint:
                 model = self._wrap_model(self.model_wrapped)
                 # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
+            else:
+                assert self.args.using_flex_checkpoint, "default using flex_checkpoint!"
+
+                model = self._wrap_model(self.model_wrapped)
+                if model is not self.model:
+                    self.model_wrapped = model
+
+                if delay_optimizer_creation:
+                    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+                if resume_from_checkpoint is not None:
+                    model_sharded_state_dict = self.model.sharded_state_dict()
+                    self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    init_optimizer(self.optimizer)
+                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
+                    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint)
+                    self._load_scheduler(resume_from_checkpoint)
         else:
             model = self.model_wrapped
             if delay_optimizer_creation:
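Condensed, the flex_checkpoint resume branch added above performs the sequence below. This is a minimal sketch rather than the exact Trainer code; it assumes the sharded_state_dict() APIs used in this diff, and that init_optimizer is the helper pulled into the import block above (assumed to be the `from .trainer_utils import (...)` block).

import paddle.distributed as dist

# Assumed import path: the commit adds `init_optimizer` to the trainer_utils import list.
from paddlenlp.trainer.trainer_utils import init_optimizer


def resume_with_flex_checkpoint(model, optimizer, checkpoint_dir):
    # Sharded (per-rank) view of the model parameters.
    model_sd = model.sharded_state_dict()
    # Same order as the diff: one optimizer.sharded_state_dict() call, init_optimizer()
    # to materialize the optimizer state, then the sharded view that is actually loaded into.
    optimizer.sharded_state_dict(model_sd)
    init_optimizer(optimizer)
    opt_sd = optimizer.sharded_state_dict(model_sd)
    # dist.load_state_dict restores the shards written by dist.save_state_dict at save time.
    dist.load_state_dict({**model_sd, **opt_sd}, checkpoint_dir)
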
@@ -1342,6 +1363,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                     logger.warning(
                         f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
                     )
+
                 elif isinstance(self.optimizer, HybridParallelOptimizer):
                     self.optimizer._step(parameters_list)
                 else:

@@ -1593,7 +1615,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             if num_steps == 0:
                 logs["loss"] = 0.0
             else:
-                logs["loss"] = round(tr_loss_scalar / num_steps, 8)
+                logs["loss"] = round(tr_loss_scalar / num_steps, 8)

             logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
             logs["global_step"] = int(self.state.global_step)

@@ -1968,7 +1990,6 @@ def apply_decay_param_fun(x):
                 grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None,
                 **optimizer_kwargs,
             )
-
         return self.optimizer

     def _apply_to_optimizer(self, action):

@@ -2212,7 +2233,7 @@ def _wrap_model(self, model, training=True):
         in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1
         in_sep_parallel_mode = self.args.sep_parallel_degree > 1
         in_cp_parallel_mode = self.args.context_parallel_degree > 1
-
+
         # Multi-gpu training
         if self.args.world_size > 1 and (not self.args.use_hybrid_parallel):
             # MOE use DDP to broadcaset parameters.

@@ -2234,7 +2255,6 @@ def _wrap_model(self, model, training=True):
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
                 assert self.optimizer is not None, "optimizer is empty!"
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
-
         # Pipeline mode
         if in_pipeline_parallel_mode:
             if self.args.amp_master_grad:

@@ -2279,20 +2299,18 @@ def get_expected_keys(inputs, keys):
                     "Using default prepare pipeline inputs func, only support input_ids and labels as inputs."
                 )
             model._prepare_pipeline_inputs_func = _prepare_pipeline_inputs_func
-
+
             assert self.optimizer is not None, "Pipeline mode need decorate optimizer, pelease init optimizer."
             if self.args.amp_master_grad:
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
             if (
                 hasattr(self.args, "enable_sharding_comm_overlap")
                 and self.args.enable_sharding_comm_overlap
                 and self.args.unified_checkpoint
                 and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model.register_sharding_comm_overlap_hook(self.optimizer)
-
         # No pipeline mode, sharding only
         if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
             # Sharded DDP!

@@ -2306,7 +2324,6 @@ def get_expected_keys(inputs, keys):
                 model = paddle.distributed.fleet.meta_parallel.TensorParallel(
                     model, hcg, strategy=fleet.fleet._user_defined_strategy
                 )
-
             if ShardingOption.SHARD_OP in self.args.sharding:
                 if self.args.amp_master_grad:
                     mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)  # return value has no use

@@ -2325,7 +2342,6 @@ def get_expected_keys(inputs, keys):
                     level = "p_g_os"

                 from paddle.distributed.sharding import group_sharded_parallel
-
                 # add dp_group and exclude_layer params
                 # https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/distributed/sharding/group_sharded_parallel_cn.html#group-sharded-parallel
                 extra_kwargs = {}

@@ -2348,6 +2364,7 @@ def get_expected_keys(inputs, keys):
                     offload=cpu_offload,
                     **extra_kwargs,
                 )
+
                 if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad:
                     assert hasattr(optimizer, "use_main_grad"), (
                         "Current installed paddle doesn't support sharding stage 2 with main grad, "

@@ -2373,7 +2390,6 @@ def get_expected_keys(inputs, keys):
             if self.args.amp_master_grad:
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
         # stage1 has v1 and v2 version
         if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
             if "split_param" in self.args.sharding_parallel_config:

@@ -2388,7 +2404,6 @@ def get_expected_keys(inputs, keys):
                 and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
             ):
                 self.optimizer._set_broadcast_overlap(True, model)
-
         return model

     def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:

@@ -2700,6 +2715,10 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             self.save_model(output_dir)

+        model_sharded_state_dict = self.model.sharded_state_dict()
+        if self.args.using_flex_checkpoint:
+            os.makedirs(output_dir, exist_ok=True)
+
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
             metric_to_check = self.args.metric_for_best_model

@@ -2763,23 +2782,30 @@ def _save_checkpoint(self, model, metrics=None):
                         signal_dir,
                     )
                 else:
-                    if self.dp_group.rank > 0:  # this should only work for MoE saving
-                        self._save_ckpt_func(
-                            self._filter_moe_no_sync_optimizer_params(),
-                            os.path.join(output_dir, optimizer_name),
-                            saved_signal_path,
-                        )
-
-                    else:
-                        state_dict = self.optimizer.state_dict()
-                        save_path = os.path.join(output_dir, optimizer_name)
-                        if self.args.use_async_save:
-                            assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
-                            self._async_optimizer_saver.run(
-                                state_dict, save_path, saved_signal_path=saved_signal_path
+                    if not self.args.using_flex_checkpoint:
+                        if self.dp_group.rank > 0:  # this should only work for MoE saving
+                            self._save_ckpt_func(
+                                self._filter_moe_no_sync_optimizer_params(),
+                                os.path.join(output_dir, optimizer_name),
+                                saved_signal_path,
                             )
+
                         else:
-                        self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                            state_dict = self.optimizer.state_dict()
+                            save_path = os.path.join(output_dir, optimizer_name)
+                            if self.args.use_async_save:
+                                assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
+                                self._async_optimizer_saver.run(
+                                    state_dict, save_path, saved_signal_path=saved_signal_path
+                                )
+                            else:
+                                self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                    else:
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        dist.save_state_dict(
+                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                            output_dir,
+                        )
             else:
                 if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                     global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1

@@ -2800,7 +2826,7 @@ def _save_checkpoint(self, model, metrics=None):
                         output_dir,
                         signal_dir,
                     )
-                else:
+                elif not self.args.using_flex_checkpoint:
                     if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                         self._save_ckpt_func(
                             self._filter_moe_no_sync_optimizer_params(),

@@ -2814,6 +2840,13 @@ def _save_checkpoint(self, model, metrics=None):
                             saved_signal_path,
                         )

+                else:
+                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    dist.save_state_dict(
+                        {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                        output_dir,
+                    )
+
         # FIXME: maybe only save one copy
         paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
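With using_flex_checkpoint set, both optimizer-saving branches in _save_checkpoint collapse to the same write: merge the model and optimizer sharded state dicts and hand them to dist.save_state_dict. A minimal sketch of that path (hypothetical helper name; in the commit the logic lives inline in _save_checkpoint):

import os

import paddle.distributed as dist


def save_flex_checkpoint(model, optimizer, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    model_sd = model.sharded_state_dict()
    opt_sd = optimizer.sharded_state_dict(model_sd)
    # One merged mapping goes to Paddle's distributed checkpoint writer; the matching
    # dist.load_state_dict call in train() reads it back on resume.
    dist.save_state_dict({**model_sd, **opt_sd}, output_dir)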

@@ -3077,13 +3110,32 @@ def _save(
             with open(path, "w") as f:
                 json.dump(model_meta, f)

+    def _load_scheduler(self, checkpoint):
+        if checkpoint is None:
+            self.runtime_timer.stop()
+            return
+
+        if not self.args.ignore_load_lr_and_optim:
+            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
+                self.lr_scheduler.set_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
+                )
+            else:
+                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
+
+            if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
+                self.scaler.load_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
+                )
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
         self.runtime_timer.start("checkpoint loading time")
         if checkpoint is None:
             self.runtime_timer.stop()
             return

+
         logger.info("Loading optimizer and scheduler...")
         if (not self.args.should_load_sharding_stage1_model) and self.args.ignore_load_lr_and_optim:
             self.runtime_timer.stop()

@@ -3118,6 +3170,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
         ):
             model = self.model_wrapped
+
             opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
                 model=model,
                 optimizer=self.optimizer,

@@ -3149,18 +3202,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                 raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")

-        if not self.args.ignore_load_lr_and_optim:
-            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
-                self.lr_scheduler.set_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
-                )
-            else:
-                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
-
-            if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
-                self.scaler.load_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
-                )
+        self._load_scheduler(checkpoint)

         if self.args.offload_optim:
             logger.info("Offloading optimizer state...")
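The scheduler and grad-scaler restore that used to sit inline in _load_optimizer_and_scheduler is now the shared _load_scheduler helper, so both resume paths end at the same place. Roughly (a sketch; the init_optimizer step shown in the train() hunk is omitted for brevity):

import paddle.distributed as dist


def restore_from_checkpoint(trainer, checkpoint):
    if trainer.args.using_flex_checkpoint:
        # Flex path: load the merged sharded state, then the scheduler/scaler files.
        model_sd = trainer.model.sharded_state_dict()
        opt_sd = trainer.optimizer.sharded_state_dict(model_sd)
        dist.load_state_dict({**model_sd, **opt_sd}, checkpoint)
        trainer._load_scheduler(checkpoint)
    else:
        # Classic path: _load_optimizer_and_scheduler now delegates to _load_scheduler.
        trainer._load_optimizer_and_scheduler(checkpoint)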
