Skip to content

Commit da77c0b

Browse files
committed
weight decay
1 parent 0cc7a79 commit da77c0b

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

improving_transformers_world_model/agent.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,8 @@ def __init__(
 optim_klass = AdoptAtan2,
 actor_lr = 1e-4,
 critic_lr = 1e-4,
+actor_weight_decay = 1e-3,
+critic_weight_decay = 1e-3,
 max_grad_norm = 0.5,
 actor_optim_kwargs: dict = dict(),
 critic_optim_kwargs: dict = dict(),
@@ -429,8 +431,8 @@ def __init__(
 
 self.max_grad_norm = max_grad_norm
 
-self.actor_optim = optim_klass((*actor.parameters(), *impala.parameters()), lr = actor_lr, **actor_optim_kwargs)
-self.critic_optim = optim_klass((*critic.parameters(), *impala.parameters()), lr = actor_lr, **actor_optim_kwargs)
+self.actor_optim = optim_klass((*actor.parameters(), *impala.parameters()), lr = actor_lr, weight_decay = actor_weight_decay, **actor_optim_kwargs)
+self.critic_optim = optim_klass((*critic.parameters(), *impala.parameters()), lr = critic_lr, weight_decay = critic_weight_decay, **critic_optim_kwargs)
 
 # use a batch norm for standardizing the target - section A.1.2 in paper
 
pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
 [project]
 name = "improving-transformers-world-model"
-version = "0.0.53"
+version = "0.0.54"
 description = "Improving Transformers World Model for RL"
 authors = [
 { name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)