Entrypoint
Entrypoint API — BasePPOExp, EvalOnlyEntrypoint.
Training Entrypoint
The main entrypoint is the BasePPOExp class which runs the main training loop.
class BasePPOExp
BasePPOExp(cfg: SkyRLTrainConfig)

Functions:
| Name | Description |
|---|---|
| get_cfg_as_str | Returns the config rendered as a YAML string. |
| get_train_dataset | Initializes the training dataset. |
| get_eval_dataset | Initializes the evaluation dataset. |
| get_colocate_pg | Initializes a placement group for colocated training. |
| get_generator | Initializes the generator. |
| get_trainer | Initializes the trainer. |
| get_tracker | Initializes the tracker for experiment tracking. |
| get_inference_client | Setup and return the inference engine client. |
| run | Builds the trainer and runs the training loop. |
Attributes:
| Name | Type | Description |
|---|---|---|
| cfg | | |
| tokenizer | | |
| train_dataset | | |
| eval_dataset | | |
| colocate_pg | | |
Initializes a PPO experiment.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
cfg | SkyRLTrainConfig | The fully resolved SkyRLTrainConfig instance. | required |
Source code in skyrl/train/entrypoints/main_base.py:124-433
class BasePPOExp:
    def __init__(self, cfg: SkyRLTrainConfig):
        """
        Initializes a PPO experiment.

        Args:
            cfg: The fully resolved SkyRLTrainConfig instance.
        """
        self.cfg = cfg
        # Left padding so generation-ready prompts align on the right edge.
        self.tokenizer = get_tokenizer(
            self.cfg.trainer.policy.model.path,
            trust_remote_code=True,
            use_fast=not self.cfg.trainer.disable_fast_tokenizer,
            padding_side="left",
        )
        self.train_dataset = self.get_train_dataset()
        self.eval_dataset = self.get_eval_dataset()
        self.colocate_pg = self.get_colocate_pg()
        # New inference resources (created lazily when _SKYRL_USE_NEW_INFERENCE=1)
        self._server_group = None
        self._inference_router = None

    @staticmethod
    def get_cfg_as_str(cfg: SkyRLTrainConfig) -> str:
        """Renders the resolved config as a YAML string (used for logging)."""
        return get_config_as_yaml_str(cfg)

    def get_train_dataset(self):
        """Initializes the training dataset.

        Returns:
            PromptDataset: The training dataset.
        """
        prompts_dataset = PromptDataset(
            datasets=self.cfg.data.train_data,
            tokenizer=self.tokenizer,
            max_prompt_length=self.cfg.trainer.max_prompt_length,
            num_workers=8,
        )
        # make sure the dataset is large enough to train on
        assert (
            len(prompts_dataset) >= self.cfg.trainer.train_batch_size
        ), f"dataset should be at least as large as `train_batch_size` {self.cfg.trainer.train_batch_size}, got size {len(prompts_dataset)}"
        return prompts_dataset

    def get_eval_dataset(self):
        """Initializes the evaluation dataset.

        Returns:
            PromptDataset: The evaluation dataset, or None when evaluation
            is disabled or no validation data is configured.
        """
        if self.cfg.trainer.eval_interval > 0 and self.cfg.data.val_data:
            prompts_dataset = PromptDataset(
                datasets=self.cfg.data.val_data,
                tokenizer=self.tokenizer,
                max_prompt_length=self.cfg.trainer.max_prompt_length,
                num_workers=8,
            )
            return prompts_dataset
        return None

    def get_colocate_pg(self, timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[PlacementGroup]:
        """Initializes a placement group for colocated training.

        Creates a single placement group with per-GPU bundles for all inference
        engines.

        Args:
            timeout (int): The timeout for the placement group to be ready.

        Returns:
            A PlacementGroup when colocate_all is True, else None.
        """
        if not self.cfg.trainer.placement.colocate_all:
            return None
        ie_cfg = self.cfg.generator.inference_engine
        # One GPU slot per rank per engine: TP x PP x DP.
        per_engine_gpu_count = ie_cfg.tensor_parallel_size * ie_cfg.pipeline_parallel_size * ie_cfg.data_parallel_size
        total_gpu_slots = ie_cfg.num_engines * per_engine_gpu_count
        pg = placement_group(
            [{"GPU": 1, "CPU": 1}] * total_gpu_slots,
            strategy="PACK",
        )
        get_ray_pg_ready_with_timeout(pg, timeout=timeout)
        return pg

    def get_generator(self, cfg, tokenizer, inference_engine_client):
        """Initializes the generator.

        Returns:
            GeneratorInterface: The generator.
        """
        # Imported lazily to avoid pulling generator deps at module import time.
        from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator

        return SkyRLGymGenerator(
            generator_cfg=cfg.generator,
            skyrl_gym_cfg=cfg.environment.skyrl_gym,
            inference_engine_client=inference_engine_client,
            tokenizer=tokenizer,
        )

    def get_trainer(
        self,
        cfg,
        tracker,
        tokenizer,
        train_dataset,
        eval_dataset,
        inference_engine_client,
        generator: GeneratorInterface,
        colocate_pg,
    ):
        """Initializes the trainer.

        Returns:
            RayPPOTrainer: The trainer.
        """
        return RayPPOTrainer(
            cfg=cfg,
            tracker=tracker,
            tokenizer=tokenizer,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            inference_engine_client=inference_engine_client,
            generator=generator,
            colocate_pg=colocate_pg,
        )

    def get_tracker(self):
        """Initializes the tracker for experiment tracking.

        Returns:
            Tracking: The tracker.
        """
        return Tracking(
            project_name=self.cfg.trainer.project_name,
            experiment_name=self.cfg.trainer.run_name,
            backends=self.cfg.trainer.logger,
            config=self.cfg,
        )

    def get_inference_client(self) -> InferenceEngineInterface:
        """Setup and return the inference engine client.

        This is a hook method that can be overridden by subclasses to customize
        inference engine creation (e.g., FlashRL, custom backends).

        Returns:
            InferenceEngineInterface: The inference engine client.
        """
        if _SKYRL_USE_NEW_INFERENCE:
            logger.info("Initializing new inference client")
            return self._get_new_inference_client()
        else:
            return self._get_legacy_inference_client()

    def _get_legacy_inference_client(self) -> InferenceEngineInterface:
        """Legacy inference client using Ray actors."""
        if self.cfg.generator.inference_engine.run_engines_locally:
            inference_engines = create_ray_wrapped_inference_engines_from_config(
                self.cfg, self.colocate_pg, self.tokenizer
            )
        else:
            inference_engines = create_remote_inference_engines_from_config(self.cfg, self.tokenizer)
        return InferenceEngineClient(
            inference_engines,
            self.tokenizer,
            self.cfg.trainer.policy.model.path,
            self.cfg.trainer.policy.model.lora,
            self.cfg.generator.inference_engine,
        )

    def _get_new_inference_client(self):
        """New inference client using HTTP endpoints.

        Config combinations:
            - Colocated + external URLs -> ERROR (validated earlier)
            - Neither set -> Build servers internally
            - external_server_urls only -> Create router over external servers
            - external_proxy_url only -> Use proxy for both data + control plane
            - Both set -> Fully external (proxy for data plane, servers for control plane)

        Returns:
            RemoteInferenceClient: The new inference client.
        """
        # Imported lazily: only needed when _SKYRL_USE_NEW_INFERENCE is set.
        from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import (
            RemoteInferenceClient,
        )
        from skyrl.backends.skyrl_train.inference_servers.router import InferenceRouter
        from skyrl.backends.skyrl_train.inference_servers.server_group import (
            ServerGroup,
        )

        ie_cfg = self.cfg.generator.inference_engine
        is_colocated = self.cfg.trainer.placement.colocate_all
        external_proxy_url = ie_cfg.external_proxy_url
        external_server_urls = ie_cfg.external_server_urls
        has_external_proxy = external_proxy_url is not None
        has_external_servers = external_server_urls is not None
        if has_external_proxy and has_external_servers:
            # Case: Both external - fully external setup
            proxy_url = external_proxy_url
            server_urls = list(external_server_urls)
            logger.info(
                f"HTTP Inference: Using fully external setup - " f"proxy_url={proxy_url}, server_urls={server_urls}"
            )
        elif has_external_proxy and not has_external_servers:
            # Case: Proxy only - assume proxy handles control plane too
            proxy_url = external_proxy_url
            server_urls = [proxy_url]
            logger.info(
                f"HTTP Inference: Using external proxy for both data and " f"control plane - proxy_url={proxy_url}"
            )
        elif has_external_servers and not has_external_proxy:
            # Case: Servers only - create internal router over them
            server_urls = list(external_server_urls)
            self._inference_router = InferenceRouter(server_urls=server_urls)
            proxy_url = self._inference_router.start()
            logger.info(
                f"HTTP Inference: Created internal router over external "
                f"servers - server_urls={server_urls}, proxy_url={proxy_url}"
            )
        else:
            # Case: Neither - build servers and router internally
            cli_args = build_vllm_cli_args(self.cfg)
            self._server_group = ServerGroup(
                cli_args=cli_args,
                num_servers=ie_cfg.num_engines,
                placement_group=self.colocate_pg if is_colocated else None,
                enable_dp=ie_cfg.data_parallel_size > 1,
                distributed_executor_backend=ie_cfg.distributed_executor_backend,
            )
            server_infos = self._server_group.start()
            server_urls = [info.url for info in server_infos]
            self._inference_router = InferenceRouter(server_urls=server_urls)
            proxy_url = self._inference_router.start()
            logger.info(
                f"HTTP Inference: Built servers and router internally - "
                f"proxy_url={proxy_url}, server_urls={server_urls}, colocated={is_colocated}"
            )
        return RemoteInferenceClient(
            proxy_url=proxy_url,
            server_urls=server_urls,
            model_name=self.cfg.trainer.policy.model.path,
        )

    def _setup_trainer(self):
        """Setup and return the trainer.

        Instantiates the trainer and all the associated models for training.

        Returns:
            RayPPOTrainer: The trainer.
        """
        logger.info(self.get_cfg_as_str(self.cfg))
        os.makedirs(self.cfg.trainer.export_path, exist_ok=True)
        os.makedirs(self.cfg.trainer.ckpt_path, exist_ok=True)
        # Select worker implementations for the configured training strategy.
        if self.cfg.trainer.strategy in ("fsdp", "fsdp2"):
            from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import (
                CriticWorker,
                PolicyWorker,
                RefWorker,
            )
        elif self.cfg.trainer.strategy == "megatron":
            from skyrl.backends.skyrl_train.workers.megatron.megatron_worker import (
                CriticWorker,
                PolicyWorker,
                RefWorker,
            )
        else:
            raise ValueError(f"Unknown strategy type: {self.cfg.trainer.strategy}")
        # NOTE (sumanthrh): Instantiate tracker before trainer init.
        # We have custom validation before this step to give better error messages.
        tracker = self.get_tracker()
        inference_engine_client = self.get_inference_client()
        generator: GeneratorInterface = self.get_generator(self.cfg, self.tokenizer, inference_engine_client)
        trainer = self.get_trainer(
            cfg=self.cfg,
            tracker=tracker,
            tokenizer=self.tokenizer,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            inference_engine_client=inference_engine_client,
            generator=generator,
            colocate_pg=self.colocate_pg,
        )
        # Build the models
        trainer.build_models(PolicyWorker, CriticWorker, RefWorker)
        return trainer

    def run(self):
        """Builds the trainer and runs the training loop to completion."""
        trainer = self._setup_trainer()
        # Start the training loop
        asyncio.run(trainer.train())
cfg = cfg

attr tokenizer
tokenizer = get_tokenizer(self.cfg.trainer.policy.model.path, trust_remote_code=True, use_fast=(not self.cfg.trainer.disable_fast_tokenizer), padding_side='left')

attr train_dataset
train_dataset = self.get_train_dataset()

attr eval_dataset
eval_dataset = self.get_eval_dataset()

attr colocate_pg
colocate_pg = self.get_colocate_pg()

method (staticmethod) get_cfg_as_str
get_cfg_as_str(cfg: SkyRLTrainConfig) -> str

Source code in skyrl/train/entrypoints/main_base.py:147-149
@staticmethod
def get_cfg_as_str(cfg: SkyRLTrainConfig) -> str:
    """Renders the resolved config as a YAML string (used for logging)."""
    return get_config_as_yaml_str(cfg)
get_train_dataset()

Initializes the training dataset.
Returns:
| Name | Type | Description |
|---|---|---|
PromptDataset | The training dataset. |
Source code in skyrl/train/entrypoints/main_base.py:151-167
def get_train_dataset(self):
    """Initializes the training dataset.

    Returns:
        PromptDataset: The training dataset.
    """
    prompts_dataset = PromptDataset(
        datasets=self.cfg.data.train_data,
        tokenizer=self.tokenizer,
        max_prompt_length=self.cfg.trainer.max_prompt_length,
        num_workers=8,
    )
    # make sure the dataset is large enough to train on
    assert (
        len(prompts_dataset) >= self.cfg.trainer.train_batch_size
    ), f"dataset should be at least as large as `train_batch_size` {self.cfg.trainer.train_batch_size}, got size {len(prompts_dataset)}"
    return prompts_dataset
get_eval_dataset()

Initializes the evaluation dataset.
Returns:
| Name | Type | Description |
|---|---|---|
PromptDataset | The evaluation dataset. |
Source code in skyrl/train/entrypoints/main_base.py:169-183
def get_eval_dataset(self):
    """Initializes the evaluation dataset.

    Returns:
        PromptDataset: The evaluation dataset, or None when evaluation
        is disabled or no validation data is configured.
    """
    if self.cfg.trainer.eval_interval > 0 and self.cfg.data.val_data:
        prompts_dataset = PromptDataset(
            datasets=self.cfg.data.val_data,
            tokenizer=self.tokenizer,
            max_prompt_length=self.cfg.trainer.max_prompt_length,
            num_workers=8,
        )
        return prompts_dataset
    return None
get_colocate_pg(timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[PlacementGroup]

Initializes a placement group for colocated training.
Creates a single placement group with per-GPU bundles for all inference engines.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timeout | int | The timeout for the placement group to be ready. | SKYRL_RAY_PG_TIMEOUT_IN_S |
Returns:
| Type | Description |
|---|---|
| Optional[PlacementGroup] | A PlacementGroup when colocate_all is True, else None. |
Source code in skyrl/train/entrypoints/main_base.py:185-209
def get_colocate_pg(self, timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[PlacementGroup]:
    """Initializes a placement group for colocated training.

    Creates a single placement group with per-GPU bundles for all inference
    engines.

    Args:
        timeout (int): The timeout for the placement group to be ready.

    Returns:
        A PlacementGroup when colocate_all is True, else None.
    """
    if not self.cfg.trainer.placement.colocate_all:
        return None
    ie_cfg = self.cfg.generator.inference_engine
    # One GPU slot per rank per engine: TP x PP x DP.
    per_engine_gpu_count = ie_cfg.tensor_parallel_size * ie_cfg.pipeline_parallel_size * ie_cfg.data_parallel_size
    total_gpu_slots = ie_cfg.num_engines * per_engine_gpu_count
    pg = placement_group(
        [{"GPU": 1, "CPU": 1}] * total_gpu_slots,
        strategy="PACK",
    )
    get_ray_pg_ready_with_timeout(pg, timeout=timeout)
    return pg
get_generator(cfg, tokenizer, inference_engine_client)

Initializes the generator.
Returns:
| Name | Type | Description |
|---|---|---|
GeneratorInterface | The generator. |
Source code in skyrl/train/entrypoints/main_base.py:211-224
def get_generator(self, cfg, tokenizer, inference_engine_client):
    """Initializes the generator.

    Returns:
        GeneratorInterface: The generator.
    """
    # Imported lazily to avoid pulling generator deps at module import time.
    from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator

    return SkyRLGymGenerator(
        generator_cfg=cfg.generator,
        skyrl_gym_cfg=cfg.environment.skyrl_gym,
        inference_engine_client=inference_engine_client,
        tokenizer=tokenizer,
    )
get_trainer(cfg, tracker, tokenizer, train_dataset, eval_dataset, inference_engine_client, generator: GeneratorInterface, colocate_pg)

Initializes the trainer.
Returns:
| Name | Type | Description |
|---|---|---|
RayPPOTrainer | The trainer. |
Source code in skyrl/train/entrypoints/main_base.py:226-251
def get_trainer(
    self,
    cfg,
    tracker,
    tokenizer,
    train_dataset,
    eval_dataset,
    inference_engine_client,
    generator: GeneratorInterface,
    colocate_pg,
):
    """Initializes the trainer.

    Returns:
        RayPPOTrainer: The trainer.
    """
    return RayPPOTrainer(
        cfg=cfg,
        tracker=tracker,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        inference_engine_client=inference_engine_client,
        generator=generator,
        colocate_pg=colocate_pg,
    )
get_tracker()

Initializes the tracker for experiment tracking.
Returns:
| Name | Type | Description |
|---|---|---|
Tracking | The tracker. |
Source code in skyrl/train/entrypoints/main_base.py:253-264
def get_tracker(self):
    """Initializes the tracker for experiment tracking.

    Returns:
        Tracking: The tracker.
    """
    return Tracking(
        project_name=self.cfg.trainer.project_name,
        experiment_name=self.cfg.trainer.run_name,
        backends=self.cfg.trainer.logger,
        config=self.cfg,
    )
get_inference_client() -> InferenceEngineInterface

Setup and return the inference engine client.
This is a hook method that can be overridden by subclasses to customize inference engine creation (e.g., FlashRL, custom backends).
Returns:
| Name | Type | Description |
|---|---|---|
InferenceEngineInterface | InferenceEngineInterface | The inference engine client. |
Source code in skyrl/train/entrypoints/main_base.py:266-279
def get_inference_client(self) -> InferenceEngineInterface:
    """Setup and return the inference engine client.

    This is a hook method that can be overridden by subclasses to customize
    inference engine creation (e.g., FlashRL, custom backends).

    Returns:
        InferenceEngineInterface: The inference engine client.
    """
    if _SKYRL_USE_NEW_INFERENCE:
        logger.info("Initializing new inference client")
        return self._get_new_inference_client()
    else:
        return self._get_legacy_inference_client()
run()Source code in skyrl/train/entrypoints/main_base.py:430-433
def run(self):
    """Builds the trainer and runs the training loop to completion."""
    trainer = self._setup_trainer()
    # Start the training loop
    asyncio.run(trainer.train())
## Evaluation Entrypoint

The evaluation-only entrypoint is the EvalOnlyEntrypoint class which runs evaluation without training.
class EvalOnlyEntrypoint
Bases: BasePPOExp
Functions:
| Name | Description |
|---|---|
get_train_dataset | Override to avoid requiring a train dataset for eval-only runs. |
run | |
get_cfg_as_str | |
get_eval_dataset | Initializes the evaluation dataset. |
get_colocate_pg | Initializes a placement group for colocated training. |
get_generator | Initializes the generator. |
get_trainer | Initializes the trainer. |
get_tracker | Initializes the tracker for experiment tracking. |
get_inference_client | Setup and return the inference engine client. |
Attributes:
| Name | Type | Description |
|---|---|---|
cfg | ||
tokenizer | ||
train_dataset | ||
eval_dataset | ||
colocate_pg |
Source code in skyrl/train/entrypoints/main_generate.py:21-44
class EvalOnlyEntrypoint(BasePPOExp):
    def get_train_dataset(self):
        """Override to avoid requiring a train dataset for eval-only runs."""
        return None

    async def run(self) -> dict[str, Any]:
        """Runs a single evaluation pass and logs the results.

        Returns:
            dict[str, Any]: The evaluation metrics returned by `evaluate`.
        """
        assert self.eval_dataset is not None, "The evaluation only entrypoint requires an eval dataset is provided"
        inference_engine_client = self.get_inference_client()
        # Wake the inference engines before generating.
        await inference_engine_client.wake_up()
        generator = self.get_generator(self.cfg, self.tokenizer, inference_engine_client)
        results: dict[str, Any] = await evaluate(
            eval_dataloader=build_dataloader(self.cfg, self.eval_dataset, is_train=False),
            generator=generator,
            cfg=self.cfg,
            global_step=None,
            tokenizer=self.tokenizer,
        )
        tracker = self.get_tracker()
        # Log once at step 0 since there is no training timeline.
        tracker.log(results, step=0, commit=True)
        return results
get_train_dataset()Override to avoid requiring a train dataset for eval-only runs.
Source code in skyrl/train/entrypoints/main_generate.py:22-24
def get_train_dataset(self):
    """Override to avoid requiring a train dataset for eval-only runs."""
    return None
run() -> dict[str, Any]Source code in skyrl/train/entrypoints/main_generate.py:26-44
async def run(self) -> dict[str, Any]:
    """Runs a single evaluation pass and logs the results.

    Returns:
        dict[str, Any]: The evaluation metrics returned by `evaluate`.
    """
    assert self.eval_dataset is not None, "The evaluation only entrypoint requires an eval dataset is provided"
    inference_engine_client = self.get_inference_client()
    # Wake the inference engines before generating.
    await inference_engine_client.wake_up()
    generator = self.get_generator(self.cfg, self.tokenizer, inference_engine_client)
    results: dict[str, Any] = await evaluate(
        eval_dataloader=build_dataloader(self.cfg, self.eval_dataset, is_train=False),
        generator=generator,
        cfg=self.cfg,
        global_step=None,
        tokenizer=self.tokenizer,
    )
    tracker = self.get_tracker()
    # Log once at step 0 since there is no training timeline.
    tracker.log(results, step=0, commit=True)
    return results
cfg = cfgattr tokenizer
tokenizer = get_tokenizer(self.cfg.trainer.policy.model.path, trust_remote_code=True, use_fast=(not self.cfg.trainer.disable_fast_tokenizer), padding_side='left')attr train_dataset
train_dataset = self.get_train_dataset()attr eval_dataset
eval_dataset = self.get_eval_dataset()attr colocate_pg
colocate_pg = self.get_colocate_pg()method staticmethod get_cfg_as_str
get_cfg_as_str(cfg: SkyRLTrainConfig) -> str

method get_eval_dataset
get_eval_dataset()Initializes the evaluation dataset.
Returns:
| Name | Type | Description |
|---|---|---|
PromptDataset | The evaluation dataset. |
method get_colocate_pg
get_colocate_pg(timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[PlacementGroup]

Initializes a placement group for colocated training.
Creates a single placement group with per-GPU bundles for all inference engines.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
timeout | int | The timeout for the placement group to be ready. | SKYRL_RAY_PG_TIMEOUT_IN_S |
Returns:
| Type | Description |
|---|---|
| Optional[PlacementGroup] | A PlacementGroup when colocate_all is True, else None. |
method get_generator
get_generator(cfg, tokenizer, inference_engine_client)Initializes the generator.
Returns:
| Name | Type | Description |
|---|---|---|
GeneratorInterface | The generator. |
method get_trainer
get_trainer(cfg, tracker, tokenizer, train_dataset, eval_dataset, inference_engine_client, generator: GeneratorInterface, colocate_pg)

Initializes the trainer.
Returns:
| Name | Type | Description |
|---|---|---|
RayPPOTrainer | The trainer. |
method get_tracker
get_tracker()Initializes the tracker for experiment tracking.
Returns:
| Name | Type | Description |
|---|---|---|
Tracking | The tracker. |
method get_inference_client
get_inference_client() -> InferenceEngineInterface

Setup and return the inference engine client.
This is a hook method that can be overridden by subclasses to customize inference engine creation (e.g., FlashRL, custom backends).
Returns:
| Name | Type | Description |
|---|---|---|
InferenceEngineInterface | InferenceEngineInterface | The inference engine client. |