SkyRL
API Reference

Entrypoint

Entrypoint API — BasePPOExp, EvalOnlyEntrypoint.

Training Entrypoint

The main entrypoint is the BasePPOExp class which runs the main training loop.

class BasePPOExp

BasePPOExp(cfg: SkyRLTrainConfig)

Functions:

NameDescription
get_cfg_as_strReturns the config as a YAML string.
get_train_datasetInitializes the training dataset.
get_eval_datasetInitializes the evaluation dataset.
get_colocate_pgInitializes a placement group for colocated training.
get_generatorInitializes the generator.
get_trainerInitializes the trainer.
get_trackerInitializes the tracker for experiment tracking.
get_inference_clientSetup and return the inference engine client.
runRuns the main training loop.

Attributes:

Initializes a PPO experiment.

Parameters:

NameTypeDescriptionDefault
cfgSkyRLTrainConfigThe fully resolved SkyRLTrainConfig instance.required
Source code in skyrl/train/entrypoints/main_base.py:124-433
class BasePPOExp:
    """Base PPO training entrypoint: builds the tokenizer, datasets, inference
    client, generator, and trainer from a resolved config, then runs training."""

    def __init__(self, cfg: SkyRLTrainConfig):
        """
        Initializes a PPO experiment.

        Args:
            cfg: The fully resolved SkyRLTrainConfig instance.
        """
        self.cfg = cfg
        # Shared tokenizer: used for dataset tokenization and by the generator.
        self.tokenizer = get_tokenizer(
            self.cfg.trainer.policy.model.path,
            trust_remote_code=True,
            use_fast=not self.cfg.trainer.disable_fast_tokenizer,
            padding_side="left",
        )
        self.train_dataset = self.get_train_dataset()
        self.eval_dataset = self.get_eval_dataset()
        self.colocate_pg = self.get_colocate_pg()

        # New inference resources (created lazily when _SKYRL_USE_NEW_INFERENCE=1)
        self._server_group = None
        self._inference_router = None

    @staticmethod
    def get_cfg_as_str(cfg: SkyRLTrainConfig) -> str:
        """Return the config serialized as a YAML string."""
        return get_config_as_yaml_str(cfg)

    def get_train_dataset(self):
        """Initializes the training dataset.

        Returns:
            PromptDataset: The training dataset.
        """
        prompts_dataset = PromptDataset(
            datasets=self.cfg.data.train_data,
            tokenizer=self.tokenizer,
            max_prompt_length=self.cfg.trainer.max_prompt_length,
            num_workers=8,
        )
        # make sure the dataset is large enough to train on
        # NOTE(review): `assert` is stripped under `python -O`; consider raising
        # ValueError for this validation instead.
        assert (
            len(prompts_dataset) >= self.cfg.trainer.train_batch_size
        ), f"dataset should be at least as large as `train_batch_size` {self.cfg.trainer.train_batch_size}, got size {len(prompts_dataset)}"
        return prompts_dataset

    def get_eval_dataset(self):
        """Initializes the evaluation dataset.

        Returns:
            PromptDataset: The evaluation dataset, or None when evaluation is
            disabled (eval_interval <= 0) or no validation data is configured.
        """
        if self.cfg.trainer.eval_interval > 0 and self.cfg.data.val_data:
            prompts_dataset = PromptDataset(
                datasets=self.cfg.data.val_data,
                tokenizer=self.tokenizer,
                max_prompt_length=self.cfg.trainer.max_prompt_length,
                num_workers=8,
            )
            return prompts_dataset
        return None

    def get_colocate_pg(self, timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[PlacementGroup]:
        """Initializes a placement group for colocated training.

        Creates a single placement group with per-GPU bundles for all inference
        engines.

        Args:
            timeout (int): The timeout for the placement group to be ready.

        Returns:
            A PlacementGroup when colocate_all is True, else None.
        """
        if not self.cfg.trainer.placement.colocate_all:
            return None

        ie_cfg = self.cfg.generator.inference_engine
        # One GPU slot per (TP x PP x DP) rank, for every engine.
        per_engine_gpu_count = ie_cfg.tensor_parallel_size * ie_cfg.pipeline_parallel_size * ie_cfg.data_parallel_size
        total_gpu_slots = ie_cfg.num_engines * per_engine_gpu_count

        # One {GPU: 1, CPU: 1} bundle per GPU slot; PACK places bundles close together.
        pg = placement_group(
            [{"GPU": 1, "CPU": 1}] * total_gpu_slots,
            strategy="PACK",
        )
        get_ray_pg_ready_with_timeout(pg, timeout=timeout)
        return pg

    def get_generator(self, cfg, tokenizer, inference_engine_client):
        """Initializes the generator.

        Returns:
            GeneratorInterface: The generator.
        """
        # Imported here so the generator module is only loaded when needed.
        from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator

        return SkyRLGymGenerator(
            generator_cfg=cfg.generator,
            skyrl_gym_cfg=cfg.environment.skyrl_gym,
            inference_engine_client=inference_engine_client,
            tokenizer=tokenizer,
        )

    def get_trainer(
        self,
        cfg,
        tracker,
        tokenizer,
        train_dataset,
        eval_dataset,
        inference_engine_client,
        generator: GeneratorInterface,
        colocate_pg,
    ):
        """Initializes the trainer.

        Returns:
            RayPPOTrainer: The trainer.
        """
        # All dependencies are injected; this simply forwards them to RayPPOTrainer.
        return RayPPOTrainer(
            cfg=cfg,
            tracker=tracker,
            tokenizer=tokenizer,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            inference_engine_client=inference_engine_client,
            generator=generator,
            colocate_pg=colocate_pg,
        )

    def get_tracker(self):
        """Initializes the tracker for experiment tracking.

        Returns:
            Tracking: The tracker.
        """
        # Project/run naming and logger backends all come from the trainer config.
        return Tracking(
            project_name=self.cfg.trainer.project_name,
            experiment_name=self.cfg.trainer.run_name,
            backends=self.cfg.trainer.logger,
            config=self.cfg,
        )

    def get_inference_client(self) -> InferenceEngineInterface:
        """Setup and return the inference engine client.

        This is a hook method that can be overridden by subclasses to customize
        inference engine creation (e.g., FlashRL, custom backends).

        Returns:
            InferenceEngineInterface: The inference engine client.
        """
        # _SKYRL_USE_NEW_INFERENCE selects the HTTP-based client path.
        if _SKYRL_USE_NEW_INFERENCE:
            logger.info("Initializing new inference client")
            return self._get_new_inference_client()
        else:
            return self._get_legacy_inference_client()

    def _get_legacy_inference_client(self) -> InferenceEngineInterface:
        """Legacy inference client using Ray actors."""
        # Engines either run locally as Ray-wrapped actors (optionally colocated
        # via self.colocate_pg) or are reached as remote endpoints.
        if self.cfg.generator.inference_engine.run_engines_locally:
            inference_engines = create_ray_wrapped_inference_engines_from_config(
                self.cfg, self.colocate_pg, self.tokenizer
            )
        else:
            inference_engines = create_remote_inference_engines_from_config(self.cfg, self.tokenizer)

        return InferenceEngineClient(
            inference_engines,
            self.tokenizer,
            self.cfg.trainer.policy.model.path,
            self.cfg.trainer.policy.model.lora,
            self.cfg.generator.inference_engine,
        )

    def _get_new_inference_client(self):
        """New inference client using HTTP endpoints.

        Config combinations:
        - Colocated + external URLs → ERROR (validated earlier)
        - Neither set → Build servers internally
        - external_server_urls only → Create router over external servers
        - external_proxy_url only → Use proxy for both data + control plane
        - Both set → Fully external (proxy for data plane, servers for control plane)

        Returns:
            RemoteInferenceClient: The new inference client.
        """
        from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import (
            RemoteInferenceClient,
        )
        from skyrl.backends.skyrl_train.inference_servers.router import InferenceRouter
        from skyrl.backends.skyrl_train.inference_servers.server_group import (
            ServerGroup,
        )

        ie_cfg = self.cfg.generator.inference_engine
        is_colocated = self.cfg.trainer.placement.colocate_all
        external_proxy_url = ie_cfg.external_proxy_url
        external_server_urls = ie_cfg.external_server_urls

        has_external_proxy = external_proxy_url is not None
        has_external_servers = external_server_urls is not None

        if has_external_proxy and has_external_servers:
            # Case: Both external - fully external setup
            proxy_url = external_proxy_url
            server_urls = list(external_server_urls)
            logger.info(
                f"HTTP Inference: Using fully external setup - " f"proxy_url={proxy_url}, server_urls={server_urls}"
            )

        elif has_external_proxy and not has_external_servers:
            # Case: Proxy only - assume proxy handles control plane too
            proxy_url = external_proxy_url
            server_urls = [proxy_url]
            logger.info(
                f"HTTP Inference: Using external proxy for both data and " f"control plane - proxy_url={proxy_url}"
            )

        elif has_external_servers and not has_external_proxy:
            # Case: Servers only - create internal router over them
            server_urls = list(external_server_urls)
            self._inference_router = InferenceRouter(server_urls=server_urls)
            proxy_url = self._inference_router.start()
            logger.info(
                f"HTTP Inference: Created internal router over external "
                f"servers - server_urls={server_urls}, proxy_url={proxy_url}"
            )

        else:
            # Case: Neither - build servers and router internally
            cli_args = build_vllm_cli_args(self.cfg)

            self._server_group = ServerGroup(
                cli_args=cli_args,
                num_servers=ie_cfg.num_engines,
                placement_group=self.colocate_pg if is_colocated else None,
                enable_dp=ie_cfg.data_parallel_size > 1,
                distributed_executor_backend=ie_cfg.distributed_executor_backend,
            )
            server_infos = self._server_group.start()
            server_urls = [info.url for info in server_infos]

            self._inference_router = InferenceRouter(server_urls=server_urls)
            proxy_url = self._inference_router.start()
            logger.info(
                f"HTTP Inference: Built servers and router internally - "
                f"proxy_url={proxy_url}, server_urls={server_urls}, colocated={is_colocated}"
            )

        return RemoteInferenceClient(
            proxy_url=proxy_url,
            server_urls=server_urls,
            model_name=self.cfg.trainer.policy.model.path,
        )

    def _setup_trainer(self):
        """Setup and return the trainer.

        Instantiates the trainer and all the associated models for training.

        Returns:
            RayPPOTrainer: The trainer.
        """
        logger.info(self.get_cfg_as_str(self.cfg))
        os.makedirs(self.cfg.trainer.export_path, exist_ok=True)
        os.makedirs(self.cfg.trainer.ckpt_path, exist_ok=True)

        # Worker classes are strategy-specific; import only the selected backend.
        if self.cfg.trainer.strategy in ("fsdp", "fsdp2"):
            from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import (
                CriticWorker,
                PolicyWorker,
                RefWorker,
            )
        elif self.cfg.trainer.strategy == "megatron":
            from skyrl.backends.skyrl_train.workers.megatron.megatron_worker import (
                CriticWorker,
                PolicyWorker,
                RefWorker,
            )
        else:
            raise ValueError(f"Unknown strategy type: {self.cfg.trainer.strategy}")

        # NOTE (sumanthrh): Instantiate tracker before trainer init.
        # We have custom validation before this step to give better error messages.
        tracker = self.get_tracker()

        inference_engine_client = self.get_inference_client()

        generator: GeneratorInterface = self.get_generator(self.cfg, self.tokenizer, inference_engine_client)

        trainer = self.get_trainer(
            cfg=self.cfg,
            tracker=tracker,
            tokenizer=self.tokenizer,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            inference_engine_client=inference_engine_client,
            generator=generator,
            colocate_pg=self.colocate_pg,
        )

        # Build the models
        trainer.build_models(PolicyWorker, CriticWorker, RefWorker)
        return trainer

    def run(self):
        """Build the trainer and run the training loop to completion."""
        trainer = self._setup_trainer()
        # Start the training loop
        asyncio.run(trainer.train())

attr cfg

cfg = cfg

attr tokenizer

tokenizer = get_tokenizer(self.cfg.trainer.policy.model.path, trust_remote_code=True, use_fast=(not self.cfg.trainer.disable_fast_tokenizer), padding_side='left')

attr train_dataset

train_dataset = self.get_train_dataset()

attr eval_dataset

eval_dataset = self.get_eval_dataset()

attr colocate_pg

colocate_pg = self.get_colocate_pg()

method staticmethod get_cfg_as_str

get_cfg_as_str(cfg: SkyRLTrainConfig) -> str
Source code in skyrl/train/entrypoints/main_base.py:147-149
    @staticmethod
    def get_cfg_as_str(cfg: SkyRLTrainConfig) -> str:
        """Return the config serialized as a YAML string."""
        return get_config_as_yaml_str(cfg)

method get_train_dataset

get_train_dataset()

Initializes the training dataset.

Returns:

NameTypeDescription
PromptDatasetThe training dataset.
Source code in skyrl/train/entrypoints/main_base.py:151-167
    def get_train_dataset(self):
        """Initializes the training dataset.

        Returns:
            PromptDataset: The training dataset.
        """
        prompts_dataset = PromptDataset(
            datasets=self.cfg.data.train_data,
            tokenizer=self.tokenizer,
            max_prompt_length=self.cfg.trainer.max_prompt_length,
            num_workers=8,
        )
        # make sure the dataset is large enough to train on
        # NOTE(review): `assert` is stripped under `python -O`; consider raising
        # ValueError for this validation instead.
        assert (
            len(prompts_dataset) >= self.cfg.trainer.train_batch_size
        ), f"dataset should be at least as large as `train_batch_size` {self.cfg.trainer.train_batch_size}, got size {len(prompts_dataset)}"
        return prompts_dataset

method get_eval_dataset

get_eval_dataset()

Initializes the evaluation dataset.

Returns:

NameTypeDescription
PromptDatasetThe evaluation dataset.
Source code in skyrl/train/entrypoints/main_base.py:169-183
    def get_eval_dataset(self):
        """Initializes the evaluation dataset.

        Returns:
            PromptDataset: The evaluation dataset, or None when evaluation is
            disabled (eval_interval <= 0) or no validation data is configured.
        """
        if self.cfg.trainer.eval_interval > 0 and self.cfg.data.val_data:
            prompts_dataset = PromptDataset(
                datasets=self.cfg.data.val_data,
                tokenizer=self.tokenizer,
                max_prompt_length=self.cfg.trainer.max_prompt_length,
                num_workers=8,
            )
            return prompts_dataset
        return None

method get_colocate_pg

get_colocate_pg(timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[PlacementGroup]

Initializes a placement group for colocated training.

Creates a single placement group with per-GPU bundles for all inference engines.

Parameters:

NameTypeDescriptionDefault
timeoutintThe timeout for the placement group to be ready.SKYRL_RAY_PG_TIMEOUT_IN_S

Returns:

TypeDescription
Optional[PlacementGroup]A PlacementGroup when colocate_all is True, else None.
Source code in skyrl/train/entrypoints/main_base.py:185-209
    def get_colocate_pg(self, timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[PlacementGroup]:
        """Initializes a placement group for colocated training.

        Creates a single placement group with per-GPU bundles for all inference
        engines.

        Args:
            timeout (int): The timeout for the placement group to be ready.

        Returns:
            A PlacementGroup when colocate_all is True, else None.
        """
        if not self.cfg.trainer.placement.colocate_all:
            return None

        ie_cfg = self.cfg.generator.inference_engine
        # One GPU slot per (TP x PP x DP) rank, for every engine.
        per_engine_gpu_count = ie_cfg.tensor_parallel_size * ie_cfg.pipeline_parallel_size * ie_cfg.data_parallel_size
        total_gpu_slots = ie_cfg.num_engines * per_engine_gpu_count

        # One {GPU: 1, CPU: 1} bundle per GPU slot; PACK places bundles close together.
        pg = placement_group(
            [{"GPU": 1, "CPU": 1}] * total_gpu_slots,
            strategy="PACK",
        )
        get_ray_pg_ready_with_timeout(pg, timeout=timeout)
        return pg

method get_generator

get_generator(cfg, tokenizer, inference_engine_client)

Initializes the generator.

Returns:

NameTypeDescription
GeneratorInterfaceThe generator.
Source code in skyrl/train/entrypoints/main_base.py:211-224
    def get_generator(self, cfg, tokenizer, inference_engine_client):
        """Initializes the generator.

        Returns:
            GeneratorInterface: The generator.
        """
        # Imported here so the generator module is only loaded when needed.
        from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator

        return SkyRLGymGenerator(
            generator_cfg=cfg.generator,
            skyrl_gym_cfg=cfg.environment.skyrl_gym,
            inference_engine_client=inference_engine_client,
            tokenizer=tokenizer,
        )

method get_trainer

get_trainer(cfg, tracker, tokenizer, train_dataset, eval_dataset, inference_engine_client, generator: GeneratorInterface, colocate_pg)

Initializes the trainer.

Returns:

NameTypeDescription
RayPPOTrainerThe trainer.
Source code in skyrl/train/entrypoints/main_base.py:226-251
    def get_trainer(
        self,
        cfg,
        tracker,
        tokenizer,
        train_dataset,
        eval_dataset,
        inference_engine_client,
        generator: GeneratorInterface,
        colocate_pg,
    ):
        """Initializes the trainer.

        Returns:
            RayPPOTrainer: The trainer.
        """
        # All dependencies are injected; this simply forwards them to RayPPOTrainer.
        return RayPPOTrainer(
            cfg=cfg,
            tracker=tracker,
            tokenizer=tokenizer,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            inference_engine_client=inference_engine_client,
            generator=generator,
            colocate_pg=colocate_pg,
        )

method get_tracker

get_tracker()

Initializes the tracker for experiment tracking.

Returns:

NameTypeDescription
TrackingThe tracker.
Source code in skyrl/train/entrypoints/main_base.py:253-264
    def get_tracker(self):
        """Initializes the tracker for experiment tracking.

        Returns:
            Tracking: The tracker.
        """
        # Project/run naming and logger backends all come from the trainer config.
        return Tracking(
            project_name=self.cfg.trainer.project_name,
            experiment_name=self.cfg.trainer.run_name,
            backends=self.cfg.trainer.logger,
            config=self.cfg,
        )

method get_inference_client

get_inference_client() -> InferenceEngineInterface

Setup and return the inference engine client.

This is a hook method that can be overridden by subclasses to customize inference engine creation (e.g., FlashRL, custom backends).

Returns:

NameTypeDescription
InferenceEngineInterfaceThe inference engine client.
Source code in skyrl/train/entrypoints/main_base.py:266-279
    def get_inference_client(self) -> InferenceEngineInterface:
        """Setup and return the inference engine client.

        This is a hook method that can be overridden by subclasses to customize
        inference engine creation (e.g., FlashRL, custom backends).

        Returns:
            InferenceEngineInterface: The inference engine client.
        """
        # _SKYRL_USE_NEW_INFERENCE selects the HTTP-based client path.
        if _SKYRL_USE_NEW_INFERENCE:
            logger.info("Initializing new inference client")
            return self._get_new_inference_client()
        else:
            return self._get_legacy_inference_client()

method run

run()
Source code in skyrl/train/entrypoints/main_base.py:430-433
    def run(self):
        """Build the trainer and run the training loop to completion."""
        trainer = self._setup_trainer()
        # Start the training loop
        asyncio.run(trainer.train())

Evaluation Entrypoint

The evaluation-only entrypoint is the EvalOnlyEntrypoint class which runs evaluation without training.

class EvalOnlyEntrypoint

Bases: BasePPOExp

Functions:

NameDescription
get_train_datasetOverride to avoid requiring a train dataset for eval-only runs.
run
get_cfg_as_str
get_eval_datasetInitializes the evaluation dataset.
get_colocate_pgInitializes a placement group for colocated training.
get_generatorInitializes the generator.
get_trainerInitializes the trainer.
get_trackerInitializes the tracker for experiment tracking.
get_inference_clientSetup and return the inference engine client.

Attributes:

Source code in skyrl/train/entrypoints/main_generate.py:21-44
class EvalOnlyEntrypoint(BasePPOExp):
    """Evaluation-only entrypoint: runs a single evaluation pass without training.

    NOTE(review): unlike BasePPOExp.run (synchronous), run() here is an async
    coroutine and must be awaited by the caller — confirm call sites handle this.
    """

    def get_train_dataset(self):
        """Override to avoid requiring a train dataset for eval-only runs."""
        return None

    async def run(self) -> dict[str, Any]:
        """Evaluate on the eval dataset, log the metrics, and return them."""
        assert self.eval_dataset is not None, "The evaluation only entrypoint requires an eval dataset is provided"

        inference_engine_client = self.get_inference_client()
        # Wake the inference engines before generation.
        await inference_engine_client.wake_up()
        generator = self.get_generator(self.cfg, self.tokenizer, inference_engine_client)

        results: dict[str, Any] = await evaluate(
            eval_dataloader=build_dataloader(self.cfg, self.eval_dataset, is_train=False),
            generator=generator,
            cfg=self.cfg,
            global_step=None,
            tokenizer=self.tokenizer,
        )

        # Log a single snapshot at step 0 since there is no training progression.
        tracker = self.get_tracker()
        tracker.log(results, step=0, commit=True)

        return results

method get_train_dataset

get_train_dataset()

Override to avoid requiring a train dataset for eval-only runs.

Source code in skyrl/train/entrypoints/main_generate.py:22-24
    def get_train_dataset(self):
        """Override to avoid requiring a train dataset for eval-only runs."""
        # Returning None skips the base class's train dataset construction.
        return None

method run

run() -> dict[str, Any]
Source code in skyrl/train/entrypoints/main_generate.py:26-44
    async def run(self) -> dict[str, Any]:
        """Evaluate on the eval dataset, log the metrics, and return them."""
        assert self.eval_dataset is not None, "The evaluation only entrypoint requires an eval dataset is provided"

        inference_engine_client = self.get_inference_client()
        # Wake the inference engines before generation.
        await inference_engine_client.wake_up()
        generator = self.get_generator(self.cfg, self.tokenizer, inference_engine_client)

        results: dict[str, Any] = await evaluate(
            eval_dataloader=build_dataloader(self.cfg, self.eval_dataset, is_train=False),
            generator=generator,
            cfg=self.cfg,
            global_step=None,
            tokenizer=self.tokenizer,
        )

        # Log a single snapshot at step 0 since there is no training progression.
        tracker = self.get_tracker()
        tracker.log(results, step=0, commit=True)

        return results

attr cfg

cfg = cfg

attr tokenizer

tokenizer = get_tokenizer(self.cfg.trainer.policy.model.path, trust_remote_code=True, use_fast=(not self.cfg.trainer.disable_fast_tokenizer), padding_side='left')

attr train_dataset

train_dataset = self.get_train_dataset()

attr eval_dataset

eval_dataset = self.get_eval_dataset()

attr colocate_pg

colocate_pg = self.get_colocate_pg()

method staticmethod get_cfg_as_str

get_cfg_as_str(cfg: SkyRLTrainConfig) -> str

method get_eval_dataset

get_eval_dataset()

Initializes the evaluation dataset.

Returns:

NameTypeDescription
PromptDatasetThe evaluation dataset.

method get_colocate_pg

get_colocate_pg(timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[PlacementGroup]

Initializes a placement group for colocated training.

Creates a single placement group with per-GPU bundles for all inference engines.

Parameters:

NameTypeDescriptionDefault
timeoutintThe timeout for the placement group to be ready.SKYRL_RAY_PG_TIMEOUT_IN_S

Returns:

TypeDescription
Optional[PlacementGroup]A PlacementGroup when colocate_all is True, else None.

method get_generator

get_generator(cfg, tokenizer, inference_engine_client)

Initializes the generator.

Returns:

NameTypeDescription
GeneratorInterfaceThe generator.

method get_trainer

get_trainer(cfg, tracker, tokenizer, train_dataset, eval_dataset, inference_engine_client, generator: GeneratorInterface, colocate_pg)

Initializes the trainer.

Returns:

NameTypeDescription
RayPPOTrainerThe trainer.

method get_tracker

get_tracker()

Initializes the tracker for experiment tracking.

Returns:

NameTypeDescription
TrackingThe tracker.

method get_inference_client

get_inference_client() -> InferenceEngineInterface

Setup and return the inference engine client.

This is a hook method that can be overridden by subclasses to customize inference engine creation (e.g., FlashRL, custom backends).

Returns:

NameTypeDescription
InferenceEngineInterfaceThe inference engine client.

On this page