
Commit 9805b27

adapter flex_checkpoint
1 parent a205bc3 commit 9805b27

5 files changed (+178, -40 lines)

paddlenlp/trainer/trainer.py

Lines changed: 79 additions & 39 deletions
@@ -159,13 +159,13 @@
     get_last_checkpoint,
     get_scheduler,
     has_length,
+    init_optimizer,
     set_seed,
     should_skip_data,
     speed_metrics,
     split_parallel_config,
 )
 from .training_args import TrainingArguments
-from .unified_checkpoint import UnifiedCheckpointHandler
 from .utils import reshard as reshard_util
 from .utils.async_save import AsyncSaver
@@ -197,7 +197,6 @@
 if is_datasets_available():
     import datasets
 
-
 try:
     from paddle.distributed.fleet.utils import mix_precision_utils
 except:
@@ -914,7 +913,7 @@ def train(
         self._memory_tracker.start()
 
         if not self.args.enable_auto_parallel:
-            if not self.args.should_load_sharding_stage1_model:
+            if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
                 self._load_from_checkpoint(resume_from_checkpoint)
 
             if self.args.should_load_sharding_stage1_model:
@@ -934,14 +933,32 @@ def train(
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            else:
+            elif not self.args.using_flex_checkpoint:
                 model = self._wrap_model(self.model_wrapped)
                 # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
+            else:
+                assert self.args.using_flex_checkpoint, "default using flex_checkpoint!"
+
+                model = self._wrap_model(self.model_wrapped)
+                if model is not self.model:
+                    self.model_wrapped = model
+
+                if delay_optimizer_creation:
+                    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+                if resume_from_checkpoint is not None:
+                    model_sharded_state_dict = self.model.sharded_state_dict()
+                    self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    init_optimizer(self.optimizer)
+                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
+                    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint)
+                    self._load_scheduler(resume_from_checkpoint)
         else:
             model = self.model_wrapped
             if delay_optimizer_creation:
@@ -1342,6 +1359,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                            logger.warning(
                                f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
                            )
+
                    elif isinstance(self.optimizer, HybridParallelOptimizer):
                        self.optimizer._step(parameters_list)
                    else:
@@ -1968,7 +1986,6 @@ def apply_decay_param_fun(x):
                 grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None,
                 **optimizer_kwargs,
             )
-
         return self.optimizer
 
     def _apply_to_optimizer(self, action):
@@ -2234,7 +2251,6 @@ def _wrap_model(self, model, training=True):
            mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
            assert self.optimizer is not None, "optimizer is empty!"
            self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
-
        # Pipeline mode
        if in_pipeline_parallel_mode:
            if self.args.amp_master_grad:
@@ -2284,15 +2300,13 @@ def get_expected_keys(inputs, keys):
            if self.args.amp_master_grad:
                self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
            if (
                hasattr(self.args, "enable_sharding_comm_overlap")
                and self.args.enable_sharding_comm_overlap
                and self.args.unified_checkpoint
                and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
            ):
                model.register_sharding_comm_overlap_hook(self.optimizer)
-
        # No pipeline mode, sharding only
        if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
            # Sharded DDP!
@@ -2306,7 +2320,6 @@ def get_expected_keys(inputs, keys):
                model = paddle.distributed.fleet.meta_parallel.TensorParallel(
                    model, hcg, strategy=fleet.fleet._user_defined_strategy
                )
-
            if ShardingOption.SHARD_OP in self.args.sharding:
                if self.args.amp_master_grad:
                    mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)  # return value has no use
@@ -2348,6 +2361,7 @@ def get_expected_keys(inputs, keys):
                    offload=cpu_offload,
                    **extra_kwargs,
                )
+
                if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad:
                    assert hasattr(optimizer, "use_main_grad"), (
                        "Current installed paddle doesn't support sharding stage 2 with main grad, "
@@ -2373,7 +2387,6 @@ def get_expected_keys(inputs, keys):
                if self.args.amp_master_grad:
                    self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
                self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
        # stage1 has v1 and v2 version
        if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
            if "split_param" in self.args.sharding_parallel_config:
@@ -2388,7 +2401,6 @@ def get_expected_keys(inputs, keys):
                and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
            ):
                self.optimizer._set_broadcast_overlap(True, model)
-
        return model
 
    def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
@@ -2700,6 +2712,10 @@ def _save_checkpoint(self, model, metrics=None):
        else:
            self.save_model(output_dir)
 
+        model_sharded_state_dict = self.model.sharded_state_dict()
+        if self.args.using_flex_checkpoint:
+            os.makedirs(output_dir, exist_ok=True)
+
        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
@@ -2763,23 +2779,32 @@ def _save_checkpoint(self, model, metrics=None):
                        signal_dir,
                    )
                else:
-                    if self.dp_group.rank > 0:  # this should only work for MoE saving
-                        self._save_ckpt_func(
-                            self._filter_moe_no_sync_optimizer_params(),
-                            os.path.join(output_dir, optimizer_name),
-                            saved_signal_path,
-                        )
-
-                    else:
-                        state_dict = self.optimizer.state_dict()
-                        save_path = os.path.join(output_dir, optimizer_name)
-                        if self.args.use_async_save:
-                            assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
-                            self._async_optimizer_saver.run(
-                                state_dict, save_path, saved_signal_path=saved_signal_path
+                    if not self.args.using_flex_checkpoint:
+                        if self.dp_group.rank > 0:  # this should only work for MoE saving
+                            self._save_ckpt_func(
+                                self._filter_moe_no_sync_optimizer_params(),
+                                os.path.join(output_dir, optimizer_name),
+                                saved_signal_path,
                            )
+
                        else:
-                            self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                            state_dict = self.optimizer.state_dict()
+                            save_path = os.path.join(output_dir, optimizer_name)
+                            if self.args.use_async_save:
+                                assert not strtobool(
+                                    os.getenv("FLAG_LLM_PDC", "False")
+                                ), "Dont support FLAG_LLM_PDC"
+                                self._async_optimizer_saver.run(
+                                    state_dict, save_path, saved_signal_path=saved_signal_path
+                                )
+                            else:
+                                self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                    else:
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        dist.save_state_dict(
+                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                            output_dir,
+                        )
            else:
                if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                    global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2800,7 +2825,7 @@ def _save_checkpoint(self, model, metrics=None):
                        output_dir,
                        signal_dir,
                    )
-                else:
+                elif not self.args.using_flex_checkpoint:
                    if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                        self._save_ckpt_func(
                            self._filter_moe_no_sync_optimizer_params(),
@@ -2814,6 +2839,13 @@ def _save_checkpoint(self, model, metrics=None):
                            saved_signal_path,
                        )
 
+                else:
+                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    dist.save_state_dict(
+                        {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                        output_dir,
+                    )
+
        # FIXME: maybe only save one copy
        paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
 
@@ -3077,6 +3109,24 @@ def _save(
            with open(path, "w") as f:
                json.dump(model_meta, f)
 
+    def _load_scheduler(self, checkpoint):
+        if checkpoint is None:
+            self.runtime_timer.stop()
+            return
+
+        if not self.args.ignore_load_lr_and_optim:
+            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
+                self.lr_scheduler.set_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
+                )
+            else:
+                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
+
+            if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
+                self.scaler.load_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
+                )
+
    def _load_optimizer_and_scheduler(self, checkpoint):
        """If optimizer and scheduler states exist, load them."""
        self.runtime_timer.start("checkpoint loading time")
@@ -3118,6 +3168,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
            ):
                model = self.model_wrapped
+
            opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
                model=model,
                optimizer=self.optimizer,
@@ -3149,18 +3200,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")
 
-        if not self.args.ignore_load_lr_and_optim:
-            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
-                self.lr_scheduler.set_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
-                )
-            else:
-                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
-
-            if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
-                self.scaler.load_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
-                )
+        self._load_scheduler(checkpoint)
 
        if self.args.offload_optim:
            logger.info("Offloading optimizer state...")

paddlenlp/trainer/trainer_utils.py

Lines changed: 69 additions & 0 deletions
@@ -28,6 +28,7 @@
 import random
 import threading
 import time
+from collections import defaultdict
 from contextlib import contextmanager
 from enum import Enum
 from pathlib import Path
@@ -53,6 +54,21 @@
 from ..utils.pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool
 from .utils.helper import distributed_file
 
+try:
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+        DygraphShardingOptimizerV2,
+    )
+except:
+    DygraphShardingOptimizerV2 = None
+
+try:
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+        DygraphShardingOptimizer,
+    )
+except:
+    DygraphShardingOptimizer = None
+
+
 __all__ = [
     "TrainOutput",
     "PredictionOutput",
@@ -1357,3 +1373,56 @@ def set_comm_config(configs, attr, dict_obj):
     set_comm_config("moe_sharding_configs", "check_nccl_config", nccl_config.get("moe_sharding_check", None))
     set_comm_config("default_comm_group_configs", "nccl_config", nccl_config.get("default", None))
     return strategy
+
+
+def init_optimizer(optimizer):
+    """
+    Initialize the optimizer's states according to its type.
+
+    For DygraphShardingOptimizer (V1), initializes accumulators for local parameters.
+    For DygraphShardingOptimizerV2, manually initializes master weights and state dict for sharded parameters.
+    For other cases, initializes accumulators for all parameters.
+
+    Args:
+        optimizer: The optimizer instance to be initialized.
+    """
+    if DygraphShardingOptimizer is not None and isinstance(optimizer._inner_opt, DygraphShardingOptimizer):
+        local_params = optimizer._rank2params[optimizer._sharding_rank]
+        optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), local_params)
+        return
+
+    elif DygraphShardingOptimizerV2 is not None and isinstance(optimizer._inner_opt, DygraphShardingOptimizerV2):
+
+        def init_param_optimizer_states(param_iter):
+            master_weights = {}
+            state_dict = {}
+            for static_name, shape in param_iter:
+                master_weights[static_name] = paddle.zeros(shape, dtype="float32")
+                for moment in ("moment1_0", "moment2_0"):
+                    key = f"{static_name}_fp32_master_0_{moment}"
+                    state_dict[key] = paddle.zeros(shape, dtype="float32")
+                for beta in ("beta1_pow_acc_0", "beta2_pow_acc_0"):
+                    key = f"{static_name}_fp32_master_0_{beta}"
+                    state_dict[key] = paddle.zeros((1,), dtype="float32")
+            return master_weights, state_dict
+
+        def buffer_params():
+            for buffer in optimizer._comm_buffer_list:
+                for param_name, grad_view in buffer._sharding_param_grad_view.items():
+                    numel = grad_view._param.numel().item()
+                    param_begin = grad_view._param_begin
+                    param_end = grad_view._param_end
+                    index = grad_view._index
+                    padding_begin = index + numel
+                    shape = (min(padding_begin, param_end) - param_begin,)
+                    if shape[0] > 0:
+                        yield param_name, shape
+
+        master_weights, state_dict = init_param_optimizer_states(buffer_params())
+        state_dict["master_weights"] = master_weights
+        state_dict["LR_Scheduler"] = {"last_epoch": 1, "last_lr": 5e-06}
+        optimizer.set_state_dict(state_dict)
+        return
+    optimizer._create_accumulators(
+        paddle.base.framework.default_main_program().global_block(), optimizer._parameter_list
+    )
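Note: the DygraphShardingOptimizerV2 branch of init_optimizer exists so that dist.load_state_dict has pre-allocated destination tensors: for every locally owned parameter slice it creates an fp32 master weight plus Adam accumulators under the *_fp32_master_0_* key pattern. Below is a standalone toy illustration of that key layout; the parameter name and shape are invented for the example.

# Toy illustration of the state-dict layout built for DygraphShardingOptimizerV2
# (hypothetical parameter name/shape, mirroring the key format in the diff above).
import paddle


def build_empty_adam_states(param_shapes):
    master_weights = {}
    state_dict = {}
    for static_name, shape in param_shapes.items():
        # fp32 master weight for the locally owned slice of the parameter
        master_weights[static_name] = paddle.zeros(shape, dtype="float32")
        # first/second Adam moments, same shape as the owned slice
        for moment in ("moment1_0", "moment2_0"):
            state_dict[f"{static_name}_fp32_master_0_{moment}"] = paddle.zeros(shape, dtype="float32")
        # scalar beta power accumulators
        for beta in ("beta1_pow_acc_0", "beta2_pow_acc_0"):
            state_dict[f"{static_name}_fp32_master_0_{beta}"] = paddle.zeros((1,), dtype="float32")
    state_dict["master_weights"] = master_weights
    return state_dict


print(sorted(build_empty_adam_states({"linear_0.w_0": (1024,)}).keys()))

In the real function the shapes are the 1-D lengths of each rank's owned slice of the communication buffer (from _param_begin to _param_end), not the full parameter shapes.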

paddlenlp/trainer/training_args.py

Lines changed: 7 additions & 0 deletions
@@ -921,6 +921,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether to use async_save instead of paddle.save."},
     )
+    using_flex_checkpoint: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether use FlexCheckpoint."},
+    )
     ordered_save_group_size: int = field(
         default=0,
         metadata={
@@ -2355,6 +2359,8 @@ def should_save_model_state(self):
             return True
         elif self.enable_auto_parallel:
             return True
+        elif self.using_flex_checkpoint:
+            return False
         elif self.use_hybrid_parallel:
             # save on dataset rank 0
             return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
@@ -2370,6 +2376,7 @@ def _no_sync_in_gradient_accumulation(self):
 
     @property
     def should_save_sharding_stage1_model(self):
+        # return True
         if self.enable_auto_parallel:
             return False
         return (
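Note: the new using_flex_checkpoint field behaves like the other boolean training arguments, so it can be set in code or, presumably, through the auto-generated --using_flex_checkpoint command-line flag. A rough usage sketch, illustrative only (the output directory is a placeholder):

# Illustrative only: enabling the flag added in this commit.
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./flex_ckpt_out",  # placeholder path
    using_flex_checkpoint=True,  # defaults to False per the field definition above
)
print(args.using_flex_checkpoint)  # True; the Trainer then takes the dist.save_state_dict path

With the flag on, the should_save_model_state branch added above returns False, and checkpoints are written as sharded state dicts under the checkpoint directory instead of the per-rank optimizer files saved by self._save_ckpt_func.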
