
Commit 20a584f

[cherry-pick] add zcc_ema_loss_threshold args to avoid merging models with loss spike (#11079)
1 parent 68645cc commit 20a584f

File tree

3 files changed (+35, -11 lines)


paddlenlp/trainer/trainer.py

Lines changed: 10 additions & 0 deletions
@@ -1384,6 +1384,16 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
 
                 self.state.global_step += 1
                 self.state.epoch = epoch + (step + 1) / steps_in_epoch
+
+                # For ZCC EMA
+                if self.args.enable_zero_cost_checkpoint:
+                    tr_loss_for_zcc = tr_loss.clone()
+                    dist.all_reduce(
+                        tr_loss_for_zcc, dist.ReduceOp.SUM
+                    )  # Under 3D parallelism the loss is broadcast within each pp group, so the global reduce-mean gains a pp_world_size factor in both numerator and denominator, which cancels out
+                    tr_loss_for_zcc_scalar = tr_loss_for_zcc.item() / dist.get_world_size()
+                    self.state.loss = tr_loss_for_zcc_scalar
+
                 self.state.consumed_samples = (
                     self.state.global_step
                     * args.per_device_train_batch_size
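
The comment on the all_reduce is the key reasoning step: every rank inside a pipeline-parallel group reports the same broadcast loss, so dividing the summed loss by the full world size still yields the mean over data-parallel groups. A minimal sketch of that cancellation (plain Python, no distributed runtime; the layout numbers are illustrative, not from this commit):

# Sketch: with pipeline parallelism every rank in a pp group holds the same
# broadcast loss, so sum-over-all-ranks / world_size equals the mean over
# data-parallel groups; the pp_world_size factor cancels.
dp_world_size, pp_world_size = 4, 2            # hypothetical parallel layout
group_losses = [2.0, 2.5, 3.0, 3.5]            # one loss per data-parallel group

# every pipeline rank inside a group reports the same (broadcast) value
per_rank_losses = [loss for loss in group_losses for _ in range(pp_world_size)]

world_size = dp_world_size * pp_world_size
global_mean = sum(per_rank_losses) / world_size    # what the diff computes
true_mean = sum(group_losses) / dp_world_size      # mean over dp groups only

assert abs(global_mean - true_mean) < 1e-9         # pp factor cancels
print(global_mean)  # 2.75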

paddlenlp/trainer/training_args.py

Lines changed: 4 additions & 0 deletions
@@ -1050,6 +1050,10 @@ class TrainingArguments:
         default=1,
         metadata={"help": "Interval between updating EMA parameters."},
     )
+    zcc_ema_loss_threshold: Optional[float] = field(
+        default=None,
+        metadata={"help": "If set (not None), only perform the EMA update when the training loss is smaller than this threshold."},
+    )
     save_tokenizer: Optional[bool] = field(
         default=True,
         metadata={"help": "Save tokenizer to output_dir."},
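
A hedged usage sketch of the new field, constructing TrainingArguments directly in Python; the output directory and the threshold value are placeholders, and any other ZCC/EMA settings a real run needs (such as the EMA coefficient) are omitted here:

from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",          # placeholder path
    enable_zero_cost_checkpoint=True,    # ZCC must be enabled for the EMA path
    zcc_ema_loss_threshold=2.5,          # skip EMA updates whenever loss >= 2.5
)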

paddlenlp/trainer/utils/zero_cost_checkpoint.py

Lines changed: 21 additions & 11 deletions
@@ -173,10 +173,11 @@ def ema_reset(self):
         self.ema_buffer_modele_params = None
 
     @imperative_base.no_grad()
-    def ema_accumulate(self):
+    def ema_accumulate(self, global_step, loss, zcc_ema_loss_threshold):
         """
         perform ema update: `\alpha * EMA + (1 - \alpha) * model`
-        build `self.ema_buffer` if necessary
+        build `self.ema_buffer` if necessary
+        when loss < threshold, do the ema update
         """
         # logger.info(f'[ZCC EMA] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
         # do update: ema = alpha * ema + (1-alpha) * model
@@ -185,14 +186,19 @@ def ema_accumulate(self):
         cpu_master_weights = self.optimizer_fusion_storage_helper.cpu_buffer._slice(
             self.master_min_offset, self.master_max_offset
         ).cpu()
-        self.ema_buffer = self.ema_coef * self.ema_buffer + (1 - self.ema_coef) * cpu_master_weights
-        # logger.info(f'[ZCC EMA2] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
-        for index, ema_buf in self.ema_buffer_model_params.items():
-            _, cpu_buf = self.param_fusion_storage_helper.inited_buffers[index]
-            updated_ema = self.ema_coef * ema_buf + (1 - self.ema_coef) * cpu_buf
-            self.ema_buffer_model_params[index] = updated_ema
-
-        logger.info(f"[ZCC EMA] accumulating, buffer type:{self.ema_buffer.place} {self.ema_buffer.dtype}, done")
+        if zcc_ema_loss_threshold is None or loss < zcc_ema_loss_threshold:
+            self.ema_buffer = self.ema_coef * self.ema_buffer + (1 - self.ema_coef) * cpu_master_weights
+            for index, ema_buf in self.ema_buffer_model_params.items():
+                _, cpu_buf = self.param_fusion_storage_helper.inited_buffers[index]
+                updated_ema = self.ema_coef * ema_buf + (1 - self.ema_coef) * cpu_buf
+                self.ema_buffer_model_params[index] = updated_ema
+            logger.info(
+                f"[ZCC EMA] accumulating, buffer type:{self.ema_buffer.place} {self.ema_buffer.dtype}, done"
+            )
+        else:
+            logger.info(
+                f"[ZCC EMA] accumulating SKIP for global_step:{global_step}, because loss:{loss} >= threshold:{zcc_ema_loss_threshold}"
+            )
 
     @imperative_base.no_grad()
     def ema_state_dict(self):
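
The gating added above is just a conditional EMA update. A standalone sketch of the same idea, with numpy standing in for the fused CPU buffers (the function name, signature, and values here are illustrative, not the module's API):

import numpy as np

def ema_accumulate(ema, params, loss, ema_coef=0.999, loss_threshold=None):
    """Return the EMA buffer, updated only when loss is below the threshold."""
    if loss_threshold is None or loss < loss_threshold:
        # ema = alpha * ema + (1 - alpha) * params
        return ema_coef * ema + (1.0 - ema_coef) * params
    # loss spike: leave the EMA untouched so the merged weights are not polluted
    return ema

ema = np.zeros(4)
ema = ema_accumulate(ema, np.full(4, 2.0), loss=1.8, loss_threshold=2.5)  # applied
ema = ema_accumulate(ema, np.full(4, 9.0), loss=7.3, loss_threshold=2.5)  # skipped

In the commit, the background offload worker (next hunk) supplies the three new arguments: the trainer's global step, the reduced loss stored in trainer state, and the threshold read from the training arguments.
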
@@ -790,7 +796,11 @@ def process_offload_task(self, dump, global_step):
             self.global_step.value = global_step
 
             if self.ema_coef is not None:
-                self.zcc_ema_processor.ema_accumulate()
+                self.zcc_ema_processor.ema_accumulate(
+                    self.trainer_state.global_step,
+                    self.trainer_state.loss,
+                    self.training_args_content.zcc_ema_loss_threshold,
+                )
 
             # continue to process dumping task at the last chunk
             if self.offloaded_numels == self.all_numel:
