Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions zoo/jericho/configs/jericho_unizero_ddp_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,20 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
gpu_num = 4
collector_env_num: int = 4 # Number of collector environments
n_episode = int(collector_env_num*gpu_num)
batch_size = int(1*gpu_num)
accumulation_steps=1

# Encoder selection: picks the model name/path and the batch_size / accumulation_steps used with that encoder
encoder_option = 'legacy' # ['qwen', 'legacy']. Legacy uses the bge encoder

if encoder_option == 'qwen':
model_name: str = 'Qwen/Qwen3-0.6B'
batch_size = int(1*gpu_num)
accumulation_steps=64
elif encoder_option == 'legacy':
model_name: str = 'BAAI/bge-base-en-v1.5'
batch_size = int(64*gpu_num)
accumulation_steps=1
else:
raise ValueError(f"Unsupported encoder option: {encoder_option}")

# TODO
# batch_size = batch_size * 2
Expand Down Expand Up @@ -62,16 +74,6 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
# reanalyze_partition: Partition ratio from the replay buffer to use during reanalysis
reanalyze_partition: float = 0.75

# Model name or path - configurable according to the predefined model paths or names
encoder_option = 'legacy' # ['qwen', 'legacy']. Legacy uses the bge encoder

if encoder_option == 'qwen':
model_name: str = 'Qwen/Qwen3-0.6B'
elif encoder_option == 'legacy':
model_name: str = 'BAAI/bge-base-en-v1.5'
else:
raise ValueError(f"Unsupported encoder option: {encoder_option}")

# ------------------------------------------------------------------
# TODO: Debug configuration - override some parameters for debugging purposes
# ------------------------------------------------------------------
Expand Down Expand Up @@ -136,7 +138,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
embed_dim=embed_dim,
obs_type="text", # TODO: Modify as needed.
env_num=max(collector_env_num, evaluator_env_num),
decode_loss_mode='None', # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None.
decode_loss_mode=None, # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None.
latent_recon_loss_weight=0.1 # TODO: decoder loss weight
),
),
Expand All @@ -152,7 +154,6 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
cos_lr_scheduler=False,
fixed_temperature_value=0.25,
manual_temperature_decay=False,
# manual_temperature_decay=True,

num_simulations=num_simulations,
n_episode=n_episode,
Expand Down