Entrypoint
Entrypoint API — BasePPOExp, EvalOnlyEntrypoint.
Training Entrypoint
The main entrypoint is the BasePPOExp class which runs the main training loop.
class BasePPOExp
BasePPOExp(cfg: SkyRLTrainConfig)
Functions:
| Name | Description |
|---|---|
get_cfg_as_str | Returns the config serialized as a YAML string. |
get_train_dataset | Initializes the training dataset. |
get_eval_dataset | Initializes the evaluation dataset. |
get_colocate_pg | Initializes a placement group for colocated training. |
get_generator | Initializes the generator. |
get_trainer | Initializes the trainer. |
get_tracker | Initializes the tracker for experiment tracking. |
get_inference_client | Setup and return the inference engine client. |
run | Sets up the trainer and runs the training loop. |
Attributes:
| Name | Type | Description |
|---|---|---|
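cfg | SkyRLTrainConfig | The fully resolved experiment configuration. |
tokenizer | | Tokenizer loaded from `trainer.policy.model.path` with left padding. |
train_dataset | PromptDataset | The training dataset. |
eval_dataset | Optional[PromptDataset] | The evaluation dataset, or None when `eval_interval <= 0` or no `val_data` is configured. |
colocate_pg | Optional[ResolvedPlacementGroup] | Placement group for colocated training, or None when `colocate_all` is off. |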
Initializes a PPO experiment.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
cfg | SkyRLTrainConfig | The fully resolved SkyRLTrainConfig instance. | required |
Source code in skyrl/train/entrypoints/main_base.py:133-414
class BasePPOExp:
    """Base PPO experiment that wires config, data, inference, generator and trainer together.

    The constructor loads the tokenizer and datasets and, when
    ``trainer.placement.colocate_all`` is set, reserves a shared placement
    group. ``run`` then creates the tracker, inference client, generator and
    trainer, builds the worker models and runs the training loop. The
    ``get_*`` methods are hooks that subclasses can override.
    """

    def __init__(self, cfg: SkyRLTrainConfig):
        """
        Initializes a PPO experiment.

        Args:
            cfg: The fully resolved SkyRLTrainConfig instance.
        """
        self.cfg = cfg
        self.tokenizer = get_tokenizer(
            self.cfg.trainer.policy.model.path,
            trust_remote_code=True,
            use_fast=not self.cfg.trainer.disable_fast_tokenizer,
            padding_side="left",
        )
        self.train_dataset = self.get_train_dataset()
        self.eval_dataset = self.get_eval_dataset()
        self.colocate_pg = self.get_colocate_pg()
        # New inference resources (created lazily when _SKYRL_USE_NEW_INFERENCE=1)
        self._server_groups = None
        self._prefill_server_groups = None
        self._decode_server_groups = None
        self._inference_router = None
        # Set by `_setup_trainer`; initialized here so the attribute exists
        # even before `run` is called.
        self.trainer = None

    @staticmethod
    def get_cfg_as_str(cfg: SkyRLTrainConfig) -> str:
        """Return ``cfg`` serialized as a YAML string."""
        return get_config_as_yaml_str(cfg)

    def get_train_dataset(self):
        """Initializes the training dataset.

        Returns:
            PromptDataset: The training dataset.

        Raises:
            AssertionError: If the dataset has fewer entries than
                ``trainer.train_batch_size``.
        """
        prompts_dataset = PromptDataset(
            datasets=self.cfg.data.train_data,
            tokenizer=self.tokenizer,
            max_prompt_length=self.cfg.trainer.max_prompt_length,
            num_workers=8,
        )
        # make sure the dataset is large enough to train on
        assert (
            len(prompts_dataset) >= self.cfg.trainer.train_batch_size
        ), f"dataset should be at least as large as `train_batch_size` {self.cfg.trainer.train_batch_size}, got size {len(prompts_dataset)}"
        return prompts_dataset

    def get_eval_dataset(self):
        """Initializes the evaluation dataset.

        Returns:
            PromptDataset: The evaluation dataset, or None when
            ``trainer.eval_interval`` is not positive or ``data.val_data`` is falsy.
        """
        if self.cfg.trainer.eval_interval > 0 and self.cfg.data.val_data:
            prompts_dataset = PromptDataset(
                datasets=self.cfg.data.val_data,
                tokenizer=self.tokenizer,
                max_prompt_length=self.cfg.trainer.max_prompt_length,
                num_workers=8,
            )
            return prompts_dataset
        return None

    def get_colocate_pg(self, timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[ResolvedPlacementGroup]:
        """Initializes a placement group for colocated training.

        Creates a single placement group with per-GPU bundles for all inference
        engines. The returned wrapper computes GPU-aware bundle ordering at init time.

        Args:
            timeout (int): The timeout for the placement group to be ready.

        Returns:
            ResolvedPlacementGroup: The placement group wrapper for colocated
            training, or None when ``trainer.placement.colocate_all`` is off.
        """
        if not self.cfg.trainer.placement.colocate_all:
            return None
        ie_cfg = self.cfg.generator.inference_engine
        # One {"GPU": 1, "CPU": 1} bundle per GPU used by any inference engine.
        per_engine_gpu_count = ie_cfg.tensor_parallel_size * ie_cfg.pipeline_parallel_size * ie_cfg.data_parallel_size
        total_gpu_slots = ie_cfg.num_engines * per_engine_gpu_count
        pg = placement_group(
            [{"GPU": 1, "CPU": 1}] * total_gpu_slots,
            strategy="PACK",
        )
        get_ray_pg_ready_with_timeout(pg, timeout=timeout)
        return ResolvedPlacementGroup(pg)

    def get_generator(self, cfg, tokenizer, inference_engine_client):
        """Initializes the generator.

        Uses the vision-language generator when
        ``cfg.generator.vision_language_generator`` is set, otherwise the
        text gym generator. Imports are deferred so only the selected
        generator module is loaded.

        Args:
            cfg: The experiment config.
            tokenizer: Tokenizer passed through to the generator.
            inference_engine_client: Inference client passed through to the generator.

        Returns:
            GeneratorInterface: The generator.
        """
        if cfg.generator.vision_language_generator:
            from skyrl.train.generators.skyrl_vlm_generator import SkyRLVLMGymGenerator

            generator_cls = SkyRLVLMGymGenerator
        else:
            from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator

            generator_cls = SkyRLGymGenerator
        return generator_cls(
            generator_cfg=cfg.generator,
            skyrl_gym_cfg=cfg.environment.skyrl_gym,
            inference_engine_client=inference_engine_client,
            tokenizer=tokenizer,
            policy_model_name=resolve_policy_model_name(cfg),
        )

    def get_trainer(
        self,
        cfg,
        tracker,
        tokenizer,
        train_dataset,
        eval_dataset,
        inference_engine_client,
        generator: GeneratorInterface,
        colocate_pg,
    ):
        """Initializes the trainer.

        All arguments are forwarded unchanged to ``RayPPOTrainer``.

        Returns:
            RayPPOTrainer: The trainer.
        """
        return RayPPOTrainer(
            cfg=cfg,
            tracker=tracker,
            tokenizer=tokenizer,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            inference_engine_client=inference_engine_client,
            generator=generator,
            colocate_pg=colocate_pg,
        )

    def get_tracker(self):
        """Initializes the tracker for experiment tracking.

        Returns:
            Tracking: The tracker, configured from ``cfg.trainer``
            (project name, run name, logger backends, tags) and the full config.
        """
        return Tracking(
            project_name=self.cfg.trainer.project_name,
            experiment_name=self.cfg.trainer.run_name,
            backends=self.cfg.trainer.logger,
            config=self.cfg,
            tags=self.cfg.trainer.tags,
        )

    def get_inference_client(self) -> InferenceEngineInterface:
        """Setup and return the inference engine client.

        This is a hook method that can be overridden by subclasses to customize
        inference engine creation (e.g., FlashRL, custom backends).

        Returns:
            InferenceEngineInterface: The inference engine client.
        """
        if _SKYRL_USE_NEW_INFERENCE:
            logger.info("Initializing new inference client")
            return self._get_new_inference_client()
        else:
            return self._get_legacy_inference_client()

    def _get_legacy_inference_client(self) -> InferenceEngineInterface:
        """Legacy inference client using Ray actors.

        Creates Ray-wrapped local engines when
        ``generator.inference_engine.run_engines_locally`` is set, otherwise
        connects to remote engines, and wraps them in an InferenceEngineClient.
        """
        if self.cfg.generator.inference_engine.run_engines_locally:
            inference_engines = create_ray_wrapped_inference_engines_from_config(
                self.cfg, self.colocate_pg, self.tokenizer
            )
        else:
            inference_engines = create_remote_inference_engines_from_config(self.cfg, self.tokenizer)
        return InferenceEngineClient(
            inference_engines,
            self.tokenizer,
            self.cfg.trainer.policy.model.path,
            self.cfg.trainer.policy.model.lora,
            self.cfg.generator.inference_engine,
        )

    def _get_new_inference_client(self):
        """New inference client using HTTP endpoints.

        Stores the router and server groups returned by the setup on ``self``.
        In colocated mode the engines are put to sleep right after startup.

        Returns:
            RemoteInferenceClient: The new inference client.
        """
        from skyrl.backends.skyrl_train.inference_servers.setup import (
            build_new_inference_client,
        )

        is_colocated = self.cfg.trainer.placement.colocate_all
        client, server_setup = build_new_inference_client(
            self.cfg,
            self.tokenizer,
            placement_group=self.colocate_pg if is_colocated else None,
        )
        self._inference_router = server_setup.router
        self._server_groups = server_setup.server_groups
        self._prefill_server_groups = server_setup.prefill_server_groups
        self._decode_server_groups = server_setup.decode_server_groups
        if is_colocated:
            # Callers must invoke get_inference_client() from a sync context (no running event loop).
            asyncio.run(client.sleep())
            logger.info("HTTP Inference: Colocated mode - slept inference engines after startup")
        return client

    def _setup_trainer(self):
        """Setup and return the trainer.

        Instantiates the trainer and all the associated models for training:
        logs the config, creates ``export_path`` and ``ckpt_path``, selects the
        worker classes for ``trainer.strategy`` and builds the models.

        Returns:
            RayPPOTrainer: The trainer.

        Raises:
            ValueError: If ``trainer.strategy`` is neither "fsdp" nor "megatron".
        """
        logger.info(self.get_cfg_as_str(self.cfg))
        os.makedirs(self.cfg.trainer.export_path, exist_ok=True)
        os.makedirs(self.cfg.trainer.ckpt_path, exist_ok=True)

        if self.cfg.trainer.strategy == "fsdp":
            from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import (
                CriticWorker,
                PolicyWorker,
                RefWorker,
            )
        elif self.cfg.trainer.strategy == "megatron":
            from skyrl.backends.skyrl_train.workers.megatron.megatron_worker import (
                CriticWorker,
                PolicyWorker,
                RefWorker,
            )
        else:
            raise ValueError(f"Unknown strategy type: {self.cfg.trainer.strategy}")

        # NOTE (sumanthrh): Instantiate tracker before trainer init.
        # We have custom validation before this step to give better error messages.
        tracker = self.get_tracker()

        inference_engine_client = self.get_inference_client()

        generator: GeneratorInterface = self.get_generator(self.cfg, self.tokenizer, inference_engine_client)

        trainer = self.get_trainer(
            cfg=self.cfg,
            tracker=tracker,
            tokenizer=self.tokenizer,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            inference_engine_client=inference_engine_client,
            generator=generator,
            colocate_pg=self.colocate_pg,
        )
        # Expose the trainer on self so callers can log exceptions raised
        # during `build_models` (which happens before _setup_trainer returns).
        self.trainer = trainer

        # Build the models
        trainer.build_models(PolicyWorker, CriticWorker, RefWorker)
        return trainer

    def run(self):
        """Set up the trainer and run the training loop to completion.

        Any exception raised during setup or training is re-raised unchanged.
        Before re-raising, it is reported through the trainer's tracker when one
        is available, and logged locally otherwise.
        """
        self.trainer = None
        try:
            trainer = self._setup_trainer()
            # Start the training loop
            asyncio.run(trainer.train())
        except Exception as e:
            # OOMs raised inside actor init (e.g. FSDPPolicyWorkerBase.init_model)
            # surface here as RayTaskError. Without this they only land in Ray
            # worker logs; route them through the tracker so wandb users see
            # them as an `error/tracebacks` table row.
            active_trainer = self.trainer
            if active_trainer is not None and active_trainer.tracker is not None:
                try:
                    active_trainer.tracker.log_exception(e, step=active_trainer.global_step)
                except Exception:
                    # Reporting is best-effort: a tracker failure must not
                    # replace the original error, which is re-raised below.
                    logger.exception("Failed to report the run exception to the tracker")
            else:
                logger.error(f"Setup failed before tracker was initialized:\n{e}")
            raise
# attr cfg
cfg = cfg
attr tokenizer
tokenizer = get_tokenizer(self.cfg.trainer.policy.model.path, trust_remote_code=True, use_fast=(not self.cfg.trainer.disable_fast_tokenizer), padding_side='left')
attr train_dataset
train_dataset = self.get_train_dataset()
attr eval_dataset
eval_dataset = self.get_eval_dataset()
attr colocate_pg
colocate_pg = self.get_colocate_pg()
method staticmethod get_cfg_as_str
get_cfg_as_str(cfg: SkyRLTrainConfig) -> str
Source code in skyrl/train/entrypoints/main_base.py:158-160
@staticmethod
def get_cfg_as_str(cfg: SkyRLTrainConfig) -> str:
    """Serialize ``cfg`` to YAML text.

    Args:
        cfg: The config to render.

    Returns:
        str: The YAML representation of ``cfg``.
    """
    yaml_text = get_config_as_yaml_str(cfg)
    return yaml_text
# method get_train_dataset
get_train_dataset()
Initializes the training dataset.
Returns:
| Name | Type | Description |
|---|---|---|
PromptDataset | The training dataset. |
Source code in skyrl/train/entrypoints/main_base.py:162-178
def get_train_dataset(self):
    """Build the training dataset from ``cfg.data.train_data``.

    Returns:
        PromptDataset: The training dataset.

    Raises:
        AssertionError: If the dataset has fewer entries than
            ``trainer.train_batch_size``.
    """
    trainer_cfg = self.cfg.trainer
    dataset = PromptDataset(
        datasets=self.cfg.data.train_data,
        tokenizer=self.tokenizer,
        max_prompt_length=trainer_cfg.max_prompt_length,
        num_workers=8,
    )
    batch_size = trainer_cfg.train_batch_size
    # The dataset must hold at least one `train_batch_size` worth of prompts.
    assert len(dataset) >= batch_size, (
        f"dataset should be at least as large as `train_batch_size` {batch_size}, got size {len(dataset)}"
    )
    return dataset
# method get_eval_dataset
get_eval_dataset()
Initializes the evaluation dataset.
Returns:
| Name | Type | Description |
|---|---|---|
PromptDataset | The evaluation dataset. |
Source code in skyrl/train/entrypoints/main_base.py:180-194
def get_eval_dataset(self):
    """Build the evaluation dataset when evaluation is enabled.

    Returns:
        PromptDataset: The evaluation dataset, or None when
        ``trainer.eval_interval`` is not positive or ``data.val_data`` is falsy.
    """
    eval_enabled = self.cfg.trainer.eval_interval > 0
    if not (eval_enabled and self.cfg.data.val_data):
        return None
    return PromptDataset(
        datasets=self.cfg.data.val_data,
        tokenizer=self.tokenizer,
        max_prompt_length=self.cfg.trainer.max_prompt_length,
        num_workers=8,
    )
# method get_colocate_pg
get_colocate_pg(timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[ResolvedPlacementGroup]
Initializes a placement group for colocated training.
Creates a single placement group with per-GPU bundles for all inference engines. The returned wrapper computes GPU-aware bundle ordering at init time.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timeout | int | The timeout for the placement group to be ready. | SKYRL_RAY_PG_TIMEOUT_IN_S |
Returns:
| Name | Type | Description |
|---|---|---|
ResolvedPlacementGroup | Optional[ResolvedPlacementGroup] | The placement group wrapper for colocated training, or None. |
Source code in skyrl/train/entrypoints/main_base.py:196-220
def get_colocate_pg(self, timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[ResolvedPlacementGroup]:
    """Reserve one shared placement group for colocated training.

    When ``trainer.placement.colocate_all`` is set, one ``{"GPU": 1, "CPU": 1}``
    bundle is requested per inference-engine GPU, i.e.
    ``num_engines * tensor_parallel_size * pipeline_parallel_size * data_parallel_size``
    bundles with the ``PACK`` strategy. The returned wrapper computes GPU-aware
    bundle ordering at init time.

    Args:
        timeout (int): How long to wait for the placement group to become ready.

    Returns:
        ResolvedPlacementGroup: The wrapped placement group, or None when
        colocation is disabled.
    """
    if self.cfg.trainer.placement.colocate_all:
        engine_cfg = self.cfg.generator.inference_engine
        gpus_per_engine = (
            engine_cfg.tensor_parallel_size * engine_cfg.pipeline_parallel_size * engine_cfg.data_parallel_size
        )
        bundle_count = engine_cfg.num_engines * gpus_per_engine
        group = placement_group([{"GPU": 1, "CPU": 1}] * bundle_count, strategy="PACK")
        get_ray_pg_ready_with_timeout(group, timeout=timeout)
        return ResolvedPlacementGroup(group)
    return None
# method get_generator
get_generator(cfg, tokenizer, inference_engine_client)
Initializes the generator.
Returns:
| Name | Type | Description |
|---|---|---|
GeneratorInterface | The generator. |
Source code in skyrl/train/entrypoints/main_base.py:222-243
def get_generator(self, cfg, tokenizer, inference_engine_client):
    """Construct the generator selected by the config.

    ``cfg.generator.vision_language_generator`` chooses between the
    vision-language gym generator and the text gym generator. Only the chosen
    module is imported.

    Args:
        cfg: The experiment config.
        tokenizer: Tokenizer passed through to the generator.
        inference_engine_client: Inference client passed through to the generator.

    Returns:
        GeneratorInterface: The generator.
    """
    if not cfg.generator.vision_language_generator:
        from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator as chosen_cls
    else:
        from skyrl.train.generators.skyrl_vlm_generator import SkyRLVLMGymGenerator as chosen_cls
    generator = chosen_cls(
        generator_cfg=cfg.generator,
        skyrl_gym_cfg=cfg.environment.skyrl_gym,
        inference_engine_client=inference_engine_client,
        tokenizer=tokenizer,
        policy_model_name=resolve_policy_model_name(cfg),
    )
    return generator
# method get_trainer
get_trainer(cfg, tracker, tokenizer, train_dataset, eval_dataset, inference_engine_client, generator: GeneratorInterface, colocate_pg)
Initializes the trainer.
Returns:
| Name | Type | Description |
|---|---|---|
RayPPOTrainer | The trainer. |
Source code in skyrl/train/entrypoints/main_base.py:245-270
def get_trainer(
    self,
    cfg,
    tracker,
    tokenizer,
    train_dataset,
    eval_dataset,
    inference_engine_client,
    generator: GeneratorInterface,
    colocate_pg,
):
    """Assemble the PPO trainer from already-built components.

    Args:
        cfg: The experiment config.
        tracker: Experiment tracker.
        tokenizer: The tokenizer.
        train_dataset: Training dataset.
        eval_dataset: Evaluation dataset (may be None).
        inference_engine_client: Inference engine client.
        generator: The generator.
        colocate_pg: Placement group for colocated training (may be None).

    Returns:
        RayPPOTrainer: The trainer.
    """
    trainer_kwargs = dict(
        cfg=cfg,
        tracker=tracker,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        inference_engine_client=inference_engine_client,
        generator=generator,
        colocate_pg=colocate_pg,
    )
    return RayPPOTrainer(**trainer_kwargs)
# method get_tracker
get_tracker()
Initializes the tracker for experiment tracking.
Returns:
| Name | Type | Description |
|---|---|---|
Tracking | The tracker. |
Source code in skyrl/train/entrypoints/main_base.py:272-284
def get_tracker(self):
    """Create the experiment tracker from ``cfg.trainer`` settings.

    Returns:
        Tracking: Tracker using the configured project name, run name,
        logger backends and tags, with the full config attached.
    """
    trainer_cfg = self.cfg.trainer
    return Tracking(
        project_name=trainer_cfg.project_name,
        experiment_name=trainer_cfg.run_name,
        backends=trainer_cfg.logger,
        config=self.cfg,
        tags=trainer_cfg.tags,
    )
# method get_inference_client
get_inference_client() -> InferenceEngineInterface
Setup and return the inference engine client.
This is a hook method that can be overridden by subclasses to customize inference engine creation (e.g., FlashRL, custom backends).
Returns:
| Name | Type | Description |
|---|---|---|
InferenceEngineInterface | InferenceEngineInterface | The inference engine client. |
Source code in skyrl/train/entrypoints/main_base.py:286-299
def get_inference_client(self) -> InferenceEngineInterface:
    """Create the inference engine client.

    Subclasses may override this hook to customize inference engine creation
    (e.g., FlashRL, custom backends). The HTTP-based client is used when
    ``_SKYRL_USE_NEW_INFERENCE`` is truthy, the Ray-actor client otherwise.

    Returns:
        InferenceEngineInterface: The inference engine client.
    """
    if not _SKYRL_USE_NEW_INFERENCE:
        return self._get_legacy_inference_client()
    logger.info("Initializing new inference client")
    return self._get_new_inference_client()
# method run
run()
Source code in skyrl/train/entrypoints/main_base.py:399-414
def run(self):
    """Set up the trainer and run the training loop to completion.

    Any exception raised during setup or training is re-raised unchanged.
    Before re-raising, it is reported through the trainer's tracker when one
    is available, and logged locally otherwise.
    """
    self.trainer = None
    try:
        trainer = self._setup_trainer()
        # Start the training loop
        asyncio.run(trainer.train())
    except Exception as e:
        # OOMs raised inside actor init (e.g. FSDPPolicyWorkerBase.init_model)
        # surface here as RayTaskError. Without this they only land in Ray
        # worker logs; route them through the tracker so wandb users see
        # them as an `error/tracebacks` table row.
        active_trainer = self.trainer
        if active_trainer is not None and active_trainer.tracker is not None:
            try:
                active_trainer.tracker.log_exception(e, step=active_trainer.global_step)
            except Exception:
                # Reporting is best-effort: a tracker failure must not
                # replace the original error, which is re-raised below.
                logger.exception("Failed to report the run exception to the tracker")
        else:
            logger.error(f"Setup failed before tracker was initialized:\n{e}")
        raise
# Evaluation Entrypoint
The evaluation-only entrypoint is the EvalOnlyEntrypoint class which runs evaluation without training.
class EvalOnlyEntrypoint
Bases: BasePPOExp
Functions:
| Name | Description |
|---|---|
get_train_dataset | Override to avoid requiring a train dataset for eval-only runs. |
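run | Runs evaluation on the eval dataset and logs the results. |
get_cfg_as_str | Returns the config serialized as a YAML string. |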
get_eval_dataset | Initializes the evaluation dataset. |
get_colocate_pg | Initializes a placement group for colocated training. |
get_generator | Initializes the generator. |
get_trainer | Initializes the trainer. |
get_tracker | Initializes the tracker for experiment tracking. |
get_inference_client | Setup and return the inference engine client. |
Attributes:
| Name | Type | Description |
|---|---|---|
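cfg | SkyRLTrainConfig | The fully resolved experiment configuration. |
tokenizer | | Tokenizer loaded from `trainer.policy.model.path` with left padding. |
train_dataset | | Always None: `get_train_dataset` is overridden for eval-only runs. |
eval_dataset | Optional[PromptDataset] | The evaluation dataset; `run` requires it to be set. |
colocate_pg | Optional[ResolvedPlacementGroup] | Placement group for colocated training, or None when `colocate_all` is off. |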
Source code in skyrl/train/entrypoints/main_generate.py:22-45
class EvalOnlyEntrypoint(BasePPOExp):
    """Entrypoint that runs evaluation on the eval dataset without training."""

    def get_train_dataset(self):
        """Skip loading training data, which eval-only runs do not need.

        Returns:
            None: No training dataset is created.
        """
        return None

    async def run(self, inference_engine_client: InferenceEngineInterface) -> dict[str, Any]:
        """Run evaluation on the eval dataset and log the results.

        Args:
            inference_engine_client: Client handed to the generator; it is
                woken up before evaluation starts.

        Returns:
            dict[str, Any]: The results produced by the evaluation function,
            also logged to the tracker at step 0.

        Raises:
            AssertionError: If no eval dataset is configured.
        """
        assert self.eval_dataset is not None, "The evaluation only entrypoint requires an eval dataset is provided"
        await inference_engine_client.wake_up()
        generator = self.get_generator(self.cfg, self.tokenizer, inference_engine_client)

        # Use the step-wise evaluator when `generator.step_wise_trajectories` is set.
        if self.cfg.generator.step_wise_trajectories:
            evaluation_fn = evaluate_step_wise
        else:
            evaluation_fn = evaluate
        eval_loader = build_dataloader(self.cfg, self.eval_dataset, is_train=False)
        metrics: dict[str, Any] = await evaluation_fn(
            eval_dataloader=eval_loader,
            generator=generator,
            cfg=self.cfg,
            global_step=None,
            tokenizer=self.tokenizer,
        )

        self.get_tracker().log(metrics, step=0, commit=True)
        return metrics
# method get_train_dataset
get_train_dataset()
Override to avoid requiring a train dataset for eval-only runs.
Source code in skyrl/train/entrypoints/main_generate.py:23-25
def get_train_dataset(self):
    """Eval-only runs train nothing, so no training dataset is built.

    Returns:
        None: Always.
    """
    return None
# method run
run(inference_engine_client: InferenceEngineInterface) -> dict[str, Any]
Source code in skyrl/train/entrypoints/main_generate.py:27-45
async def run(self, inference_engine_client: InferenceEngineInterface) -> dict[str, Any]:
    """Run evaluation on the eval dataset and log the results.

    Args:
        inference_engine_client: Client handed to the generator; it is woken
            up before evaluation starts.

    Returns:
        dict[str, Any]: The results produced by the evaluation function, also
        logged to the tracker at step 0.

    Raises:
        AssertionError: If no eval dataset is configured.
    """
    assert self.eval_dataset is not None, "The evaluation only entrypoint requires an eval dataset is provided"
    await inference_engine_client.wake_up()
    generator = self.get_generator(self.cfg, self.tokenizer, inference_engine_client)

    # Use the step-wise evaluator when `generator.step_wise_trajectories` is set.
    if self.cfg.generator.step_wise_trajectories:
        evaluation_fn = evaluate_step_wise
    else:
        evaluation_fn = evaluate
    eval_loader = build_dataloader(self.cfg, self.eval_dataset, is_train=False)
    metrics: dict[str, Any] = await evaluation_fn(
        eval_dataloader=eval_loader,
        generator=generator,
        cfg=self.cfg,
        global_step=None,
        tokenizer=self.tokenizer,
    )

    self.get_tracker().log(metrics, step=0, commit=True)
    return metrics
# attr cfg
cfg = cfg
attr tokenizer
tokenizer = get_tokenizer(self.cfg.trainer.policy.model.path, trust_remote_code=True, use_fast=(not self.cfg.trainer.disable_fast_tokenizer), padding_side='left')
attr train_dataset
train_dataset = self.get_train_dataset()
attr eval_dataset
eval_dataset = self.get_eval_dataset()
attr colocate_pg
colocate_pg = self.get_colocate_pg()
method staticmethod get_cfg_as_str
get_cfg_as_str(cfg: SkyRLTrainConfig) -> str
method get_eval_dataset
get_eval_dataset()
Initializes the evaluation dataset.
Returns:
| Name | Type | Description |
|---|---|---|
PromptDataset | The evaluation dataset. |
method get_colocate_pg
get_colocate_pg(timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[ResolvedPlacementGroup]
Initializes a placement group for colocated training.
Creates a single placement group with per-GPU bundles for all inference engines. The returned wrapper computes GPU-aware bundle ordering at init time.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timeout | int | The timeout for the placement group to be ready. | SKYRL_RAY_PG_TIMEOUT_IN_S |
Returns:
| Name | Type | Description |
|---|---|---|
ResolvedPlacementGroup | Optional[ResolvedPlacementGroup] | The placement group wrapper for colocated training, or None. |
method get_generator
get_generator(cfg, tokenizer, inference_engine_client)
Initializes the generator.
Returns:
| Name | Type | Description |
|---|---|---|
GeneratorInterface | The generator. |
method get_trainer
get_trainer(cfg, tracker, tokenizer, train_dataset, eval_dataset, inference_engine_client, generator: GeneratorInterface, colocate_pg)
Initializes the trainer.
Returns:
| Name | Type | Description |
|---|---|---|
RayPPOTrainer | The trainer. |
method get_tracker
get_tracker()
Initializes the tracker for experiment tracking.
Returns:
| Name | Type | Description |
|---|---|---|
Tracking | The tracker. |
method get_inference_client
get_inference_client() -> InferenceEngineInterface
Setup and return the inference engine client.
This is a hook method that can be overridden by subclasses to customize inference engine creation (e.g., FlashRL, custom backends).
Returns:
| Name | Type | Description |
|---|---|---|
InferenceEngineInterface | InferenceEngineInterface | The inference engine client. |