     get_last_checkpoint,
     get_scheduler,
     has_length,
+    init_optimizer,
     set_seed,
     should_skip_data,
     speed_metrics,
     split_parallel_config,
 )
 from .training_args import TrainingArguments
-from .unified_checkpoint import UnifiedCheckpointHandler
 from .utils import reshard as reshard_util
 from .utils.async_save import AsyncSaver
 

 if is_datasets_available():
     import datasets
 
-
 try:
     from paddle.distributed.fleet.utils import mix_precision_utils
 except:
@@ -914,7 +913,7 @@ def train(
         self._memory_tracker.start()
 
         if not self.args.enable_auto_parallel:
-            if not self.args.should_load_sharding_stage1_model:
+            if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
                 self._load_from_checkpoint(resume_from_checkpoint)
 
             if self.args.should_load_sharding_stage1_model:
@@ -934,14 +933,32 @@ def train(
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            else:
+            elif not self.args.using_flex_checkpoint:
                 model = self._wrap_model(self.model_wrapped)
                 # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
+            else:
+                assert self.args.using_flex_checkpoint, "default using flex_checkpoint!"
+
+                model = self._wrap_model(self.model_wrapped)
+                if model is not self.model:
+                    self.model_wrapped = model
+
+                if delay_optimizer_creation:
+                    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+                if resume_from_checkpoint is not None:
+                    model_sharded_state_dict = self.model.sharded_state_dict()
+                    self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    init_optimizer(self.optimizer)
+                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
+                    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint)
+                    self._load_scheduler(resume_from_checkpoint)
         else:
             model = self.model_wrapped
             if delay_optimizer_creation:
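Reviewer note: the new `else` branch above is the whole flex-checkpoint resume path. The sketch below restates that flow in isolation; it is an illustration, not the PR's code. `trainer`, `checkpoint_dir`, and `max_steps` are placeholders, `dist` stands for `paddle.distributed` as in the trainer module, and the `paddlenlp.trainer.trainer_utils` import path for `init_optimizer` is an assumption (the diff imports it relatively).

```python
import paddle.distributed as dist

# Assumed import path; the diff adds `init_optimizer` to the trainer's own utils import block.
from paddlenlp.trainer.trainer_utils import init_optimizer


def resume_from_flex_checkpoint(trainer, checkpoint_dir, max_steps):
    """Sketch of the flex-checkpoint branch added to Trainer.train()."""
    # Wrap the model and create the optimizer/scheduler first, exactly like the
    # non-flex branch, so parameter and optimizer structures exist before loading.
    model = trainer._wrap_model(trainer.model_wrapped)
    if model is not trainer.model:
        trainer.model_wrapped = model
    trainer.create_optimizer_and_scheduler(num_training_steps=max_steps)

    if checkpoint_dir is None:
        return

    # Build rank-local sharded views of the weights and optimizer state.
    # init_optimizer materializes the optimizer tensors so they can be filled in place.
    model_sd = trainer.model.sharded_state_dict()
    trainer.optimizer.sharded_state_dict(model_sd)
    init_optimizer(trainer.optimizer)
    opt_sd = trainer.optimizer.sharded_state_dict(model_sd)

    # A single flat {key: local shard} mapping; dist.load_state_dict reads the
    # distributed checkpoint and reshards it onto the current parallel topology.
    dist.load_state_dict({**model_sd, **opt_sd}, checkpoint_dir)

    # The LR scheduler and loss scaler still come from plain paddle.load files.
    trainer._load_scheduler(checkpoint_dir)
```

The double call to `optimizer.sharded_state_dict` mirrors the diff: the first call happens before `init_optimizer`, the second builds the mapping that is actually loaded.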
@@ -1342,6 +1359,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                                 logger.warning(
                                     f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
                                 )
+
                         elif isinstance(self.optimizer, HybridParallelOptimizer):
                             self.optimizer._step(parameters_list)
                         else:
@@ -1968,7 +1986,6 @@ def apply_decay_param_fun(x):
                 grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None,
                 **optimizer_kwargs,
             )
-
         return self.optimizer
 
     def _apply_to_optimizer(self, action):
@@ -2234,7 +2251,6 @@ def _wrap_model(self, model, training=True):
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
                 assert self.optimizer is not None, "optimizer is empty!"
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
-
         # Pipeline mode
         if in_pipeline_parallel_mode:
             if self.args.amp_master_grad:
@@ -2284,15 +2300,13 @@ def get_expected_keys(inputs, keys):
             if self.args.amp_master_grad:
                 self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
             if (
                 hasattr(self.args, "enable_sharding_comm_overlap")
                 and self.args.enable_sharding_comm_overlap
                 and self.args.unified_checkpoint
                 and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model.register_sharding_comm_overlap_hook(self.optimizer)
-
         # No pipeline mode, sharding only
         if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
             # Sharded DDP!
@@ -2306,7 +2320,6 @@ def get_expected_keys(inputs, keys):
                 model = paddle.distributed.fleet.meta_parallel.TensorParallel(
                     model, hcg, strategy=fleet.fleet._user_defined_strategy
                 )
-
             if ShardingOption.SHARD_OP in self.args.sharding:
                 if self.args.amp_master_grad:
                     mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)  # return value has no use
@@ -2348,6 +2361,7 @@ def get_expected_keys(inputs, keys):
                         offload=cpu_offload,
                         **extra_kwargs,
                     )
+
                 if ShardingOption.SHARD_GRAD_OP in self.args.sharding and self.args.amp_master_grad:
                     assert hasattr(optimizer, "use_main_grad"), (
                         "Current installed paddle doesn't support sharding stage 2 with main grad, "
@@ -2373,7 +2387,6 @@ def get_expected_keys(inputs, keys):
                 if self.args.amp_master_grad:
                     self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
                 self.optimizer = fleet.distributed_optimizer(self.optimizer)
-
         # stage1 has v1 and v2 version
         if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding:
             if "split_param" in self.args.sharding_parallel_config:
@@ -2388,7 +2401,6 @@ def get_expected_keys(inputs, keys):
                 and "enable_stage1_broadcast_overlap" in self.args.sharding_parallel_config
             ):
                 self.optimizer._set_broadcast_overlap(True, model)
-
         return model
 
     def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
@@ -2700,6 +2712,10 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             self.save_model(output_dir)
 
+        model_sharded_state_dict = self.model.sharded_state_dict()
+        if self.args.using_flex_checkpoint:
+            os.makedirs(output_dir, exist_ok=True)
+
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
             metric_to_check = self.args.metric_for_best_model
@@ -2763,23 +2779,32 @@ def _save_checkpoint(self, model, metrics=None):
                         signal_dir,
                     )
                 else:
-                    if self.dp_group.rank > 0:  # this should only work for MoE saving
-                        self._save_ckpt_func(
-                            self._filter_moe_no_sync_optimizer_params(),
-                            os.path.join(output_dir, optimizer_name),
-                            saved_signal_path,
-                        )
-
-                    else:
-                        state_dict = self.optimizer.state_dict()
-                        save_path = os.path.join(output_dir, optimizer_name)
-                        if self.args.use_async_save:
-                            assert not strtobool(os.getenv("FLAG_LLM_PDC", "False")), "Dont support FLAG_LLM_PDC"
-                            self._async_optimizer_saver.run(
-                                state_dict, save_path, saved_signal_path=saved_signal_path
+                    if not self.args.using_flex_checkpoint:
+                        if self.dp_group.rank > 0:  # this should only work for MoE saving
+                            self._save_ckpt_func(
+                                self._filter_moe_no_sync_optimizer_params(),
+                                os.path.join(output_dir, optimizer_name),
+                                saved_signal_path,
                             )
+
                         else:
-                            self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                            state_dict = self.optimizer.state_dict()
+                            save_path = os.path.join(output_dir, optimizer_name)
+                            if self.args.use_async_save:
+                                assert not strtobool(
+                                    os.getenv("FLAG_LLM_PDC", "False")
+                                ), "Dont support FLAG_LLM_PDC"
+                                self._async_optimizer_saver.run(
+                                    state_dict, save_path, saved_signal_path=saved_signal_path
+                                )
+                            else:
+                                self._save_ckpt_func(state_dict, save_path, saved_signal_path)
+                    else:
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        dist.save_state_dict(
+                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                            output_dir,
+                        )
             else:
                 if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                     global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2800,7 +2825,7 @@ def _save_checkpoint(self, model, metrics=None):
                     output_dir,
                     signal_dir,
                 )
-            else:
+            elif not self.args.using_flex_checkpoint:
                 if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                     self._save_ckpt_func(
                         self._filter_moe_no_sync_optimizer_params(),
@@ -2814,6 +2839,13 @@ def _save_checkpoint(self, model, metrics=None):
                         saved_signal_path,
                     )
 
+            else:
+                optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                dist.save_state_dict(
+                    {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                    output_dir,
+                )
+
         # FIXME: maybe only save one copy
         paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
 
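Reviewer note: with `using_flex_checkpoint` enabled, both optimizer-saving branches of `_save_checkpoint` above reduce to the same pattern. A minimal sketch, assuming `dist` is `paddle.distributed` and `trainer`/`output_dir` are placeholders:

```python
import os

import paddle.distributed as dist


def save_flex_checkpoint(trainer, output_dir):
    """Sketch of the flex-checkpoint save path added to Trainer._save_checkpoint()."""
    # Every rank holds only its local shards, so the checkpoint directory is
    # created on all ranks before the collective write.
    os.makedirs(output_dir, exist_ok=True)

    # The optimizer's sharded view is built against the model's view so that
    # both can be merged into one flat mapping with consistent keys.
    model_sd = trainer.model.sharded_state_dict()
    opt_sd = trainer.optimizer.sharded_state_dict(model_sd)

    # Replaces the per-rank paddle.save of optimizer.state_dict(): the
    # distributed checkpoint writer stores the merged shards under output_dir.
    dist.save_state_dict({**model_sd, **opt_sd}, output_dir)
```

The learning-rate scheduler is still written separately with `paddle.save` right after this block, as the unchanged context lines show.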
@@ -3077,6 +3109,24 @@ def _save(
             with open(path, "w") as f:
                 json.dump(model_meta, f)
 
+    def _load_scheduler(self, checkpoint):
+        if checkpoint is None:
+            self.runtime_timer.stop()
+            return
+
+        if not self.args.ignore_load_lr_and_optim:
+            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
+                self.lr_scheduler.set_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
+                )
+            else:
+                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
+
+            if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
+                self.scaler.load_state_dict(
+                    paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
+                )
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
         self.runtime_timer.start("checkpoint loading time")
@@ -3118,6 +3168,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model = self.model_wrapped
+
             opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
                 model=model,
                 optimizer=self.optimizer,
@@ -3149,18 +3200,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                 raise ValueError(f"optimizer-state-dict not found, opt: {os.path.join(checkpoint, optimizer_name)}.")
 
-        if not self.args.ignore_load_lr_and_optim:
-            if distributed_isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
-                self.lr_scheduler.set_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCHEDULER_NAME)))
-                )
-            else:
-                raise ValueError(f"scheduler-file not found, scheduler:{os.path.join(checkpoint, SCHEDULER_NAME)}")
-
-            if self.do_grad_scaling and distributed_isfile(os.path.join(checkpoint, SCALER_NAME)):
-                self.scaler.load_state_dict(
-                    paddle.load(distributed_file(os.path.join(checkpoint, SCALER_NAME)), return_numpy=True)
-                )
+        self._load_scheduler(checkpoint)
 
         if self.args.offload_optim:
             logger.info("Offloading optimizer state...")
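Reviewer note: for context, the user-facing switch this diff keys on is `args.using_flex_checkpoint`. The snippet below shows the assumed usage; treating the flag as a `TrainingArguments` field, along with the concrete paths and step counts, is an assumption, since the field definition is outside this diff.

```python
from paddlenlp.trainer import TrainingArguments

# Assumed configuration: with the flag on, _save_checkpoint writes sharded
# checkpoints via dist.save_state_dict, and train(resume_from_checkpoint=...)
# restores them via dist.load_state_dict instead of per-rank optimizer files.
args = TrainingArguments(
    output_dir="./checkpoints",      # placeholder
    using_flex_checkpoint=True,      # flag checked throughout this diff
    save_steps=500,                  # placeholder
    max_steps=10_000,                # placeholder
)
```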
0 commit comments