@@ -411,28 +411,22 @@ def _load_one_state_dict_from_checkpoint(self, resume_from_checkpoint, base_weig
             return state_dict
 
     def _load_optimizer_state_of_one_shard(self, checkpoint, base_opt_name, optimizer_name_suffix, group_getter=None):
-        def load_impl(_base_opt_name):
-            optimizer_name = _add_variant(_base_opt_name, optimizer_name_suffix)
-            path = os.path.join(checkpoint, optimizer_name)
-            logger.info(f"load optimizer state from {path}")
-            if os.path.isfile(path):
-                return self._remap_parameter_name(
-                    checkpoint,
-                    self._modify_ckpt_for_compatibility(paddlenlp_load(path, map_location="cpu")),
-                    is_opt=True,
-                )
-            logger.info(f"{path} not exists")
-            return None
-
-        opt_state = load_impl(base_opt_name)
         if self.is_ema:
-            ema_opt_state = load_impl(base_opt_name.replace("optimizer", "ema"))
-            if ema_opt_state is not None:
-                assert opt_state is not None, "optimizer state should exist when EMA optimizer state exists"
-                opt_state["master_weights"] = ema_opt_state.pop("master_weights", {})
-            else:
-                assert opt_state is None, "optimizer state should not exist when EMA optimizer state does not exist"
-        return opt_state
+            base_opt_name = base_opt_name.replace("optimizer", "ema")
+        optimizer_name = _add_variant(base_opt_name, optimizer_name_suffix)
+        path = os.path.join(checkpoint, optimizer_name)
+        logger.info(f"load optimizer state from {path}")
+        if os.path.isfile(path):
+            opt_state = paddlenlp_load(path, map_location="cpu")
+            if self.is_ema:
+                opt_state = {"master_weights": opt_state.get("master_weights", {})}
+            return self._remap_parameter_name(
+                checkpoint,
+                self._modify_ckpt_for_compatibility(opt_state),
+                is_opt=True,
+            )
+        logger.info(f"{path} not exists")
+        return None
 
     def _modify_ckpt_for_compatibility(self, ckpt):
         master_weights = ckpt.get("master_weights", None)
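
For reference, a minimal standalone sketch of the per-shard loading path introduced above. This is illustrative only: _add_variant below is a simplified stand-in for the PaddleNLP helper, and load_fn is a hypothetical loader argument used in place of paddlenlp_load.

import os

def _add_variant(name, suffix):
    # Simplified stand-in: insert the shard suffix before the file extension,
    # e.g. "optimizer.pdopt" with suffix "tp00" -> "optimizer-tp00.pdopt".
    if not suffix:
        return name
    stem, ext = os.path.splitext(name)
    return f"{stem}-{suffix}{ext}"

def load_one_shard_opt_state(checkpoint, base_opt_name, suffix, is_ema, load_fn):
    # EMA checkpoints reuse the optimizer file layout under an "ema" file stem.
    if is_ema:
        base_opt_name = base_opt_name.replace("optimizer", "ema")
    path = os.path.join(checkpoint, _add_variant(base_opt_name, suffix))
    if not os.path.isfile(path):
        return None
    opt_state = load_fn(path)
    if is_ema:
        # For EMA state only the master weights are kept; everything else is dropped.
        opt_state = {"master_weights": opt_state.get("master_weights", {})}
    return opt_state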
@@ -611,11 +605,7 @@ def reshard_sharding(node_model_state):
 
         node_model_state = load_model_slices()
         node_model_state = reshard_pp(node_model_state)
-        opt_state = reshard_sharding(node_model_state)
-        if self.is_ema:
-            return {"master_weights": opt_state.get("master_weights", {})}
-        else:
-            return opt_state
+        return reshard_sharding(node_model_state)
 
     def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=False, state_dict=None):
        weight_name_suffix = self.args.sharded_name_suffix()
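
A note on the second hunk: because _load_optimizer_state_of_one_shard now reduces EMA state to {"master_weights": ...} at load time, the sharding-reshard path no longer needs its own is_ema branch and can return reshard_sharding(node_model_state) directly.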