@@ -31,6 +31,7 @@
 import warnings
 from collections import OrderedDict
 from collections.abc import Mapping
+
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -163,9 +164,12 @@
     should_skip_data,
     speed_metrics,
     split_parallel_config,
+    init_optimizer,
 )
 from .training_args import TrainingArguments
 from .unified_checkpoint import UnifiedCheckpointHandler
+from .unified_checkpoint.utils import generate_base_static_name
+from paddle.distributed.checkpoint.sharded_tensor import ShardedTensor, build_sharded_state_dict
 from .utils import reshard as reshard_util
 from .utils.async_save import AsyncSaver
 
@@ -197,7 +201,6 @@
 if is_datasets_available():
     import datasets
 
-
 try:
     from paddle.distributed.fleet.utils import mix_precision_utils
 except:
@@ -914,7 +917,7 @@ def train(
         self._memory_tracker.start()
 
         if not self.args.enable_auto_parallel:
-            if not self.args.should_load_sharding_stage1_model:
+            if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
                 self._load_from_checkpoint(resume_from_checkpoint)
 
             if self.args.should_load_sharding_stage1_model:
@@ -934,14 +937,32 @@ def train(
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            else:
+            elif not self.args.using_flex_checkpoint:
                 model = self._wrap_model(self.model_wrapped)
                 # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
+            else:
+                assert self.args.using_flex_checkpoint, "default using flex_checkpoint!"
+
+                model = self._wrap_model(self.model_wrapped)
+                if model is not self.model:
+                    self.model_wrapped = model
+
+                if delay_optimizer_creation:
+                    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
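+                # Flex-checkpoint resume: rebuild the model/optimizer sharded state dicts,
+                # then load them collectively from the checkpoint via dist.load_state_dict.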
+                if resume_from_checkpoint is not None:
+                    model_sharded_state_dict = self.model.sharded_state_dict()
+                    self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    init_optimizer(self.optimizer)
+                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
+                    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint)
+                    self._load_scheduler(resume_from_checkpoint)
         else:
             model = self.model_wrapped
             if delay_optimizer_creation:
@@ -1342,6 +1363,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                             logger.warning(
                                 f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
                             )
+
                     elif isinstance(self.optimizer, HybridParallelOptimizer):
                         self.optimizer._step(parameters_list)
                     else:
@@ -1593,7 +1615,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             if num_steps == 0:
                 logs["loss"] = 0.0
             else:
-                logs["loss"] = round(tr_loss_scalar / num_steps, 8)
+                logs["loss"] = round(tr_loss_scalar / num_steps, 8)
 
             logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
             logs["global_step"] = int(self.state.global_step)
@@ -1968,7 +1990,6 @@ def apply_decay_param_fun(x):
                 grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None,
                 **optimizer_kwargs,
             )
-
         return self.optimizer
 
     def _apply_to_optimizer(self, action):
@@ -2212,7 +2233,7 @@ def _wrap_model(self, model, training=True):
         in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1
         in_sep_parallel_mode = self.args.sep_parallel_degree > 1
         in_cp_parallel_mode = self.args.context_parallel_degree > 1
-
+
         # Multi-gpu training
         if self.args.world_size > 1 and (not self.args.use_hybrid_parallel):
             # MOE use DDP to broadcaset parameters.
@@ -2234,7 +2255,6 @@ def _wrap_model(self, model, training=True):
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
                 assert self.optimizer is not None, "optimizer is empty!"
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
-
         # Pipeline mode
         if in_pipeline_parallel_mode:
             if self.args.amp_master_grad:
@@ -2279,20 +2299,18 @@ def get_expected_keys(inputs, keys):
                     "Using default prepare pipeline inputs func, only support input_ids and labels as inputs."
                 )
                 model._prepare_pipeline_inputs_func = _prepare_pipeline_inputs_func
-
+
             assert self.optimizer is not None, "Pipeline mode need decorate optimizer, pelease init optimizer."
             if self.args.amp_master_grad:
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
             if (
                 hasattr(self.args, "enable_sharding_comm_overlap")
                 and self.args.enable_sharding_comm_overlap
                 and self.args.unified_checkpoint
                 and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model.register_sharding_comm_overlap_hook(self.optimizer)
-
         # No pipeline mode, sharding only
         if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
             # Sharded DDP!
@@ -2306,7 +2324,6 @@ def get_expected_keys(inputs, keys):
                 model = paddle.distributed.fleet.meta_parallel.TensorParallel(
                     model, hcg, strategy=fleet.fleet._user_defined_strategy
                 )
-
             if ShardingOption.SHARD_OP in self.args.sharding:
                 if self.args.amp_master_grad:
                     mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)  # return value has no use
@@ -2325,7 +2342,6 @@ def get_expected_keys(inputs, keys):
                     level = "p_g_os"
 
                 from paddle.distributed.sharding import group_sharded_parallel
-
                 # add dp_group and exclude_layer params
                 # https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/distributed/sharding/group_sharded_parallel_cn.html#group-sharded-parallel
                 extra_kwargs = {}
@@ -2348,6 +2364,7 @@ def get_expected_keys(inputs, keys):
                     offload=cpu_offload,
                     **extra_kwargs,
                 )
+
                 if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad:
                     assert hasattr(optimizer, "use_main_grad"), (
                         "Current installed paddle doesn't support sharding stage 2 with main grad, "
@@ -2373,7 +2390,6 @@ def get_expected_keys(inputs, keys):
            if self.args.amp_master_grad:
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
         # stage1 has v1 and v2 version
         if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
             if "split_param" in self.args.sharding_parallel_config:
@@ -2388,7 +2404,6 @@ def get_expected_keys(inputs, keys):
                 and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
             ):
                 self.optimizer._set_broadcast_overlap(True, model)
-
         return model
 
     def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
@@ -2700,6 +2715,10 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             self.save_model(output_dir)
 
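+        # Collect the model's sharded state dict up front; the flex-checkpoint save path
+        # below persists it together with the optimizer's sharded state.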
+        model_sharded_state_dict = self.model.sharded_state_dict()
+        if self.args.using_flex_checkpoint:
+            os.makedirs(output_dir, exist_ok=True)
+
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
             metric_to_check = self.args.metric_for_best_model
@@ -2763,23 +2782,30 @@ def _save_checkpoint(self, model, metrics=None):
                         signal_dir,
                     )
                 else:
-                    if self.dp_group.rank > 0:  # this should only work for MoE saving
-                        self._save_ckpt_func(
-                            self._filter_moe_no_sync_optimizer_params(),
-                            os.path.join(output_dir, optimizer_name),
-                            saved_signal_path,
-                        )
-
-                    else:
-                        state_dict = self.optimizer.state_dict()
-                        save_path = os.path.join(output_dir, optimizer_name)
-                        if self.args.use_async_save:
-                            assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
-                            self._async_optimizer_saver.run(
-                                state_dict, save_path, saved_signal_path=saved_signal_path
+                    if not self.args.using_flex_checkpoint:
+                        if self.dp_group.rank > 0:  # this should only work for MoE saving
+                            self._save_ckpt_func(
+                                self._filter_moe_no_sync_optimizer_params(),
+                                os.path.join(output_dir, optimizer_name),
+                                saved_signal_path,
                             )
+
                         else:
-                        self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                            state_dict = self.optimizer.state_dict()
+                            save_path = os.path.join(output_dir, optimizer_name)
+                            if self.args.use_async_save:
+                                assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
+                                self._async_optimizer_saver.run(
+                                    state_dict, save_path, saved_signal_path=saved_signal_path
+                                )
+                            else:
+                                self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                    else:
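+                        # Flex checkpoint: save the merged model/optimizer sharded state dicts via dist.save_state_dict.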
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        dist.save_state_dict(
+                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                            output_dir,
+                        )
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2800,7 +2826,7 @@ def _save_checkpoint(self, model, metrics=None):
                     output_dir,
                     signal_dir,
                 )
-            else:
+            elif not self.args.using_flex_checkpoint:
                 if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                     self._save_ckpt_func(
                         self._filter_moe_no_sync_optimizer_params(),
@@ -2814,6 +2840,13 @@ def _save_checkpoint(self, model, metrics=None):
                         saved_signal_path,
                     )
 
+            else:
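+                # Flex checkpoint: persist the model and optimizer sharded state dicts together under output_dir.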
+                optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                dist.save_state_dict(
+                    {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                    output_dir,
+                )
+
         # FIXME: maybe only save one copy
         paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
 
@@ -3077,13 +3110,32 @@ def _save(
         with open(path, "w") as f:
             json.dump(model_meta, f)
 
+    def _load_scheduler(self, checkpoint):
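+        """Load the LR scheduler state, and the GradScaler state when grad scaling is enabled, from a checkpoint."""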
+        if checkpoint is None:
+            self.runtime_timer.stop()
+            return
+
+        if not self.args.ignore_load_lr_and_optim:
+            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
+                self.lr_scheduler.set_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
+                )
+            else:
+                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
+
+            if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
+                self.scaler.load_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
+                )
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
         self.runtime_timer.start("checkpoint loading time")
         if checkpoint is None:
             self.runtime_timer.stop()
             return
 
+
         logger.info("Loading optimizer and scheduler...")
         if (not self.args.should_load_sharding_stage1_model) and self.args.ignore_load_lr_and_optim:
             self.runtime_timer.stop()
@@ -3118,6 +3170,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model = self.model_wrapped
+
             opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
                 model=model,
                 optimizer=self.optimizer,
@@ -3149,18 +3202,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                 raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")
 
-        if not self.args.ignore_load_lr_and_optim:
-            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
-                self.lr_scheduler.set_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
-                )
-            else:
-                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
-
-            if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
-                self.scaler.load_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
-                )
+        self._load_scheduler(checkpoint)
 
         if self.args.offload_optim:
             logger.info("Offloading optimizer state...")