From da68451bde8fa741c58e134415093474b3414af3 Mon Sep 17 00:00:00 2001
From: shangjunyuan
Date: Mon, 23 Jun 2025 22:50:47 +0800
Subject: [PATCH] [cherry-pick] add zcc_ema_loss_threshold arg to avoid merging models with a loss spike

---
 paddlenlp/trainer/trainer.py                  | 10 +++++++
 paddlenlp/trainer/training_args.py            |  6 +++++
 .../trainer/utils/zero_cost_checkpoint.py     | 27 ++++++++++++-------
 3 files changed, 33 insertions(+), 10 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index b48679dbf26a..cea31a9ac8ba 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -1370,6 +1370,16 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
 
                 self.state.global_step += 1
                 self.state.epoch = epoch + (step + 1) / steps_in_epoch
+
+                # For ZCC EMA
+                if self.args.enable_zero_cost_checkpoint:
+                    tr_loss_for_zcc = tr_loss.clone()
+                    dist.all_reduce(
+                        tr_loss_for_zcc, dist.ReduceOp.SUM
+                    )  # Under 3D parallelism the loss is broadcast within each pp group, so the pp_world_size factor in the global reduce-mean cancels out.
+                    tr_loss_for_zcc_scalar = tr_loss_for_zcc.item() / dist.get_world_size()
+                    self.state.loss = tr_loss_for_zcc_scalar
+
                 self.state.consumed_samples = (
                     self.state.global_step
                     * args.per_device_train_batch_size
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 30a3e7b3dc62..f2c5a5acf5d0 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1045,6 +1045,12 @@ class TrainingArguments:
         default=1,
         metadata={"help": "Interval between updating EMA parameters."},
     )
+    zcc_ema_loss_threshold: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "If not None, only perform the EMA update when the training loss is smaller than this threshold."
+        },
+    )
     save_tokenizer: Optional[bool] = field(
         default=True,
         metadata={"help": "Save tokenizer to output_dir."},
diff --git a/paddlenlp/trainer/utils/zero_cost_checkpoint.py b/paddlenlp/trainer/utils/zero_cost_checkpoint.py
index 5d92cc91cf95..a02ae37c899f 100644
--- a/paddlenlp/trainer/utils/zero_cost_checkpoint.py
+++ b/paddlenlp/trainer/utils/zero_cost_checkpoint.py
@@ -171,10 +171,11 @@ def ema_reset(self):
         self.ema_buffer_modele_params = None
 
     @imperative_base.no_grad()
-    def ema_accumulate(self):
+    def ema_accumulate(self, global_step, loss, zcc_ema_loss_threshold):
         """
         perform ema update : ` \alpha * EMA + (1-\alpha) + model`
-        build `self.ema_buffer` if necessary
+        build `self.ema_buffer` if necessary
+        only do the ema update when loss < threshold
         """
         # logger.info(f'[ZCC EMA] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
         # do update: ema = alpha * ema + (1-alpha) * model
@@ -183,14 +184,16 @@ def ema_accumulate(self):
         cpu_master_weights = self.optimizer_fusion_storage_helper.cpu_buffer._slice(
             self.master_min_offset, self.master_max_offset
         ).cpu()
-        self.ema_buffer = self.ema_coef * self.ema_buffer + (1 - self.ema_coef) * cpu_master_weights
-        # logger.info(f'[ZCC EMA2] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
-        for index, ema_buf in self.ema_buffer_model_params.items():
-            _, cpu_buf = self.param_fusion_storage_helper.inited_buffers[index]
-            updated_ema = self.ema_coef * ema_buf + (1 - self.ema_coef) * cpu_buf
-            self.ema_buffer_model_params[index] = updated_ema
+        if zcc_ema_loss_threshold is None or loss < zcc_ema_loss_threshold:
+            self.ema_buffer = self.ema_coef * self.ema_buffer + (1 - self.ema_coef) * cpu_master_weights
+            for index, ema_buf in self.ema_buffer_model_params.items():
+                _, cpu_buf = self.param_fusion_storage_helper.inited_buffers[index]
+                updated_ema = self.ema_coef * ema_buf + (1 - self.ema_coef) * cpu_buf
+                self.ema_buffer_model_params[index] = updated_ema
+            logger.info(f"[ZCC EMA] accumulating, buffer type:{self.ema_buffer.place} {self.ema_buffer.dtype}, done")
+        else:
+            logger.info(f"[ZCC EMA] accumulating SKIP for global_step:{global_step}, because loss:{loss} >= threshold:{zcc_ema_loss_threshold}")
 
-        logger.info(f"[ZCC EMA] accumulating, buffer type:{self.ema_buffer.place} {self.ema_buffer.dtype}, done")
 
     @imperative_base.no_grad()
     def ema_state_dict(self):
@@ -771,7 +774,11 @@ def process_offload_task(self, dump, global_step):
         self.global_step.value = global_step
 
         if self.ema_coef is not None:
-            self.zcc_ema_processor.ema_accumulate()
+            self.zcc_ema_processor.ema_accumulate(
+                self.trainer_state.global_step,
+                self.trainer_state.loss,
+                self.training_args_content.zcc_ema_loss_threshold
+            )
 
         # continue to process dumping task at the last chunk
         if self.offloaded_numels == self.all_numel:
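
The gating rule the patch adds can be summarized as: skip the EMA step whenever the globally averaged loss reaches zcc_ema_loss_threshold, otherwise apply ema = alpha * ema + (1 - alpha) * model. The following is a minimal, self-contained sketch of that rule under stated assumptions; SimpleZCCEMA and its scalar buffer are hypothetical stand-ins for the fused CPU buffers used by the real ZCC EMA processor, not PaddleNLP API.

# Illustrative sketch only: threshold-gated EMA mirroring the logic added to ema_accumulate().
class SimpleZCCEMA:
    def __init__(self, ema_coef, loss_threshold=None):
        self.ema_coef = ema_coef                # alpha in ema = alpha * ema + (1 - alpha) * model
        self.loss_threshold = loss_threshold    # plays the role of zcc_ema_loss_threshold
        self.ema_value = None

    def accumulate(self, model_value, loss, global_step):
        # Skip the update when the loss spikes to or above the threshold.
        if self.loss_threshold is not None and loss >= self.loss_threshold:
            print(f"[EMA] skip step {global_step}: loss {loss} >= threshold {self.loss_threshold}")
            return
        if self.ema_value is None:
            self.ema_value = model_value        # first call seeds the EMA buffer
        else:
            self.ema_value = self.ema_coef * self.ema_value + (1 - self.ema_coef) * model_value


if __name__ == "__main__":
    ema = SimpleZCCEMA(ema_coef=0.9, loss_threshold=8.0)
    for step, (weight, loss) in enumerate([(1.0, 2.3), (2.0, 25.0), (3.0, 2.1)], start=1):
        ema.accumulate(weight, loss, step)
    # Step 2 is skipped, so the result blends only steps 1 and 3: 0.9 * 1.0 + 0.1 * 3.0 = 1.2
    print("ema value:", ema.ema_value)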