Configuration
Configuration dataclasses for SkyRL training.
Top-Level Config
The root configuration object and helpers.
class SkyRLTrainConfig
SkyRLTrainConfig(data: DataConfig = DataConfig(), trainer: TrainerConfig = TrainerConfig(), generator: GeneratorConfig = GeneratorConfig(), environment: EnvironmentConfig = EnvironmentConfig()) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
from_cli_overrides | Construct a SkyRLTrainConfig from CLI arguments or a dict of overrides. |
Attributes:
| Name | Type | Description |
|---|---|---|
data | DataConfig | |
trainer | TrainerConfig | |
generator | GeneratorConfig | |
environment | EnvironmentConfig |
Source code in skyrl/train/config/config.py:694-787
@dataclass
class SkyRLTrainConfig(BaseConfig):
    """Root training config aggregating the data, trainer, generator, and environment sections."""

    data: DataConfig = field(default_factory=DataConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
    generator: GeneratorConfig = field(default_factory=GeneratorConfig)
    environment: EnvironmentConfig = field(default_factory=EnvironmentConfig)

    def __post_init__(self):
        # Resolve cross-section defaults after all sub-configs have been constructed.
        # generator.max_input_length defaults to trainer.max_prompt_length
        if self.generator.max_input_length is None:
            self.generator.max_input_length = self.trainer.max_prompt_length
        # generator rope params default to trainer rope params
        if self.generator.rope_scaling is None and self.trainer.rope_scaling is not None:
            self.generator.rope_scaling = self.trainer.rope_scaling
        if self.generator.rope_theta is None and self.trainer.rope_theta is not None:
            self.generator.rope_theta = self.trainer.rope_theta
        # Copy temperature from generator sampling params to algorithm config
        # so workers can access it without needing the generator config
        if self.trainer.algorithm.temperature is None:
            self.trainer.algorithm.temperature = self.generator.sampling_params.temperature
        if self.trainer.algorithm.max_seq_len is None:
            # NOTE (erictang000): this is the max sequence length including the prompt, since max response length
            # per batch can be variable based on the prompt length. This is used to normalize the loss for
            # seq_mean_token_sum_norm loss reduction.
            # TODO(Charlie): This calculation is not correct for multi-turn and users should use `max_seq_len` instead.
            # Should we just force users to set max_seq_len if loss reduction is seq_mean_token_sum_norm, regardless of
            # multi-turn or not?
            self.trainer.algorithm.max_seq_len = (
                self.generator.max_input_length + self.generator.sampling_params.max_generate_length
            )

    @classmethod
    def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SkyRLTrainConfig":
        """Construct a SkyRLTrainConfig from CLI arguments or a dict of overrides.

        Parses CLI arguments and builds a typed config. Dataclass field defaults
        are used for any values not specified on the command line.

        Supports both new-style config paths (e.g., generator.inference_engine.backend)
        and legacy YAML-style paths (e.g., generator.backend) for backward compatibility.

        Args:
            args: Either a list of CLI arguments in 'key.path=value' format, or a dict
                mapping dot-notation keys to values.
                Example list: ['trainer.policy.model.path=Qwen/Qwen2.5-1.5B-Instruct', 'trainer.seed=123']
                Example dict: {'trainer.policy.model.path': 'Qwen/Qwen2.5-1.5B-Instruct', 'trainer.seed': 123}

        Returns:
            A fully constructed SkyRLTrainConfig with CLI overrides applied.

        Raises:
            ValueError: If an argument uses the unsupported '+' prefix.
        """
        if isinstance(args, dict):
            # Normalize dict overrides into the same 'key=value' form as CLI args.
            args = [f"{k}={v}" for k, v in args.items()]
        # NOTE(review): imports are function-local — presumably to avoid an import
        # cycle at module load time; confirm.
        from skyrl.train.config.legacy import (
            is_legacy_config,
            translate_legacy_config,
            warn_legacy_config,
        )
        from skyrl.train.config.utils import get_legacy_config

        # Check for unsupported '+' prefix
        for arg in args:
            if arg.startswith("+"):
                raise ValueError(
                    f"The '+' prefix for adding new config fields is not supported: '{arg}'. "
                    "To add custom config fields, subclass the relevant config dataclass."
                )
        overrides = OmegaConf.from_cli(args)
        # Try new format first
        try:
            return cls.from_dict_config(overrides)
        except ValueError:
            # Fall back to legacy format: load base YAML, merge overrides, translate
            try:
                base_cfg = get_legacy_config()
                merged = OmegaConf.merge(base_cfg, overrides)
                merged_dict = OmegaConf.to_container(merged, resolve=True)
                if is_legacy_config(merged_dict):
                    warn_legacy_config()
                    translated = translate_legacy_config(merged_dict)
                    return build_nested_dataclass(cls, translated)
            except Exception:
                pass  # Legacy translation failed, re-raise original error
            # Re-raise original error if not a legacy config issue
            raise

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfigConstruct a typed BaseConfig from a Hydra DictConfig.
attr data
data: DataConfig = field(default_factory=DataConfig)attr trainer
trainer: TrainerConfig = field(default_factory=TrainerConfig)attr generator
generator: GeneratorConfig = field(default_factory=GeneratorConfig)attr environment
environment: EnvironmentConfig = field(default_factory=EnvironmentConfig)method classmethod from_cli_overrides
from_cli_overrides(args: Union[List[str], dict]) -> SkyRLTrainConfigConstruct a SkyRLTrainConfig from CLI arguments or a dict of overrides.
Parses CLI arguments and builds a typed config. Dataclass field defaults are used for any values not specified on the command line.
Supports both new-style config paths (e.g., generator.inference_engine.backend) and legacy YAML-style paths (e.g., generator.backend) for backward compatibility.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
args | Union[List[str], dict] | Either a list of CLI arguments in 'key.path=value' format, or a dict mapping dot-notation keys to values. Example list: ['trainer.policy.model.path=Qwen/Qwen2.5-1.5B-Instruct', 'trainer.seed=123'] Example dict: {'trainer.policy.model.path': 'Qwen/Qwen2.5-1.5B-Instruct', 'trainer.seed': 123} | required |
Returns:
| Type | Description |
|---|---|
| SkyRLTrainConfig | A fully constructed SkyRLTrainConfig with CLI overrides applied. |
Raises:
| Type | Description |
|---|---|
| ValueError | If an argument uses the unsupported '+' prefix. |
Source code in skyrl/train/config/config.py:728-787
@classmethod
def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SkyRLTrainConfig":
    """Construct a SkyRLTrainConfig from CLI arguments or a dict of overrides.

    Parses CLI arguments and builds a typed config. Dataclass field defaults
    are used for any values not specified on the command line.

    Supports both new-style config paths (e.g., generator.inference_engine.backend)
    and legacy YAML-style paths (e.g., generator.backend) for backward compatibility.

    Args:
        args: Either a list of CLI arguments in 'key.path=value' format, or a dict
            mapping dot-notation keys to values.
            Example list: ['trainer.policy.model.path=Qwen/Qwen2.5-1.5B-Instruct', 'trainer.seed=123']
            Example dict: {'trainer.policy.model.path': 'Qwen/Qwen2.5-1.5B-Instruct', 'trainer.seed': 123}

    Returns:
        A fully constructed SkyRLTrainConfig with CLI overrides applied.

    Raises:
        ValueError: If an argument uses the unsupported '+' prefix.
    """
    if isinstance(args, dict):
        # Normalize dict overrides into the same 'key=value' form as CLI args.
        args = [f"{k}={v}" for k, v in args.items()]
    # NOTE(review): imports are function-local — presumably to avoid an import
    # cycle at module load time; confirm.
    from skyrl.train.config.legacy import (
        is_legacy_config,
        translate_legacy_config,
        warn_legacy_config,
    )
    from skyrl.train.config.utils import get_legacy_config

    # Check for unsupported '+' prefix
    for arg in args:
        if arg.startswith("+"):
            raise ValueError(
                f"The '+' prefix for adding new config fields is not supported: '{arg}'. "
                "To add custom config fields, subclass the relevant config dataclass."
            )
    overrides = OmegaConf.from_cli(args)
    # Try new format first
    try:
        return cls.from_dict_config(overrides)
    except ValueError:
        # Fall back to legacy format: load base YAML, merge overrides, translate
        try:
            base_cfg = get_legacy_config()
            merged = OmegaConf.merge(base_cfg, overrides)
            merged_dict = OmegaConf.to_container(merged, resolve=True)
            if is_legacy_config(merged_dict):
                warn_legacy_config()
                translated = translate_legacy_config(merged_dict)
                return build_nested_dataclass(cls, translated)
        except Exception:
            pass  # Legacy translation failed, re-raise original error
        # Re-raise original error if not a legacy config issue
        raise

method make_config
make_config(algorithm_cls: Optional[Type[AlgorithmConfig]] = None, trainer_cls: Optional[Type[TrainerConfig]] = None, generator_cls: Optional[Type[GeneratorConfig]] = None) -> Type[SkyRLTrainConfig]Create a SkyRLTrainConfig subclass with custom nested config classes.
Convenience helper to avoid boilerplate when extending configs for custom algorithms or generators. For full IDE autocomplete on custom fields, use explicit subclassing instead (see examples/algorithms/dapo/main_dapo.py).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
algorithm_cls | Optional[Type[AlgorithmConfig]] | Custom AlgorithmConfig subclass. If provided without trainer_cls, a TrainerConfig subclass is automatically created. | None |
trainer_cls | Optional[Type[TrainerConfig]] | Custom TrainerConfig subclass. Takes precedence over algorithm_cls for the trainer config. | None |
generator_cls | Optional[Type[GeneratorConfig]] | Custom GeneratorConfig subclass. | None |
Returns:
| Type | Description |
|---|---|
| Type[SkyRLTrainConfig] | A SkyRLTrainConfig subclass wired up with the custom config classes. |
Example::
@dataclass
class MyAlgorithmConfig(AlgorithmConfig):
my_param: int = 42
MyConfig = make_config(algorithm_cls=MyAlgorithmConfig)
cfg = MyConfig.from_cli_overrides(sys.argv[1:])method get_config_as_dict
get_config_as_dict(cfg: Union[dict, BaseConfig]) -> dictmethod get_config_as_yaml_str
get_config_as_yaml_str(cfg: BaseConfig) -> strData & Model
class DataConfig
DataConfig(train_data: List[str] = (lambda: [os.path.expanduser('~/data/gsm8k/train.parquet')])(), val_data: List[str] = (lambda: [os.path.expanduser('~/data/gsm8k/validation.parquet')])()) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
train_data | List[str] | |
val_data | List[str] |
Source code in skyrl/train/config/config.py:40-43
@dataclass
class DataConfig(BaseConfig):
    """Dataset file lists; defaults point at the GSM8K parquet files under the home directory."""

    train_data: List[str] = field(default_factory=lambda: [os.path.expanduser("~/data/gsm8k/train.parquet")])
    val_data: List[str] = field(default_factory=lambda: [os.path.expanduser("~/data/gsm8k/validation.parquet")])

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfigConstruct a typed BaseConfig from a Hydra DictConfig.
attr train_data
train_data: List[str] = field(default_factory=(lambda: [os.path.expanduser('~/data/gsm8k/train.parquet')]))attr val_data
val_data: List[str] = field(default_factory=(lambda: [os.path.expanduser('~/data/gsm8k/validation.parquet')]))class ModelConfig
ModelConfig(path: Optional[str] = None, lora: SkyRLLoraConfig = SkyRLLoraConfig()) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
path | Optional[str] | |
lora | SkyRLLoraConfig |
Source code in skyrl/train/config/config.py:65-68
@dataclass
class ModelConfig(BaseConfig):
    """Model source plus optional LoRA settings."""

    # HF model id or local path. ``None`` may be filled in later — e.g. the ref
    # model falls back to the policy path in TrainerConfig.__post_init__.
    path: Optional[str] = None
    lora: SkyRLLoraConfig = field(default_factory=SkyRLLoraConfig)

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfigConstruct a typed BaseConfig from a Hydra DictConfig.
attr path
path: Optional[str] = Noneattr lora
lora: SkyRLLoraConfig = field(default_factory=SkyRLLoraConfig)class SkyRLLoraConfig
SkyRLLoraConfig(rank: int = 0, alpha: int = 16, dropout: float = 0.0, lora_sync_path: str = '/tmp/skyrl_lora_sync', target_modules: str = 'all-linear', exclude_modules: Optional[str] = None, init_method: str = 'kaiming') -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
rank | int | |
alpha | int | |
dropout | float | |
lora_sync_path | str | |
target_modules | str | |
exclude_modules | Optional[str] | |
init_method | str | For FSDP, corresponds to init_lora_weights in PEFT. |
Source code in skyrl/train/config/config.py:52-62
@dataclass
class SkyRLLoraConfig(BaseConfig):
    """LoRA adapter settings for trainer workers."""

    # NOTE(review): default rank of 0 presumably means LoRA is disabled — confirm.
    rank: int = 0
    alpha: int = 16
    dropout: float = 0.0
    # NOTE(review): path appears to be used to sync adapter weights to the
    # inference engine (per the name) — confirm consumer.
    lora_sync_path: str = "/tmp/skyrl_lora_sync"
    target_modules: str = "all-linear"
    exclude_modules: Optional[str] = None
    init_method: str = "kaiming"
"""For FSDP, corresponds to ``init_lora_weights`` in PEFT.
For Megatron, used for ``lora_A_init_method``; supports "xavier", "normal", "kaiming", "zero"."""

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfigConstruct a typed BaseConfig from a Hydra DictConfig.
attr rank
rank: int = 0attr alpha
alpha: int = 16attr dropout
dropout: float = 0.0attr lora_sync_path
lora_sync_path: str = '/tmp/skyrl_lora_sync'attr target_modules
target_modules: str = 'all-linear'attr exclude_modules
exclude_modules: Optional[str] = Noneattr init_method
init_method: str = 'kaiming'For FSDP, corresponds to init_lora_weights in PEFT.
For Megatron, used for lora_A_init_method; supports "xavier", "normal", "kaiming", "zero".
Training
class TrainerConfig
TrainerConfig(placement: PlacementConfig = PlacementConfig(), sequence_parallel_backend: str = 'ulysses', strategy: str = 'fsdp2', policy: PolicyConfig = PolicyConfig(), ref: RefConfig = RefConfig(), critic: CriticConfig = CriticConfig(), algorithm: AlgorithmConfig = AlgorithmConfig(), fully_async: FullyAsyncConfig = FullyAsyncConfig(), gradient_checkpointing: bool = True, gradient_checkpointing_use_reentrant: bool = False, seed: int = 42, resume_mode: Optional[str] = 'latest', resume_path: Optional[str] = None, log_path: str = '/tmp/skyrl-logs', ckpt_path: str = (lambda: os.path.expanduser('~/ckpts/'))(), max_ckpts_to_keep: int = -1, ckpt_interval: int = 10, hf_save_interval: int = -1, export_path: str = (lambda: os.path.expanduser('~/exports/'))(), bf16: bool = True, epochs: int = 1, update_epochs_per_batch: int = 1, train_batch_size: int = 1024, policy_mini_batch_size: int = 256, critic_mini_batch_size: int = 256, micro_train_batch_size_per_gpu: int = 1, micro_forward_batch_size_per_gpu: int = 1, update_ref_every_epoch: bool = False, use_sample_packing: bool = True, eval_batch_size: int = 1024, eval_before_train: bool = True, eval_interval: int = 5, max_prompt_length: int = 512, flash_attn: bool = True, disable_fast_tokenizer: bool = False, project_name: str = 'skyrl', run_name: str = 'test_run', logger: str = 'wandb', dump_data_batch: bool = False, dump_eval_results: bool = True, rope_scaling: Optional[Dict[str, Any]] = None, rope_theta: Optional[float] = None) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
Source code in skyrl/train/config/config.py:558-614
@dataclass
class TrainerConfig(BaseConfig):
    """Trainer settings: worker placement, model configs, algorithm, batching, checkpointing, and logging."""

    placement: PlacementConfig = field(default_factory=PlacementConfig)
    sequence_parallel_backend: str = "ulysses"
    strategy: str = "fsdp2"
    policy: PolicyConfig = field(default_factory=PolicyConfig)
    ref: RefConfig = field(default_factory=RefConfig)
    critic: CriticConfig = field(default_factory=CriticConfig)
    algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
    fully_async: FullyAsyncConfig = field(default_factory=FullyAsyncConfig)
    gradient_checkpointing: bool = True
    gradient_checkpointing_use_reentrant: bool = False
    seed: int = 42
    resume_mode: Optional[str] = "latest"
    """``None``/``"none"``, ``"latest"``, or ``"from_path"``."""
    resume_path: Optional[str] = None
    log_path: str = "/tmp/skyrl-logs"
    """Path for infrastructure log files. For multi-node, use a shared filesystem path to consolidate logs."""
    ckpt_path: str = field(default_factory=lambda: os.path.expanduser("~/ckpts/"))
    max_ckpts_to_keep: int = -1
    """``-1`` to keep all checkpoints, ``N`` to keep only the last N."""
    ckpt_interval: int = 10
    hf_save_interval: int = -1
    """Save HuggingFace-format model every N steps. ``-1`` to disable."""
    export_path: str = field(default_factory=lambda: os.path.expanduser("~/exports/"))
    """Path for exported artifacts (HF models, debug dumps, etc.)."""
    bf16: bool = True
    epochs: int = 1
    update_epochs_per_batch: int = 1
    """Number of gradient update passes over each training batch."""
    train_batch_size: int = 1024
    """See ``utils/utils.py::validate_batch_sizes`` for train, mini, and micro batch size constraints."""
    policy_mini_batch_size: int = 256
    critic_mini_batch_size: int = 256
    micro_train_batch_size_per_gpu: int = 1
    micro_forward_batch_size_per_gpu: int = 1
    update_ref_every_epoch: bool = False
    use_sample_packing: bool = True
    eval_batch_size: int = 1024
    eval_before_train: bool = True
    eval_interval: int = 5
    """``-1`` to disable evaluation."""
    max_prompt_length: int = 512
    flash_attn: bool = True
    disable_fast_tokenizer: bool = False
    project_name: str = "skyrl"
    run_name: str = "test_run"
    logger: str = "wandb"
    dump_data_batch: bool = False
    dump_eval_results: bool = True
    # Generator-side rope settings default to these two values
    # (see SkyRLTrainConfig.__post_init__).
    rope_scaling: Optional[Dict[str, Any]] = None
    rope_theta: Optional[float] = None

    def __post_init__(self):
        # ref model defaults to the policy model
        if self.ref.model.path is None:
            self.ref.model.path = self.policy.model.path

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfigConstruct a typed BaseConfig from a Hydra DictConfig.
attr placement
placement: PlacementConfig = field(default_factory=PlacementConfig)attr sequence_parallel_backend
sequence_parallel_backend: str = 'ulysses'attr strategy
strategy: str = 'fsdp2'attr policy
policy: PolicyConfig = field(default_factory=PolicyConfig)attr ref
ref: RefConfig = field(default_factory=RefConfig)attr critic
critic: CriticConfig = field(default_factory=CriticConfig)attr algorithm
algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)attr fully_async
fully_async: FullyAsyncConfig = field(default_factory=FullyAsyncConfig)attr gradient_checkpointing
gradient_checkpointing: bool = Trueattr gradient_checkpointing_use_reentrant
gradient_checkpointing_use_reentrant: bool = Falseattr seed
seed: int = 42attr resume_mode
resume_mode: Optional[str] = 'latest'None/"none", "latest", or "from_path".
attr resume_path
resume_path: Optional[str] = Noneattr log_path
log_path: str = '/tmp/skyrl-logs'Path for infrastructure log files. For multi-node, use a shared filesystem path to consolidate logs.
attr ckpt_path
ckpt_path: str = field(default_factory=(lambda: os.path.expanduser('~/ckpts/')))attr max_ckpts_to_keep
max_ckpts_to_keep: int = -1-1 to keep all checkpoints, N to keep only the last N.
attr ckpt_interval
ckpt_interval: int = 10attr hf_save_interval
hf_save_interval: int = -1Save HuggingFace-format model every N steps. -1 to disable.
attr export_path
export_path: str = field(default_factory=(lambda: os.path.expanduser('~/exports/')))Path for exported artifacts (HF models, debug dumps, etc.).
attr bf16
bf16: bool = Trueattr epochs
epochs: int = 1attr update_epochs_per_batch
update_epochs_per_batch: int = 1Number of gradient update passes over each training batch.
attr train_batch_size
train_batch_size: int = 1024See utils/utils.py::validate_batch_sizes for train, mini, and micro batch size constraints.
attr policy_mini_batch_size
policy_mini_batch_size: int = 256attr critic_mini_batch_size
critic_mini_batch_size: int = 256attr micro_train_batch_size_per_gpu
micro_train_batch_size_per_gpu: int = 1attr micro_forward_batch_size_per_gpu
micro_forward_batch_size_per_gpu: int = 1attr update_ref_every_epoch
update_ref_every_epoch: bool = Falseattr use_sample_packing
use_sample_packing: bool = Trueattr eval_batch_size
eval_batch_size: int = 1024attr eval_before_train
eval_before_train: bool = Trueattr eval_interval
eval_interval: int = 5-1 to disable evaluation.
attr max_prompt_length
max_prompt_length: int = 512attr flash_attn
flash_attn: bool = Trueattr disable_fast_tokenizer
disable_fast_tokenizer: bool = Falseattr project_name
project_name: str = 'skyrl'attr run_name
run_name: str = 'test_run'attr logger
logger: str = 'wandb'attr dump_data_batch
dump_data_batch: bool = Falseattr dump_eval_results
dump_eval_results: bool = Trueattr rope_scaling
rope_scaling: Optional[Dict[str, Any]] = Noneattr rope_theta
rope_theta: Optional[float] = Noneclass OptimizerConfig
OptimizerConfig(lr: float = 1e-06, adam_betas: List[float] = (lambda: [0.9, 0.999])(), weight_decay: float = 0.01, max_grad_norm: float = 1.0, offload_after_step: bool = True, num_warmup_steps: int = 0, scheduler: str = 'constant_with_warmup') -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
lr | float | |
adam_betas | List[float] | |
weight_decay | float | |
max_grad_norm | float | |
offload_after_step | bool | Offload optimizer state to CPU after each full training step. Only applicable when colocate_all=True. |
num_warmup_steps | int | Number of mini-batch steps to warmup the optimizer. |
scheduler | str |
Source code in skyrl/train/config/config.py:76-86
@dataclass
class OptimizerConfig(BaseConfig):
    """Optimizer and learning-rate-schedule hyperparameters."""

    lr: float = 1e-6
    adam_betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
    weight_decay: float = 1e-2
    max_grad_norm: float = 1.0
    offload_after_step: bool = True
    """Offload optimizer state to CPU after each full training step. Only applicable when ``colocate_all=True``."""
    num_warmup_steps: int = 0
    """Number of mini-batch steps to warmup the optimizer."""
    scheduler: str = "constant_with_warmup"

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfigConstruct a typed BaseConfig from a Hydra DictConfig.
attr lr
lr: float = 1e-06attr adam_betas
adam_betas: List[float] = field(default_factory=(lambda: [0.9, 0.999]))attr weight_decay
weight_decay: float = 0.01attr max_grad_norm
max_grad_norm: float = 1.0attr offload_after_step
offload_after_step: bool = TrueOffload optimizer state to CPU after each full training step. Only applicable when colocate_all=True.
attr num_warmup_steps
num_warmup_steps: int = 0Number of mini-batch steps to warmup the optimizer.
attr scheduler
scheduler: str = 'constant_with_warmup'class MixedPrecisionConfig
MixedPrecisionConfig(param_dtype: str = 'bf16', reduce_dtype: str = 'fp32', buffer_dtype: str = 'fp32') -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
param_dtype | str | |
reduce_dtype | str | |
buffer_dtype | str |
Source code in skyrl/train/config/config.py:89-93
@dataclass
class MixedPrecisionConfig(BaseConfig):
    """Mixed-precision dtypes used by FSDP (see FSDPConfig.mixed_precision)."""

    param_dtype: str = "bf16"
    reduce_dtype: str = "fp32"
    buffer_dtype: str = "fp32"

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfigConstruct a typed BaseConfig from a Hydra DictConfig.
attr param_dtype
param_dtype: str = 'bf16'attr reduce_dtype
reduce_dtype: str = 'fp32'attr buffer_dtype
buffer_dtype: str = 'fp32'Backend Config
class FSDPConfig
FSDPConfig(cpu_offload: bool = False, reshard_after_forward: Union[bool, int] = True, fsdp_size: int = -1, mixed_precision: Optional[MixedPrecisionConfig] = None, wrap_policy: dict = dict()) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
cpu_offload | bool | Offload params and optimizer state to CPU during the forward pass. |
reshard_after_forward | Union[bool, int] | FSDP2 only. Accepts True, False, or an int between 1 and fsdp_size. |
fsdp_size | int | |
mixed_precision | Optional[MixedPrecisionConfig] | |
wrap_policy | dict |
Source code in skyrl/train/config/config.py:96-105
@dataclass
class FSDPConfig(BaseConfig):
    """FSDP/FSDP2 sharding settings."""

    cpu_offload: bool = False
    """Offload params and optimizer state to CPU during the forward pass."""
    reshard_after_forward: Union[bool, int] = True
    """FSDP2 only. Accepts True, False, or an int between 1 and ``fsdp_size``."""
    fsdp_size: int = -1
    mixed_precision: Optional[MixedPrecisionConfig] = None
    # specify wrap policy as a dict with `transformer_layer_cls_to_wrap` key for custom module based wrapping
    wrap_policy: dict = field(default_factory=dict)

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfigConstruct a typed BaseConfig from a Hydra DictConfig.
attr cpu_offload
cpu_offload: bool = FalseOffload params and optimizer state to CPU during the forward pass.
attr reshard_after_forward
reshard_after_forward: Union[bool, int] = TrueFSDP2 only. Accepts True, False, or an int between 1 and fsdp_size.
attr fsdp_size
fsdp_size: int = -1attr mixed_precision
mixed_precision: Optional[MixedPrecisionConfig] = Noneattr wrap_policy
wrap_policy: dict = field(default_factory=dict)class MegatronConfig
MegatronConfig(tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, context_parallel_size: int = 1, expert_model_parallel_size: int = 1, expert_tensor_parallel_size: Optional[int] = None, moe_token_dispatcher_type: str = 'alltoall', moe_router_load_balancing_type: str = 'none', moe_grouped_gemm: bool = True, moe_router_score_function: Optional[str] = None, moe_router_enable_expert_bias: Optional[bool] = None, ddp_config: MegatronDDPConfig = MegatronDDPConfig(), torch_profiler_config: MegatronTorchProfilerConfig = MegatronTorchProfilerConfig(), lora_config: MegatronLoraConfig = MegatronLoraConfig(), optimizer_config_kwargs: Dict[str, Any] = (lambda: copy.deepcopy(DEFAULT_MEGATRON_OPTIMIZER_KWARGS))(), transformer_config_kwargs: Dict[str, Any] = (lambda: copy.deepcopy(DEFAULT_TRANSFORMER_CONFIG_KWARGS))(), empty_cuda_cache: Optional[bool] = None, model_config_kwargs: dict = dict(), dist_ckpt_optim_fully_reshardable: bool = False) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
tensor_model_parallel_size | int | |
pipeline_model_parallel_size | int | |
context_parallel_size | int | |
expert_model_parallel_size | int | |
expert_tensor_parallel_size | Optional[int] | |
moe_token_dispatcher_type | str | |
moe_router_load_balancing_type | str | |
moe_grouped_gemm | bool | |
moe_router_score_function | Optional[str] | |
moe_router_enable_expert_bias | Optional[bool] | |
ddp_config | MegatronDDPConfig | |
torch_profiler_config | MegatronTorchProfilerConfig | |
lora_config | MegatronLoraConfig | |
optimizer_config_kwargs | Dict[str, Any] | |
transformer_config_kwargs | Dict[str, Any] | |
empty_cuda_cache | Optional[bool] | |
model_config_kwargs | dict | |
dist_ckpt_optim_fully_reshardable | bool |
Source code in skyrl/train/config/config.py:148-172
@dataclass
class MegatronConfig(BaseConfig):
    """Megatron backend settings: parallelism sizes, MoE flags, and pass-through kwargs."""

    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1
    context_parallel_size: int = 1
    expert_model_parallel_size: int = 1
    expert_tensor_parallel_size: Optional[int] = None
    # MoE runtime configuration flags
    moe_token_dispatcher_type: str = "alltoall"
    moe_router_load_balancing_type: str = "none"
    moe_grouped_gemm: bool = True
    moe_router_score_function: Optional[str] = None
    moe_router_enable_expert_bias: Optional[bool] = None
    ddp_config: MegatronDDPConfig = field(default_factory=MegatronDDPConfig)
    torch_profiler_config: MegatronTorchProfilerConfig = field(default_factory=MegatronTorchProfilerConfig)
    lora_config: MegatronLoraConfig = field(default_factory=MegatronLoraConfig)
    # deepcopy so per-instance mutation cannot leak into the module-level default dicts
    optimizer_config_kwargs: Dict[str, Any] = field(
        default_factory=lambda: copy.deepcopy(DEFAULT_MEGATRON_OPTIMIZER_KWARGS)
    )
    transformer_config_kwargs: Dict[str, Any] = field(
        default_factory=lambda: copy.deepcopy(DEFAULT_TRANSFORMER_CONFIG_KWARGS)
    )
    empty_cuda_cache: Optional[bool] = None
    model_config_kwargs: dict = field(default_factory=dict)
    dist_ckpt_optim_fully_reshardable: bool = False

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfigConstruct a typed BaseConfig from a Hydra DictConfig.
attr tensor_model_parallel_size
tensor_model_parallel_size: int = 1attr pipeline_model_parallel_size
pipeline_model_parallel_size: int = 1attr context_parallel_size
context_parallel_size: int = 1attr expert_model_parallel_size
expert_model_parallel_size: int = 1attr expert_tensor_parallel_size
expert_tensor_parallel_size: Optional[int] = Noneattr moe_token_dispatcher_type
moe_token_dispatcher_type: str = 'alltoall'attr moe_router_load_balancing_type
moe_router_load_balancing_type: str = 'none'attr moe_grouped_gemm
moe_grouped_gemm: bool = Trueattr moe_router_score_function
moe_router_score_function: Optional[str] = Noneattr moe_router_enable_expert_bias
moe_router_enable_expert_bias: Optional[bool] = Noneattr ddp_config
ddp_config: MegatronDDPConfig = field(default_factory=MegatronDDPConfig)attr torch_profiler_config
torch_profiler_config: MegatronTorchProfilerConfig = field(default_factory=MegatronTorchProfilerConfig)attr lora_config
lora_config: MegatronLoraConfig = field(default_factory=MegatronLoraConfig)attr optimizer_config_kwargs
optimizer_config_kwargs: Dict[str, Any] = field(default_factory=(lambda: copy.deepcopy(DEFAULT_MEGATRON_OPTIMIZER_KWARGS)))attr transformer_config_kwargs
transformer_config_kwargs: Dict[str, Any] = field(default_factory=(lambda: copy.deepcopy(DEFAULT_TRANSFORMER_CONFIG_KWARGS)))attr empty_cuda_cache
empty_cuda_cache: Optional[bool] = Noneattr model_config_kwargs
model_config_kwargs: dict = field(default_factory=dict)attr dist_ckpt_optim_fully_reshardable
dist_ckpt_optim_fully_reshardable: bool = Falseclass MegatronDDPConfig
MegatronDDPConfig(grad_reduce_in_fp32: bool = True, overlap_grad_reduce: bool = False, overlap_param_gather: bool = False, average_in_collective: bool = True) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
grad_reduce_in_fp32 | bool | |
overlap_grad_reduce | bool | |
overlap_param_gather | bool | |
average_in_collective | bool |
Source code in skyrl/train/config/config.py:113-118
@dataclass
class MegatronDDPConfig(BaseConfig):
    """Distributed-data-parallel communication settings for the Megatron backend."""

    grad_reduce_in_fp32: bool = True
    overlap_grad_reduce: bool = False
    overlap_param_gather: bool = False
    average_in_collective: bool = True

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfigConstruct a typed BaseConfig from a Hydra DictConfig.
attr grad_reduce_in_fp32
grad_reduce_in_fp32: bool = Trueattr overlap_grad_reduce
overlap_grad_reduce: bool = Falseattr overlap_param_gather
overlap_param_gather: bool = Falseattr average_in_collective
average_in_collective: bool = Trueclass MegatronLoraConfig
MegatronLoraConfig(lora_type: str = 'lora') -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
lora_type | str |
Source code in skyrl/train/config/config.py:128-130
@dataclass
class MegatronLoraConfig(BaseConfig):
    """LoRA settings specific to the Megatron backend (see MegatronConfig.lora_config)."""
    lora_type: str = "lora"

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfigConstruct a typed BaseConfig from a Hydra DictConfig.
attr lora_type
lora_type: str = 'lora'class MegatronTorchProfilerConfig
MegatronTorchProfilerConfig(enable: bool = False, ranks: List[int] = list(), save_path: Optional[str] = None) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
Source code in skyrl/train/config/config.py:121-125
@dataclass
class MegatronTorchProfilerConfig(BaseConfig):
    """Torch profiler settings for the Megatron backend."""

    enable: bool = False
    # Ranks to profile; empty by default.
    ranks: List[int] = field(default_factory=list)
    save_path: Optional[str] = None

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfigConstruct a typed BaseConfig from a Hydra DictConfig.
attr enable
enable: bool = Falseattr ranks
ranks: List[int] = field(default_factory=list)attr save_path
save_path: Optional[str] = NonePlacement
class PlacementConfig
PlacementConfig(colocate_all: bool = True, colocate_policy_ref: bool = True, policy_num_nodes: int = 1, policy_num_gpus_per_node: int = 1, critic_num_nodes: int = 1, critic_num_gpus_per_node: int = 1, ref_num_nodes: int = 1, ref_num_gpus_per_node: int = 1) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
colocate_all | bool | When True, training and inference share the same GPUs. |
colocate_policy_ref | bool | |
policy_num_nodes | int | |
policy_num_gpus_per_node | int | |
critic_num_nodes | int | |
critic_num_gpus_per_node | int | |
ref_num_nodes | int | |
ref_num_gpus_per_node | int |
Source code in skyrl/train/config/config.py:180-190
@dataclass
class PlacementConfig(BaseConfig):
colocate_all: bool = True
"""When True, training and inference share the same GPUs."""
colocate_policy_ref: bool = True
policy_num_nodes: int = 1
policy_num_gpus_per_node: int = 1
critic_num_nodes: int = 1
critic_num_gpus_per_node: int = 1
ref_num_nodes: int = 1
ref_num_gpus_per_node: int = 1

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr colocate_all
colocate_all: bool = True
When True, training and inference share the same GPUs.
attr colocate_policy_ref
colocate_policy_ref: bool = True
attr policy_num_nodes
policy_num_nodes: int = 1
attr policy_num_gpus_per_node
policy_num_gpus_per_node: int = 1
attr critic_num_nodes
critic_num_nodes: int = 1
attr critic_num_gpus_per_node
critic_num_gpus_per_node: int = 1
attr ref_num_nodes
ref_num_nodes: int = 1
attr ref_num_gpus_per_node
ref_num_gpus_per_node: int = 1
Policy & Algorithm
class PolicyConfig
PolicyConfig(model: ModelConfig = (lambda: copy.deepcopy(ModelConfig(path='Qwen/Qwen2.5-1.5B-Instruct')))(), optimizer_config: OptimizerConfig = OptimizerConfig(), fsdp_config: FSDPConfig = FSDPConfig(), sequence_parallel_size: int = 1, use_torch_compile: bool = False, record_memory: bool = False, megatron_config: MegatronConfig = MegatronConfig(), model_config_kwargs: dict = dict()) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
model | ModelConfig | |
optimizer_config | OptimizerConfig | |
fsdp_config | FSDPConfig | |
sequence_parallel_size | int | |
use_torch_compile | bool | Apply torch.compile to logits calculation. |
record_memory | bool | Save memory snapshots to {ckpt_path}/memory_snapshots/. |
megatron_config | MegatronConfig | |
model_config_kwargs | dict | Pass-through kwargs for the HuggingFace model config (FSDP backends). |
Source code in skyrl/train/config/config.py:198-212
@dataclass
class PolicyConfig(BaseConfig):
model: ModelConfig = field(default_factory=lambda: copy.deepcopy(ModelConfig(path="Qwen/Qwen2.5-1.5B-Instruct")))
optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig)
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
sequence_parallel_size: int = 1
use_torch_compile: bool = False
"""Apply torch.compile to logits calculation."""
record_memory: bool = False
"""Save memory snapshots to ``{ckpt_path}/memory_snapshots/``.
Visualize by dragging pickle files to https://docs.pytorch.org/memory_viz."""
megatron_config: MegatronConfig = field(default_factory=MegatronConfig)
model_config_kwargs: dict = field(default_factory=dict)
"""Pass-through kwargs for the HuggingFace model config (FSDP backends).
For Megatron, use ``policy.megatron_config.transformer_config_kwargs`` instead."""

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr model
model: ModelConfig = field(default_factory=(lambda: copy.deepcopy(ModelConfig(path='Qwen/Qwen2.5-1.5B-Instruct'))))
attr optimizer_config
optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig)
attr fsdp_config
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
attr sequence_parallel_size
sequence_parallel_size: int = 1
attr use_torch_compile
use_torch_compile: bool = False
Apply torch.compile to logits calculation.
attr record_memory
record_memory: bool = False
Save memory snapshots to {ckpt_path}/memory_snapshots/.
Visualize by dragging pickle files to https://docs.pytorch.org/memory_viz.
attr megatron_config
megatron_config: MegatronConfig = field(default_factory=MegatronConfig)
attr model_config_kwargs
model_config_kwargs: dict = field(default_factory=dict)
Pass-through kwargs for the HuggingFace model config (FSDP backends).
For Megatron, use policy.megatron_config.transformer_config_kwargs instead.
class CriticConfig
CriticConfig(model: ModelConfig = ModelConfig(), optimizer_config: OptimizerConfig = (lambda: OptimizerConfig(lr=5e-06))(), fsdp_config: FSDPConfig = FSDPConfig(), sequence_parallel_size: int = 1, model_config_kwargs: dict = dict()) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
model | ModelConfig | |
optimizer_config | OptimizerConfig | |
fsdp_config | FSDPConfig | |
sequence_parallel_size | int | |
model_config_kwargs | dict |
Source code in skyrl/train/config/config.py:215-221
@dataclass
class CriticConfig(BaseConfig):
model: ModelConfig = field(default_factory=ModelConfig)
optimizer_config: OptimizerConfig = field(default_factory=lambda: OptimizerConfig(lr=5e-6))
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
sequence_parallel_size: int = 1
model_config_kwargs: dict = field(default_factory=dict)

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr model
model: ModelConfig = field(default_factory=ModelConfig)
attr optimizer_config
optimizer_config: OptimizerConfig = field(default_factory=(lambda: OptimizerConfig(lr=5e-06)))
attr fsdp_config
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
attr sequence_parallel_size
sequence_parallel_size: int = 1
attr model_config_kwargs
model_config_kwargs: dict = field(default_factory=dict)
class RefConfig
RefConfig(model: ModelConfig = ModelConfig(), sequence_parallel_size: int = 1, fsdp_config: FSDPConfig = FSDPConfig(), megatron_config: MegatronConfig = MegatronConfig(), model_config_kwargs: dict = dict()) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
model | ModelConfig | |
sequence_parallel_size | int | |
fsdp_config | FSDPConfig | |
megatron_config | MegatronConfig | |
model_config_kwargs | dict |
Source code in skyrl/train/config/config.py:225-231
@dataclass
class RefConfig(BaseConfig):
model: ModelConfig = field(default_factory=ModelConfig)
sequence_parallel_size: int = 1
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
megatron_config: MegatronConfig = field(default_factory=MegatronConfig)
model_config_kwargs: dict = field(default_factory=dict)

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr model
model: ModelConfig = field(default_factory=ModelConfig)
attr sequence_parallel_size
sequence_parallel_size: int = 1
attr fsdp_config
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
attr megatron_config
megatron_config: MegatronConfig = field(default_factory=MegatronConfig)
attr model_config_kwargs
model_config_kwargs: dict = field(default_factory=dict)
class AlgorithmConfig
AlgorithmConfig(advantage_estimator: str = 'grpo', kl_ctrl: KLCtrlConfig = KLCtrlConfig(), kl_estimator_type: str = 'k3', use_kl_in_reward: bool = False, use_kl_loss: bool = True, kl_loss_coef: float = 0.001, use_entropy_loss: bool = False, entropy_loss_coef: float = 0.01, temperature: Optional[float] = None, advantage_batch_normalize: bool = False, value_head_prefix: str = 'value_head', policy_loss_type: str = 'regular', loss_reduction: str = 'token_mean', grpo_norm_by_std: bool = True, zero_variance_filter: bool = False, lambd: float = 1.0, gamma: float = 1.0, eps_clip_low: float = 0.2, eps_clip_high: float = 0.2, clip_ratio_c: float = 3.0, tis_imp_ratio_cap: float = -1.0, use_tis: bool = False, off_policy_correction: OffPolicyCorrectionConfig = OffPolicyCorrectionConfig(), sapo: SAPOConfig = SAPOConfig(), value_clip: float = 0.2, dynamic_sampling: DynamicSamplingConfig = DynamicSamplingConfig(), clip_cov: ClipCovConfig = ClipCovConfig(), kl_cov: KLCovConfig = KLCovConfig(), cispo: CISPOConfig = CISPOConfig(), max_seq_len: Optional[int] = None) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
advantage_estimator | str | "grpo", "gae", "rloo", "reinforce++", or custom via AdvantageEstimatorRegistry. |
kl_ctrl | KLCtrlConfig | Only used when use_kl_in_reward=True (not applied when use_kl_loss=True). |
kl_estimator_type | str | "k1", "k2", "k3", "abs". See http://joschu.net/blog/kl-approx.html. |
use_kl_in_reward | bool | Apply KL penalty to rewards. Mutually exclusive with use_kl_loss. |
use_kl_loss | bool | Apply KL loss in the policy model. Mutually exclusive with use_kl_in_reward. |
kl_loss_coef | float | |
use_entropy_loss | bool | |
entropy_loss_coef | float | |
temperature | Optional[float] | Temperature for scaling logits in policy loss computation. |
advantage_batch_normalize | bool | |
value_head_prefix | str | |
policy_loss_type | str | "regular", "dual_clip", "gspo", "clip_cov", "kl_cov", or custom via PolicyLossRegistry. |
loss_reduction | str | "token_mean", "sequence_mean", or "seq_mean_token_sum_norm". |
grpo_norm_by_std | bool | |
zero_variance_filter | bool | Loss-mask prompts with zero-variance rewards. Only applicable when rewards are response-level. |
lambd | float | |
gamma | float | |
eps_clip_low | float | |
eps_clip_high | float | |
clip_ratio_c | float | Dual-clip parameter. |
tis_imp_ratio_cap | float | Deprecated: use off_policy_correction.tis_ratio_type="token" and token_tis_ratio_clip_high instead. |
use_tis | bool | Deprecated: use off_policy_correction instead. |
off_policy_correction | OffPolicyCorrectionConfig | |
sapo | SAPOConfig | |
value_clip | float | |
dynamic_sampling | DynamicSamplingConfig | |
clip_cov | ClipCovConfig | Only used when policy_loss_type="clip_cov". |
kl_cov | KLCovConfig | Only used when policy_loss_type="kl_cov". |
cispo | CISPOConfig | Only used when policy_loss_type="cispo". |
max_seq_len | Optional[int] | Used for seq_mean_token_sum_norm loss reduction; set explicitly for multi-turn. |
Source code in skyrl/train/config/config.py:322-375
@dataclass
class AlgorithmConfig(BaseConfig):
advantage_estimator: str = "grpo"
"""``"grpo"``, ``"gae"``, ``"rloo"``, ``"reinforce++"``, or custom via ``AdvantageEstimatorRegistry``."""
kl_ctrl: KLCtrlConfig = field(default_factory=KLCtrlConfig)
"""Only used when ``use_kl_in_reward=True`` (not applied when ``use_kl_loss=True``).
Uses ``kl_loss_coef`` as the initial KL coefficient."""
kl_estimator_type: str = "k3"
"""``"k1"``, ``"k2"``, ``"k3"``, ``"abs"``. See http://joschu.net/blog/kl-approx.html."""
use_kl_in_reward: bool = False
"""Apply KL penalty to rewards. Mutually exclusive with ``use_kl_loss``."""
use_kl_loss: bool = True
"""Apply KL loss in the policy model. Mutually exclusive with ``use_kl_in_reward``."""
kl_loss_coef: float = 0.001
use_entropy_loss: bool = False
entropy_loss_coef: float = 0.01
temperature: Optional[float] = None
"""Temperature for scaling logits in policy loss computation.
If ``None``, will be set to the temperature provided by ``generator.sampling_params.temperature`` during config validation.
NOTE: When using HTTP endpoints directly, make sure to set this value to the temperature used during generation
"""
advantage_batch_normalize: bool = False
value_head_prefix: str = "value_head"
policy_loss_type: str = "regular"
"""``"regular"``, ``"dual_clip"``, ``"gspo"``, ``"clip_cov"``, ``"kl_cov"``, or custom via ``PolicyLossRegistry``."""
loss_reduction: str = "token_mean"
"""``"token_mean"``, ``"sequence_mean"``, or ``"seq_mean_token_sum_norm"``."""
grpo_norm_by_std: bool = True
zero_variance_filter: bool = False
"""Loss-mask prompts with zero-variance rewards. Only applicable when rewards are response-level."""
lambd: float = 1.0
gamma: float = 1.0
eps_clip_low: float = 0.2
eps_clip_high: float = 0.2
clip_ratio_c: float = 3.0
"""Dual-clip parameter."""
tis_imp_ratio_cap: float = -1.0
"""Deprecated: use ``off_policy_correction.tis_ratio_type="token"`` and ``token_tis_ratio_clip_high`` instead."""
use_tis: bool = False
"""Deprecated: use ``off_policy_correction`` instead."""
off_policy_correction: OffPolicyCorrectionConfig = field(default_factory=OffPolicyCorrectionConfig)
sapo: SAPOConfig = field(default_factory=SAPOConfig)
value_clip: float = 0.2
dynamic_sampling: DynamicSamplingConfig = field(default_factory=DynamicSamplingConfig)
clip_cov: ClipCovConfig = field(default_factory=ClipCovConfig)
"""Only used when ``policy_loss_type="clip_cov"``."""
kl_cov: KLCovConfig = field(default_factory=KLCovConfig)
"""Only used when ``policy_loss_type="kl_cov"``."""
cispo: CISPOConfig = field(default_factory=CISPOConfig)
"""Only used when ``policy_loss_type="cispo"``."""
max_seq_len: Optional[int] = None
"""Used for ``seq_mean_token_sum_norm`` loss reduction; set explicitly for multi-turn.
If ``None``, calculated as ``generator.max_input_length + generator.sampling_params.max_generate_length``."""

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr advantage_estimator
advantage_estimator: str = 'grpo'
"grpo", "gae", "rloo", "reinforce++", or custom via AdvantageEstimatorRegistry.
attr kl_ctrl
kl_ctrl: KLCtrlConfig = field(default_factory=KLCtrlConfig)
Only used when use_kl_in_reward=True (not applied when use_kl_loss=True).
Uses kl_loss_coef as the initial KL coefficient.
attr kl_estimator_type
kl_estimator_type: str = 'k3'
"k1", "k2", "k3", "abs". See http://joschu.net/blog/kl-approx.html.
attr use_kl_in_reward
use_kl_in_reward: bool = False
Apply KL penalty to rewards. Mutually exclusive with use_kl_loss.
attr use_kl_loss
use_kl_loss: bool = True
Apply KL loss in the policy model. Mutually exclusive with use_kl_in_reward.
attr kl_loss_coef
kl_loss_coef: float = 0.001
attr use_entropy_loss
use_entropy_loss: bool = False
attr entropy_loss_coef
entropy_loss_coef: float = 0.01
attr temperature
temperature: Optional[float] = None
Temperature for scaling logits in policy loss computation.
If None, will be set to the temperature provided by generator.sampling_params.temperature during config validation.
NOTE: When using HTTP endpoints directly, make sure to set this value to the temperature used during generation.
attr advantage_batch_normalize
advantage_batch_normalize: bool = False
attr value_head_prefix
value_head_prefix: str = 'value_head'
attr policy_loss_type
policy_loss_type: str = 'regular'
"regular", "dual_clip", "gspo", "clip_cov", "kl_cov", or custom via PolicyLossRegistry.
attr loss_reduction
loss_reduction: str = 'token_mean'
"token_mean", "sequence_mean", or "seq_mean_token_sum_norm".
attr grpo_norm_by_std
grpo_norm_by_std: bool = True
attr zero_variance_filter
zero_variance_filter: bool = False
Loss-mask prompts with zero-variance rewards. Only applicable when rewards are response-level.
attr lambd
lambd: float = 1.0
attr gamma
gamma: float = 1.0
attr eps_clip_low
eps_clip_low: float = 0.2
attr eps_clip_high
eps_clip_high: float = 0.2
attr clip_ratio_c
clip_ratio_c: float = 3.0
Dual-clip parameter.
attr tis_imp_ratio_cap
tis_imp_ratio_cap: float = -1.0
Deprecated: use off_policy_correction.tis_ratio_type="token" and token_tis_ratio_clip_high instead.
attr use_tis
use_tis: bool = False
Deprecated: use off_policy_correction instead.
attr off_policy_correction
off_policy_correction: OffPolicyCorrectionConfig = field(default_factory=OffPolicyCorrectionConfig)
attr sapo
sapo: SAPOConfig = field(default_factory=SAPOConfig)
attr value_clip
value_clip: float = 0.2
attr dynamic_sampling
dynamic_sampling: DynamicSamplingConfig = field(default_factory=DynamicSamplingConfig)
attr clip_cov
clip_cov: ClipCovConfig = field(default_factory=ClipCovConfig)
Only used when policy_loss_type="clip_cov".
attr kl_cov
kl_cov: KLCovConfig = field(default_factory=KLCovConfig)
Only used when policy_loss_type="kl_cov".
attr cispo
cispo: CISPOConfig = field(default_factory=CISPOConfig)
Only used when policy_loss_type="cispo".
attr max_seq_len
max_seq_len: Optional[int] = None
Used for seq_mean_token_sum_norm loss reduction; set explicitly for multi-turn.
If None, calculated as generator.max_input_length + generator.sampling_params.max_generate_length.
class KLCtrlConfig
KLCtrlConfig(type: str = 'fixed', kl_target: float = 0.1, horizon: int = 10000) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
type | str | "fixed" or "adaptive". |
kl_target | float | Target KL divergence for the adaptive KL controller. |
horizon | int | Controls the update rate of the adaptive KL controller. |
Source code in skyrl/train/config/config.py:239-247
@dataclass
class KLCtrlConfig(BaseConfig):
type: str = "fixed"
"""``"fixed"`` or ``"adaptive"``."""
kl_target: float = 0.1
"""Target KL divergence for the adaptive KL controller."""
horizon: int = 10000
"""Controls the update rate of the adaptive KL controller."""

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr type
type: str = 'fixed'
"fixed" or "adaptive".
attr kl_target
kl_target: float = 0.1
Target KL divergence for the adaptive KL controller.
attr horizon
horizon: int = 10000
Controls the update rate of the adaptive KL controller.
Algorithm Extensions
class SAPOConfig
SAPOConfig(tau_pos: float = 1.0, tau_neg: float = 1.05) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
Source code in skyrl/train/config/config.py:250-253
@dataclass
class SAPOConfig(BaseConfig):
tau_pos: float = 1.0
tau_neg: float = 1.05

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr tau_pos
tau_pos: float = 1.0
attr tau_neg
tau_neg: float = 1.05
class DynamicSamplingConfig
DynamicSamplingConfig(type: Optional[str] = None, max_sample_batches: int = 30, min_replace_ratio: float = 0.3) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
type | Optional[str] | "filter", "replace", or None. |
max_sample_batches | int | Sample at most this many batches before stopping. -1 to sample forever. |
min_replace_ratio | float | Minimum proportion of good samples to replace bad samples. Only used with "replace" strategy. |
Source code in skyrl/train/config/config.py:256-263
@dataclass
class DynamicSamplingConfig(BaseConfig):
type: Optional[str] = None
"""``"filter"``, ``"replace"``, or ``None``."""
max_sample_batches: int = 30
"""Sample at most this many batches before stopping. ``-1`` to sample forever."""
min_replace_ratio: float = 0.3
"""Minimum proportion of good samples to replace bad samples. Only used with ``"replace"`` strategy."""

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr type
type: Optional[str] = None
"filter", "replace", or None.
attr max_sample_batches
max_sample_batches: int = 30
Sample at most this many batches before stopping. -1 to sample forever.
attr min_replace_ratio
min_replace_ratio: float = 0.3
Minimum proportion of good samples to replace bad samples. Only used with "replace" strategy.
class ClipCovConfig
ClipCovConfig(clip_ratio: float = 0.0002, clip_cov_lb: float = 1.0, clip_cov_ub: float = 5.0) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
clip_ratio | float | Fraction of tokens to clip based on covariance. |
clip_cov_lb | float | |
clip_cov_ub | float |
Source code in skyrl/train/config/config.py:266-272
@dataclass
class ClipCovConfig(BaseConfig):
clip_ratio: float = 0.0002
"""Fraction of tokens to clip based on covariance."""
clip_cov_lb: float = 1.0
clip_cov_ub: float = 5.0

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr clip_ratio
clip_ratio: float = 0.0002
Fraction of tokens to clip based on covariance.
attr clip_cov_lb
clip_cov_lb: float = 1.0
attr clip_cov_ub
clip_cov_ub: float = 5.0
class KLCovConfig
KLCovConfig(kl_cov_frac: float = 0.2, ppo_kl_coef: float = 1.0) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
kl_cov_frac | float | Fraction of tokens to apply KL regularization to. |
ppo_kl_coef | float |
Source code in skyrl/train/config/config.py:275-280
@dataclass
class KLCovConfig(BaseConfig):
kl_cov_frac: float = 0.2
"""Fraction of tokens to apply KL regularization to."""
ppo_kl_coef: float = 1.0

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr kl_cov_frac
kl_cov_frac: float = 0.2
Fraction of tokens to apply KL regularization to.
attr ppo_kl_coef
ppo_kl_coef: float = 1.0
class CISPOConfig
CISPOConfig(cispo_eps_clip_low: float = 0.0, cispo_eps_clip_high: float = 5.0) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
cispo_eps_clip_low | float | Offset for lower bound of importance sampling ratio clipping (as opposed to PPO token update clipping). |
cispo_eps_clip_high | float | Offset for upper bound of importance sampling ratio clipping (as opposed to PPO token update clipping). |
Source code in skyrl/train/config/config.py:283-289
@dataclass
class CISPOConfig(BaseConfig):
cispo_eps_clip_low: float = 0.0
"""Offset for lower bound of importance sampling ratio clipping (as opposed to PPO token update clipping)."""
cispo_eps_clip_high: float = 5.0
"""Offset for upper bound of importance sampling ratio clipping (as opposed to PPO token update clipping)."""

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr cispo_eps_clip_low
cispo_eps_clip_low: float = 0.0
Offset for lower bound of importance sampling ratio clipping (as opposed to PPO token update clipping).
attr cispo_eps_clip_high
cispo_eps_clip_high: float = 5.0
Offset for upper bound of importance sampling ratio clipping (as opposed to PPO token update clipping).
class OffPolicyCorrectionConfig
OffPolicyCorrectionConfig(tis_ratio_type: Optional[str] = None, token_tis_ratio_clip_high: float = 2.0, sequence_tis_ratio_clip_high: float = 5.0, sequence_mask_metric: Optional[str] = None, geo_mask_high: float = 1.01, geo_mask_low: float = 0.99, product_mask_high: float = 2.0, product_mask_low: float = 0.5, outlier_token_is_threshold_low: Optional[float] = None, outlier_token_is_threshold_high: Optional[float] = None, token_mask_is_threshold_low: Optional[float] = None, token_mask_is_threshold_high: Optional[float] = None) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
tis_ratio_type | Optional[str] | Importance sampling ratio type for PPO loss correction: None, "token", or "sequence". |
token_tis_ratio_clip_high | float | Used when tis_ratio_type="token". Recommended range: 1.5--5.0. |
sequence_tis_ratio_clip_high | float | Used when tis_ratio_type="sequence". Recommended range: 2.0--10.0. |
sequence_mask_metric | Optional[str] | Method for masking sequences with cumulative IS ratios outside cap: None, "product", or "geometric". |
geo_mask_high | float | Used when sequence_mask_metric="geometric". Recommended ~0.99--1.01; MoE models may need a wider range. |
geo_mask_low | float | Used when sequence_mask_metric="geometric". |
product_mask_high | float | Used when sequence_mask_metric="product". Recommended ~0.5--2.0. |
product_mask_low | float | Used when sequence_mask_metric="product". |
outlier_token_is_threshold_low | Optional[float] | Set to mask sequences with any token IS ratio below this threshold. Suggested: 1e-4. None to disable. |
outlier_token_is_threshold_high | Optional[float] | Set to mask sequences with any token IS ratio above this threshold. Suggested: 100. None to disable. |
token_mask_is_threshold_low | Optional[float] | Set to mask per-token when IS ratio < token_mask_is_threshold_low. None to disable. |
token_mask_is_threshold_high | Optional[float] | Set to mask per-token when IS ratio > token_mask_is_threshold_high. None to disable. |
Source code in skyrl/train/config/config.py:293-319
@dataclass
class OffPolicyCorrectionConfig(BaseConfig):
tis_ratio_type: Optional[str] = None
"""Importance sampling ratio type for PPO loss correction: ``None``, ``"token"``, or ``"sequence"``.
The ratio is ``exp(logprobs_policy_old - logprobs_rollout_policy)``."""
token_tis_ratio_clip_high: float = 2.0
"""Used when ``tis_ratio_type="token"``. Recommended range: 1.5--5.0."""
sequence_tis_ratio_clip_high: float = 5.0
"""Used when ``tis_ratio_type="sequence"``. Recommended range: 2.0--10.0."""
sequence_mask_metric: Optional[str] = None
"""Method for masking sequences with cumulative IS ratios outside cap: ``None``, ``"product"``, or ``"geometric"``."""
geo_mask_high: float = 1.01
"""Used when ``sequence_mask_metric="geometric"``. Recommended ~0.99--1.01; MoE models may need a wider range."""
geo_mask_low: float = 0.99
"""Used when ``sequence_mask_metric="geometric"``."""
product_mask_high: float = 2.0
"""Used when ``sequence_mask_metric="product"``. Recommended ~0.5--2.0."""
product_mask_low: float = 0.5
"""Used when ``sequence_mask_metric="product"``."""
outlier_token_is_threshold_low: Optional[float] = None
"""Set to mask sequences with any token IS ratio below this threshold. Suggested: 1e-4. ``None`` to disable."""
outlier_token_is_threshold_high: Optional[float] = None
"""Set to mask sequences with any token IS ratio above this threshold. Suggested: 100. ``None`` to disable."""
token_mask_is_threshold_low: Optional[float] = None
"""Set to mask per-token when IS ratio < `token_mask_is_threshold_low`. ``None`` to disable."""
token_mask_is_threshold_high: Optional[float] = None
"""Set to mask per-token when IS ratio > `token_mask_is_threshold_high`. ``None`` to disable."""

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr tis_ratio_type
tis_ratio_type: Optional[str] = None
Importance sampling ratio type for PPO loss correction: None, "token", or "sequence".
The ratio is exp(logprobs_policy_old - logprobs_rollout_policy).
attr token_tis_ratio_clip_high
token_tis_ratio_clip_high: float = 2.0
Used when tis_ratio_type="token". Recommended range: 1.5--5.0.
attr sequence_tis_ratio_clip_high
sequence_tis_ratio_clip_high: float = 5.0
Used when tis_ratio_type="sequence". Recommended range: 2.0--10.0.
attr sequence_mask_metric
sequence_mask_metric: Optional[str] = None
Method for masking sequences with cumulative IS ratios outside cap: None, "product", or "geometric".
attr geo_mask_high
geo_mask_high: float = 1.01
Used when sequence_mask_metric="geometric". Recommended ~0.99--1.01; MoE models may need a wider range.
attr geo_mask_low
geo_mask_low: float = 0.99
Used when sequence_mask_metric="geometric".
attr product_mask_high
product_mask_high: float = 2.0
Used when sequence_mask_metric="product". Recommended ~0.5--2.0.
attr product_mask_low
product_mask_low: float = 0.5
Used when sequence_mask_metric="product".
attr outlier_token_is_threshold_low
outlier_token_is_threshold_low: Optional[float] = None
Set to mask sequences with any token IS ratio below this threshold. Suggested: 1e-4. None to disable.
attr outlier_token_is_threshold_high
outlier_token_is_threshold_high: Optional[float] = None
Set to mask sequences with any token IS ratio above this threshold. Suggested: 100. None to disable.
attr token_mask_is_threshold_low
token_mask_is_threshold_low: Optional[float] = None
Set to mask per-token when IS ratio < token_mask_is_threshold_low. None to disable.
attr token_mask_is_threshold_high
token_mask_is_threshold_high: Optional[float] = None
Set to mask per-token when IS ratio > token_mask_is_threshold_high. None to disable.
class FullyAsyncConfig
FullyAsyncConfig(max_staleness_steps: int = 4, num_parallel_generation_workers: int = 768) -> NoneBases: BaseConfig
Knobs for fully async training. See https://docs.skyrl.ai/docs/tutorials/fully_async#step-2-config-knobs-to-tune-for-fully-async-training.
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
max_staleness_steps | int | Maximum off-policy steps allowed. If a trajectory group is scheduled at step i and trained at step j, |
num_parallel_generation_workers | int | Number of generation workers to spawn. Should be >= policy_mini_batch_size and |
Source code in skyrl/train/config/config.py:383-393
@dataclass
class FullyAsyncConfig(BaseConfig):
"""Knobs for fully async training.
See https://docs.skyrl.ai/docs/tutorials/fully_async#step-2-config-knobs-to-tune-for-fully-async-training."""
max_staleness_steps: int = 4
"""Maximum off-policy steps allowed. If a trajectory group is scheduled at step *i* and trained at step *j*,
then ``j - i <= max_staleness_steps``. Larger values increase throughput but also off-policy-ness."""
num_parallel_generation_workers: int = 768
"""Number of generation workers to spawn. Should be >= ``policy_mini_batch_size`` and
<= ``policy_mini_batch_size * (max_staleness_steps + 1)``."""

from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr max_staleness_steps
max_staleness_steps: int = 4
Maximum off-policy steps allowed. If a trajectory group is scheduled at step i and trained at step j,
then j - i <= max_staleness_steps. Larger values increase throughput but also off-policy-ness.
attr num_parallel_generation_workers
num_parallel_generation_workers: int = 768
Number of generation workers to spawn. Should be >= policy_mini_batch_size and
<= policy_mini_batch_size * (max_staleness_steps + 1).
Inference & Generation
class SamplingParams
SamplingParams(max_generate_length: int = 1024, repetition_penalty: float = 1.0, temperature: float = 1.0, top_p: float = 1.0, min_p: float = 0.0, top_k: int = -1, logprobs: Optional[int] = 1, stop: Optional[List[str]] = None, additional_kwargs: Optional[Dict[str, Any]] = None) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
max_generate_length | int | |
repetition_penalty | float | |
temperature | float | |
top_p | float | |
min_p | float | |
top_k | int | |
logprobs | Optional[int] | |
stop | Optional[List[str]] | |
additional_kwargs | Optional[Dict[str, Any]] |
Source code in skyrl/train/config/config.py:401-411
@dataclass
class SamplingParams(BaseConfig):
    """Sampling parameters forwarded to the inference engine at generation time."""

    max_generate_length: int = 1024  # token budget per generated response
    repetition_penalty: float = 1.0
    temperature: float = 1.0
    top_p: float = 1.0
    min_p: float = 0.0
    top_k: int = -1
    logprobs: Optional[int] = 1
    stop: Optional[List[str]] = None  # stop strings; None leaves the engine default
    additional_kwargs: Optional[Dict[str, Any]] = None  # pass-through engine kwargs
from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr max_generate_length
max_generate_length: int = 1024
attr repetition_penalty
repetition_penalty: float = 1.0
attr temperature
temperature: float = 1.0
attr top_p
top_p: float = 1.0
attr min_p
min_p: float = 0.0
attr top_k
top_k: int = -1
attr logprobs
logprobs: Optional[int] = 1
attr stop
stop: Optional[List[str]] = None
attr additional_kwargs
additional_kwargs: Optional[Dict[str, Any]] = None
class ChatTemplateConfig
ChatTemplateConfig(source: str = 'name', name_or_path: Optional[str] = None) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
source | str | |
name_or_path | Optional[str] |
Source code in skyrl/train/config/config.py:414-417
@dataclass
class ChatTemplateConfig(BaseConfig):
    """Selects which chat template to apply and where it comes from."""

    source: str = "name"  # how name_or_path is interpreted — presumably "name" vs a file path; confirm against loader
    name_or_path: Optional[str] = None
from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr source
source: str = 'name'
attr name_or_path
name_or_path: Optional[str] = None
class InferenceEngineConfig
InferenceEngineConfig(model_dtype: str = 'bfloat16', run_engines_locally: bool = True, num_engines: int = 1, backend: str = 'vllm', weight_sync_backend: str = 'nccl', weight_transfer_threshold_cuda_ipc_GB: float = 1.0, tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1, expert_parallel_size: int = 1, data_parallel_size: int = 1, async_engine: bool = True, vllm_v1_disable_multiproc: bool = True, enable_prefix_caching: bool = True, enable_chunked_prefill: bool = True, max_num_batched_tokens: int = 8192, enforce_eager: bool = True, fully_sharded_loras: bool = False, enable_ray_prometheus_stats: bool = False, gpu_memory_utilization: float = 0.8, max_num_seqs: int = 1024, remote_urls: List[str] = (lambda: [])(), enable_http_endpoint: bool = False, http_endpoint_host: str = '127.0.0.1', http_endpoint_port: int = 8000, served_model_name: Optional[str] = None, distributed_executor_backend: str = 'ray', engine_init_kwargs: Dict[str, Any] = dict(), override_existing_update_group: str = 'auto', external_proxy_url: Optional[str] = None, external_server_urls: Optional[List[str]] = None) -> NoneBases: BaseConfig
Configuration for inference engine instantiation and management.
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
model_dtype | str | Should match the dtype used by the inference engine. |
run_engines_locally | bool | |
num_engines | int | |
backend | str | "vllm". |
weight_sync_backend | str | |
weight_transfer_threshold_cuda_ipc_GB | float | When using cuda_ipc, send weights in batches of this size (GB). |
tensor_parallel_size | int | |
pipeline_parallel_size | int | |
expert_parallel_size | int | |
data_parallel_size | int | |
async_engine | bool | |
vllm_v1_disable_multiproc | bool | Sets VLLM_ENABLE_V1_MULTIPROCESSING=0 for reproducibility. |
enable_prefix_caching | bool | |
enable_chunked_prefill | bool | |
max_num_batched_tokens | int | |
enforce_eager | bool | Disable CUDA graphs for stability. Set to False for higher performance, |
fully_sharded_loras | bool | |
enable_ray_prometheus_stats | bool | Enable Ray Prometheus stats logger for inference engine metrics (vLLM v1 only). |
gpu_memory_utilization | float | |
max_num_seqs | int | |
remote_urls | List[str] | |
enable_http_endpoint | bool | When True, launch an OpenAI-compatible HTTP endpoint for the inference engine client so that generators can send requests to this server instead of using .generate() Python calls. |
http_endpoint_host | str | |
http_endpoint_port | int | |
served_model_name | Optional[str] | Model name for HTTP endpoint validation. If set, must be used in the model field of |
distributed_executor_backend | str | Distributed executor backend for vLLM. Set to "ray" to use the Ray backend |
engine_init_kwargs | Dict[str, Any] | Pass-through kwargs for the vLLM engine. Names must match the engine's args. |
override_existing_update_group | str | "auto", "enable", or "disable". |
external_proxy_url | Optional[str] | Data-plane URL (load-balanced router) for the new inference layer. |
external_server_urls | Optional[List[str]] | Control-plane URLs (direct backend access) for the new inference layer. |
Source code in skyrl/train/config/config.py:425-478
@dataclass
class InferenceEngineConfig(BaseConfig):
    """Configuration for inference engine instantiation and management."""

    model_dtype: str = "bfloat16"
    """Should match the dtype used by the inference engine."""
    run_engines_locally: bool = True
    num_engines: int = 1
    backend: str = "vllm"
    """``"vllm"``."""
    weight_sync_backend: str = "nccl"
    weight_transfer_threshold_cuda_ipc_GB: float = 1.0
    """When using ``cuda_ipc``, send weights in batches of this size (GB)."""
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    expert_parallel_size: int = 1
    data_parallel_size: int = 1
    async_engine: bool = True
    vllm_v1_disable_multiproc: bool = True
    """Sets ``VLLM_ENABLE_V1_MULTIPROCESSING=0`` for reproducibility."""
    enable_prefix_caching: bool = True
    enable_chunked_prefill: bool = True
    max_num_batched_tokens: int = 8192
    enforce_eager: bool = True
    """Disable CUDA graphs for stability. Set to ``False`` for higher performance,
    but this may affect convergence for long-running or long-context training jobs."""
    fully_sharded_loras: bool = False
    enable_ray_prometheus_stats: bool = False
    """Enable Ray Prometheus stats logger for inference engine metrics (vLLM v1 only)."""
    gpu_memory_utilization: float = 0.8
    max_num_seqs: int = 1024
    # `list` is the idiomatic default_factory for an empty list (was `lambda: []`).
    remote_urls: List[str] = field(default_factory=list)
    enable_http_endpoint: bool = False
    """When ``True``, launch an OpenAI-compatible HTTP endpoint for the inference engine client so that generators can send requests to this server instead of using ``.generate()`` Python calls.
    NOTE: When using HTTP endpoints directly, make sure to set ``trainer.algorithm.temperature`` to the temperature used during generation
    """
    http_endpoint_host: str = "127.0.0.1"
    http_endpoint_port: int = 8000
    served_model_name: Optional[str] = None
    """Model name for HTTP endpoint validation. If set, must be used in the ``model`` field of
    ``/chat/completions`` requests instead of the model path. If ``None``, the model path is used."""
    distributed_executor_backend: str = "ray"
    """Distributed executor backend for vLLM. Set to ``"ray"`` to use the Ray backend
    or ``"mp"`` to use the multiprocessing backend (single-node serving only). Per-engine
    placement groups are created when ``"mp"`` is used."""
    engine_init_kwargs: Dict[str, Any] = field(default_factory=dict)
    """Pass-through kwargs for the vLLM engine. Names must match the engine's args."""
    override_existing_update_group: str = "auto"
    """``"auto"``, ``"enable"``, or ``"disable"``."""
    external_proxy_url: Optional[str] = None
    """Data-plane URL (load-balanced router) for the new inference layer."""
    external_server_urls: Optional[List[str]] = None
    """Control-plane URLs (direct backend access) for the new inference layer."""
from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr model_dtype
model_dtype: str = 'bfloat16'
Should match the dtype used by the inference engine.
attr run_engines_locally
run_engines_locally: bool = True
attr num_engines
num_engines: int = 1
attr backend
backend: str = 'vllm'
"vllm".
attr weight_sync_backend
weight_sync_backend: str = 'nccl'
attr weight_transfer_threshold_cuda_ipc_GB
weight_transfer_threshold_cuda_ipc_GB: float = 1.0
When using cuda_ipc, send weights in batches of this size (GB).
attr tensor_parallel_size
tensor_parallel_size: int = 1
attr pipeline_parallel_size
pipeline_parallel_size: int = 1
attr expert_parallel_size
expert_parallel_size: int = 1
attr data_parallel_size
data_parallel_size: int = 1
attr async_engine
async_engine: bool = True
attr vllm_v1_disable_multiproc
vllm_v1_disable_multiproc: bool = True
Sets VLLM_ENABLE_V1_MULTIPROCESSING=0 for reproducibility.
attr enable_prefix_caching
enable_prefix_caching: bool = True
attr enable_chunked_prefill
enable_chunked_prefill: bool = True
attr max_num_batched_tokens
max_num_batched_tokens: int = 8192
attr enforce_eager
enforce_eager: bool = True
Disable CUDA graphs for stability. Set to False for higher performance,
but this may affect convergence for long-running or long-context training jobs.
attr fully_sharded_loras
fully_sharded_loras: bool = False
attr enable_ray_prometheus_stats
enable_ray_prometheus_stats: bool = False
Enable Ray Prometheus stats logger for inference engine metrics (vLLM v1 only).
attr gpu_memory_utilization
gpu_memory_utilization: float = 0.8
attr max_num_seqs
max_num_seqs: int = 1024
attr remote_urls
remote_urls: List[str] = field(default_factory=(lambda: []))
attr enable_http_endpoint
enable_http_endpoint: bool = False
When True, launch an OpenAI-compatible HTTP endpoint for the inference engine client so that generators can send requests to this server instead of using .generate() Python calls.
NOTE: When using HTTP endpoints directly, make sure to set trainer.algorithm.temperature to the temperature used during generation
attr http_endpoint_host
http_endpoint_host: str = '127.0.0.1'
attr http_endpoint_port
http_endpoint_port: int = 8000
attr served_model_name
served_model_name: Optional[str] = None
Model name for HTTP endpoint validation. If set, must be used in the model field of
/chat/completions requests instead of the model path. If None, the model path is used.
attr distributed_executor_backend
distributed_executor_backend: str = 'ray'
Distributed executor backend for vLLM. Set to "ray" to use the Ray backend
or "mp" to use the multiprocessing backend (single-node serving only). Per-engine
placement groups are created when "mp" is used.
attr engine_init_kwargs
engine_init_kwargs: Dict[str, Any] = field(default_factory=dict)
Pass-through kwargs for the vLLM engine. Names must match the engine's args.
attr override_existing_update_group
override_existing_update_group: str = 'auto'
"auto", "enable", or "disable".
attr external_proxy_url
external_proxy_url: Optional[str] = None
Data-plane URL (load-balanced router) for the new inference layer.
attr external_server_urls
external_server_urls: Optional[List[str]] = None
Control-plane URLs (direct backend access) for the new inference layer.
class GeneratorConfig
GeneratorConfig(inference_engine: InferenceEngineConfig = InferenceEngineConfig(), n_samples_per_prompt: int = 5, batched: bool = False, max_turns: int = 1, max_input_length: Optional[int] = None, chat_template: ChatTemplateConfig = ChatTemplateConfig(), chat_template_kwargs: Dict[str, Any] = dict(), sampling_params: SamplingParams = SamplingParams(), use_conversation_multi_turn: bool = True, append_eos_token_after_stop_str_in_multi_turn: bool = True, eval_sampling_params: Optional[SamplingParams] = None, eval_n_samples_per_prompt: int = 1, zero_reward_on_non_stop: bool = False, apply_overlong_filtering: bool = False, rope_scaling: Optional[Dict[str, Any]] = None, rope_theta: Optional[float] = None, step_wise_trajectories: bool = False) -> NoneBases: BaseConfig
Configuration for generation behavior.
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
inference_engine | InferenceEngineConfig | |
n_samples_per_prompt | int | |
batched | bool | |
max_turns | int | |
max_input_length | Optional[int] | Max generator input length for multi-turn conversations. For single-turn, set equal to max_prompt_length. |
chat_template | ChatTemplateConfig | |
chat_template_kwargs | Dict[str, Any] | Kwargs passed to tokenizer.apply_chat_template. |
sampling_params | SamplingParams | |
use_conversation_multi_turn | bool | If True, each multi-turn model response and env observation is stored in a separate |
append_eos_token_after_stop_str_in_multi_turn | bool | When use_conversation_multi_turn=True and sampling_params.stop is set, append |
eval_sampling_params | Optional[SamplingParams] | Separate sampling params for evaluation. If None, then it defaults to SamplingParams(temperature=0.0, max_generate_length=generator.sampling_params.max_generate_length). |
eval_n_samples_per_prompt | int | |
zero_reward_on_non_stop | bool | Set reward to 0 when stop_reason is not "stop" (i.e., generation was truncated or aborted). |
apply_overlong_filtering | bool | Apply DAPO Overlong Filtering: mask out all tokens in the loss mask for trajectories that |
rope_scaling | Optional[Dict[str, Any]] | Can differ from the trainer's rope_scaling, useful for thinking models. |
rope_theta | Optional[float] | |
step_wise_trajectories | bool |
Source code in skyrl/train/config/config.py:486-524
@dataclass
class GeneratorConfig(BaseConfig):
    """Configuration for generation behavior."""

    inference_engine: InferenceEngineConfig = field(default_factory=InferenceEngineConfig)
    n_samples_per_prompt: int = 5
    batched: bool = False
    max_turns: int = 1
    max_input_length: Optional[int] = None
    """Max generator input length for multi-turn conversations. For single-turn, set equal to ``max_prompt_length``."""
    chat_template: ChatTemplateConfig = field(default_factory=ChatTemplateConfig)
    chat_template_kwargs: Dict[str, Any] = field(default_factory=dict)
    """Kwargs passed to ``tokenizer.apply_chat_template``."""
    sampling_params: SamplingParams = field(default_factory=SamplingParams)
    use_conversation_multi_turn: bool = True
    """If ``True``, each multi-turn model response and env observation is stored in a separate
    assistant/user message. If ``False``, they are appended to the original assistant response."""
    append_eos_token_after_stop_str_in_multi_turn: bool = True
    """When ``use_conversation_multi_turn=True`` and ``sampling_params.stop`` is set, append
    ``eos_token_id`` to generations that end with a matched stop string."""
    eval_sampling_params: Optional[SamplingParams] = None
    """Separate sampling params for evaluation. If ``None``, then it defaults to ``SamplingParams(temperature=0.0, max_generate_length=generator.sampling_params.max_generate_length)``."""
    eval_n_samples_per_prompt: int = 1
    zero_reward_on_non_stop: bool = False
    """Set reward to 0 when ``stop_reason`` is not ``"stop"`` (i.e., generation was truncated or aborted)."""
    apply_overlong_filtering: bool = False
    """Apply DAPO Overlong Filtering: mask out all tokens in the loss mask for trajectories that
    exceed max length (truncated, no EOS token)."""
    rope_scaling: Optional[Dict[str, Any]] = None
    """Can differ from the trainer's ``rope_scaling``, useful for thinking models."""
    rope_theta: Optional[float] = None
    step_wise_trajectories: bool = False

    def __post_init__(self):
        # Evaluation defaults to greedy decoding (temperature 0) with the same
        # generation budget as training-time sampling.
        if self.eval_sampling_params is not None:
            return
        self.eval_sampling_params = SamplingParams(
            temperature=0.0,
            max_generate_length=self.sampling_params.max_generate_length,
        )
from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr inference_engine
inference_engine: InferenceEngineConfig = field(default_factory=InferenceEngineConfig)
attr n_samples_per_prompt
n_samples_per_prompt: int = 5
attr batched
batched: bool = False
attr max_turns
max_turns: int = 1
attr max_input_length
max_input_length: Optional[int] = None
Max generator input length for multi-turn conversations. For single-turn, set equal to max_prompt_length.
attr chat_template
chat_template: ChatTemplateConfig = field(default_factory=ChatTemplateConfig)
attr chat_template_kwargs
chat_template_kwargs: Dict[str, Any] = field(default_factory=dict)
Kwargs passed to tokenizer.apply_chat_template.
attr sampling_params
sampling_params: SamplingParams = field(default_factory=SamplingParams)
attr use_conversation_multi_turn
use_conversation_multi_turn: bool = True
If True, each multi-turn model response and env observation is stored in a separate
assistant/user message. If False, they are appended to the original assistant response.
attr append_eos_token_after_stop_str_in_multi_turn
append_eos_token_after_stop_str_in_multi_turn: bool = True
When use_conversation_multi_turn=True and sampling_params.stop is set, append
eos_token_id to generations that end with a matched stop string.
attr eval_sampling_params
eval_sampling_params: Optional[SamplingParams] = None
Separate sampling params for evaluation. If None, then it defaults to SamplingParams(temperature=0.0, max_generate_length=generator.sampling_params.max_generate_length).
attr eval_n_samples_per_prompt
eval_n_samples_per_prompt: int = 1
attr zero_reward_on_non_stop
zero_reward_on_non_stop: bool = False
Set reward to 0 when stop_reason is not "stop" (i.e., generation was truncated or aborted).
attr apply_overlong_filtering
apply_overlong_filtering: bool = False
Apply DAPO Overlong Filtering: mask out all tokens in the loss mask for trajectories that exceed max length (truncated, no EOS token).
attr rope_scaling
rope_scaling: Optional[Dict[str, Any]] = None
Can differ from the trainer's rope_scaling, useful for thinking models.
attr rope_theta
rope_theta: Optional[float] = None
attr step_wise_trajectories
step_wise_trajectories: bool = False
Environment
class EnvironmentConfig
EnvironmentConfig(env_class: str = 'gsm8k', skyrl_gym: SkyRLGymConfig = SkyRLGymConfig()) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
env_class | str | |
skyrl_gym | SkyRLGymConfig |
Source code in skyrl/train/config/config.py:547-550
@dataclass
class EnvironmentConfig(BaseConfig):
    """Top-level environment selection plus nested SkyRL-Gym settings."""

    env_class: str = "gsm8k"  # which environment class to run
    skyrl_gym: SkyRLGymConfig = field(default_factory=SkyRLGymConfig)
from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr env_class
env_class: str = 'gsm8k'
attr skyrl_gym
skyrl_gym: SkyRLGymConfig = field(default_factory=SkyRLGymConfig)
class SkyRLGymConfig
SkyRLGymConfig(max_env_workers: int = 32, text2sql: Text2SQLEnvConfig = Text2SQLEnvConfig(), llm_as_a_judge: GSM8kLLMJudgeEnvConfig = GSM8kLLMJudgeEnvConfig(), search: SearchEnvConfig = SearchEnvConfig()) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
max_env_workers | int | |
text2sql | Text2SQLEnvConfig | |
llm_as_a_judge | GSM8kLLMJudgeEnvConfig | |
search | SearchEnvConfig |
Source code in skyrl/train/config/config.py:539-544
@dataclass
class SkyRLGymConfig(BaseConfig):
    """Worker-pool size and per-task settings for SkyRL-Gym environments."""

    max_env_workers: int = 32  # parallel environment workers
    text2sql: Text2SQLEnvConfig = field(default_factory=Text2SQLEnvConfig)
    llm_as_a_judge: GSM8kLLMJudgeEnvConfig = field(default_factory=GSM8kLLMJudgeEnvConfig)
    search: SearchEnvConfig = field(default_factory=SearchEnvConfig)
from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr max_env_workers
max_env_workers: int = 32
attr text2sql
text2sql: Text2SQLEnvConfig = field(default_factory=Text2SQLEnvConfig)
attr llm_as_a_judge
llm_as_a_judge: GSM8kLLMJudgeEnvConfig = field(default_factory=GSM8kLLMJudgeEnvConfig)
attr search
search: SearchEnvConfig = field(default_factory=SearchEnvConfig)
class GSM8kLLMJudgeEnvConfig
GSM8kLLMJudgeEnvConfig(model: str = 'gpt-4o-mini', base_url: Optional[str] = None) -> NoneBases: BaseConfig
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
Source code in skyrl/train/config/config.py:533-536
@dataclass
class GSM8kLLMJudgeEnvConfig(BaseConfig):
    """LLM-as-a-judge settings for the GSM8k environment."""

    model: str = "gpt-4o-mini"  # judge model identifier
    base_url: Optional[str] = None  # OpenAI-compatible endpoint; None uses the client default
from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr model
model: str = 'gpt-4o-mini'
attr base_url
base_url: Optional[str] = None
base_url: Optional[str] = None