@@ -270,7 +270,7 @@ def get_group_ids(self):


class ShardingIO:
-    def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=False):
+    def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=False, is_ema=False):
        self.args = args
        self.model = model
        self.optimizer = optimizer
@@ -282,6 +282,7 @@ def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=F
        self.remap_parameter_name = remap_parameter_name
        self.remapper = None
+        self.is_ema = is_ema

    def _get_remapper(self, checkpoint):
        if not self.remap_parameter_name:
@@ -395,28 +396,43 @@ def _load_one_state_dict_from_checkpoint(self, resume_from_checkpoint, base_weig
        """
        load state_dict of one shard from_checkpoint, Only load model state dict.
        """
+        if self.is_ema:
+            base_weight_name = base_weight_name.replace("model_state", "ema").replace("pdparams", "pdopt")
        file_path = os.path.join(resume_from_checkpoint, _add_variant(base_weight_name, weight_name_suffix))
        if not os.path.isfile(file_path):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}, no {file_path}")

        logger.info(f"Loading model from {file_path}.")
        # We load the model state dict on the CPU to avoid an OOM error.
        state_dict = paddle.load(file_path, return_numpy=True)
+        if self.is_ema:
+            state_dict.pop("master_weights", None)
        state_dict = self._remap_parameter_name(resume_from_checkpoint, state_dict, is_opt=False)
        return state_dict

    def _load_optimizer_state_of_one_shard(self, checkpoint, base_opt_name, optimizer_name_suffix, group_getter=None):
-        optimizer_name = _add_variant(base_opt_name, optimizer_name_suffix)
-        path = os.path.join(checkpoint, optimizer_name)
-        logger.info(f"load optimizer state from {path}")
-        if os.path.isfile(path):
-            return self._remap_parameter_name(
-                checkpoint,
-                self._modify_ckpt_for_compatibility(paddlenlp_load(path, map_location="cpu")),
-                is_opt=True,
-            )
-        logger.info(f"{path} not exists")
-        return None
+        def load_impl(_base_opt_name):
+            optimizer_name = _add_variant(_base_opt_name, optimizer_name_suffix)
+            path = os.path.join(checkpoint, optimizer_name)
+            logger.info(f"load optimizer state from {path}")
+            if os.path.isfile(path):
+                return self._remap_parameter_name(
+                    checkpoint,
+                    self._modify_ckpt_for_compatibility(paddlenlp_load(path, map_location="cpu")),
+                    is_opt=True,
+                )
+            logger.info(f"{path} not exists")
+            return None
+
+        opt_state = load_impl(base_opt_name)
+        if self.is_ema:
+            ema_opt_state = load_impl(base_opt_name.replace("optimizer", "ema"))
+            if ema_opt_state is not None:
+                assert opt_state is not None, "optimizer state should exist when EMA optimizer state exists"
+                opt_state["master_weights"] = ema_opt_state.pop("master_weights", {})
+            else:
+                assert opt_state is None, "optimizer state should not exist when EMA optimizer state does not exist"
+        return opt_state

    def _modify_ckpt_for_compatibility(self, ckpt):
        master_weights = ckpt.get("master_weights", None)
@@ -595,7 +611,11 @@ def reshard_sharding(node_model_state):

        node_model_state = load_model_slices()
        node_model_state = reshard_pp(node_model_state)
-        return reshard_sharding(node_model_state)
+        opt_state = reshard_sharding(node_model_state)
+        if self.is_ema:
+            return {"master_weights": opt_state.get("master_weights", {})}
+        else:
+            return opt_state

    def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=False, state_dict=None):
        weight_name_suffix = self.args.sharded_name_suffix()
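
Note (not part of the diff above): the core of the EMA handling is the merge in _load_optimizer_state_of_one_shard, which splices the EMA shard's master_weights into the regular optimizer state. Below is a minimal standalone sketch of that merge using plain dicts and no Paddle dependency; merge_ema_master_weights is an illustrative name, not a helper introduced by this patch.

def merge_ema_master_weights(opt_state, ema_opt_state):
    # Mirrors the is_ema branch above: if the EMA shard exists, its
    # master_weights replace the regular optimizer's master_weights; if it
    # does not, the regular optimizer shard is expected to be absent too.
    if ema_opt_state is None:
        assert opt_state is None, "optimizer state should not exist when EMA optimizer state does not exist"
        return None
    assert opt_state is not None, "optimizer state should exist when EMA optimizer state exists"
    opt_state["master_weights"] = ema_opt_state.pop("master_weights", {})
    return opt_state


# Example: the EMA shard carries the averaged fp32 master weights.
opt = {"LR_Scheduler": {"last_lr": 1e-3}, "master_weights": {"linear.w_0": [1.0, 2.0]}}
ema = {"master_weights": {"linear.w_0": [0.9, 1.9]}}
print(merge_ema_master_weights(opt, ema)["master_weights"])  # {'linear.w_0': [0.9, 1.9]}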