Commit 3a5acd2

Support loading sharded EMA checkpoints without loading non-EMA checkpoints (#11076)

* support_ema_loading_no_pdopt
* polish code

1 parent d324fea · commit 3a5acd2

File tree

2 files changed: +17, -27 lines

paddlenlp/trainer/utils/reshard/common.py

Lines changed: 1 addition & 1 deletion

@@ -102,7 +102,7 @@ def convert_opt_name_to_tname(tensor_names, opt_names):
                 opt_to_t[t] = t[: -len(s)]
                 _find = True
                 break
-        assert _find
+        assert _find, t
     return opt_to_t
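The only change here adds the failing key to the assertion message, so a suffix-matching failure now reports which optimizer-state name had no recognized suffix. For context, below is a minimal, self-contained sketch of the kind of suffix-stripping loop this assertion guards; the function name and the default suffix list are illustrative assumptions, not the actual contents of common.py.

# Illustrative sketch only: the function name and suffix list are assumptions.
def map_opt_names_to_tensor_names(opt_names, suffixes=("_moment1_0", "_moment2_0", "_beta1_pow_acc_0", "_beta2_pow_acc_0")):
    opt_to_t = {}
    for t in opt_names:
        _find = False
        for s in suffixes:
            if t.endswith(s):
                # Strip the optimizer-state suffix to recover the tensor name.
                opt_to_t[t] = t[: -len(s)]
                _find = True
                break
        # With ", t" in the message, a failure reports which name did not match any suffix.
        assert _find, t
    return opt_to_t

# Example: map_opt_names_to_tensor_names(["linear_0.w_0_moment1_0"])
#          -> {"linear_0.w_0_moment1_0": "linear_0.w_0"}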

paddlenlp/trainer/utils/sharding_io.py

Lines changed: 16 additions & 26 deletions

@@ -411,28 +411,22 @@ def _load_one_state_dict_from_checkpoint(self, resume_from_checkpoint, base_weig
         return state_dict

     def _load_optimizer_state_of_one_shard(self, checkpoint, base_opt_name, optimizer_name_suffix, group_getter=None):
-        def load_impl(_base_opt_name):
-            optimizer_name = _add_variant(_base_opt_name, optimizer_name_suffix)
-            path = os.path.join(checkpoint, optimizer_name)
-            logger.info(f"load optimizer state from {path}")
-            if os.path.isfile(path):
-                return self._remap_parameter_name(
-                    checkpoint,
-                    self._modify_ckpt_for_compatibility(paddlenlp_load(path, map_location="cpu")),
-                    is_opt=True,
-                )
-            logger.info(f"{path} not exists")
-            return None
-
-        opt_state = load_impl(base_opt_name)
         if self.is_ema:
-            ema_opt_state = load_impl(base_opt_name.replace("optimizer", "ema"))
-            if ema_opt_state is not None:
-                assert opt_state is not None, "optimizer state should exist when EMA optimizer state exists"
-                opt_state["master_weights"] = ema_opt_state.pop("master_weights", {})
-            else:
-                assert opt_state is None, "optimizer state should not exist when EMA optimizer state does not exist"
-        return opt_state
+            base_opt_name = base_opt_name.replace("optimizer", "ema")
+        optimizer_name = _add_variant(base_opt_name, optimizer_name_suffix)
+        path = os.path.join(checkpoint, optimizer_name)
+        logger.info(f"load optimizer state from {path}")
+        if os.path.isfile(path):
+            opt_state = paddlenlp_load(path, map_location="cpu")
+            if self.is_ema:
+                opt_state = {"master_weights": opt_state.get("master_weights", {})}
+            return self._remap_parameter_name(
+                checkpoint,
+                self._modify_ckpt_for_compatibility(opt_state),
+                is_opt=True,
+            )
+        logger.info(f"{path} not exists")
+        return None

     def _modify_ckpt_for_compatibility(self, ckpt):
         master_weights = ckpt.get("master_weights", None)
@@ -611,11 +605,7 @@ def reshard_sharding(node_model_state):

         node_model_state = load_model_slices()
         node_model_state = reshard_pp(node_model_state)
-        opt_state = reshard_sharding(node_model_state)
-        if self.is_ema:
-            return {"master_weights": opt_state.get("master_weights", {})}
-        else:
-            return opt_state
+        return reshard_sharding(node_model_state)

     def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel=False, state_dict=None):
         weight_name_suffix = self.args.sharded_name_suffix()
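Taken together, the two hunks move the EMA handling into the per-shard loader: when self.is_ema is set, the shard path is built from the "ema" file name instead of the "optimizer" one and only master_weights is kept from what was loaded, so a matching non-EMA optimizer shard no longer has to exist on disk and the reshard path needs no special-casing of its own. Below is a rough standalone sketch of that branch; the file-name argument and load_fn are stand-ins for what _add_variant and paddlenlp_load produce in the real sharding_io.py, not the actual implementation.

import os

def load_ema_master_weights_sketch(checkpoint_dir, ema_file_name, load_fn):
    # Sketch only: ema_file_name stands in for the result of
    # _add_variant(base_opt_name.replace("optimizer", "ema"), optimizer_name_suffix),
    # and load_fn stands in for paddlenlp_load.
    path = os.path.join(checkpoint_dir, ema_file_name)
    if not os.path.isfile(path):
        # Missing shard: the real loader logs this and returns None to the caller.
        return None
    state = load_fn(path)
    # Keep only the EMA master weights; no non-EMA optimizer state is ever read.
    return {"master_weights": state.get("master_weights", {})}

In the committed code, the dict produced this way is still passed through _modify_ckpt_for_compatibility and _remap_parameter_name, exactly like a regular optimizer shard.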
