Skip to content

Commit ed4e57a

Browse files
Acme Contributor and copybara-github authored and committed
Add Wasserstein Policy Optimization (WPO, http://arxiv.org/pdf/2505.00663) to Acme
PiperOrigin-RevId: 754303872 Change-Id: I1e360c066abc5a122289b8ff87bb3794674dc61f
1 parent 7578c75 commit ed4e57a

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

acme/jax/losses/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@
1818
from acme.jax.losses.mpo import MPO
1919
from acme.jax.losses.mpo import MPOParams
2020
from acme.jax.losses.mpo import MPOStats
21+
from acme.jax.losses.wpo import WPO
22+
from acme.jax.losses.wpo import WPOParams
23+
from acme.jax.losses.wpo import WPOStats
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Example running WPO on continuous control tasks."""
16+
17+
from absl import flags
18+
from acme import specs
19+
from acme.agents.jax import wpo
20+
from acme.agents.jax.wpo import types as wpo_types
21+
import helpers
22+
from absl import app
23+
from acme.jax import experiments
24+
from acme.utils import lp_utils
25+
import launchpad as lp
26+
27+
# Command-line flags configuring how and where the WPO example is run.

# If True, launch a distributed Launchpad program; otherwise run the whole
# experiment single-threaded in this process.
RUN_DISTRIBUTED = flags.DEFINE_bool(
    'run_distributed', True, 'Should an agent be executed in a distributed '
    'way. If False, will run single-threaded.')
# Environment selector, e.g. 'gym:HalfCheetah-v2' or a DM control
# '{domain_name}:{task_name}' pair prefixed with 'control:'.
ENV_NAME = flags.DEFINE_string(
    'env_name', 'gym:HalfCheetah-v2',
    'What environment to run on, in the format {gym|control}:{task}, '
    'where "control" refers to the DM control suite. DM Control tasks are '
    'further split into {domain_name}:{task_name}.')
# Seed passed through to the experiment for reproducibility.
SEED = flags.DEFINE_integer('seed', 0, 'Random seed.')
# Total number of actor environment steps before the experiment stops.
NUM_STEPS = flags.DEFINE_integer(
    'num_steps', 1_000_000,
    'Number of environment steps to run the experiment for.')
# Evaluation cadence, in actor environment steps (used in the
# single-threaded path below).
EVAL_EVERY = flags.DEFINE_integer(
    'eval_every', 50_000,
    'How often (in actor environment steps) to run evaluation episodes.')
# Number of episodes to run per periodic evaluation.
EVAL_EPISODES = flags.DEFINE_integer(
    'evaluation_episodes', 10,
    'Number of evaluation episodes to run periodically.')
46+
47+
def build_experiment_config():
  """Builds WPO experiment config which can be executed in different ways.

  Returns:
    An `experiments.ExperimentConfig` wiring a WPO agent builder, network
    factory and environment factory together, driven by the module-level
    command-line flags.
  """
  # ENV_NAME is '{suite}:{task}'. Split only on the first ':' since DM
  # control tasks are themselves of the form '{domain_name}:{task_name}'.
  suite, task = ENV_NAME.value.split(':', 1)

  def network_factory(spec: specs.EnvironmentSpec) -> wpo.WPONetworks:
    # Policy and critic towers sized for continuous-control benchmarks.
    return wpo.make_control_networks(
        spec,
        policy_layer_sizes=(256, 256, 256),
        critic_layer_sizes=(256, 256, 256),
        policy_init_scale=0.5)

  # Configure and construct the agent builder.
  config = wpo.WPOConfig(
      policy_loss_config=wpo_types.GaussianPolicyLossConfig(epsilon_mean=0.01),
      samples_per_insert=64,
      learning_rate=3e-4,
      experience_type=wpo_types.FromTransitions(n_step=5),
      dual_learning_rate=0.0)  # Turn off dual learning.
  agent_builder = wpo.WPOBuilder(config, sgd_steps_per_learner_step=1)

  return experiments.ExperimentConfig(
      builder=agent_builder,
      environment_factory=lambda _: helpers.make_environment(suite, task),
      network_factory=network_factory,
      seed=SEED.value,
      max_num_actor_steps=NUM_STEPS.value)
73+
74+
75+
def main(_):
  """Runs the WPO experiment, distributed or single-threaded per flags."""
  experiment_config = build_experiment_config()
  if not RUN_DISTRIBUTED.value:
    # Single-threaded: actor, learner and evaluator all run in this process.
    experiments.run_experiment(
        experiment=experiment_config,
        eval_every=EVAL_EVERY.value,
        num_eval_episodes=EVAL_EPISODES.value)
    return
  # Distributed: assemble a Launchpad program with four actors and launch it.
  program = experiments.make_distributed_experiment(
      experiment=experiment_config, num_actors=4)
  lp.launch(program, xm_resources=lp_utils.make_xm_docker_resources(program))
86+
87+
88+
if __name__ == '__main__':
  # app.run parses absl flags before dispatching to main.
  app.run(main)

0 commit comments

Comments (0)