Skip to content

Commit 8ec1a77

Browse files
Allen Wangcopybara-github
authored andcommitted
Add an optional customizable loss callable in the PPONetworks, whose output will be added to total loss.
PiperOrigin-RevId: 785739565 Change-Id: I9046e7178e0dfe47c8b276cb0b78d7906245cde2
1 parent 6828035 commit 8ec1a77

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

acme/agents/jax/ppo/learning.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ def ppo_loss(
140140
# values = values * jnp.fmax(value_std, 1e-6) + value_mean
141141
target_values = (target_values - value_mean) / jnp.fmax(value_std, 1e-6)
142142
policy_log_probs = ppo_networks.log_prob(distribution_params, actions)
143+
if ppo_networks.extra_loss is not None:
144+
extra_loss = ppo_networks.extra_loss(distribution_params)
145+
else:
146+
extra_loss = 0.0
143147
key, sub_key = jax.random.split(key)
144148
policy_entropies = ppo_networks.entropy(distribution_params, sub_key)
145149

@@ -168,10 +172,11 @@ def ppo_loss(
168172
# https://arxiv.org/pdf/2006.05990.pdf
169173
value_loss = jnp.mean(unclipped_value_loss)
170174

171-
total_ppo_loss = total_policy_loss + value_cost * value_loss
175+
total_ppo_loss = total_policy_loss + value_cost * value_loss + extra_loss
172176
return total_ppo_loss, { # pytype: disable=bad-return-type # numpy-scalars
173177
'loss_total': total_ppo_loss,
174178
'loss_policy_total': total_policy_loss,
179+
'loss_extra': extra_loss,
175180
'loss_policy_pg': clipped_ppo_policy_loss,
176181
'loss_policy_entropy': policy_entropy_loss,
177182
'loss_critic': value_loss,

acme/agents/jax/ppo/networks.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ class PPONetworks:
7777
entropy: EntropyFn
7878
sample: networks_lib.SampleFn
7979
sample_eval: Optional[networks_lib.SampleFn] = None
80+
extra_loss: Optional[
81+
Callable[[networks_lib.NetworkOutput], networks_lib.Value]
82+
] = None
8083

8184

8285
def make_inference_fn(

0 commit comments

Comments
 (0)