
Commit 419b349

fix
1 parent 9805b27 commit 419b349

3 files changed: +16 additions, −4 deletions


paddlenlp/trainer/trainer.py

Lines changed: 2 additions & 1 deletion

@@ -166,6 +166,7 @@
     split_parallel_config,
 )
 from .training_args import TrainingArguments
+from .unified_checkpoint import UnifiedCheckpointHandler
 from .utils import reshard as reshard_util
 from .utils.async_save import AsyncSaver

@@ -957,7 +958,7 @@ def train(
         init_optimizer(self.optimizer)
         optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
         sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
-        dist.load_state_dict(sharded_state_dict, resume_from_checkpoint)
+        dist.load_state_dict(sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config)
         self._load_scheduler(resume_from_checkpoint)
     else:
         model = self.model_wrapped
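The change above threads the new aoa_config training argument through to FlexCheckpoint's loader. A minimal sketch of the resulting call, assuming a hypothetical AoA mapping (the rule syntax is defined by FlexCheckpoint; the entry below is a placeholder, not a real rule):

    import paddle.distributed as dist

    # Placeholder AoA rule: maps a checkpoint-side weight name to the
    # model-side name(s) it should populate. Keys/values are illustrative only.
    aoa_config = {"ckpt.embed_tokens.weight": ["model.embed_tokens.weight"]}

    # sharded_state_dict is the merged model + optimizer sharded state dict
    # built a few lines earlier in train(); resume_from_checkpoint is the
    # checkpoint directory path.
    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint, aoa_config=aoa_config)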

paddlenlp/trainer/training_args.py

Lines changed: 11 additions & 1 deletion

@@ -407,6 +407,10 @@ class TrainingArguments:
             Whether to release gradients during training. Default is `False`.
         ckpt_quant_stage (`str`, *optional*):
             Whether to activate checkpoint quantization. O0: deactivated, O1: Int8 compression, O2: Int4 compression. Default is `O0`.
+        using_flex_checkpoint (`bool`, *optional*):
+            Whether to use FlexCheckpoint for save and load. Default is `False`.
+        aoa_config (`Optional[dict[str, list[str]]]`, *optional*):
+            The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is `None`.
     """

     output_dir: str = field(

@@ -1086,6 +1090,13 @@ class TrainingArguments:
         default=None, metadata={"help": "Path to a config file for fine-grained control of NCCL communication groups; defaults to None, which disables this option."}
     )

+    aoa_config: Optional[dict[str, list[str]]] = field(
+        default=None,
+        metadata={
+            "help": "The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None."
+        },
+    )
+
     def __post_init__(self):
         world_size = paddle.distributed.get_world_size()
         if in_auto_parallel_align_mode():

@@ -2376,7 +2387,6 @@ def _no_sync_in_gradient_accumulation(self):

     @property
     def should_save_sharding_stage1_model(self):
-        # return True
         if self.enable_auto_parallel:
             return False
         return (
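As a usage sketch, the new field can be set like any other TrainingArguments option; the mapping below is hypothetical and only illustrates the declared dict[str, list[str]] shape:

    from paddlenlp.trainer import TrainingArguments

    args = TrainingArguments(
        output_dir="./checkpoints",
        # Placeholder AoA rule: checkpoint-side name -> model-side name(s).
        aoa_config={"ckpt.weight_name": ["model.weight_name"]},
    )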

paddlenlp/transformers/llama/modeling.py

Lines changed: 3 additions & 2 deletions

@@ -30,7 +30,9 @@
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.recompute.recompute import recompute
-from paddle.distributed.flex_checkpoint import build_sharded_state_dict
+from paddle.distributed.flex_checkpoint.dcp.sharded_weight import (
+    build_sharded_state_dict,
+)

 from paddlenlp.transformers.refined_recompute import (
     RRColumnParallelLinear,

@@ -1427,7 +1429,6 @@ def get_tensor_parallel_split_mappings(num_layers):

     @classmethod
     def _get_fuse_or_split_param_mappings(cls, config: LlamaConfig, is_fuse=False):
-        raise NotImplementedError
         # return parameter fuse utils
         from paddlenlp.transformers.conversion_utils import split_or_fuse_func
