
Commit f91d680

Better integration between gym & wave in examples
1 parent d43b396

File tree: 7 files changed, +57 -37 lines

examples/deepq-model/render.py
examples/deepq-model/rewards.py
examples/deepq-model/train.py
examples/stable-baselines/common/callbacks.py
examples/stable-baselines/evaluate_model.py
examples/stable-baselines/render_model.py
examples/stable-baselines/train.py

examples/deepq-model/render.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 import numpy as np
 from model import DQN
-from reward import CustomReward
+from rewards import CustomReward

 wave = True
 render_episodes = 7
@@ -14,7 +14,7 @@

 env = gym.make("LunarLander-v2")

-model = DQN.load("trained-model")
+model = DQN.load("{}-trained-model".format("wave" if wave else "gym"))

 episode = render_episodes
 reward_sum = 0

examples/deepq-model/rewards.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def reset(self):


 class LunarCustomReward(LunarLanderReward):
-    """ Custom reward that applies no penalty for engine usage (infinite fuel)
+    """ Custom reward that applies penalty for engine usage (infinite fuel)
     and allows more velocity for touching ground without crashing
     (the lander is more resistant to hits) """

examples/deepq-model/train.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 import time
 from model import DQN
-from reward import CustomReward
+from rewards import CustomReward

 # CONFIG
 wave = True
@@ -24,6 +24,6 @@
 str_t = time.strftime("%H h, %M m, %S s", time.gmtime(t))
 print("Trained in {} during {} timesteps".format(str_t, learn_timesteps))

-model.save("trained-model")
+model.save("{}-trained-model".format("wave" if wave else "gym"))

 env.close()
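With this change, the deepq example saves and loads backend-specific model files, so a Wave-trained agent is never confused with a Gym-trained one. A minimal sketch of the naming round-trip shared by train.py and render.py (illustrative values only; DQN comes from the example's own model module):

    # Shared naming convention assumed by train.py and render.py after this commit.
    wave = True
    model_name = "{}-trained-model".format("wave" if wave else "gym")

    # train.py ends with:   model.save(model_name)      -> writes e.g. "wave-trained-model"
    # render.py reloads it: model = DQN.load(model_name)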

examples/stable-baselines/common/callbacks.py

Lines changed: 6 additions & 2 deletions
@@ -31,12 +31,16 @@ def callback(locals_, globals_):
     steps = n_calls * C
     if steps // N >= next_index:
         print(
-            "Saving model {}{} at step {} ...".format(
+            "Saving snapshot {}{} at step {} ...".format(
                 int(steps / order), order_str, steps
             )
         )
         locals_["self"].save(
-            "{}{}{}{}".format(file_path, file_prefix, int(steps / order), order_str)
+            str(
+                file_path.joinpath(
+                    "{}{}{}".format(file_prefix, int(steps / order), order_str)
+                )
+            )
         )
         next_index = steps // N + 1
     return True
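Note that the callback now treats file_path as a pathlib.Path (built by train.py below) rather than a plain string prefix. A small sketch of how one snapshot name is assembled under that assumption, with made-up values:

    from pathlib import Path

    # Hypothetical values mirroring the names used inside the callback.
    file_path = Path("models/wave_a2c_example")   # directory created by train.py
    file_prefix = "snapshot-"
    steps, order, order_str = 4000, int(1e3), "K"

    # Same expression as in the diff: directory / "snapshot-4K"
    snapshot = str(file_path.joinpath("{}{}{}".format(file_prefix, int(steps / order), order_str)))
    print(snapshot)  # models/wave_a2c_example/snapshot-4K (separator depends on the OS)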

examples/stable-baselines/evaluate_model.py

Lines changed: 23 additions & 12 deletions
@@ -1,3 +1,5 @@
+""" Evaluates Wave Lunar Lander model """
+
 import common.shutup as shutup

 shutup.future_warnings()
@@ -11,8 +13,8 @@
 from stable_baselines import A2C  # noqa: E402
 from arlie.envs.lunar_lander.score import LunarLanderScore  # noqa: E402

+wave = True
 eval_timesteps = 1e5
-multi = True
 num_cpu = 12


@@ -23,9 +25,11 @@ def evaluate(env, model, num_steps=1000):
     :param num_steps: (int) number of timesteps to evaluate it
     :return: (float) Mean reward, (int) Number of episodes performed
     """
-    scores = [LunarLanderScore() for _ in range(env.num_envs)]
-    episode_scores = [[0.0] for _ in range(env.num_envs)]
     episode_rewards = [[0.0] for _ in range(env.num_envs)]
+    if wave:
+        scores = [LunarLanderScore() for _ in range(env.num_envs)]
+        episode_scores = [[0.0] for _ in range(env.num_envs)]
+
     obs = env.reset()
     steps = (int)(num_steps // env.num_envs)
     for i in range(steps):
@@ -37,24 +41,30 @@ def evaluate(env, model, num_steps=1000):

         # Stats
         for i in range(env.num_envs):
-            scores[i].store_step(obs[i], actions[i], info[i])
-            episode_scores[i][-1] = scores[i].get()
             episode_rewards[i][-1] += rewards[i]
+            if wave:
+                scores[i].store_step(obs[i], actions[i], info[i])
+                episode_scores[i][-1] = scores[i].get()
             if dones[i]:
-                episode_scores[i].append(0.0)
                 episode_rewards[i].append(0.0)
-                scores[i].reset()
+                if wave:
+                    episode_scores[i].append(0.0)
+                    scores[i].reset()

-    mean_scores = [0.0 for _ in range(env.num_envs)]
     mean_rewards = [0.0 for _ in range(env.num_envs)]
+    if wave:
+        mean_scores = [0.0 for _ in range(env.num_envs)]
     n_episodes = 0
     for i in range(env.num_envs):
-        mean_scores[i] = np.mean(episode_scores[i][:-1])
         mean_rewards[i] = np.mean(episode_rewards[i][:-1])
+        if wave:
+            mean_scores[i] = np.mean(episode_scores[i][:-1])
         n_episodes += len(episode_rewards[i]) - 1

     # Compute mean reward
-    mean_score = round(np.mean(mean_scores), 1)
+    mean_score = "NaN"
+    if wave:
+        mean_score = round(np.mean(mean_scores), 1)
     mean_reward = round(np.mean(mean_rewards), 1)

     return mean_score, mean_reward, n_episodes
@@ -70,10 +80,11 @@ def evaluate(env, model, num_steps=1000):
     print("Path '{}' does not exist.".format(model_path))
     exit(-1)

+id = "LunarLander" if wave else "LunarLander-v2"
 if num_cpu > 1:
-    env = make_multi_env(num_cpu, "LunarLander", True, render_mode=False)
+    env = make_multi_env(num_cpu, id, wave, render_mode=False, reset_mode="random")
 else:
-    env = make_env("LunarLander", True, render_mode=False, reset_mode="random")
+    env = make_env(id, wave, render_mode=False, reset_mode="random")

 if len(sys.argv) == 1:
     print("No model provided")

examples/stable-baselines/render_model.py

Lines changed: 2 additions & 3 deletions
@@ -9,9 +9,8 @@
 from stable_baselines import A2C  # noqa: E402


-render_episodes = 20
 wave = True
-multi = True
+render_episodes = 20

 if len(sys.argv) < 2:
     print("USAGE: {} PATH-TO-MODEL-FILE".format(sys.argv[0]))
@@ -23,7 +22,7 @@
     exit(-1)

 id = "LunarLander" if wave else "LunarLander-v2"
-env = make_env(id, wave, port=4000)
+env = make_env(id, wave, port=4000, reset_mode="random")

 model = A2C.load(model_path)
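With reset_mode="random", rendered episodes start from randomised initial states regardless of backend. For context, a rough sketch of a render loop that could follow the A2C.load call above (standard stable-baselines predict/step usage; the example's actual loop may differ):

    # Hypothetical render loop; env, model and render_episodes come from the lines above.
    obs = env.reset()
    for _ in range(render_episodes):
        done = False
        while not done:
            action, _states = model.predict(obs)
            obs, reward, done, info = env.step(action)
        obs = env.reset()
    env.close()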

examples/stable-baselines/train.py

Lines changed: 21 additions & 15 deletions
@@ -5,59 +5,65 @@

 import os  # noqa: E402
 import time  # noqa: E402
+from pathlib import Path  # noqa: E402
 from common.utils import make_env, make_multi_env  # noqa: E402
 from common.callbacks import save_callback  # noqa: E402
 from stable_baselines.common.policies import MlpPolicy  # noqa: E402
 from stable_baselines import A2C  # noqa: E402


 # CONFIG
-model_path = "models/wave_example_a2c/"
 wave = True
 label = "a2c_example"
 order = int(1e3)
 order_str = "K"
 learn_timesteps = 24 * order
 save_interval = 2 * order
 num_cpu = 12
-log_dir = "logs"
+models_dir = "./models"
+log_dir = "./logs"

 if __name__ == "__main__":
-    try:
-        os.mkdir(model_path)
-    except FileExistsError:
-        pass
-    try:
-        os.mkdir(log_dir)
-    except FileExistsError:
-        pass
+    # e.g.: ./models/wave_a2c_example/
+    model_path = Path(models_dir).joinpath(
+        "{}_{}".format("wave" if wave else "gym", label)
+    )
+    # e.g.: ./logs
+    log_path = Path(log_dir)
+
+    # create folders
+    model_path.mkdir(parents=True, exist_ok=True)
+    log_path.mkdir(exist_ok=True)

+    # create the wave or gym environment, with or without multiprocessing
     id = "LunarLander" if wave else "LunarLander-v2"
     if num_cpu > 1:
         env = make_multi_env(num_cpu, id, wave, render_mode=False, reset_mode="random")
     else:
-        env = make_env(id, wave, render_mode=False)
+        env = make_env(id, wave, render_mode=False, reset_mode="random")

+    # create A2C with Mlp policy, and the callback to save snapshots
     model = A2C(MlpPolicy, env, ent_coef=0.1, verbose=0, tensorboard_log=log_dir)
     callback = save_callback(
         model_path,
-        "model-",
+        "snapshot-",
         save_interval,
         call_interval=model.n_steps * num_cpu,
         order=order,
         order_str=order_str,
     )

+    # save final model
     print("Training...")
     _t = time.time()
     model.learn(total_timesteps=learn_timesteps, callback=callback)
     t = time.time() - _t
     str_t = time.strftime("%H h, %M m, %S s", time.gmtime(t))
     print("Trained in {} during {} timesteps".format(str_t, learn_timesteps))

-    final_model = model_path + "-{}{}-final".format(
-        int(learn_timesteps / order), order_str
+    final_model = model_path.joinpath(
+        "{}{}-final".format(int(learn_timesteps / order), order_str)
     )
-    model.save(final_model)
+    model.save(str(final_model))

     env.close()
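String concatenation for paths is replaced with pathlib here, so the per-run directory and snapshot names are built the same way in train.py and callbacks.py. A short sketch of the resulting layout under the default config (names taken from the diff; the actual number of snapshots depends on save_interval):

    from pathlib import Path

    # Defaults from the diff.
    wave, label, models_dir = True, "a2c_example", "./models"

    model_path = Path(models_dir).joinpath("{}_{}".format("wave" if wave else "gym", label))
    model_path.mkdir(parents=True, exist_ok=True)   # ./models/wave_a2c_example

    # final model written at the end of training, e.g. ./models/wave_a2c_example/24K-final
    final_model = model_path.joinpath("{}{}-final".format(24, "K"))
    print(str(final_model))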
