SkyRL
API Reference › SkyRL

Entrypoint

Entrypoint API — BasePPOExp, EvalOnlyEntrypoint.

Training Entrypoint

The main entrypoint is the BasePPOExp class which runs the main training loop.

class BasePPOExp

BasePPOExp(cfg: SkyRLTrainConfig)

Functions:

Name | Description
get_cfg_as_str | Returns the config rendered as a YAML string.
get_train_dataset | Initializes the training dataset.
get_eval_dataset | Initializes the evaluation dataset.
get_colocate_pg | Initializes a placement group for colocated training.
get_generator | Initializes the generator.
get_trainer | Initializes the trainer.
get_tracker | Initializes the tracker for experiment tracking.
get_inference_client | Setup and return the inference engine client.
run | Sets up the trainer and runs the training loop.

Attributes:

Initializes a PPO experiment.

Parameters:

Name | Type | Description | Default
cfg | SkyRLTrainConfig | The fully resolved SkyRLTrainConfig instance. | required
Source code in skyrl/train/entrypoints/main_base.py:133-414
class BasePPOExp:
    """Base PPO experiment entrypoint.

    The constructor loads the tokenizer, builds the train/eval datasets and
    (when colocation is enabled) the shared placement group. ``run`` then
    creates the tracker, inference client, generator and trainer via the
    ``get_*`` hook methods and executes the training loop.
    """

    def __init__(self, cfg: SkyRLTrainConfig):
        """
        Initializes a PPO experiment.

        The tokenizer, both datasets and the colocated placement group are
        created eagerly here, so failures in any of them surface at
        construction time rather than in ``run``.

        Args:
            cfg: The fully resolved SkyRLTrainConfig instance.
        """
        self.cfg = cfg
        # The tokenizer must exist before the datasets are built: both dataset
        # factories pass ``self.tokenizer`` to ``PromptDataset``.
        self.tokenizer = get_tokenizer(
            self.cfg.trainer.policy.model.path,
            trust_remote_code=True,
            use_fast=not self.cfg.trainer.disable_fast_tokenizer,
            padding_side="left",
        )
        self.train_dataset = self.get_train_dataset()
        self.eval_dataset = self.get_eval_dataset()
        self.colocate_pg = self.get_colocate_pg()

        # New inference resources (created lazily when _SKYRL_USE_NEW_INFERENCE=1)
        # They stay None until `_get_new_inference_client` populates them.
        self._server_groups = None
        self._prefill_server_groups = None
        self._decode_server_groups = None
        self._inference_router = None

    @staticmethod
    def get_cfg_as_str(cfg: SkyRLTrainConfig) -> str:
        """Render ``cfg`` as a string via ``get_config_as_yaml_str``.

        Args:
            cfg: The configuration to render.

        Returns:
            str: The value returned by ``get_config_as_yaml_str(cfg)``.
        """
        return get_config_as_yaml_str(cfg)

    def get_train_dataset(self):
        """Initializes the training dataset.

        Returns:
            PromptDataset: The training dataset.

        Raises:
            AssertionError: If the dataset holds fewer entries than
                ``cfg.trainer.train_batch_size``.
        """
        prompts_dataset = PromptDataset(
            datasets=self.cfg.data.train_data,
            tokenizer=self.tokenizer,
            max_prompt_length=self.cfg.trainer.max_prompt_length,
            num_workers=8,
        )
        # make sure the dataset is large enough to train on
        assert (
            len(prompts_dataset) >= self.cfg.trainer.train_batch_size
        ), f"dataset should be at least as large as `train_batch_size` {self.cfg.trainer.train_batch_size}, got size {len(prompts_dataset)}"
        return prompts_dataset

    def get_eval_dataset(self):
        """Initializes the evaluation dataset.

        Returns:
            PromptDataset: The evaluation dataset, or None when
            ``cfg.trainer.eval_interval`` is not positive or
            ``cfg.data.val_data`` is empty/unset.
        """
        if self.cfg.trainer.eval_interval > 0 and self.cfg.data.val_data:
            prompts_dataset = PromptDataset(
                datasets=self.cfg.data.val_data,
                tokenizer=self.tokenizer,
                max_prompt_length=self.cfg.trainer.max_prompt_length,
                num_workers=8,
            )
            return prompts_dataset
        return None

    def get_colocate_pg(self, timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[ResolvedPlacementGroup]:
        """Initializes a placement group for colocated training.

        Creates a single placement group with per-GPU bundles for all inference
        engines. The returned wrapper computes GPU-aware bundle ordering at init time.

        Args:
            timeout (int): The timeout for the placement group to be ready.

        Returns:
            ResolvedPlacementGroup: The placement group wrapper for colocated training, or None
            when ``cfg.trainer.placement.colocate_all`` is false.
        """
        if not self.cfg.trainer.placement.colocate_all:
            return None

        ie_cfg = self.cfg.generator.inference_engine
        # GPUs used by one engine = TP x PP x DP; total slots = engines x per-engine GPUs.
        per_engine_gpu_count = ie_cfg.tensor_parallel_size * ie_cfg.pipeline_parallel_size * ie_cfg.data_parallel_size
        total_gpu_slots = ie_cfg.num_engines * per_engine_gpu_count

        # One bundle (1 GPU + 1 CPU) per GPU slot.
        pg = placement_group(
            [{"GPU": 1, "CPU": 1}] * total_gpu_slots,
            strategy="PACK",
        )
        get_ray_pg_ready_with_timeout(pg, timeout=timeout)
        return ResolvedPlacementGroup(pg)

    def get_generator(self, cfg, tokenizer, inference_engine_client):
        """Initializes the generator.

        Selects ``SkyRLVLMGymGenerator`` when
        ``cfg.generator.vision_language_generator`` is truthy, otherwise
        ``SkyRLGymGenerator``. The generator module is imported lazily inside
        the selected branch only.

        Args:
            cfg: The training config; ``cfg.generator`` and
                ``cfg.environment.skyrl_gym`` are forwarded to the generator.
            tokenizer: Tokenizer forwarded to the generator.
            inference_engine_client: Inference client forwarded to the generator.

        Returns:
            GeneratorInterface: The generator.
        """
        if cfg.generator.vision_language_generator:
            from skyrl.train.generators.skyrl_vlm_generator import SkyRLVLMGymGenerator

            generator_cls = SkyRLVLMGymGenerator
        else:
            from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator

            generator_cls = SkyRLGymGenerator

        return generator_cls(
            generator_cfg=cfg.generator,
            skyrl_gym_cfg=cfg.environment.skyrl_gym,
            inference_engine_client=inference_engine_client,
            tokenizer=tokenizer,
            policy_model_name=resolve_policy_model_name(cfg),
        )

    def get_trainer(
        self,
        cfg,
        tracker,
        tokenizer,
        train_dataset,
        eval_dataset,
        inference_engine_client,
        generator: GeneratorInterface,
        colocate_pg,
    ):
        """Initializes the trainer.

        Every argument is forwarded unchanged, by keyword, to ``RayPPOTrainer``.

        Args:
            cfg: The training config.
            tracker: The experiment tracker.
            tokenizer: The tokenizer.
            train_dataset: The training dataset.
            eval_dataset: The evaluation dataset (``_setup_trainer`` may pass None).
            inference_engine_client: The inference engine client.
            generator: The generator.
            colocate_pg: The colocated placement group wrapper
                (``_setup_trainer`` may pass None).

        Returns:
            RayPPOTrainer: The trainer.
        """
        return RayPPOTrainer(
            cfg=cfg,
            tracker=tracker,
            tokenizer=tokenizer,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            inference_engine_client=inference_engine_client,
            generator=generator,
            colocate_pg=colocate_pg,
        )

    def get_tracker(self):
        """Initializes the tracker for experiment tracking.

        Project name, run name, logger backends and tags are read from
        ``cfg.trainer``; the whole config is passed as the tracked config.

        Returns:
            Tracking: The tracker.
        """
        return Tracking(
            project_name=self.cfg.trainer.project_name,
            experiment_name=self.cfg.trainer.run_name,
            backends=self.cfg.trainer.logger,
            config=self.cfg,
            tags=self.cfg.trainer.tags,
        )

    def get_inference_client(self) -> InferenceEngineInterface:
        """Setup and return the inference engine client.

        This is a hook method that can be overridden by subclasses to customize
        inference engine creation (e.g., FlashRL, custom backends).

        Dispatches on the module-level ``_SKYRL_USE_NEW_INFERENCE`` flag between
        the new and the legacy client.

        Returns:
            InferenceEngineInterface: The inference engine client.
        """
        if _SKYRL_USE_NEW_INFERENCE:
            logger.info("Initializing new inference client")
            return self._get_new_inference_client()
        else:
            return self._get_legacy_inference_client()

    def _get_legacy_inference_client(self) -> InferenceEngineInterface:
        """Legacy inference client using Ray actors.

        Engines are created locally (Ray-wrapped, using the colocated placement
        group) when ``run_engines_locally`` is set, otherwise as remote engines.

        Returns:
            InferenceEngineInterface: An ``InferenceEngineClient`` over the engines.
        """
        if self.cfg.generator.inference_engine.run_engines_locally:
            inference_engines = create_ray_wrapped_inference_engines_from_config(
                self.cfg, self.colocate_pg, self.tokenizer
            )
        else:
            inference_engines = create_remote_inference_engines_from_config(self.cfg, self.tokenizer)

        return InferenceEngineClient(
            inference_engines,
            self.tokenizer,
            self.cfg.trainer.policy.model.path,
            self.cfg.trainer.policy.model.lora,
            self.cfg.generator.inference_engine,
        )

    def _get_new_inference_client(self):
        """New inference client using HTTP endpoints.

        Stores the router and server groups returned by
        ``build_new_inference_client`` on ``self``. In colocated mode the
        client is put to sleep right after startup.

        Returns:
            RemoteInferenceClient: The new inference client.
        """
        # Imported lazily so the new inference stack is only loaded when used.
        from skyrl.backends.skyrl_train.inference_servers.setup import (
            build_new_inference_client,
        )

        is_colocated = self.cfg.trainer.placement.colocate_all
        client, server_setup = build_new_inference_client(
            self.cfg,
            self.tokenizer,
            placement_group=self.colocate_pg if is_colocated else None,
        )
        # NOTE(review): presumably kept on self so the servers stay referenced
        # for the lifetime of the experiment -- confirm.
        self._inference_router = server_setup.router
        self._server_groups = server_setup.server_groups
        self._prefill_server_groups = server_setup.prefill_server_groups
        self._decode_server_groups = server_setup.decode_server_groups

        if is_colocated:
            # Callers must invoke get_inference_client() from a sync context (no running event loop).
            asyncio.run(client.sleep())
            logger.info("HTTP Inference: Colocated mode - slept inference engines after startup")

        return client

    def _setup_trainer(self):
        """Setup and return the trainer.

        Instantiates the trainer and all the associated models for training.
        Logs the config, creates the export and checkpoint directories, picks
        the worker classes for the configured strategy, then builds tracker,
        inference client, generator and trainer in that order.

        Returns:
            RayPPOTrainer: The trainer.

        Raises:
            ValueError: If ``cfg.trainer.strategy`` is neither ``"fsdp"`` nor
                ``"megatron"``.
        """
        logger.info(self.get_cfg_as_str(self.cfg))
        os.makedirs(self.cfg.trainer.export_path, exist_ok=True)
        os.makedirs(self.cfg.trainer.ckpt_path, exist_ok=True)

        # Worker classes are imported lazily so only the selected backend is loaded.
        if self.cfg.trainer.strategy == "fsdp":
            from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import (
                CriticWorker,
                PolicyWorker,
                RefWorker,
            )
        elif self.cfg.trainer.strategy == "megatron":
            from skyrl.backends.skyrl_train.workers.megatron.megatron_worker import (
                CriticWorker,
                PolicyWorker,
                RefWorker,
            )
        else:
            raise ValueError(f"Unknown strategy type: {self.cfg.trainer.strategy}")

        # NOTE (sumanthrh): Instantiate tracker before trainer init.
        # We have custom validation before this step to give better error messages.
        tracker = self.get_tracker()

        inference_engine_client = self.get_inference_client()

        generator: GeneratorInterface = self.get_generator(self.cfg, self.tokenizer, inference_engine_client)

        trainer = self.get_trainer(
            cfg=self.cfg,
            tracker=tracker,
            tokenizer=self.tokenizer,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            inference_engine_client=inference_engine_client,
            generator=generator,
            colocate_pg=self.colocate_pg,
        )
        # Expose the trainer on self so callers can log exceptions raised
        # during `build_models` (which happens before _setup_trainer returns).
        self.trainer = trainer

        # Build the models
        trainer.build_models(PolicyWorker, CriticWorker, RefWorker)
        return trainer

    def run(self):
        """Set up the trainer and run the training loop.

        Any exception from setup or training is reported through the trainer's
        tracker when one is available, otherwise through the module logger, and
        is then re-raised.
        """
        # Reset so the except branch can tell whether setup got far enough to
        # create a trainer (and therefore a tracker).
        self.trainer = None
        try:
            trainer = self._setup_trainer()
            # Start the training loop
            asyncio.run(trainer.train())
        except Exception as e:
            # OOMs raised inside actor init (e.g. FSDPPolicyWorkerBase.init_model)
            # surface here as RayTaskError. Without this they only land in Ray
            # worker logs; route them through the tracker so wandb users see
            # them as an `error/tracebacks` table row.
            if self.trainer is not None and self.trainer.tracker is not None:
                self.trainer.tracker.log_exception(e, step=self.trainer.global_step)
            else:
                logger.error(f"Setup failed before tracker was initialized:\n{e}")
            raise

attr cfg

cfg = cfg

attr tokenizer

tokenizer = get_tokenizer(self.cfg.trainer.policy.model.path, trust_remote_code=True, use_fast=(not self.cfg.trainer.disable_fast_tokenizer), padding_side='left')

attr train_dataset

train_dataset = self.get_train_dataset()

attr eval_dataset

eval_dataset = self.get_eval_dataset()

attr colocate_pg

colocate_pg = self.get_colocate_pg()

method staticmethod get_cfg_as_str

get_cfg_as_str(cfg: SkyRLTrainConfig) -> str
Source code in skyrl/train/entrypoints/main_base.py:158-160
    @staticmethod
    def get_cfg_as_str(cfg: SkyRLTrainConfig) -> str:
        """Render ``cfg`` as a string via ``get_config_as_yaml_str``.

        Args:
            cfg: The configuration to render.

        Returns:
            str: The value returned by ``get_config_as_yaml_str(cfg)``.
        """
        cfg_text = get_config_as_yaml_str(cfg)
        return cfg_text

method get_train_dataset

get_train_dataset()

Initializes the training dataset.

Returns:

Name | Type | Description
PromptDataset | — | The training dataset.
Source code in skyrl/train/entrypoints/main_base.py:162-178
    def get_train_dataset(self):
        """Build the prompt dataset used for training.

        The sources come from ``cfg.data.train_data`` and the prompt length
        limit from ``cfg.trainer.max_prompt_length``.

        Returns:
            PromptDataset: The training dataset.

        Raises:
            AssertionError: If the dataset holds fewer entries than
                ``cfg.trainer.train_batch_size``.
        """
        trainer_cfg = self.cfg.trainer
        train_prompts = PromptDataset(
            datasets=self.cfg.data.train_data,
            tokenizer=self.tokenizer,
            max_prompt_length=trainer_cfg.max_prompt_length,
            num_workers=8,
        )

        # A dataset smaller than a single training batch cannot be trained on.
        required_size = trainer_cfg.train_batch_size
        assert (
            len(train_prompts) >= required_size
        ), f"dataset should be at least as large as `train_batch_size` {required_size}, got size {len(train_prompts)}"
        return train_prompts

method get_eval_dataset

get_eval_dataset()

Initializes the evaluation dataset.

Returns:

Name | Type | Description
PromptDataset | — | The evaluation dataset, or None when `eval_interval` is not positive or no validation data is configured.
Source code in skyrl/train/entrypoints/main_base.py:180-194
    def get_eval_dataset(self):
        """Build the prompt dataset used for evaluation, if evaluation is enabled.

        Returns:
            PromptDataset: The evaluation dataset, or None when
            ``cfg.trainer.eval_interval`` is not positive or
            ``cfg.data.val_data`` is empty/unset.
        """
        eval_enabled = self.cfg.trainer.eval_interval > 0 and self.cfg.data.val_data
        if not eval_enabled:
            return None

        return PromptDataset(
            datasets=self.cfg.data.val_data,
            tokenizer=self.tokenizer,
            max_prompt_length=self.cfg.trainer.max_prompt_length,
            num_workers=8,
        )

method get_colocate_pg

get_colocate_pg(timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[ResolvedPlacementGroup]

Initializes a placement group for colocated training.

Creates a single placement group with per-GPU bundles for all inference engines. The returned wrapper computes GPU-aware bundle ordering at init time.

Parameters:

Name | Type | Description | Default
timeout | int | The timeout for the placement group to be ready. | SKYRL_RAY_PG_TIMEOUT_IN_S

Returns:

Name | Type | Description
ResolvedPlacementGroup | Optional[ResolvedPlacementGroup] | The placement group wrapper for colocated training, or None.
Source code in skyrl/train/entrypoints/main_base.py:196-220
    def get_colocate_pg(self, timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[ResolvedPlacementGroup]:
        """Create the shared placement group used when everything is colocated.

        A single placement group is requested with one bundle (1 GPU + 1 CPU)
        per GPU slot across all inference engines. The returned wrapper computes
        GPU-aware bundle ordering at init time.

        Args:
            timeout (int): The timeout for the placement group to be ready.

        Returns:
            ResolvedPlacementGroup: The placement group wrapper for colocated training, or None
            when ``cfg.trainer.placement.colocate_all`` is false.
        """
        if self.cfg.trainer.placement.colocate_all:
            engine_cfg = self.cfg.generator.inference_engine
            # GPUs used by one engine = TP x PP x DP.
            gpus_per_engine = (
                engine_cfg.tensor_parallel_size * engine_cfg.pipeline_parallel_size * engine_cfg.data_parallel_size
            )
            bundle_count = engine_cfg.num_engines * gpus_per_engine

            bundles = [{"GPU": 1, "CPU": 1}] * bundle_count
            colocated_pg = placement_group(bundles, strategy="PACK")
            get_ray_pg_ready_with_timeout(colocated_pg, timeout=timeout)
            return ResolvedPlacementGroup(colocated_pg)

        return None

method get_generator

get_generator(cfg, tokenizer, inference_engine_client)

Initializes the generator.

Returns:

Name | Type | Description
GeneratorInterface | — | The generator.
Source code in skyrl/train/entrypoints/main_base.py:222-243
    def get_generator(self, cfg, tokenizer, inference_engine_client):
        """Create the generator matching the configured modality.

        ``SkyRLVLMGymGenerator`` is used when
        ``cfg.generator.vision_language_generator`` is truthy, otherwise
        ``SkyRLGymGenerator``. Only the selected generator module is imported.

        Args:
            cfg: The training config; ``cfg.generator`` and
                ``cfg.environment.skyrl_gym`` are forwarded to the generator.
            tokenizer: Tokenizer forwarded to the generator.
            inference_engine_client: Inference client forwarded to the generator.

        Returns:
            GeneratorInterface: The generator.
        """
        # Lazy, branch-local imports: bind the chosen class straight to one name.
        if cfg.generator.vision_language_generator:
            from skyrl.train.generators.skyrl_vlm_generator import SkyRLVLMGymGenerator as generator_cls
        else:
            from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator as generator_cls

        generator = generator_cls(
            generator_cfg=cfg.generator,
            skyrl_gym_cfg=cfg.environment.skyrl_gym,
            inference_engine_client=inference_engine_client,
            tokenizer=tokenizer,
            policy_model_name=resolve_policy_model_name(cfg),
        )
        return generator

method get_trainer

get_trainer(cfg, tracker, tokenizer, train_dataset, eval_dataset, inference_engine_client, generator: GeneratorInterface, colocate_pg)

Initializes the trainer.

Returns:

Name | Type | Description
RayPPOTrainer | — | The trainer.
Source code in skyrl/train/entrypoints/main_base.py:245-270
    def get_trainer(
        self,
        cfg,
        tracker,
        tokenizer,
        train_dataset,
        eval_dataset,
        inference_engine_client,
        generator: GeneratorInterface,
        colocate_pg,
    ):
        """Create the trainer from already-built components.

        Every argument is forwarded unchanged, by keyword, to ``RayPPOTrainer``.

        Args:
            cfg: The training config.
            tracker: The experiment tracker.
            tokenizer: The tokenizer.
            train_dataset: The training dataset.
            eval_dataset: The evaluation dataset.
            inference_engine_client: The inference engine client.
            generator: The generator.
            colocate_pg: The colocated placement group wrapper.

        Returns:
            RayPPOTrainer: The trainer.
        """
        trainer_kwargs = {
            "cfg": cfg,
            "tracker": tracker,
            "tokenizer": tokenizer,
            "train_dataset": train_dataset,
            "eval_dataset": eval_dataset,
            "inference_engine_client": inference_engine_client,
            "generator": generator,
            "colocate_pg": colocate_pg,
        }
        return RayPPOTrainer(**trainer_kwargs)

method get_tracker

get_tracker()

Initializes the tracker for experiment tracking.

Returns:

Name | Type | Description
Tracking | — | The tracker.
Source code in skyrl/train/entrypoints/main_base.py:272-284
    def get_tracker(self):
        """Create the experiment tracker from the trainer config.

        Project name, run name, logger backends and tags are read from
        ``cfg.trainer``; the whole config is passed as the tracked config.

        Returns:
            Tracking: The tracker.
        """
        trainer_cfg = self.cfg.trainer
        tracker = Tracking(
            project_name=trainer_cfg.project_name,
            experiment_name=trainer_cfg.run_name,
            backends=trainer_cfg.logger,
            config=self.cfg,
            tags=trainer_cfg.tags,
        )
        return tracker

method get_inference_client

get_inference_client() -> InferenceEngineInterface

Setup and return the inference engine client.

This is a hook method that can be overridden by subclasses to customize inference engine creation (e.g., FlashRL, custom backends).

Returns:

Name | Type | Description
InferenceEngineInterface | InferenceEngineInterface | The inference engine client.
Source code in skyrl/train/entrypoints/main_base.py:286-299
    def get_inference_client(self) -> InferenceEngineInterface:
        """Build and return the inference engine client.

        Subclasses may override this hook to customize inference engine
        creation (e.g., FlashRL, custom backends). The default dispatches on
        the module-level ``_SKYRL_USE_NEW_INFERENCE`` flag.

        Returns:
            InferenceEngineInterface: The inference engine client.
        """
        if not _SKYRL_USE_NEW_INFERENCE:
            return self._get_legacy_inference_client()

        logger.info("Initializing new inference client")
        return self._get_new_inference_client()

method run

run()
Source code in skyrl/train/entrypoints/main_base.py:399-414
    def run(self):
        """Set up the trainer and run the training loop.

        Any exception from setup or training is reported and then re-raised
        unchanged. Reporting goes through the trainer's tracker when one is
        available, otherwise through the module logger. Reporting is
        best-effort: a failure while reporting is logged and never replaces the
        original exception.

        Raises:
            Exception: Whatever ``_setup_trainer`` or ``trainer.train()`` raised.
        """
        # Reset so the except branch can tell whether setup got far enough to
        # create a trainer (and therefore a tracker).
        self.trainer = None
        try:
            trainer = self._setup_trainer()
            # Start the training loop
            asyncio.run(trainer.train())
        except Exception as e:
            # OOMs raised inside actor init (e.g. FSDPPolicyWorkerBase.init_model)
            # surface here as RayTaskError. Without this they only land in Ray
            # worker logs; route them through the tracker so wandb users see
            # them as an `error/tracebacks` table row.
            if self.trainer is not None and self.trainer.tracker is not None:
                try:
                    self.trainer.tracker.log_exception(e, step=self.trainer.global_step)
                except Exception as log_err:
                    # If reporting itself fails (tracker backend error, missing
                    # `global_step`, ...), that error would otherwise propagate
                    # instead of `e` and hide the real failure.
                    logger.error(f"Failed to log exception to tracker: {log_err}")
            else:
                logger.error(f"Setup failed before tracker was initialized:\n{e}")
            # Bare raise re-raises the original exception with its traceback.
            raise

Evaluation Entrypoint

The evaluation-only entrypoint is the EvalOnlyEntrypoint class which runs evaluation without training.

class EvalOnlyEntrypoint

Bases: BasePPOExp

Functions:

Name | Description
get_train_dataset | Override to avoid requiring a train dataset for eval-only runs.
run | Runs evaluation on the eval dataset and returns the results.
get_cfg_as_str | Returns the config rendered as a YAML string.
get_eval_dataset | Initializes the evaluation dataset.
get_colocate_pg | Initializes a placement group for colocated training.
get_generator | Initializes the generator.
get_trainer | Initializes the trainer.
get_tracker | Initializes the tracker for experiment tracking.
get_inference_client | Setup and return the inference engine client.

Attributes:

Source code in skyrl/train/entrypoints/main_generate.py:22-45
class EvalOnlyEntrypoint(BasePPOExp):
    """Entrypoint that runs evaluation on the eval dataset without training."""

    def get_train_dataset(self):
        """Return None: eval-only runs do not need a training dataset."""
        return None

    async def run(self, inference_engine_client: InferenceEngineInterface) -> dict[str, Any]:
        """Evaluate with the given inference client and log the results.

        Wakes the inference engines, builds a generator, runs either the
        step-wise or the regular evaluation routine (chosen by
        ``cfg.generator.step_wise_trajectories``) and logs the results to a
        freshly created tracker at step 0.

        Args:
            inference_engine_client: The inference client used for generation.

        Returns:
            dict[str, Any]: The results returned by the evaluation routine.

        Raises:
            AssertionError: If no evaluation dataset was built.
        """
        assert self.eval_dataset is not None, "The evaluation only entrypoint requires an eval dataset is provided"

        await inference_engine_client.wake_up()
        eval_generator = self.get_generator(self.cfg, self.tokenizer, inference_engine_client)

        if self.cfg.generator.step_wise_trajectories:
            run_eval = evaluate_step_wise
        else:
            run_eval = evaluate

        eval_loader = build_dataloader(self.cfg, self.eval_dataset, is_train=False)
        metrics: dict[str, Any] = await run_eval(
            eval_dataloader=eval_loader,
            generator=eval_generator,
            cfg=self.cfg,
            global_step=None,
            tokenizer=self.tokenizer,
        )

        # The tracker is only created once evaluation has produced results.
        self.get_tracker().log(metrics, step=0, commit=True)

        return metrics

method get_train_dataset

get_train_dataset()

Override to avoid requiring a train dataset for eval-only runs.

Source code in skyrl/train/entrypoints/main_generate.py:23-25
    def get_train_dataset(self):
        """Override to avoid requiring a train dataset for eval-only runs.

        Returns:
            None: Always; no training dataset is built.
        """
        return None

method run

run(inference_engine_client: InferenceEngineInterface) -> dict[str, Any]
Source code in skyrl/train/entrypoints/main_generate.py:27-45
    async def run(self, inference_engine_client: InferenceEngineInterface) -> dict[str, Any]:
        """Evaluate with the given inference client and log the results.

        Wakes the inference engines, builds a generator, runs either the
        step-wise or the regular evaluation routine (chosen by
        ``cfg.generator.step_wise_trajectories``) and logs the results to a
        freshly created tracker at step 0.

        Args:
            inference_engine_client: The inference client used for generation.

        Returns:
            dict[str, Any]: The results returned by the evaluation routine.

        Raises:
            AssertionError: If no evaluation dataset was built.
        """
        assert self.eval_dataset is not None, "The evaluation only entrypoint requires an eval dataset is provided"

        await inference_engine_client.wake_up()
        eval_generator = self.get_generator(self.cfg, self.tokenizer, inference_engine_client)

        if self.cfg.generator.step_wise_trajectories:
            run_eval = evaluate_step_wise
        else:
            run_eval = evaluate

        eval_loader = build_dataloader(self.cfg, self.eval_dataset, is_train=False)
        metrics: dict[str, Any] = await run_eval(
            eval_dataloader=eval_loader,
            generator=eval_generator,
            cfg=self.cfg,
            global_step=None,
            tokenizer=self.tokenizer,
        )

        # The tracker is only created once evaluation has produced results.
        self.get_tracker().log(metrics, step=0, commit=True)

        return metrics

attr cfg

cfg = cfg

attr tokenizer

tokenizer = get_tokenizer(self.cfg.trainer.policy.model.path, trust_remote_code=True, use_fast=(not self.cfg.trainer.disable_fast_tokenizer), padding_side='left')

attr train_dataset

train_dataset = self.get_train_dataset()

attr eval_dataset

eval_dataset = self.get_eval_dataset()

attr colocate_pg

colocate_pg = self.get_colocate_pg()

method staticmethod get_cfg_as_str

get_cfg_as_str(cfg: SkyRLTrainConfig) -> str

method get_eval_dataset

get_eval_dataset()

Initializes the evaluation dataset.

Returns:

Name | Type | Description
PromptDataset | — | The evaluation dataset, or None when `eval_interval` is not positive or no validation data is configured.

method get_colocate_pg

get_colocate_pg(timeout: int = SKYRL_RAY_PG_TIMEOUT_IN_S) -> Optional[ResolvedPlacementGroup]

Initializes a placement group for colocated training.

Creates a single placement group with per-GPU bundles for all inference engines. The returned wrapper computes GPU-aware bundle ordering at init time.

Parameters:

Name | Type | Description | Default
timeout | int | The timeout for the placement group to be ready. | SKYRL_RAY_PG_TIMEOUT_IN_S

Returns:

Name | Type | Description
ResolvedPlacementGroup | Optional[ResolvedPlacementGroup] | The placement group wrapper for colocated training, or None.

method get_generator

get_generator(cfg, tokenizer, inference_engine_client)

Initializes the generator.

Returns:

Name | Type | Description
GeneratorInterface | — | The generator.

method get_trainer

get_trainer(cfg, tracker, tokenizer, train_dataset, eval_dataset, inference_engine_client, generator: GeneratorInterface, colocate_pg)

Initializes the trainer.

Returns:

Name | Type | Description
RayPPOTrainer | — | The trainer.

method get_tracker

get_tracker()

Initializes the tracker for experiment tracking.

Returns:

Name | Type | Description
Tracking | — | The tracker.

method get_inference_client

get_inference_client() -> InferenceEngineInterface

Setup and return the inference engine client.

This is a hook method that can be overridden by subclasses to customize inference engine creation (e.g., FlashRL, custom backends).

Returns:

Name | Type | Description
InferenceEngineInterface | InferenceEngineInterface | The inference engine client.

On this page