Commit d324fea

support load sharded EMA checkpoints (#11073)
1 parent 133f0d2 commit d324fea

File tree

3 files changed (+39, -13 lines)


paddlenlp/trainer/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -377,6 +377,7 @@ def __init__(
                 self.model,
                 self.optimizer,
                 remap_parameter_name=self.args.load_sharded_model_remap_parameter_name,
+                is_ema=self.args.sharded_model_from_ema,
             )
 
         if self.args.unified_checkpoint:

paddlenlp/trainer/training_args.py

Lines changed: 5 additions & 0 deletions
@@ -638,6 +638,11 @@ class TrainingArguments:
         metadata={"help": "Whether to remap parameter name when load_sharded_model = true."},
     )
 
+    sharded_model_from_ema: bool = field(
+        default=False,
+        metadata={"help": "Whether to load sharded model from EMA."},
+    )
+
     tensor_parallel_degree: int = field(
         default=-1,
         metadata={
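
The new argument sits next to the existing sharded-checkpoint flags. A minimal usage sketch, assuming it is combined with load_sharded_model (the flag names come from this diff and its help text; the output directory and surrounding setup are hypothetical):

    from paddlenlp.trainer import TrainingArguments

    # Sketch only: resume from a sharded (per-rank) checkpoint, taking the weights
    # and master weights from the EMA files instead of the regular ones.
    args = TrainingArguments(
        output_dir="./checkpoints",      # hypothetical path
        load_sharded_model=True,         # existing flag referenced in the help text above
        sharded_model_from_ema=True,     # new flag added by this commit
    )

Since it is a plain boolean field with default=False, it can also be set on the command line like any other TrainingArguments option.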

paddlenlp/trainer/utils/sharding_io.py

Lines changed: 33 additions & 13 deletions
@@ -270,7 +270,7 @@ def get_group_ids(self):
 
 
 class ShardingIO:
-    def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=False):
+    def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=False, is_ema=False):
         self.args = args
         self.model = model
         self.optimizer = optimizer
@@ -282,6 +282,7 @@ def __init__(self, args, model, optimizer=None, hcg=None, remap_parameter_name=F
 
         self.remap_parameter_name = remap_parameter_name
         self.remapper = None
+        self.is_ema = is_ema
 
     def _get_remapper(self, checkpoint):
         if not self.remap_parameter_name:
@@ -395,28 +396,43 @@ def _load_one_state_dict_from_checkpoint(self, resume_from_checkpoint, base_weig
         """
         load state_dict of one shard from_checkpoint, Only load model state dict.
         """
+        if self.is_ema:
+            base_weight_name = base_weight_name.replace("model_state", "ema").replace("pdparams", "pdopt")
         file_path = os.path.join(resume_from_checkpoint, _add_variant(base_weight_name, weight_name_suffix))
         if not os.path.isfile(file_path):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}, no {file_path}")
 
         logger.info(f"Loading model from {file_path}.")
         # We load the model state dict on the CPU to avoid an OOM error.
         state_dict = paddle.load(file_path, return_numpy=True)
+        if self.is_ema:
+            state_dict.pop("master_weights", None)
         state_dict = self._remap_parameter_name(resume_from_checkpoint, state_dict, is_opt=False)
         return state_dict
 
     def _load_optimizer_state_of_one_shard(self, checkpoint, base_opt_name, optimizer_name_suffix, group_getter=None):
-        optimizer_name = _add_variant(base_opt_name, optimizer_name_suffix)
-        path = os.path.join(checkpoint, optimizer_name)
-        logger.info(f"load optimizer state from {path}")
-        if os.path.isfile(path):
-            return self._remap_parameter_name(
-                checkpoint,
-                self._modify_ckpt_for_compatibility(paddlenlp_load(path, map_location="cpu")),
-                is_opt=True,
-            )
-        logger.info(f"{path} not exists")
-        return None
+        def load_impl(_base_opt_name):
+            optimizer_name = _add_variant(_base_opt_name, optimizer_name_suffix)
+            path = os.path.join(checkpoint, optimizer_name)
+            logger.info(f"load optimizer state from {path}")
+            if os.path.isfile(path):
+                return self._remap_parameter_name(
+                    checkpoint,
+                    self._modify_ckpt_for_compatibility(paddlenlp_load(path, map_location="cpu")),
+                    is_opt=True,
+                )
+            logger.info(f"{path} not exists")
+            return None
+
+        opt_state = load_impl(base_opt_name)
+        if self.is_ema:
+            ema_opt_state = load_impl(base_opt_name.replace("optimizer", "ema"))
+            if ema_opt_state is not None:
+                assert opt_state is not None, "optimizer state should exist when EMA optimizer state exists"
+                opt_state["master_weights"] = ema_opt_state.pop("master_weights", {})
+            else:
+                assert opt_state is None, "optimizer state should not exist when EMA optimizer state does not exist"
+        return opt_state
 
     def _modify_ckpt_for_compatibility(self, ckpt):
         master_weights = ckpt.get("master_weights", None)
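
Taken together, EMA loading is a file-naming convention plus a master-weights swap: the model shard is read from the EMA file instead of model_state*.pdparams, and only the master_weights entry of the EMA optimizer file is copied into the regular optimizer state. A self-contained sketch of that convention, using assumed base file names and toy dicts rather than the real ShardingIO API:

    # Sketch of the renaming and merge logic introduced above; toy values only.

    def ema_weight_file(base_weight_name: str) -> str:
        # e.g. "model_state.pdparams" -> "ema.pdopt" (sharding suffixes are added later)
        return base_weight_name.replace("model_state", "ema").replace("pdparams", "pdopt")

    def ema_optimizer_file(base_opt_name: str) -> str:
        # e.g. "optimizer.pdopt" -> "ema.pdopt"
        return base_opt_name.replace("optimizer", "ema")

    def merge_ema_master_weights(opt_state: dict, ema_opt_state: dict) -> dict:
        # Keep the optimizer moments as loaded; only master_weights come from the EMA file.
        opt_state["master_weights"] = ema_opt_state.pop("master_weights", {})
        return opt_state

    opt_state = {"moment1": {"w": 1.0}, "master_weights": {"w": 0.50}}
    ema_opt_state = {"master_weights": {"w": 0.49}}
    print(ema_weight_file("model_state.pdparams"))             # ema.pdopt
    print(merge_ema_master_weights(opt_state, ema_opt_state))  # master_weights now hold 0.49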
@@ -595,7 +611,11 @@ def reshard_sharding(node_model_state):
 
         node_model_state = load_model_slices()
         node_model_state = reshard_pp(node_model_state)
-        return reshard_sharding(node_model_state)
+        opt_state = reshard_sharding(node_model_state)
+        if self.is_ema:
+            return {"master_weights": opt_state.get("master_weights", {})}
+        else:
+            return opt_state
 
     def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=False, state_dict=None):
         weight_name_suffix = self.args.sharded_name_suffix()
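
On the resharding path, the EMA branch keeps only the master weights and discards the optimizer moments before returning. A toy illustration of that final filter (keys and values are made up; real states map parameter names to tensors):

    # Toy illustration of the is_ema branch after resharding.
    is_ema = True
    opt_state = {
        "linear_0.w_0_moment1_0": 0.1,     # hypothetical optimizer moment
        "linear_0.w_0_moment2_0": 0.2,
        "master_weights": {"linear_0.w_0": 0.5},
    }

    if is_ema:
        opt_state = {"master_weights": opt_state.get("master_weights", {})}

    print(opt_state)  # {'master_weights': {'linear_0.w_0': 0.5}}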
