
Commit a0cebd6

take a step towards the proposed follow-up research in the paper, incorporating the world model embed when choosing actions during real env rollouts
1 parent 14c04ca commit a0cebd6

File tree

4 files changed: +105 -9 lines changed

improving_transformers_world_model/agent.py

Lines changed: 66 additions & 1 deletion
@@ -112,6 +112,30 @@ def calc_target_and_gae(
 
     return returns, gae
 
+# FiLM for conditioning policy network on world model embed - suggested for follow up research in the paper
+
+class FiLM(Module):
+    def __init__(
+        self,
+        dim,
+        dim_out
+    ):
+        super().__init__()
+        self.to_gamma = nn.Linear(dim, dim_out, bias = False)
+        self.to_beta = nn.Linear(dim, dim_out, bias = False)
+
+        nn.init.zeros_(self.to_gamma.weight)
+        nn.init.zeros_(self.to_beta.weight)
+
+    def forward(
+        self,
+        x: Float['... d'],
+        cond: Float['... d']
+    ):
+        gamma, beta = self.to_gamma(cond), self.to_beta(cond)
+
+        return x * (gamma + 1.) + beta
+
 # symbol extractor
 # detailed in section C.3

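A minimal standalone sketch of the FiLM layer added above, using plain nn.Module and toy dimensions rather than the repo's Module / Float annotations. Since to_gamma and to_beta are zero-initialized, the layer starts as an identity over x, so conditioning on the world model embed is learned gradually without perturbing the policy at initialization.

    # minimal FiLM sketch - mirrors the class added in this commit
    import torch
    from torch import nn

    class FiLM(nn.Module):
        def __init__(self, dim, dim_out):
            super().__init__()
            self.to_gamma = nn.Linear(dim, dim_out, bias = False)
            self.to_beta = nn.Linear(dim, dim_out, bias = False)

            # zero init -> gamma = 0 and beta = 0, so x * (0 + 1) + 0 == x at the start
            nn.init.zeros_(self.to_gamma.weight)
            nn.init.zeros_(self.to_beta.weight)

        def forward(self, x, cond):
            gamma, beta = self.to_gamma(cond), self.to_beta(cond)
            return x * (gamma + 1.) + beta

    film = FiLM(16, 32)            # condition dim 16 -> feature dim 32
    x = torch.randn(2, 32)         # policy features
    cond = torch.randn(2, 16)      # world model embed

    assert torch.allclose(film(x, cond), x)   # identity at initialization
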
@@ -267,6 +291,7 @@ def __init__(
         dim,
         *,
         num_actions,
+        dim_world_model_embed = None,
         num_layers = 3,
         expansion_factor = 2.,
     ):
@@ -292,9 +317,17 @@ def __init__(
 
         self.to_actions_pred = nn.Linear(dim, num_actions)
 
+        # able to condition on world model embed when predicting action - using classic film
+
+        self.can_cond_on_world_model = exists(dim_world_model_embed)
+
+        if self.can_cond_on_world_model:
+            self.world_model_film = FiLM(dim_world_model_embed, dim)
+
     def forward(
         self,
         state: Float['b c h w'],
+        world_model_embed: Float['b d'] | None = None,
         sample_action = False
     ) -> (
         Float['b'] |
@@ -303,6 +336,11 @@ def forward(
 
         embed = self.proj_in(state)
 
+        if exists(world_model_embed):
+            assert exists(self.world_model_film), f'`dim_world_model_embed` must be set on `Actor` to utilize world model for prediction'
+
+            embed = self.world_model_film(embed, world_model_embed)
+
         for layer in self.layers:
             embed = layer(embed) + embed

@@ -636,6 +674,7 @@ def interact_with_env(
         self,
         env,
         memories: Memories | None = None,
+        world_model: WorldModel | None = None,
         max_steps = float('inf')
 
     ) -> MemoriesWithNextState:
@@ -662,13 +701,39 @@
         last_done = dones[0, -1]
         time_step = states.shape[2] + 1
 
+        # maybe conditioning actor with learned world model embed
+
+        if exists(world_model):
+            world_model_cache = None
+
         while time_step < max_steps and not last_done:
 
+            world_model_embed = None
+
+            if exists(world_model):
+                with torch.no_grad():
+                    world_model.eval()
+
+                    world_model_embed, world_model_cache = world_model(
+                        state_or_token_ids = states[:, :, -1:],
+                        actions = actions[:, -1:],
+                        rewards = rewards[:, -1:],
+                        cache = world_model_cache,
+                        remove_cache_len_from_time = False,
+                        return_embed = True,
+                        return_cache = True,
+                        return_loss = False
+                    )
+
+                world_model_embed = rearrange(world_model_embed, '1 1 d -> 1 d')
+
+            # impala + actor - todo: cleanup the noisy tensor (un)squeezing ops
+
             next_state = rearrange(next_state, 'c h w -> 1 c h w')
 
             actor_critic_input, rnn_hidden = self.impala(next_state)
 
-            action, action_log_prob = self.actor(actor_critic_input, sample_action = True)
+            action, action_log_prob = self.actor(actor_critic_input, world_model_embed = world_model_embed, sample_action = True)
 
             next_state, next_reward, next_done = to_device_decorator(env)(action)

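To make the new rollout flow concrete, the toy below sketches the same pattern in isolation. TinyWorldModel and TinyActor are hypothetical stand-ins (a GRU hidden state plays the role of the transformer cache), not the repo's WorldModel / Actor, and the FiLM sketch from earlier is assumed to be in scope. Each real environment step runs the frozen world model under torch.no_grad() on only the latest step, carries its cache forward, and conditions action selection on the resulting embed via FiLM.

    # hypothetical stand-ins illustrating rollout-time conditioning on a world model embed
    import torch
    from torch import nn

    class TinyWorldModel(nn.Module):
        def __init__(self, dim_in, dim_embed):
            super().__init__()
            self.rnn = nn.GRU(dim_in, dim_embed, batch_first = True)

        def forward(self, step, cache = None):
            # step: (batch, 1, dim_in) - only the latest timestep, as in interact_with_env above
            embed, next_cache = self.rnn(step, cache)
            return embed[:, -1], next_cache

    class TinyActor(nn.Module):
        def __init__(self, dim_state, dim_embed, num_actions):
            super().__init__()
            self.proj_in = nn.Linear(dim_state, 64)
            self.film = FiLM(dim_embed, 64)          # FiLM from the sketch earlier
            self.to_logits = nn.Linear(64, num_actions)

        def forward(self, state, world_model_embed = None):
            hidden = self.proj_in(state)
            if world_model_embed is not None:
                hidden = self.film(hidden, world_model_embed)
            return self.to_logits(hidden)

    world_model = TinyWorldModel(dim_in = 8, dim_embed = 16).eval()
    actor = TinyActor(dim_state = 8, dim_embed = 16, num_actions = 5)

    state, cache = torch.randn(1, 8), None

    for _ in range(5):
        with torch.no_grad():                        # world model stays frozen during real rollouts
            embed, cache = world_model(state.unsqueeze(1), cache)

        logits = actor(state, world_model_embed = embed)
        action = torch.distributions.Categorical(logits = logits).sample()

        state = torch.randn(1, 8)                    # stand-in for env.step(action)
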
improving_transformers_world_model/world_model.py

Lines changed: 21 additions & 4 deletions
@@ -770,6 +770,7 @@ def forward(
         detach_cache = False,
         return_loss = True,
         return_loss_breakdown = False,
+        return_embed = False,
         freeze_tokenizer = True
     ):
         batch = state_or_token_ids.shape[0]
@@ -784,12 +785,18 @@ def forward(
             token_ids = state_or_token_ids
 
         if return_loss:
+            assert token_ids.shape[1] > 1
+
             token_ids, state_labels = token_ids[:, :-1], token_ids[:, 1:]
 
-            is_terminal_labels = is_terminal[:, 1:]
+            if exists(is_terminal):
+                is_terminal_labels = is_terminal[:, 1:]
 
-        actions, last_action = actions[:, :-1], actions[:, -1:]
-        rewards, last_reward = rewards[:, :-1], rewards[:, -1:]
+        if exists(actions):
+            actions, last_action = actions[:, :-1], actions[:, -1:]
+
+        if exists(rewards):
+            rewards, last_reward = rewards[:, :-1], rewards[:, -1:]
 
         # either use own learned token embeddings
         # or project the codes (which are just the nearest neighbor memorized patch) and project
@@ -819,7 +826,9 @@ def forward(
             actions = actions.masked_fill(no_actions, 0)
             action_embeds = self.action_embed(actions)
 
-            action_embeds = einx.where('b t n, b t n d, -> b t n d', ~no_actions, action_embeds, 0.)
+            if not is_empty(action_embeds):
+                action_embeds = einx.where('b t n, b t n d, -> b t n d', ~no_actions, action_embeds, 0.)
+
             action_embeds = reduce(action_embeds, 'b t n d -> b t d', 'sum')
 
             action_embed_sos = repeat(self.action_embed_sos, 'd -> b 1 d', b = batch)
@@ -863,6 +872,14 @@ def inverse_time(t):
 
         embeds_with_time = reduce(embeds_with_space, 'b ... d -> b 1 d', 'mean')
 
+        # maybe return embed
+
+        if return_embed:
+            if not return_cache:
+                return embeds_with_time
+
+            return embeds_with_time, next_cache
+
         # reward and terminal
 
         reward_logits = None

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "improving-transformers-world-model"
-version = "0.0.57"
+version = "0.0.58"
 description = "Improving Transformers World Model for RL"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_agent.py

Lines changed: 17 additions & 3 deletions
@@ -12,8 +12,10 @@
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 @pytest.mark.parametrize('critic_use_regression', (False, True))
+@pytest.mark.parametrize('actor_use_world_model_embed', (False, True))
 def test_agent(
-    critic_use_regression
+    critic_use_regression,
+    actor_use_world_model_embed
 ):
 
     # world model
@@ -53,6 +55,7 @@ def test_agent(
         actor = dict(
             dim = 32,
             num_actions = 5,
+            dim_world_model_embed = 32 if actor_use_world_model_embed else None
         ),
         critic = dict(
             dim = 64,
@@ -62,9 +65,17 @@ def test_agent(
 
     env = Env((3, 63, 63))
 
-    dream_memories = agent(world_model, state[0, :, 0], max_steps = 5)
+    dream_memories = agent(
+        world_model,
+        state[0, :, 0],
+        max_steps = 5
+    )
 
-    real_memories = agent.interact_with_env(env, max_steps = 5)
+    real_memories = agent.interact_with_env(
+        env,
+        world_model = world_model if actor_use_world_model_embed else None,
+        max_steps = 5
+    )
 
     agent.learn([dream_memories, real_memories])

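The new parametrization exercises both the unconditioned and the world-model-conditioned rollout paths. The updated test can be run directly by node id, e.g.:

    pytest tests/test_agent.py::test_agent
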
@@ -93,6 +104,9 @@ def world_model_burn_in():
     actions = torch.randint(0, 5, (2, 20, 1))
     is_terminal = torch.randint(0, 2, (2, 20)).bool()
 
+    loss = world_model(state, actions = actions, rewards = rewards, is_terminal = is_terminal)
+    loss.backward()
+
     _, burn_in_cache = world_model(
         state,
         actions = actions,
