 from paddlenlp.transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    AutoModelForTokenClassification,
     AutoTokenizer,
     PretrainedConfig,
 )
@@ -134,7 +135,6 @@ def create_actor_models(
     )
     if not training_args.autotuner_benchmark:
         reference_model.set_state_dict(actor_model.state_dict())
-
     actor_tokenizer = AutoTokenizer.from_pretrained(
         model_args.actor_model_name_or_path,
         model_max_length=data_args.max_length,
@@ -210,46 +210,43 @@ def create_critic_models(
     data_args: DataArgument,
     training_args: TrainingArguments,
     common_config: Dict,
-    reward_model,
 ):
     with timers_scope_runtimer("Critic model loading time"):
-        reward_model_config = reward_model.config
-        if model_args.critic_model_name_or_path is None:
-            model_args.critic_model_name_or_path = model_args.reward_model_name_or_path
-            critic_model = AutoModelForScore.from_config(
-                reward_model_config,
-                dtype=training_args.model_dtype,
-                score_type="critic",
-                do_normalize=False,
-                clip_range_value=training_args.clip_range_value,
-                **common_config,
+        critic_model_config = AutoConfig.from_pretrained(
+            model_args.critic_model_name_or_path,
+            tensor_parallel_output=training_args.tensor_parallel_output,
+            tensor_parallel_degree=training_args.tensor_parallel_degree,
+            tensor_parallel_rank=training_args.tensor_parallel_rank,
+            dtype=training_args.model_dtype,
+            recompute=training_args.critic_recompute,
+            recompute_granularity=model_args.critic_recompute_granularity,
+            recompute_use_reentrant=training_args.recompute_use_reentrant,
+            **common_config,
+        )
+        LlmMetaConfig.set_llm_config(critic_model_config, training_args)
+
+        critic_model_config.max_position_embeddings = data_args.max_length
+        critic_model_config.use_sparse_head_and_loss_fn = False
+        critic_model_config.num_labels = 1
+        critic_model_config.classifier_dropout = 0.0
+        critic_model_config.hidden_dropout = 0.0
+        logger.info(f"Loading Critic model with config:\n\t{critic_model_config}\n")
+
+        if not training_args.autotuner_benchmark:
+            critic_model = AutoModelForTokenClassification.from_pretrained(
+                model_args.critic_model_name_or_path,
+                config=critic_model_config,
             )
-            if not training_args.autotuner_benchmark:
-                critic_model.set_state_dict(reward_model.state_dict())
         else:
-            if not training_args.autotuner_benchmark:
-                critic_model = AutoModelForScore.from_pretrained(
-                    model_args.critic_model_name_or_path,
-                    config=reward_model_config,
-                    score_type="critic",
-                    do_normalize=False,
-                    clip_range_value=training_args.clip_range_value,
-                    **common_config,
-                )
-            else:
-                critic_model = AutoModelForScore.from_config(
-                    reward_model_config,
-                    score_type="critic",
-                    do_normalize=False,
-                    clip_range_value=training_args.clip_range_value,
-                    **common_config,
-                )
+            critic_model = AutoModelForTokenClassification.from_config(
+                critic_model_config,
+            )

     critic_tokenizer = AutoTokenizer.from_pretrained(
         model_args.critic_model_name_or_path,
         model_max_length=data_args.max_length,
         padding_side="left",
-        tokenizer_alpha=model_args.reward_critic_tokenizer_alpha,
+        tokenizer_alpha=model_args.critic_tokenizer_alpha,
         use_fast=True,
     )
     if critic_tokenizer.pad_token_id is None:
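For context on the critic change above: with `num_labels = 1`, a token-classification head emits one scalar per token, so it can play the role of the per-token value head that the old `AutoModelForScore` critic provided. A minimal sketch, assuming the PaddleNLP `Auto*` APIs used in this diff (the checkpoint path is a placeholder, not something referenced by the PR):

```python
# Sketch only: build a per-token value (critic) head via token classification.
from paddlenlp.transformers import AutoConfig, AutoModelForTokenClassification

config = AutoConfig.from_pretrained("path/to/causal-lm-checkpoint")  # placeholder path
config.num_labels = 1            # one scalar per token -> per-token value estimates
config.classifier_dropout = 0.0  # keep critic outputs deterministic
config.hidden_dropout = 0.0

critic = AutoModelForTokenClassification.from_config(config)
# Logits come out as [batch_size, seq_len, 1]; squeezing the last axis gives
# the per-token values consumed by the PPO critic loss.
```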
@@ -261,16 +258,16 @@ def create_critic_models(
         if training_args.eval_mode == "single":
             config.tensor_parallel_degree = -1
             config.tensor_parallel_rank = 0
-        with timers_scope_runtimer("Reward critic eval model loading time"):
-            critic_eval_model = AutoModelForScore.from_config(config)
+        with timers_scope_runtimer("Critic eval model loading time"):
+            critic_eval_model = AutoModelForTokenClassification.from_config(config)
     else:
         critic_eval_model = None

     return critic_model, critic_eval_model, critic_tokenizer


 def create_rl_dataset(data_args, training_args, tokenizer):
-    requires_label = True if training_args.use_rm_server else False
+    requires_label = True if training_args.use_rm_server or training_args.use_rule_reward else False
     train_ds = RLHFDataset(
         dataset_name_or_path=data_args.train_datasets,
         tokenizer=tokenizer,
@@ -333,15 +330,16 @@ def main():
     actor_model, actor_eval_model, reference_model, actor_tokenizer = create_actor_models(
         model_args, data_args, training_args, common_config, reshard_controller
     )
-
-    if not training_args.use_rm_server and model_args.reward_model_name_or_path is not None:
+    if training_args.use_rule_reward:
+        reward_model, reward_tokenizer = None, actor_tokenizer
+    elif not training_args.use_rm_server and model_args.reward_model_name_or_path is not None:
         reward_model, reward_tokenizer = create_reward_models(model_args, data_args, training_args, common_config)
     else:
         reward_model, reward_tokenizer = model_args.reward_server, actor_tokenizer

     if training_args.rl_algorithm == "ppo":
         critic_model, critic_eval_model, critic_tokenizer = create_critic_models(
-            model_args, data_args, training_args, common_config, reward_model
+            model_args, data_args, training_args, common_config
         )
     else:
         critic_model, critic_eval_model, critic_tokenizer = None, None, None
@@ -355,15 +353,23 @@ def main():
         offload_tensor_to_cpu((reference_model, "freeze_model"))

     if training_args.rl_algorithm == "ppo":
-        offload_tensor_to_cpu((reward_model, "freeze_model"))
+        if not training_args.use_rm_server and not training_args.use_rule_reward:
+            offload_tensor_to_cpu((reward_model, "freeze_model"))
         if critic_eval_model is not None:
             offload_tensor_to_cpu((critic_eval_model, "freeze_model"))

     # NOTE(gongenlei): release memory_reserved_size to equal to memory_allocated_size
     paddle.device.cuda.empty_cache()

     def compute_metrics(eval_preds):
-        accuracy = (eval_preds.predictions == 3).astype("float32").mean().item()
+        '''
+        If "use_rm_server" is True, the score ranges from -3 to 3, and 3 is the only score counted as correct (format + result).
+        If the regularized matching function is used (use_rule_reward=True, currently implemented only for the gsm8k dataset), the score ranges from 0 to 1.
+        '''
+        if training_args.use_rule_reward:
+            accuracy = (eval_preds.predictions == 1).astype("float32").mean().item()
+        else:
+            accuracy = (eval_preds.predictions == 3).astype("float32").mean().item()
         return {"accuracy": accuracy}

     try:
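The docstring added above distinguishes the two scoring regimes, which is why `compute_metrics` counts predictions equal to 1 under `use_rule_reward` and equal to 3 otherwise. As a rough illustration only (the PR's rule reward is not shown in this diff, and this helper is hypothetical), a gsm8k-style regularized-matching reward that yields 0 or 1 could look like:

```python
import re

def gsm8k_rule_reward(response: str, reference_answer: str) -> float:
    # Hypothetical sketch: reward 1.0 iff the last number in the response
    # matches the gsm8k reference answer, otherwise 0.0.
    numbers = re.findall(r"-?\d+(?:\.\d+)?", response.replace(",", ""))
    if not numbers:
        return 0.0
    return 1.0 if numbers[-1] == reference_answer.strip() else 0.0

# Example: gsm8k_rule_reward("... so the answer is 42.", "42") -> 1.0
```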
@@ -389,7 +395,7 @@ def compute_metrics(eval_preds):
         data_collator=partial(
             collate_fn,
             pad_token_id=actor_tokenizer.pad_token_id,
-            requires_label=True if training_args.use_rm_server else False,
+            requires_label=True if training_args.use_rm_server or training_args.use_rule_reward else False,
             max_prompt_len=data_args.max_prompt_len if training_args.balance_batch else None,
         ),  # NOTE: enforce prompt padding to max_prompt_len when using balance_batch
         compute_metrics=compute_metrics,  # TODO: only used for grpo (kk datasets)