@@ -140,6 +140,10 @@ def ppo_loss(
   # values = values * jnp.fmax(value_std, 1e-6) + value_mean
   target_values = (target_values - value_mean) / jnp.fmax(value_std, 1e-6)
   policy_log_probs = ppo_networks.log_prob(distribution_params, actions)
+  if ppo_networks.extra_loss is not None:
+    extra_loss = ppo_networks.extra_loss(distribution_params)
+  else:
+    extra_loss = 0.0
   key, sub_key = jax.random.split(key)
   policy_entropies = ppo_networks.entropy(distribution_params, sub_key)
@@ -168,10 +172,11 @@ def ppo_loss(
   # https://arxiv.org/pdf/2006.05990.pdf
   value_loss = jnp.mean(unclipped_value_loss)
 
-  total_ppo_loss = total_policy_loss + value_cost * value_loss
+  total_ppo_loss = total_policy_loss + value_cost * value_loss + extra_loss
   return total_ppo_loss, {  # pytype: disable=bad-return-type  # numpy-scalars
       'loss_total': total_ppo_loss,
       'loss_policy_total': total_policy_loss,
+      'loss_extra': extra_loss,
       'loss_policy_pg': clipped_ppo_policy_loss,
       'loss_policy_entropy': policy_entropy_loss,
       'loss_critic': value_loss,
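For context, here is a minimal sketch of how the new hook might be populated, assuming `ppo_networks` is a NamedTuple-style container that now carries an optional `extra_loss` field as this change reads it. `make_l2_param_penalty` and its `scale` argument are hypothetical names for illustration only, not part of this PR. The callable receives the same `distribution_params` tensor that `log_prob` and `entropy` consume, and must return a scalar, since its output is added directly to `total_ppo_loss`:

```python
from typing import Callable

import jax.numpy as jnp


def make_l2_param_penalty(
    scale: float = 1e-3,
) -> Callable[[jnp.ndarray], jnp.ndarray]:
  """Builds a hypothetical `extra_loss` callable for the new hook.

  The returned function takes `distribution_params` (the raw policy head
  output) and returns a scalar auxiliary loss.
  """

  def extra_loss(distribution_params: jnp.ndarray) -> jnp.ndarray:
    # Scalar L2 penalty on the raw distribution parameters, discouraging
    # the policy head from drifting to extreme values.
    return scale * jnp.mean(jnp.square(distribution_params))

  return extra_loss
```

Leaving `extra_loss` as `None` preserves the previous behavior: the new branch contributes `0.0` to `total_ppo_loss`, and `'loss_extra'` simply reports `0.0` in the metrics dict.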