lzero/entry/train_muzero.py (2 additions & 2 deletions)
@@ -48,10 +48,10 @@ def train_muzero(
"""

cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'sampled_efficientzero', 'sampled_muzero', 'gumbel_muzero', 'stochastic_muzero'], \
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_history', 'muzero_context', 'muzero_rnn_full_obs', 'sampled_efficientzero', 'sampled_muzero', 'gumbel_muzero', 'stochastic_muzero'], \
"train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'"

if create_cfg.policy.type in ['muzero', 'muzero_context', 'muzero_rnn_full_obs']:
if create_cfg.policy.type in ['muzero', 'muzero_history', 'muzero_context', 'muzero_rnn_full_obs']:
from lzero.mcts import MuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'efficientzero':
from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
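The net effect of this hunk: the new 'muzero_history' policy type is accepted by the entry point and mapped to the same replay buffer as vanilla MuZero. A minimal sketch of the dispatch, with `policy_type` as a hypothetical stand-in for `create_cfg.policy.type` and the buffer classes being the ones already imported in train_muzero.py:

```python
# Minimal sketch of the buffer dispatch after this change. `policy_type` is a
# hypothetical stand-in for create_cfg.policy.type.
def select_game_buffer(policy_type: str):
    if policy_type in ['muzero', 'muzero_history', 'muzero_context', 'muzero_rnn_full_obs']:
        from lzero.mcts import MuZeroGameBuffer as GameBuffer
    elif policy_type == 'efficientzero':
        from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
    else:
        raise ValueError(f"policy type {policy_type} is not handled in this sketch")
    return GameBuffer
```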
lzero/mcts/buffer/game_buffer.py (1 addition & 1 deletion)
@@ -37,7 +37,7 @@ def default_config(cls: type) -> EasyDict:
# (bool) Whether to use the root value in the reanalyzing part. Please refer to EfficientZero paper for details.
use_root_value=False,
# (int) The number of samples required for mini inference.
mini_infer_size=10240,
mini_infer_size=20480,
# (str) The type of sampled data. The default is 'transition'. Options: 'transition', 'episode'.
sample_type='transition',
)
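`mini_infer_size` bounds how many target observations go through the model in a single forward pass during reanalyze; doubling it to 20480 halves the number of slices per batch at the cost of higher peak GPU memory per pass. A rough sketch of the slicing arithmetic, mirroring the loop in game_buffer_muzero.py (the batch size below is illustrative, not from any config):

```python
import numpy as np

# Illustrative numbers only: how mini_infer_size splits the reanalyze batch.
transition_batch_size = 25600          # e.g. batch_size * (num_unroll_steps + 1)
mini_infer_size = 20480                # new default in this change (was 10240)

slices = int(np.ceil(transition_batch_size / mini_infer_size))  # -> 2 (was 3)
for i in range(slices):
    beg_index = mini_infer_size * i
    end_index = mini_infer_size * (i + 1)
    # each slice value_obs_list[beg_index:end_index] is sent to
    # model.initial_inference(...) separately to bound peak GPU memory
    print(i, beg_index, end_index)
```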
lzero/mcts/buffer/game_buffer_muzero.py (40 additions & 17 deletions)
@@ -61,6 +61,13 @@ def __init__(self, cfg: dict):
self.sample_times = 0
self.active_root_num = 0

self.history_length = self._cfg.history_length
if self.history_length > 1:
self.num_unroll_steps = self._cfg.num_unroll_steps + self.history_length
else:
self.num_unroll_steps = self._cfg.num_unroll_steps


def reset_runtime_metrics(self):
"""
Overview:
@@ -138,6 +145,7 @@ def sample(
batch_size, self._cfg.reanalyze_ratio
)
# target reward, target value
# import ipdb;ipdb.set_trace()
batch_rewards, batch_target_values = self._compute_target_reward_value(
reward_value_context, policy._target_model
)
@@ -191,29 +199,29 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
pos_in_game_segment = pos_in_game_segment_list[i]

actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
self._cfg.num_unroll_steps].tolist()
self.num_unroll_steps].tolist()

# add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid
# mask_tmp = [1. for i in range(len(actions_tmp))]
# mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))]
# mask_tmp += [0. for _ in range(self.num_unroll_steps + 1 - len(mask_tmp))]

# TODO: the child_visits after position <self._cfg.game_segment_length> in the segment (with padded part) may not be updated
# So the corresponding position should not be used in the training
mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))]
mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))]
mask_tmp += [0. for _ in range(self.num_unroll_steps + 1 - len(mask_tmp))]

# pad random action
actions_tmp += [
np.random.randint(0, game.action_space_size)
for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
for _ in range(self.num_unroll_steps - len(actions_tmp))
]

# obtain the input observations
# pad if length of obs in game_segment is less than stack+num_unroll_steps
# e.g. stack+num_unroll_steps = 4+5
obs_list.append(
game_segment_list[i].get_unroll_obs(
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
pos_in_game_segment_list[i], num_unroll_steps=self.num_unroll_steps, padding=True
)
)
action_list.append(actions_tmp)
@@ -299,7 +307,7 @@ def _prepare_reward_value_context(
# prepare the corresponding observations for bootstrapped values o_{t+k}
# o[t+ td_steps, t + td_steps + stack frames + num_unroll_steps]
# t=2+3 -> o[2+3, 2+3+4+5] -> o[5, 14]
game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps)
game_obs = game_segment.get_unroll_obs(state_index + td_steps, self.num_unroll_steps)

rewards_list.append(game_segment.reward_segment)

@@ -309,7 +317,7 @@

truncation_length = game_segment_len

for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
for current_index in range(state_index, state_index + self.num_unroll_steps + 1):
# get the <num_unroll_steps+1> bootstrapped target obs
td_steps_list.append(td_steps)
# index of bootstrapped obs o_{t+td_steps}
@@ -400,9 +408,9 @@ def _prepare_policy_reanalyzed_context(
child_visits.append(game_segment.child_visit_segment)
root_values.append(game_segment.root_value_segment)
# prepare the corresponding observations
game_obs = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps)
game_obs = game_segment.get_unroll_obs(state_index, self.num_unroll_steps)

for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
for current_index in range(state_index, state_index + self.num_unroll_steps + 1):

if current_index < game_segment_len: # original
policy_mask.append(1)
@@ -436,10 +444,16 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
# transition_batch_size = game_segment_batch_size * (num_unroll_steps+1)
transition_batch_size = len(value_obs_list)

# if self.history_length>1:
# game_segment_batch_size = len(pos_in_game_segment_list)
# transition_batch_size = transition_batch_size - (self.history_length-1)*game_segment_batch_size

batch_target_values, batch_rewards = [], []
with torch.no_grad():
# import ipdb;ipdb.set_trace()
value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type)

if transition_batch_size > self._cfg.mini_infer_size:
print(f"transition_batch_size > mini_infer_size:{transition_batch_size > self._cfg.mini_infer_size}")
# split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors
slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size))
network_output = []
@@ -448,7 +462,12 @@
end_index = self._cfg.mini_infer_size * (i + 1)
m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device)
# calculate the target value
m_output = model.initial_inference(m_obs)
# import ipdb;ipdb.set_trace()
# print(f"m_obs.shape: {m_obs.shape}")
try:
m_output = model.initial_inference(m_obs)
except Exception as e:
print(e)

if not model.training:
# if not in training, obtain the scalars of the value/reward
@@ -469,6 +488,8 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
# use the predicted values
value_list = concat_output_value(network_output)

# print(f"value_list.shape: {value_list.shape}")

# get last state value
if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]:
# TODO(pu): for board_games, very important, to check
@@ -498,7 +519,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A

truncation_length = game_segment_len_non_re

for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
for current_index in range(state_index, state_index + self.num_unroll_steps + 1):
bootstrap_index = current_index + td_steps_list[value_index]
for i, reward in enumerate(reward_list[current_index:bootstrap_index]):
if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]:
@@ -544,7 +565,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
# for board games
policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, root_values, game_segment_lens, action_mask_segment, \
to_play_segment = policy_re_context
# transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1)
# transition_batch_size = game_segment_batch_size * (self.num_unroll_steps + 1)
transition_batch_size = len(policy_obs_list)
game_segment_batch_size = len(pos_in_game_segment_list)

@@ -623,7 +644,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
for state_index, child_visit, game_index in zip(pos_in_game_segment_list, child_visits, batch_index_list):
target_policies = []

for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
for current_index in range(state_index, state_index + self.num_unroll_steps + 1):
distributions = roots_distributions[policy_index]
searched_value = roots_values[policy_index]

@@ -694,10 +715,10 @@ def _compute_target_policy_non_reanalyzed(

pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment = policy_non_re_context
game_segment_batch_size = len(pos_in_game_segment_list)
transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1)
transition_batch_size = game_segment_batch_size * (self.num_unroll_steps + 1)

to_play, action_mask = self._preprocess_to_play_and_action_mask(
game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list, self.num_unroll_steps
)

if self._cfg.model.continuous_action_space is True:
@@ -710,7 +731,9 @@
[-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
]
else:
# import ipdb;ipdb.set_trace()
legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
# legal_actions = None

with torch.no_grad():
policy_index = 0
Expand All @@ -721,7 +744,7 @@ def _compute_target_policy_non_reanalyzed(
pos_in_game_segment_list):
target_policies = []

for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
for current_index in range(state_index, state_index + self.num_unroll_steps + 1):
if current_index < game_segment_len:
policy_mask.append(1)
# NOTE: child_visit is already a distribution
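Taken together, these hunks extend the unroll horizon by `history_length` extra steps whenever `history_length > 1`, so each training sample carries enough past transitions for the history-conditioned model input; every downstream loop, mask, and padding then reads `self.num_unroll_steps` instead of `self._cfg.num_unroll_steps`. A small self-contained sketch of that bookkeeping (the concrete numbers are illustrative, not from any config):

```python
from types import SimpleNamespace

# Illustrative config values; the field names mirror the ones read in __init__,
# the numbers themselves are made up for this sketch.
cfg = SimpleNamespace(num_unroll_steps=5, history_length=4)

history_length = cfg.history_length
if history_length > 1:
    num_unroll_steps = cfg.num_unroll_steps + history_length   # 5 + 4 = 9
else:
    num_unroll_steps = cfg.num_unroll_steps

# Per-sample structures are then sized against the extended horizon, e.g. the
# validity mask is padded to num_unroll_steps + 1 entries, as in _make_batch:
actions_tmp = [0, 2, 1]                               # toy slice of real actions
mask_tmp = [1.0 for _ in range(len(actions_tmp))]     # 1 = valid transition
mask_tmp += [0.0 for _ in range(num_unroll_steps + 1 - len(mask_tmp))]
assert len(mask_tmp) == num_unroll_steps + 1          # 10 entries in this example
```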
lzero/mcts/buffer/game_segment.py (1 addition & 0 deletions)
@@ -57,6 +57,7 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea
self.zero_obs_shape = config.model.observation_shape
elif len(config.model.observation_shape) == 3:
# image obs input, e.g. atari environments
# print(f'NOTE: config.model.image_channel:{config.model.image_channel}')
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])

self.obs_segment = []
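For image observations the zero padding frame is allocated per image channel (using `config.model.image_channel`) rather than per stacked frame. A toy illustration of the resulting shape; the Atari-like numbers are assumptions for the example:

```python
import numpy as np

# Toy illustration of the zero-observation shape used to pad image observations
# past the end of a segment; the 3-channel 96x96 numbers are illustrative only.
image_channel = 3
observation_shape = (12, 96, 96)   # e.g. 4 stacked RGB frames -> 12 channels

zero_obs_shape = (image_channel, observation_shape[-2], observation_shape[-1])
zero_obs = np.zeros(zero_obs_shape, dtype=np.float32)
print(zero_obs.shape)              # (3, 96, 96): one blank frame, not a full stack
```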
lzero/mcts/utils.py (9 additions & 2 deletions)
@@ -97,11 +97,18 @@ def prepare_observation(observation_list, model_type='conv'):
Returns:
- np.ndarray: Reshaped array of observations.
"""
assert model_type in ['conv', 'mlp', 'conv_context', 'mlp_context'], "model_type must be either 'conv' or 'mlp'"
assert model_type in ['conv', 'conv_history', 'mlp', 'conv_context', 'mlp_context'], "model_type must be either 'conv' or 'mlp'"
observation_array = np.array(observation_list)

# try:
# observation_array = np.array(observation_list)
# except Exception as e:
# print(e)
# import ipdb;ipdb.set_trace()

batch_size = observation_array.shape[0]

if model_type in ['conv', 'conv_context']:
if model_type in ['conv', 'conv_history', 'conv_context']:
if observation_array.ndim == 3:
# Add a channel dimension if it's missing
observation_array = observation_array[..., np.newaxis]
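`prepare_observation` now routes the new 'conv_history' model type through the same image branch as 'conv' and 'conv_context'; the branch shown above appends a channel axis when a batch of 2-D frames arrives without one. A short sketch of that step (the shapes are illustrative assumptions):

```python
import numpy as np

# Sketch of the image branch shared by 'conv', 'conv_history' and 'conv_context':
# a batch of 2-D frames gets an explicit trailing channel axis if it is missing.
# The (8, 96, 96) batch shape below is an illustrative assumption.
observation_array = np.zeros((8, 96, 96), dtype=np.float32)   # (batch, H, W)

if observation_array.ndim == 3:
    observation_array = observation_array[..., np.newaxis]    # -> (batch, H, W, 1)

print(observation_array.shape)                                # (8, 96, 96, 1)
```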
lzero/model/common.py (1 addition & 1 deletion)
@@ -1001,7 +1001,7 @@ def __init__(
if observation_shape[1] == 96:
latent_shape = (observation_shape[1] / 16, observation_shape[2] / 16)
elif observation_shape[1] == 64:
latent_shape = (observation_shape[1] / 8, observation_shape[2] / 8)
latent_shape = (int(observation_shape[1] / 8), int(observation_shape[2] / 8))

if norm_type == 'BN':
self.norm_value = nn.BatchNorm2d(value_head_channels)
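The only change here is casting the latent spatial dimensions to int: in Python 3, `64 / 8` is the float `8.0`, and letting that float flow into downstream layer sizes (for example a flatten size handed to `nn.Linear`) typically raises a TypeError. A quick check of the arithmetic; the 64x64 observation shape is an assumption for illustration:

```python
# Why the int() cast matters: plain division produces floats in Python 3, and
# layer constructors expect integer sizes. The 64x64 shape is illustrative.
observation_shape = (3, 64, 64)

latent_shape_float = (observation_shape[1] / 8, observation_shape[2] / 8)      # (8.0, 8.0)
latent_shape = (int(observation_shape[1] / 8), int(observation_shape[2] / 8))  # (8, 8)

assert latent_shape == (8, 8)
assert isinstance(latent_shape[0], int) and not isinstance(latent_shape_float[0], int)
```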