10 changes: 10 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -1370,6 +1370,16 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):

self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch

# For ZCC EMA
if self.args.enable_zero_cost_checkpoint:
tr_loss_for_zcc = tr_loss.clone()
dist.all_reduce(
tr_loss_for_zcc, dist.ReduceOp.SUM
) # With 3D parallelism, the loss is broadcast within each pp group; in the global reduce-mean, both numerator and denominator pick up a factor of pp_world_size, which cancels out.
tr_loss_for_zcc_scalar = tr_loss_for_zcc.item() / dist.get_world_size()
self.state.loss = tr_loss_for_zcc_scalar

self.state.consumed_samples = (
self.state.global_step
* args.per_device_train_batch_size
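The comment above carries a small argument worth making concrete: with dp data-parallel groups and pp pipeline stages, each distinct loss appears pp times in the all-reduced sum while world_size = dp * pp, so the extra factor cancels. A minimal arithmetic sketch in plain Python (the group sizes are made-up example values, not read from any distributed runtime):

# Sketch: dividing the all-reduced loss sum by world_size still yields the
# data-parallel mean, because every rank in a pp group holds the same loss.
dp_size, pp_size = 4, 2            # hypothetical group sizes
world_size = dp_size * pp_size

dp_losses = [1.0, 2.0, 3.0, 4.0]   # one distinct loss per dp group
# The loss is broadcast across pipeline stages, so each value occurs pp_size times.
per_rank_losses = [loss for loss in dp_losses for _ in range(pp_size)]

all_reduced_sum = sum(per_rank_losses)      # == pp_size * sum(dp_losses)
global_mean = all_reduced_sum / world_size  # the pp_size factor cancels
assert global_mean == sum(dp_losses) / dp_size == 2.5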
6 changes: 6 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -1045,6 +1045,12 @@ class TrainingArguments:
default=1,
metadata={"help": "Interval between updating EMA parameters."},
)
zcc_ema_loss_threshold: Optional[float] = field(
default=None,
metadata={
"help": "If set not None, only do EMA when the training loss is smaller than the threshold value"
},
)
save_tokenizer: Optional[bool] = field(
default=True,
metadata={"help": "Save tokenizer to output_dir."},
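For context, a sketch of how the new flag sits next to the existing switch. `enable_zero_cost_checkpoint` and `zcc_ema_loss_threshold` come from this diff; every other value below is an illustrative placeholder, not a tested configuration:

from paddlenlp.trainer import TrainingArguments

# Illustrative only: enable ZCC and gate its EMA updates on the training loss.
args = TrainingArguments(
    output_dir="./checkpoints",        # placeholder path
    per_device_train_batch_size=1,     # placeholder value
    enable_zero_cost_checkpoint=True,  # turns on the ZCC path in trainer.py
    zcc_ema_loss_threshold=2.0,        # EMA accumulates only while loss < 2.0
)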
27 changes: 17 additions & 10 deletions paddlenlp/trainer/utils/zero_cost_checkpoint.py
@@ -171,10 +171,11 @@ def ema_reset(self):
self.ema_buffer_model_params = None

@imperative_base.no_grad()
def ema_accumulate(self):
def ema_accumulate(self, global_step, loss, zcc_ema_loss_threshold):
"""
perform ema update: `\alpha * EMA + (1-\alpha) * model`
build `self.ema_buffer` if necessary
when loss < threshold, do the EMA update
"""
# logger.info(f'[ZCC EMA] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
# do update: ema = alpha * ema + (1-alpha) * model
@@ -183,14 +184,16 @@ def ema_accumulate(self):
cpu_master_weights = self.optimizer_fusion_storage_helper.cpu_buffer._slice(
self.master_min_offset, self.master_max_offset
).cpu()
self.ema_buffer = self.ema_coef * self.ema_buffer + (1 - self.ema_coef) * cpu_master_weights
# logger.info(f'[ZCC EMA2] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
for index, ema_buf in self.ema_buffer_model_params.items():
_, cpu_buf = self.param_fusion_storage_helper.inited_buffers[index]
updated_ema = self.ema_coef * ema_buf + (1 - self.ema_coef) * cpu_buf
self.ema_buffer_model_params[index] = updated_ema
if zcc_ema_loss_threshold is None or loss < zcc_ema_loss_threshold:
self.ema_buffer = self.ema_coef * self.ema_buffer + (1 - self.ema_coef) * cpu_master_weights
for index, ema_buf in self.ema_buffer_model_params.items():
_, cpu_buf = self.param_fusion_storage_helper.inited_buffers[index]
updated_ema = self.ema_coef * ema_buf + (1 - self.ema_coef) * cpu_buf
self.ema_buffer_model_params[index] = updated_ema
logger.info(f"[ZCC EMA] accmulating, buffer type:{self.ema_buffer.place} {self.ema_buffer.dtype}, done")
else:
logger.info(f"[ZCC EMA] accmulating SKIP for global_step:{global_step}, because loss:{loss} > threshold:{zcc_ema_loss_threshold}")

logger.info(f"[ZCC EMA] accumulating, buffer type:{self.ema_buffer.place} {self.ema_buffer.dtype}, done")

@imperative_base.no_grad()
def ema_state_dict(self):
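Stripped of the fused-buffer plumbing, the gated update above is a plain EMA with a loss-conditioned skip. A self-contained sketch of the same rule (Python lists stand in for the CPU buffers; every name here is illustrative):

# Minimal sketch of the thresholded EMA rule in ema_accumulate:
# ema = coef * ema + (1 - coef) * param, skipped while the loss is too high.
def ema_step(ema, params, coef, loss, loss_threshold=None):
    """Return updated EMA values, or `ema` unchanged when loss >= threshold."""
    if loss_threshold is not None and loss >= loss_threshold:
        return ema  # spiky step: keep the previous EMA untouched
    return [coef * e + (1.0 - coef) * p for e, p in zip(ema, params)]

ema = [0.0, 0.0]
ema = ema_step(ema, [1.0, 2.0], coef=0.5, loss=1.5, loss_threshold=2.0)  # applied
ema = ema_step(ema, [9.0, 9.0], coef=0.5, loss=3.0, loss_threshold=2.0)  # skipped
assert ema == [0.5, 1.0]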
@@ -771,7 +774,11 @@ def process_offload_task(self, dump, global_step):
self.global_step.value = global_step

if self.ema_coef is not None:
self.zcc_ema_processor.ema_accumulate()
self.zcc_ema_processor.ema_accumulate(
self.trainer_state.global_step,
self.trainer_state.loss,
self.training_args_content.zcc_ema_loss_threshold
)

# continue to process dumping task at the last chunk
if self.offloaded_numels == self.all_numel: