@@ -173,10 +173,11 @@ def ema_reset(self):
173
173
self .ema_buffer_modele_params = None
174
174
175
175
@imperative_base .no_grad ()
176
- def ema_accumulate (self ):
176
+ def ema_accumulate (self , global_step , loss , zcc_ema_loss_threshold ):
177
177
"""
178
178
perform ema update : ` \a lpha * EMA + (1-\a lpha) + model`
179
- build `self.ema_buffer` if necessary
179
+ buid `self.ema_buffer` if necessary
180
+ when loss < threshold, do ema update
180
181
"""
181
182
# logger.info(f'[ZCC EMA] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
182
183
# do update: ema = alpha * ema + (1-alpha) * model
@@ -185,14 +186,19 @@ def ema_accumulate(self):
185
186
cpu_master_weights = self .optimizer_fusion_storage_helper .cpu_buffer ._slice (
186
187
self .master_min_offset , self .master_max_offset
187
188
).cpu ()
188
- self .ema_buffer = self .ema_coef * self .ema_buffer + (1 - self .ema_coef ) * cpu_master_weights
189
- # logger.info(f'[ZCC EMA2] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
190
- for index , ema_buf in self .ema_buffer_model_params .items ():
191
- _ , cpu_buf = self .param_fusion_storage_helper .inited_buffers [index ]
192
- updated_ema = self .ema_coef * ema_buf + (1 - self .ema_coef ) * cpu_buf
193
- self .ema_buffer_model_params [index ] = updated_ema
194
-
195
- logger .info (f"[ZCC EMA] accumulating, buffer type:{ self .ema_buffer .place } { self .ema_buffer .dtype } , done" )
189
+ if zcc_ema_loss_threshold is None or loss < zcc_ema_loss_threshold :
190
+ self .ema_buffer = self .ema_coef * self .ema_buffer + (1 - self .ema_coef ) * cpu_master_weights
191
+ for index , ema_buf in self .ema_buffer_model_params .items ():
192
+ _ , cpu_buf = self .param_fusion_storage_helper .inited_buffers [index ]
193
+ updated_ema = self .ema_coef * ema_buf + (1 - self .ema_coef ) * cpu_buf
194
+ self .ema_buffer_model_params [index ] = updated_ema
195
+ logger .info (
196
+ f"[ZCC EMA] accmulating, buffer type:{ self .ema_buffer .place } { self .ema_buffer .dtype } , done"
197
+ )
198
+ else :
199
+ logger .info (
200
+ f"[ZCC EMA] accmulating SKIP for global_step:{ global_step } , because loss:{ loss } > threshold:{ zcc_ema_loss_threshold } "
201
+ )
196
202
197
203
@imperative_base .no_grad ()
198
204
def ema_state_dict (self ):
@@ -790,7 +796,11 @@ def process_offload_task(self, dump, global_step):
790
796
self .global_step .value = global_step
791
797
792
798
if self .ema_coef is not None :
793
- self .zcc_ema_processor .ema_accumulate ()
799
+ self .zcc_ema_processor .ema_accumulate (
800
+ self .trainer_state .global_step ,
801
+ self .trainer_state .loss ,
802
+ self .training_args_content .zcc_ema_loss_threshold ,
803
+ )
794
804
795
805
# continue to process dumping task at the last chunk
796
806
if self .offloaded_numels == self .all_numel :
0 commit comments