SkyRL
API Reference › SkyRL › SkyRL-Train Backend

Trainer

Trainer API — RayPPOTrainer, Dispatch, Worker APIs.

Trainer Class

class RayPPOTrainer

RayPPOTrainer(cfg: SkyRLTrainConfig, tracker: Tracking, tokenizer: AutoTokenizer, train_dataset: Optional[PromptDataset], inference_engine_client: InferenceEngineClient, generator: GeneratorInterface, colocate_pg: Optional[PlacementGroup] = None, eval_dataset: Optional[PromptDataset] = None)

Functions:

| Name | Description |
| --- | --- |
| `eval` | Run generation and scoring on the evaluation dataset. |
| `train` | Main training loop for PPO. |
| `build_models` | Initialize the actors for training, and handle colocation logic. |
| `init_weight_sync_state` | Set up the connection between the policy model and the inference engine for weight syncing. |
| `sync_policy_weights_to_inference_engines` | Broadcast policy weights to inference engines. |
| `convert_to_training_input` | Converts lists to a padded batch of tensors for training. |
| `generate` | Generate rollouts. |
| `postprocess_generator_output` | Converts to per-token rewards and computes pass@N. |
| `compute_advantages_and_returns` | Calculate advantages and returns for the data batch. |
| `dump_data` | Dump data to a pickle file. |
| `pad_batch` | Pad the batch to be divisible by the DP size. |
| `fwd_logprobs_values_reward` | Calculate values from the critic, and log probs from the policy and ref model. |
| `apply_reward_kl_penalty` | Applies a penalty to the rewards for KL divergence between the policy log probs and the base model log probs. |
| `train_critic_and_policy` | Run the training step for the policy and critic models. |
| `handle_dynamic_sampling` | Handle dynamic sampling for the current batch. |
| `save_checkpoints` | Save the model, optimizer, and training states to disk. |
| `load_checkpoints` | Load complete checkpoint state and return the `global_step` to resume from. |
| `save_models` | Save the model parameters in HF format at `cfg.trainer.export_path`. |
| `update_ref_with_policy` | Update the reference model with the policy model weights (required by some algorithms). |

Attributes:

Source code in skyrl/train/trainer.py:82-1410
class RayPPOTrainer:
    def __init__(
        self,
        cfg: SkyRLTrainConfig,
        tracker: Tracking,
        tokenizer: AutoTokenizer,
        train_dataset: Optional[PromptDataset],
        inference_engine_client: InferenceEngineClient,
        generator: GeneratorInterface,
        colocate_pg: Optional[PlacementGroup] = None,
        eval_dataset: Optional[PromptDataset] = None,
    ):
        """Store configuration and collaborators, then build the dataloaders.

        Args:
            cfg: Full SkyRL training config.
            tracker: Experiment tracker used for metric logging.
            tokenizer: Tokenizer used for decoding examples and batch padding.
            train_dataset: Training prompts. May be None when data is supplied
                externally, in which case no train dataloader is built.
            inference_engine_client: Client for the rollout inference engines.
            generator: Rollout generator implementation.
            colocate_pg: Ray placement group shared by all models when
                `trainer.placement.colocate_all` is enabled.
            eval_dataset: Optional evaluation prompts; when None, no eval
                dataloader is created.

        Note: the policy/critic/ref actor groups and the worker dispatch are
        created later in `build_models`, not here.
        """
        self.cfg = cfg
        self.colocate_all = cfg.trainer.placement.colocate_all
        self.tracker = tracker
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.inference_engine_client = inference_engine_client
        self.generator = generator
        self.train_dataloader = None
        self.total_training_steps = None
        # Overridable hook; may populate train_dataloader and
        # total_training_steps when a train_dataset was provided.
        self._build_train_dataloader_and_compute_training_steps()

        self.eval_dataloader = (
            build_dataloader(self.cfg, eval_dataset, is_train=False) if eval_dataset is not None else None
        )
        self.colocate_pg = colocate_pg

        self.resume_mode = ResumeMode(cfg.trainer.resume_mode)

        # Per-step metric/timing accumulators; flushed to the tracker each step.
        self.all_metrics = {}
        self.all_timings = {}
        self.global_step = 0

        # initialized in `build_models`
        self.policy_model: PPORayActorGroup = None
        self.critic_model: Optional[PPORayActorGroup] = None
        self.ref_model: Optional[PPORayActorGroup] = None
        # used for checkpoint cleanup
        self._node_ids: Optional[List[str]] = None

        # State carried across generation rounds when dynamic sampling is on.
        self.dynamic_sampling_state: Optional[DynamicSamplingState] = None

        self.reward_kl_controller: Optional[Union[FixedKLController, AdaptiveKLController]] = None
        self.dispatch: WorkerDispatch = None
        configure_ray_worker_logging()

    @property
    def has_critic(self) -> bool:
        """Check if critic model is configured."""
        return bool(self.cfg.trainer.critic.model.path)

    def _build_train_dataloader_and_compute_training_steps(self):
        """
        Hook for constructing the training dataloader. Subclasses can override
        this to customize dataloader behavior. For instance, fully async training
        needs a batch size of 1, among other features.
        Defaults to `trainer_utils.build_dataloader` with `is_train=True`.
        When train_dataset is None (e.g. Tinker backend provides data externally),
        the dataloader is not built.
        """
        if self.train_dataset is not None:
            self.train_dataloader = build_dataloader(self.cfg, self.train_dataset, is_train=True)
            self.total_training_steps = len(self.train_dataloader) * self.cfg.trainer.epochs

    @torch.no_grad()
    async def eval(self) -> Dict[str, float]:
        """
        Run generation and scoring on the evaluation dataset.

        Metrics recorded at global_step 0 reflect the model before any
        training; otherwise they correspond to the state after
        `self.global_step` training steps.

        Returns:
            A dictionary of evaluation metrics.
        """
        # Step-wise trajectories need the step-aware evaluation path.
        evaluate_fn = evaluate_step_wise if self.cfg.generator.step_wise_trajectories else evaluate
        return await evaluate_fn(
            eval_dataloader=self.eval_dataloader,
            generator=self.generator,
            cfg=self.cfg,
            global_step=self.global_step,
            tokenizer=self.tokenizer,
        )

    async def train(self):
        """
        Main training loop for PPO.

        Runs `cfg.trainer.epochs` passes over the training dataloader. Each
        step generates rollouts, postprocesses rewards, computes log probs /
        values / advantages, and updates the policy (and critic, if
        configured). Checkpointing, evaluation, and syncing weights to the
        inference engines are interleaved according to the trainer config.
        """
        # Initialize weight sync state between policy model and inference engines.
        with Timer("init_weight_sync_state"):
            self.init_weight_sync_state()

        # Load checkpoint state if resumption is enabled.
        if self.resume_mode != ResumeMode.NONE:
            with Timer("load_checkpoints"):
                self.global_step, _ = self.load_checkpoints()

        # Prepare weights for sampling
        with Timer("sync_weights"):
            await self.dispatch.save_weights_for_sampler()

        # Eval before training
        if self.cfg.trainer.eval_interval > 0 and self.cfg.trainer.eval_before_train:
            with Timer("eval", self.all_timings):
                eval_metrics = await self.eval()
                self.tracker.log(eval_metrics, step=self.global_step, commit=True)

        # initialize kl controller
        if self.cfg.trainer.algorithm.use_kl_in_reward:
            self.reward_kl_controller = get_kl_controller(self.cfg.trainer.algorithm)

        # main training loop
        pbar = tqdm(total=self.total_training_steps, initial=self.global_step, desc="Training Batches Processed")
        start_epoch = self.global_step // len(self.train_dataloader)
        self.global_step += 1  # start training at global_step 1
        for epoch in range(start_epoch, self.cfg.trainer.epochs):
            # `batch_idx` (previously `iter`, which shadowed the builtin) is the
            # position within the epoch; used to detect the last batch below.
            for batch_idx, rand_prompts in enumerate(self.train_dataloader):
                with Timer("step", self.all_timings):
                    # for colocate_all=true, inference engine is always on GPU when starting the training step

                    # 0. truncate data to have even shards
                    rand_prompts = self._remove_tail_data(rand_prompts)
                    generator_input, uids = prepare_generator_input(
                        rand_prompts,
                        self.cfg.generator.n_samples_per_prompt,
                        get_sampling_params_for_backend(
                            self.cfg.generator.inference_engine.backend, self.cfg.generator.sampling_params
                        ),
                        self.cfg.environment.env_class,
                        "train",
                        self.global_step,
                    )

                    # 1.1. generation phase
                    with Timer("generate", self.all_timings):
                        generator_output: GeneratorOutput = await self.generate(generator_input)

                    if self.cfg.generator.step_wise_trajectories:
                        # NOTE: We use instance_ids from `trajectory_ids` here instead of re-using `uids`
                        # this is because in step-wise training, len(uids) != len(generator_output["response_ids"])
                        uids = [trajectory_id.instance_id for trajectory_id in generator_output["trajectory_ids"]]

                    # dynamic sampling
                    if self.cfg.trainer.algorithm.dynamic_sampling.type is not None:
                        generator_output, uids, keep_sampling = self.handle_dynamic_sampling(generator_output, uids)
                        if keep_sampling:  # continue sampling
                            # update progress bar for current batch (but not global step)
                            pbar.update(1)
                            continue

                    if self.colocate_all:
                        # if we are not continuing sampling, we sleep the inference engine
                        await self.inference_engine_client.sleep()

                    # 1.2 postprocess rewards
                    with Timer("postprocess_generator_output", self.all_timings):
                        generator_output = self.postprocess_generator_output(generator_output, uids)

                    # 2. print example just for debugging
                    vis = self.tokenizer.decode(generator_output["response_ids"][0])
                    log_example(
                        logger,
                        prompt=generator_input["prompts"][0],
                        response=vis,
                        reward=generator_output["rewards"][0],
                    )

                    # 3. Convert GeneratorOutput to TrainingInputBatch
                    with Timer("convert_to_training_input", self.all_timings):
                        training_input: TrainingInputBatch = self.convert_to_training_input(generator_output, uids)
                        logger.info(f"Number of sequences: {len(training_input['sequences'])}")

                    # 4. Inference and calculate values, log probs, rewards, kl divergence
                    with Timer("fwd_logprobs_values_reward", self.all_timings):
                        training_input = self.fwd_logprobs_values_reward(training_input)

                    # 5. apply kl divergence penalty to rewards
                    if self.cfg.trainer.algorithm.use_kl_in_reward:
                        with Timer("apply_reward_kl_penalty", self.all_timings):
                            training_input = self.apply_reward_kl_penalty(training_input)

                    # 6. calculate advantages and returns
                    with Timer("compute_advantages_and_returns", self.all_timings):
                        training_input = self.compute_advantages_and_returns(training_input)
                        # remove some unwanted keys
                        for key in ["rewards"]:
                            training_input.pop(key)
                        training_input.metadata.pop("uids")

                        if self.cfg.trainer.algorithm.advantage_batch_normalize:
                            training_input = normalize_advantages_dict(training_input)

                    if self.cfg.trainer.dump_data_batch:
                        # dump data to file
                        with Timer("dump_data_batch"):
                            self.dump_data(training_input, file_name=f"global_step_{self.global_step}_training_input")

                    # 7. train policy/critic model
                    # Policy model is backloaded to GPU during training
                    with Timer("train_critic_and_policy", self.all_timings):
                        status = self.train_critic_and_policy(training_input)

                    # 8. conditionally save checkpoints and hf model
                    if self.cfg.trainer.ckpt_interval > 0 and self.global_step % self.cfg.trainer.ckpt_interval == 0:
                        with Timer("save_checkpoints", self.all_timings):
                            self.save_checkpoints()
                    if (
                        self.cfg.trainer.hf_save_interval > 0
                        and self.global_step % self.cfg.trainer.hf_save_interval == 0
                    ):
                        with Timer("save_hf_model", self.all_timings):
                            self.save_models()

                    # 9. conditionally sync policy and ref at the end of the epoch
                    if (
                        self.cfg.trainer.update_ref_every_epoch
                        and self.ref_model is not None
                        and batch_idx == len(self.train_dataloader) - 1
                        and epoch != self.cfg.trainer.epochs - 1  # skip updating ref at the end of the last epoch
                    ):
                        with Timer("update_ref_with_policy", self.all_timings):
                            self.update_ref_with_policy()

                    # 10. Prepare weights for sampling
                    with Timer("sync_weights", self.all_timings):
                        await self.dispatch.save_weights_for_sampler()

                # 11. set logs
                logger.info(status)
                # log epoch info
                self.all_metrics.update({"trainer/epoch": epoch, "trainer/global_step": self.global_step})
                if self.cfg.trainer.eval_interval > 0 and (
                    self.global_step % self.cfg.trainer.eval_interval == 0
                    or self.global_step == self.total_training_steps
                ):
                    with Timer("eval", self.all_timings):
                        eval_metrics = await self.eval()
                        self.all_metrics.update(eval_metrics)

                log_payload = {
                    **self.all_metrics,
                    **{f"timing/{k}": v for k, v in self.all_timings.items()},
                }
                self.tracker.log(log_payload, step=self.global_step, commit=True)
                self.all_metrics = {}
                self.all_timings = {}

                # update progress bar after logging
                pbar.update(1)

                self.global_step += 1

                del training_input, generator_output

        pbar.close()
        if self.colocate_all:
            await self.inference_engine_client.sleep()
        if self.cfg.trainer.ckpt_interval > 0:
            with Timer("save_checkpoints", self.all_timings):
                self.save_checkpoints()
                logger.info("Saved final checkpoint.")
        if self.cfg.trainer.hf_save_interval > 0:
            with Timer("save_hf_model", self.all_timings):
                self.save_models()
                logger.info("Saved final model.")
        self.tracker.finish()
        logger.info("Training done!")

    def _remove_tail_data(self, entries: List[Any]) -> List[Any]:
        """Remove tail data to have even shards in terms of *effective* samples.

        Each prompt produces `n_samples_per_prompt` samples. For data-parallel
        training we care that the total number of samples is nicely splittable
        across the (combined) data-parallel size of all enabled models.
        """
        lcm_dp_size = self.dispatch.get_lcm_dp_size()

        n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt

        # We want the largest m <= len(entries) such that:
        #   (m * n_samples_per_prompt) % lcm_dp_size == 0
        #
        # Let g = gcd(lcm_dp_size, n_samples_per_prompt). Then this is equivalent
        # to requiring m to be a multiple of (lcm_dp_size / g).
        stride = lcm_dp_size // math.gcd(lcm_dp_size, n_samples_per_prompt)
        if stride <= 1:
            # Every prompt count is valid, keep all entries.
            return entries

        kept_prompts = (len(entries) // stride) * stride
        return entries[:kept_prompts]

    def build_models(self, PolicyWorker, CriticWorker, RefWorker):
        """
        Initialize the actors for training, and handle colocation logic.

        Args:
            PolicyWorker: Ray worker class for the policy actor group.
            CriticWorker: Ray worker class for the critic actor group (only
                used when `cfg.trainer.critic.model.path` is set).
            RefWorker: Ray worker class for the reference-model actor group
                (only used when a KL term is enabled).

        Side effects:
            Sets `self.policy_model`, `self.critic_model`, `self.ref_model`,
            and `self.dispatch`.
        """
        cfg = self.cfg
        pg = None

        # A reference model is only needed when a KL term is applied, either
        # in the loss or in the reward.
        use_ref_model = cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward

        if cfg.trainer.placement.colocate_all:
            # All models (and the inference engines) share the externally
            # provided placement group; GPU counts must match so shards map 1:1.
            num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
            num_critic_gpus = cfg.trainer.placement.critic_num_gpus_per_node * cfg.trainer.placement.critic_num_nodes
            num_ref_gpus = cfg.trainer.placement.ref_num_gpus_per_node * cfg.trainer.placement.ref_num_nodes
            ie_cfg = cfg.generator.inference_engine
            num_rollout_gpus = (
                ie_cfg.num_engines
                * ie_cfg.tensor_parallel_size
                * ie_cfg.pipeline_parallel_size
                * ie_cfg.data_parallel_size
            )
            assert (
                num_policy_gpus == num_rollout_gpus
            ), "num_policy_gpus and num_rollout_gpus must be the same when colocating all models"
            pg = self.colocate_pg

            # Fractional num_gpus_per_actor (0.2) requests let the colocated
            # actor groups be scheduled onto the same physical GPUs.
            policy_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.policy_num_nodes,
                cfg.trainer.placement.policy_num_gpus_per_node,
                PolicyWorker,
                pg=pg,
                num_gpus_per_actor=0.2 if pg else 1,
                colocate_all=True,
                sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
                record_memory=cfg.trainer.policy.record_memory,
            )
            if use_ref_model:
                assert (
                    num_policy_gpus == num_ref_gpus
                ), "num_policy_gpus and num_ref_gpus must be the same when colocating policy and ref model"
                ref_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.ref_num_nodes,
                    cfg.trainer.placement.ref_num_gpus_per_node,
                    RefWorker,
                    pg=pg,
                    num_gpus_per_actor=0.2 if pg else 1,
                    colocate_all=True,
                    sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
                )
            else:
                ref_model = None

            if cfg.trainer.critic.model.path:
                assert (
                    num_policy_gpus == num_critic_gpus
                ), "num_policy_gpus and num_critic_gpus must be the same when colocating policy and critic model"
                critic_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.critic_num_nodes,
                    cfg.trainer.placement.critic_num_gpus_per_node,
                    CriticWorker,
                    pg=pg,
                    num_gpus_per_actor=0.2,
                    colocate_all=True,
                    sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
                )
            else:
                critic_model = None

        else:
            # Not colocating everything: optionally colocate just policy + ref
            # in a freshly created placement group.
            if cfg.trainer.placement.colocate_policy_ref and use_ref_model:
                assert (
                    cfg.trainer.placement.policy_num_nodes == cfg.trainer.placement.ref_num_nodes
                    and cfg.trainer.placement.policy_num_gpus_per_node == cfg.trainer.placement.ref_num_gpus_per_node
                ), "num_nodes and num_gpus_per_node must be the same when colocate policy and ref model."

                bundles = [
                    {
                        "GPU": cfg.trainer.placement.policy_num_gpus_per_node,
                        "CPU": cfg.trainer.placement.policy_num_gpus_per_node,
                    }
                    for _ in range(cfg.trainer.placement.policy_num_nodes)
                ]
                pg = placement_group(bundles, strategy="PACK")
                get_ray_pg_ready_with_timeout(pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)

            # When pg is set, policy (0.75) and ref (0.25) split each GPU.
            policy_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.policy_num_nodes,
                cfg.trainer.placement.policy_num_gpus_per_node,
                PolicyWorker,
                pg=pg,
                num_gpus_per_actor=0.75 if pg else 1,
                colocate_all=False,
                sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
            )
            if use_ref_model:
                ref_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.ref_num_nodes,
                    cfg.trainer.placement.ref_num_gpus_per_node,
                    RefWorker,
                    pg=pg,
                    num_gpus_per_actor=0.25 if pg else 1,
                    colocate_all=False,
                    sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
                )
            else:
                ref_model = None

            # Critic always gets its own GPUs in this branch (no pg passed).
            if cfg.trainer.critic.model.path:
                critic_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.critic_num_nodes,
                    cfg.trainer.placement.critic_num_gpus_per_node,
                    CriticWorker,
                    num_gpus_per_actor=1,
                    colocate_all=False,
                    sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
                )
            else:
                critic_model = None

        # Optimizer steps each model takes per training batch; passed to the
        # workers as num_training_steps (presumably for LR scheduling — confirm
        # against the worker's async_init_model).
        policy_steps_per_train_batch = (
            cfg.trainer.train_batch_size // cfg.trainer.policy_mini_batch_size * cfg.trainer.update_epochs_per_batch
        )
        critic_steps_per_train_batch = 0
        if cfg.trainer.critic.model.path:
            critic_steps_per_train_batch = (
                cfg.trainer.train_batch_size // cfg.trainer.critic_mini_batch_size * cfg.trainer.update_epochs_per_batch
            )
        policy_num_training_steps = (
            self.total_training_steps * policy_steps_per_train_batch if self.total_training_steps is not None else None
        )
        critic_num_training_steps = (
            self.total_training_steps * critic_steps_per_train_batch if self.total_training_steps is not None else None
        )
        if not cfg.trainer.placement.colocate_all:
            # Non-colocated: kick off all model initializations and wait once.
            refs = []
            if ref_model is not None:
                refs.extend(ref_model.async_init_model(cfg.trainer.ref.model.path))
            refs.extend(
                policy_model.async_init_model(
                    cfg.trainer.policy.model.path,
                    num_training_steps=policy_num_training_steps,
                )
            )
            if cfg.trainer.critic.model.path:
                refs.extend(
                    critic_model.async_init_model(
                        cfg.trainer.critic.model.path,
                        num_training_steps=critic_num_training_steps,
                    )
                )
            ray.get(refs)
            ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
        else:
            # Colocated: initialize sequentially, offloading each model to CPU
            # so the shared GPUs are free for the next one.
            if ref_model is not None:
                ray.get(ref_model.async_init_model(cfg.trainer.ref.model.path))
                ref_model.offload_to_cpu()
            ray.get(
                policy_model.async_init_model(
                    cfg.trainer.policy.model.path,
                    num_training_steps=policy_num_training_steps,
                )
            )
            ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
            policy_model.offload_to_cpu()
            if cfg.trainer.critic.model.path:
                ray.get(
                    critic_model.async_init_model(
                        cfg.trainer.critic.model.path,
                        num_training_steps=critic_num_training_steps,
                    )
                )
                critic_model.offload_to_cpu()

        self.policy_model: PPORayActorGroup = policy_model
        self.critic_model: Optional[PPORayActorGroup] = critic_model
        self.ref_model: Optional[PPORayActorGroup] = ref_model

        # Create unified dispatch that manages all actor groups
        self.dispatch = WorkerDispatch(
            cfg=self.cfg,
            policy_actor_group=policy_model,
            critic_actor_group=critic_model,
            ref_actor_group=ref_model,
            inference_engine_client=self.inference_engine_client,
        )

        # Mark all models as offloaded if colocate_all (they were offloaded above)
        if self.colocate_all:
            self.dispatch.mark_all_offloaded()

        logger.info("init policy/ref/critic models done")

    def init_weight_sync_state(self):
        """Set up the connection between the policy model and the inference
        engines used for weight syncing, via the worker dispatch."""
        client = self.inference_engine_client
        self.dispatch.init_weight_sync_state(client)
        logger.info("Initialized weight sync state for policy model and inference engines.")

    def sync_policy_weights_to_inference_engines(self) -> List[ObjectRef]:
        """Broadcast policy weights to inference engines.

        Note: For new code, prefer using dispatch.save_weights_for_sampler() which
        handles the full weight sync protocol including offload/backload.
        This method is kept for backward compatibility with subclasses.
        TODO(tgriggs): Remove this method when migration is complete.
        """
        return self.policy_model.async_run_ray_method(
            "pass_through",
            "broadcast_to_inference_engines",
            self.inference_engine_client,
            self.cfg.generator.inference_engine,
        )

    def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch:
        """Converts lists to a padded batch of tensors for training.

        Args:
            generator_output: Rollout results containing prompt/response token
                ids, rewards, loss masks, and optionally rollout logprobs and
                step-wise metadata (`trajectory_ids`, `is_last_step`).
            uids: One id per sample; stored in `metadata["uids"]`.

        Returns:
            A `TrainingInputBatch` of padded tensors plus metadata
            (`uids`, `response_length`, `avg_response_length`, and
            `trajectory_ids` for step-wise training), padded to the DP size
            via `pad_batch`.
        """
        prompt_ids: List[List[int]] = generator_output["prompt_token_ids"]
        response_ids: List[List[int]] = generator_output["response_ids"]
        rewards: List[List[float]] = generator_output["rewards"]
        loss_masks: List[List[int]] = generator_output["loss_masks"]

        # Rollout logprobs are optional; downstream off-policy correction needs them.
        logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None)

        (
            sequences_tensor,
            attention_masks_tensor,
            response_masks_tensor,
            rewards_tensor,
            loss_masks_tensor,
            rollout_logprobs_tensor,
        ) = convert_prompts_responses_to_batch_tensors(
            self.tokenizer,
            prompt_ids,
            response_ids,
            rewards,
            loss_masks,
            logprobs,
        )

        # sanity check for off_policy_correction
        off_policy_correction = self.cfg.trainer.algorithm.off_policy_correction
        tis_ratio_type = off_policy_correction.tis_ratio_type
        sequence_mask_metric = off_policy_correction.sequence_mask_metric
        if tis_ratio_type is not None or sequence_mask_metric is not None:
            assert (
                rollout_logprobs_tensor is not None
            ), "expected non-null rollout logprobs tensor when off_policy_correction is enabled"
            assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses"

        training_input = TrainingInputBatch(
            {
                "sequences": sequences_tensor,  # Full trajectories (padded and concatenated prompts and responses)
                "attention_mask": attention_masks_tensor,
                "response_mask": response_masks_tensor,
                "rewards": rewards_tensor,
                "loss_mask": loss_masks_tensor,
                "rollout_logprobs": rollout_logprobs_tensor,
                # `is_last_step` is optional; stored as a bool tensor when present.
                "is_last_step": (
                    torch.tensor(generator_output["is_last_step"], dtype=torch.bool)
                    if generator_output.get("is_last_step", None) is not None
                    else None
                ),
            },
        )
        training_input.metadata = {"uids": uids}
        # padded response length
        training_input.metadata["response_length"] = response_masks_tensor.shape[1]
        if self.cfg.generator.step_wise_trajectories:
            assert (
                "trajectory_ids" in generator_output
            ), "Expected `trajectory_ids` in generator output for step wise training"
            training_input.metadata["trajectory_ids"] = [
                trajectory_id.to_string() for trajectory_id in generator_output["trajectory_ids"]
            ]
            # NOTE(review): sums the lengths of last-step responses only, but
            # divides by the total response count — confirm this averaging is
            # intended rather than dividing by the number of last steps.
            training_input.metadata["avg_response_length"] = sum(
                len(sample_response_ids)
                for sample_response_ids, is_last_step in zip(response_ids, generator_output["is_last_step"])
                if is_last_step
            ) / len(response_ids)
        else:
            training_input.metadata["avg_response_length"] = sum(
                len(sample_response_ids) for sample_response_ids in response_ids
            ) / len(response_ids)

        logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
        training_input = self.pad_batch(training_input)
        logger.info(f"Number of sequences after padding: {len(training_input['sequences'])}")

        return training_input

    @torch.no_grad()
    async def generate(
        self,
        input_batch: GeneratorInput,
    ) -> GeneratorOutput:
        """
        Generate rollouts.

        If colocate_all is enabled:
        - before calling this method, the policy model should be on CPU and inference engine should
            be awake (i.e. on GPU).
        - after calling this method, the same model placement still holds.
        """
        # NOTE: we assume that .generate returns samples in the same order as passed in
        generator_output: GeneratorOutput = await self.generator.generate(input_batch)

        # add rollout metrics to self.all_metrics
        if generator_output["rollout_metrics"] is not None:
            self.all_metrics.update(generator_output["rollout_metrics"])

        if not self.cfg.generator.step_wise_trajectories:
            validate_generator_output(len(input_batch["prompts"]), generator_output)

        return generator_output

    @torch.no_grad()
    def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
        """
        Converts to per token rewards and computes pass@N.

        Response-level rewards (List[float]) are expanded to per-token rewards
        with the whole reward placed on the final response token; token-level
        rewards (List[List[float]]) pass through unchanged. Reward metrics are
        accumulated into `self.all_metrics`.

        In the future algorithm specific reward or loss mask post processing should be done here.
        """
        generator_output_for_metrics = generator_output
        uids_for_metrics = uids
        if self.cfg.generator.step_wise_trajectories:
            # For step-wise training, restrict metric computation to the final
            # step of each trajectory (filter every list field by is_last_step).
            generator_output_for_metrics = defaultdict(list)
            for key in generator_output:
                if isinstance(generator_output[key], list):
                    generator_output_for_metrics[key] = [
                        generator_output[key][i]
                        for i in range(len(generator_output[key]))
                        if generator_output["is_last_step"][i]
                    ]
            uids_for_metrics = [
                uid for uid, is_last_step in zip(uids, generator_output["is_last_step"]) if is_last_step
            ]

        # only use `generator_output_for_metrics` for metrics calculation
        # For step-wise training, we only calculate metrics for the last step of each trajectory
        overall_metrics = get_metrics_from_generator_output(
            generator_output_for_metrics,
            uids_for_metrics,
        )

        # these use the full generator output
        rewards: Union[List[float], List[List[float]]] = generator_output["rewards"]
        responses: List[List[int]] = generator_output["response_ids"]
        per_token_rewards: List[List[float]] = []

        # Check if rewards are already token-level (List[List[float]]) or response-level (List[float])
        if rewards and isinstance(rewards[0], list):
            # Token-level rewards: rewards is List[List[float]]
            per_token_rewards = rewards
        else:
            # NOTE(review): zero-variance filtering only runs on this
            # response-level branch — token-level reward batches skip it;
            # confirm that is intended.
            if self.cfg.trainer.algorithm.zero_variance_filter:
                # Samples not kept by the filter get an all-zero loss mask,
                # effectively dropping them from the loss.
                kept_indices_set = set(zero_variance_filter(rewards, uids))
                generator_output["loss_masks"] = [
                    [0] * len(mask) if i not in kept_indices_set else mask
                    for i, mask in enumerate(generator_output["loss_masks"])
                ]
            # Response-level rewards: rewards is List[float], convert to per-token rewards
            for reward, response in zip(rewards, responses):
                per_token_reward = [0.0] * len(response)
                # Entire reward is assigned to the last response token.
                per_token_reward[-1] = float(reward)
                per_token_rewards.append(per_token_reward)

        n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt

        reward_metrics = {
            f"reward/avg_pass_at_{n_samples_per_prompt}": overall_metrics["pass_at_n"],
            "reward/avg_raw_reward": overall_metrics["avg_score"],
            "reward/mean_positive_reward": overall_metrics["mean_positive_reward"],
        }
        self.all_metrics.update(reward_metrics)
        logger.info(
            f"reward/avg_pass_at_{n_samples_per_prompt}: {overall_metrics['pass_at_n']}, reward/avg_raw_reward: {overall_metrics['avg_score']}, reward/mean_positive_reward: {overall_metrics['mean_positive_reward']}"
        )
        # re-assign reward but now it's per token rewards
        generator_output["rewards"] = per_token_rewards
        return generator_output

    @torch.no_grad()
    def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingInputBatch:
        """Calculate advantages and returns for the data batch.

        Also logs reward/advantage summary metrics into `self.all_metrics` and
        `data.metadata["metrics"]`, excluding padded rows (see `pad_batch`).

        Expects:
            - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["response_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["loss_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["rewards"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `.metadata["uids"]`: List[str]

        Adds:
            - `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["returns"]`: Float[torch.Tensor, "batch_size seqlen"]
        """
        token_level_rewards = data["rewards"]

        if self.cfg.generator.step_wise_trajectories:
            is_last_step = data["is_last_step"].bool()
            index = np.array(data.metadata["uids"])
            values = data["values"]
            # Use the last step of each trajectory to compute advantages. Compatible with any advantage estimator
            # NOTE(Charlie): so we ignore per-step rewards in step-wise training.
            last_step_advantages, last_step_returns = ppo_utils.compute_advantages_and_returns(
                token_level_rewards=token_level_rewards[is_last_step],
                response_mask=data["response_mask"][is_last_step],
                index=index[is_last_step.cpu().numpy()],
                adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
                values=values[is_last_step] if values is not None else None,
                config=self.cfg.trainer.algorithm,
                gamma=self.cfg.trainer.algorithm.gamma,
                lambd=self.cfg.trainer.algorithm.lambd,
                grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
            )
            # Broadcast each trajectory's advantage and return to all steps of each trajectory.
            # `traj_ids[i]` counts how many trajectory boundaries precede row i: shifting
            # `is_last_step` right by one and cumsum-ing assigns every step of a trajectory
            # the same group id (assumes steps of a trajectory are contiguous rows).
            traj_ids = (
                torch.cat([torch.tensor([False], device=is_last_step.device), is_last_step[:-1]]).int().cumsum(dim=0)
            )
            num_groups = traj_ids[-1].item() + 1
            assert num_groups == len(
                last_step_advantages
            ), f"number of groups {num_groups} doesn't match the number of trajectories as given by `is_last_step` {len(last_step_advantages)}. The `is_last_step` tensor is likely malformed"
            # Fancy-index by group id to replicate each trajectory's values to all its steps.
            advantages = last_step_advantages[traj_ids]
            returns = last_step_returns[traj_ids]
        else:
            advantages, returns = ppo_utils.compute_advantages_and_returns(
                token_level_rewards=token_level_rewards,
                response_mask=data["response_mask"],
                index=data.metadata["uids"],
                adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
                config=self.cfg.trainer.algorithm,
                values=data["values"],
                gamma=self.cfg.trainer.algorithm.gamma,
                lambd=self.cfg.trainer.algorithm.lambd,
                grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
            )
        data["returns"] = returns
        data["advantages"] = advantages

        # remove padding while calculating metrics
        pad_size = data.metadata.get("pad_size", 0)
        num_samples = len(token_level_rewards)

        # Sum of per-token rewards per sample, restricted to the non-padded rows.
        return_sums = token_level_rewards.sum(dim=-1)[: num_samples - pad_size]
        if self.cfg.generator.step_wise_trajectories:
            # Only the last step of each trajectory carries the trajectory reward.
            avg_rewards: float = return_sums[data["is_last_step"][: num_samples - pad_size]].mean().item()
        else:
            avg_rewards: float = return_sums.mean().item()

        avg_response_length = data.metadata["avg_response_length"]
        data = data.to("cpu")

        # Average advantages over response tokens only (response_mask), padding rows excluded.
        valid_advantages = torch.masked_select(
            data["advantages"][: num_samples - pad_size, ...], data["response_mask"][: num_samples - pad_size].bool()
        )
        avg_advantages: float = valid_advantages.mean().item()
        avg_advantages_abs: float = valid_advantages.abs().mean().item()

        if "metrics" not in data.metadata:
            data.metadata["metrics"] = {}
        data.metadata["metrics"].update(
            {
                "avg_final_rewards": avg_rewards,
                "avg_response_length": avg_response_length,
                "avg_advantages": avg_advantages,
                "avg_advantages_abs": avg_advantages_abs,
            }
        )

        logger.info(f"avg_final_rewards: {avg_rewards}, avg_response_length: {avg_response_length}")
        self.all_metrics.update(
            {
                "loss/avg_final_rewards": avg_rewards,
                "loss/avg_raw_advantages": avg_advantages,
                "loss/avg_raw_advantages_abs": avg_advantages_abs,
            }
        )
        return data

    def dump_data(self, data: TrainingInputBatch, file_name: str):
        """Serialize `data` to `<export_path>/dumped_data/<file_name>.pkl`."""
        target_dir = Path(self.cfg.trainer.export_path) / "dumped_data"
        target_dir.mkdir(parents=True, exist_ok=True)
        data.save(target_dir / f"{file_name}.pkl")

    def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch:
        """Pad the batch to be divisible by dp size.

        Padding rows are cloned from the head of the batch (guaranteed valid rows),
        except `loss_mask` (zeros, so padding never contributes to the loss) and
        `is_last_step` (ones). Padded entries get synthetic `pad{i}` uids; all
        other metadata is deep-copied across.
        """
        import math

        dp_size = self.dispatch.get_lcm_dp_size()
        batch_size = training_input.batch_size
        pad_size = math.ceil(batch_size / dp_size) * dp_size - batch_size
        training_input.metadata["pad_size"] = pad_size
        if not pad_size:
            return training_input

        padded_tensors = {}
        for key, tensor in training_input.items():
            if tensor is None:
                continue
            trailing_dims = tuple(tensor.shape[1:]) if len(tensor.shape) > 1 else ()
            if key == "is_last_step":
                pad = torch.ones(pad_size, *trailing_dims, dtype=tensor.dtype, device=tensor.device)
            elif key == "loss_mask":
                # ensures that padding tensors don't count towards the loss
                pad = torch.zeros(pad_size, *trailing_dims, dtype=tensor.dtype, device=tensor.device)
            else:
                # `pad_size` is guaranteed smaller than batch_size, so cloning the
                # first `pad_size` rows always yields validly-formatted padding.
                pad = tensor[:pad_size].clone()
            padded_tensors[key] = torch.cat([tensor, pad], dim=0)

        padded_input = TrainingInputBatch(padded_tensors)
        pad_uids = [f"pad{i}" for i in range(pad_size)]
        padded_input.metadata = {"uids": training_input.metadata["uids"] + pad_uids}
        if "trajectory_ids" in training_input.metadata:
            padded_input.metadata["trajectory_ids"] = training_input.metadata["trajectory_ids"] + pad_uids
        for key, value in training_input.metadata.items():
            if key not in ("uids", "trajectory_ids"):
                padded_input.metadata[key] = copy.deepcopy(value)
        return padded_input

    @torch.no_grad()
    def fwd_logprobs_values_reward(
        self,
        training_input: TrainingInputBatch,
    ):
        """
        Calculate values from the critic, log probs from the policy and ref model.

        Dispatch handles offload/backload automatically for all colocation configurations.
        Also logs the absolute logprob difference between inference-engine rollouts and
        the trainer's policy forward pass when rollout logprobs are present.

        Expects:
            - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["attention_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `.metadata["response_length"]`: Int

        Adds:
            - `["base_action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
        """
        # Forward only the fields the workers need, to minimize serialization cost.
        data_fwd_pass = training_input.select(keys=["sequences", "attention_mask"], metadata_keys=["response_length"])

        values = None
        base_log_probs = None
        action_log_probs = None

        # Critic forward (dispatch handles offload/backload automatically)
        if self.has_critic:
            critic_output = self.dispatch.forward("critic", data_fwd_pass)
            values = critic_output["output"]

        # Ref forward
        if self.ref_model is not None:
            ref_output = self.dispatch.forward("ref", data_fwd_pass)
            base_log_probs = ref_output["output"]
            self.dispatch.empty_cache("ref")

        # Policy forward
        policy_output = self.dispatch.forward("policy", data_fwd_pass)
        action_log_probs = policy_output["output"]

        # Empty cache after all forward passes
        self.dispatch.empty_cache()

        sequences_all: torch.Tensor = training_input["sequences"]
        # NOTE (sumanthrh): The slicing is needed to make sure that the batch dimension doesn't change for the tensordict.
        base_log_probs = base_log_probs[: len(sequences_all)] if base_log_probs is not None else None
        action_log_probs = action_log_probs[: len(sequences_all)]
        values = values[: len(sequences_all)] if values is not None else None

        training_input["base_action_log_probs"] = base_log_probs
        training_input["action_log_probs"] = action_log_probs
        training_input["values"] = values

        if training_input.get("rollout_logprobs", None) is not None:
            # calculates the difference in probs between inference and trainer components
            # only consider response tokens
            logprobs_diff = (
                training_input["rollout_logprobs"][training_input["loss_mask"] > 0]
                - action_log_probs[training_input["loss_mask"] > 0]
            ).abs()

            logprobs_diff_mean = logprobs_diff.mean().item()
            logprobs_diff_std = logprobs_diff.std().item()
            self.all_metrics.update(
                {
                    "policy/rollout_train_logprobs_abs_diff_mean": logprobs_diff_mean,
                    "policy/rollout_train_logprobs_abs_diff_std": logprobs_diff_std,
                }
            )
        return training_input

    def apply_reward_kl_penalty(
        self,
        data: TrainingInputBatch,
    ) -> TrainingInputBatch:
        """Applies a penalty for KL divergence between the policy log probs and the base model log probs to the rewards.

        Subtracts `kl * kl_loss_coef` from the per-token rewards (coefficient clamped
        to be non-negative), updates the adaptive KL controller if one is configured,
        and records KL summary metrics in `data.metadata["metrics"]` and `self.all_metrics`.
        """
        loss_masks_all: torch.Tensor = data["loss_mask"]
        rewards: torch.Tensor = data["rewards"]
        base_action_log_probs: torch.Tensor = data["base_action_log_probs"]
        action_log_probs: torch.Tensor = data["action_log_probs"]

        # single batched computation
        kl: Float[torch.Tensor, "batch_size seqlen"] = compute_approx_kl(  # type: ignore
            action_log_probs,
            base_action_log_probs,
            loss_mask=loss_masks_all,
            kl_estimator_type=self.cfg.trainer.algorithm.kl_estimator_type,
        )
        # Per-sample worst-case and masked-mean KL, used for metrics/controller only.
        kl_max: Float[torch.Tensor, "batch_size"] = torch.max(kl.abs(), dim=-1)[0]  # noqa: F821
        kl_mean: Float[torch.Tensor, "batch_size"] = masked_mean(kl, loss_masks_all, dim=-1)  # noqa: F821

        # NOTE (erictang000): only supporting custom rewards currently
        # Use the adaptive controller's coefficient when available, else the static config value.
        kl_loss_coef = (
            self.reward_kl_controller.value
            if self.reward_kl_controller is not None
            else self.cfg.trainer.algorithm.kl_loss_coef
        )
        # max(0, coef) guards against a negative coefficient turning the penalty into a bonus.
        rewards = rewards - kl * max(0, kl_loss_coef)
        data["rewards"] = rewards

        avg_kl: float = kl_mean.mean().item()
        avg_kl_max: float = kl_max.mean().item()

        # update the kl controller
        if self.reward_kl_controller is not None:
            self.reward_kl_controller.update(current=avg_kl, n_steps=kl.shape[0])  # n_steps is just the batch size
        if "metrics" not in data.metadata:
            data.metadata["metrics"] = {}

        data.metadata["metrics"].update(
            {
                "avg_kl": avg_kl,
                "avg_kl_max": avg_kl_max,
                "kl_loss_coef": kl_loss_coef,
            }
        )

        self.all_metrics.update(
            {
                "loss/avg_kl": avg_kl,
                "loss/avg_kl_max": avg_kl_max,
                "loss/kl_loss_coef": kl_loss_coef,
            }
        )

        return data

    def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]:
        """
        Run the epoch/mini-batch training loop for one model.

        The trainer loops over epochs and mini-batches; workers handle micro-batching
        internally for gradient accumulation. The full batch is staged in the Ray
        object store once, and workers fetch + slice locally to avoid repeated
        serialization.

        Args:
            model: Model name ("policy" or "critic")
            data: Training data batch

        Returns:
            Dict of reduced metrics from training
        """
        # Mini-batch size is an algorithm-level concept derived from config.
        samples_per_prompt = self.cfg.generator.n_samples_per_prompt
        base_mini_batch = (
            self.cfg.trainer.policy_mini_batch_size
            if model == "policy"
            else self.cfg.trainer.critic_mini_batch_size
        )
        mini_batch_size = base_mini_batch * samples_per_prompt

        metrics_acc: Dict[str, List[float]] = defaultdict(list)

        # Stage the full batch in the object store ONCE to avoid repeated serialization.
        staged_ref = self.dispatch.stage_data(data)

        for _ in range(self.cfg.trainer.update_epochs_per_batch):
            for mb_idx in range(len(data) // mini_batch_size):
                lo = mb_idx * mini_batch_size
                hi = lo + mini_batch_size

                # Workers fetch from the object store and slice locally.
                step_metrics = self.dispatch.forward_backward_from_staged(model, staged_ref, lo, hi)
                for name, value in step_metrics.items():
                    metrics_acc[name].append(value)

                # One optimizer step per mini-batch.
                grad_norm = self.dispatch.optim_step(model)
                if grad_norm is not None:
                    metrics_acc["grad_norm"].append(grad_norm)

        # `loss_fn_outputs` is not a scalar metric; drop it before reduction/logging.
        metrics_acc.pop("loss_fn_outputs", None)
        return reduce_metrics(metrics_acc)

    def train_critic_and_policy(self, data: TrainingInputBatch):
        """
        Run one training step for the critic (if configured) and then the policy.

        Uses forward_backward + optim_step for both FSDP and Megatron strategies.
        Merges both models' metrics into `self.all_metrics` under "critic/" and
        "policy/" prefixes and returns the policy status dict.
        """
        data.metadata["global_step"] = self.global_step

        critic_status = None
        if self.has_critic:
            with Timer("critic_train", self.all_timings):
                critic_status = self._execute_training_step("critic", data)

        with Timer("policy_train", self.all_timings):
            policy_status = self._execute_training_step("policy", data)

        if critic_status is not None:
            self.all_metrics.update({f"critic/{k}": v for k, v in critic_status.items()})
        self.all_metrics.update({f"policy/{k}": v for k, v in policy_status.items()})

        self.dispatch.empty_cache()

        return policy_status

    def handle_dynamic_sampling(
        self, generator_output: GeneratorOutput, uids: List[str]
    ) -> Tuple[GeneratorOutput, List[str], bool]:
        """
        Handle dynamic sampling for the current batch.

        Accumulates generator output and UIDs across repeated sampling rounds via
        `self.dynamic_sampling_state` and applies the configured strategy
        (filter/replace) to the current batch. Raises if the configured
        `max_sample_batches` limit is hit while sampling would continue.

        Args:
            generator_output: Current batch generator output
            uids: Current batch UIDs

        Returns:
            processed_output: Filtered generator output
            processed_uids: Filtered UIDs
            keep_sampling: Whether to keep sampling
        """
        ds_cfg = self.cfg.trainer.algorithm.dynamic_sampling
        max_sample_batches = ds_cfg.max_sample_batches
        sampling_config = {
            "type": ds_cfg.type,
            "max_sample_batches": max_sample_batches,
            "min_replace_ratio": ds_cfg.min_replace_ratio,
            "train_batch_size": self.cfg.trainer.train_batch_size,
            "n_samples_per_prompt": self.cfg.generator.n_samples_per_prompt,
        }

        # Count sampling rounds for the current training step.
        if self.dynamic_sampling_state is None:
            self.dynamic_sampling_state: DynamicSamplingState = {
                "sample_batch_count": 1,
            }
        else:
            self.dynamic_sampling_state["sample_batch_count"] += 1

        processed_output, processed_uids, keep_sampling, updated_state = trainer_utils.handle_dynamic_sampling(
            generator_output, uids, sampling_config, self.dynamic_sampling_state
        )

        # Fail loudly rather than resampling forever past the configured limit.
        hit_limit = (
            keep_sampling
            and max_sample_batches > 0
            and self.dynamic_sampling_state["sample_batch_count"] >= max_sample_batches
        )
        if hit_limit:
            raise RuntimeError(
                f"Exiting training loop due to hitting dynamic sampling limit for "
                f"{self.cfg.trainer.algorithm.dynamic_sampling.type} strategy with "
                f"{self.cfg.trainer.algorithm.dynamic_sampling.max_sample_batches} max sample batches. "
                f"Please check your data difficulty distribution."
            )

        # Carry state across rounds; reset once sampling for this step completes.
        self.dynamic_sampling_state = updated_state if keep_sampling else None

        return processed_output, processed_uids, keep_sampling

    def _get_dp_group_models(self, rank: int, model_type: str = ""):
        model = getattr(self, model_type)
        return model._actor_handlers[rank]

    def _get_mesh_rank(self, rank: int, model_type: str = "") -> MeshRank:
        """Return the mesh rank of the actor at `rank` for the model attribute named `model_type`."""
        actor_group: PPORayActorGroup = getattr(self, model_type)
        info: ActorInfo = actor_group.actor_infos[rank]
        return info.rank

    def save_checkpoints(self):
        """
        Save the model, optimizer, and training states to disk.

        Layout: `<ckpt_path>/global_step_<N>/{policy,critic,data.pt,trainer_state.pt}`,
        with `latest_ckpt_global_step.txt` written last as the atomic success marker.
        Dispatch handles offload/backload automatically for all colocation configurations.
        """
        # Create global step folder structure
        global_step_folder = os.path.join(self.cfg.trainer.ckpt_path, f"global_step_{self.global_step}")
        policy_save_dir = os.path.join(global_step_folder, "policy")
        critic_save_dir = os.path.join(global_step_folder, "critic")

        io.makedirs(global_step_folder, exist_ok=True)

        # Save policy checkpoint (dispatch handles offload/backload)
        self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer)

        # Save critic checkpoint (if it exists)
        if self.has_critic:
            self.dispatch.save_checkpoint("critic", critic_save_dir, self.tokenizer)

        # Save dataloader state
        # Best-effort: a failed dataloader save should not abort the checkpoint.
        dataloader_save_path = os.path.join(global_step_folder, "data.pt")
        try:
            dataloader_state_dict = self.train_dataloader.state_dict()
            with io.open_file(dataloader_save_path, "wb") as f:
                torch.save(dataloader_state_dict, f)
            logger.info(f"Saved dataloader state to {dataloader_save_path}")
        except Exception as e:
            logger.warning(f"Failed to save dataloader state: {e}")

        # Save additional trainer state
        trainer_state = {
            "global_step": self.global_step,
            "config": asdict(self.cfg),
        }
        trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
        with io.open_file(trainer_state_path, "wb") as f:
            torch.save(trainer_state, f)
        logger.info(f"Saved trainer state to {trainer_state_path}")

        # Atomic tracking - write this last after all saves succeed
        latest_checkpoint_file = os.path.join(self.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt")
        with io.open_file(latest_checkpoint_file, "w") as f:
            f.write(str(self.global_step))

        logger.info(f"Successfully saved checkpoint for global_step_{self.global_step} to: {global_step_folder}")

        # Clean up old checkpoints after successful save
        with Timer("cleanup_old_checkpoints", self.all_timings):
            self._cleanup_old_checkpoints()

    def _cleanup_old_checkpoints(self):
        """Prune stale checkpoints on every node, keeping at most `max_ckpts_to_keep`."""
        ckpt_path = self.cfg.trainer.ckpt_path
        max_keep = self.cfg.trainer.max_ckpts_to_keep
        if not self._node_ids:
            self._node_ids = self.dispatch.get_node_ids()
        run_on_each_node(self._node_ids, cleanup_old_checkpoints, ckpt_path, max_keep)
        # run on driver as well
        # NOTE (sumanthrh): the function will get called twice on the node with driver process, but it's ok because it's idempotent
        cleanup_old_checkpoints(ckpt_path, max_keep)

    def load_checkpoints(self) -> Tuple[int, str]:
        """
        Load complete checkpoint state and return the global_step to resume from.
        Returns 0 if no checkpoint is loaded.

        Resolution order: NONE -> start fresh; LATEST -> read
        `latest_ckpt_global_step.txt` and validate consistency; otherwise use
        `trainer.resume_path` (must point to a `global_step_*` directory).

        If colocate_all is True, assumes that the policy model is currently on GPU.

        Returns:
            global_step: The global step to resume from.
            checkpoint_path: The path to the checkpoint.
        """
        checkpoint_path = None
        # Check if resumption is enabled
        if self.resume_mode == ResumeMode.NONE:
            logger.info("Checkpoint resumption disabled, starting training from scratch")
            return 0, None
        # first, let's get resume_path
        elif self.resume_mode == ResumeMode.LATEST:
            latest_checkpoint_file = os.path.join(self.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt")
            if not io.exists(latest_checkpoint_file):
                logger.info("No checkpoint found, starting training from scratch")
                return 0, None
            with io.open_file(latest_checkpoint_file, "r") as f:
                ckpt_iteration = int(f.read().strip())
            checkpoint_path = os.path.join(self.cfg.trainer.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_iteration}")
            # Run validation: Make sure ckpt folder is consistent with latest_ckpt_global_step.txt
            validate_consistency_for_latest_checkpoint(
                self.cfg.trainer.ckpt_path,
                ckpt_iteration,
                checkpoint_path,
                latest_checkpoint_file,
                self.cfg.trainer.ckpt_interval,
            )
        else:
            # Get and validate resume path
            checkpoint_path = Path(self.cfg.trainer.resume_path)
            if not checkpoint_path:
                raise ValueError("`trainer.resume_path` must be specified when resume_mode is 'from_path'")

            # Validate that it's a global_step directory
            if GLOBAL_STEP_PREFIX not in checkpoint_path.name:
                raise ValueError(
                    f"`trainer.resume_path` must point to a directory whose name starting with {GLOBAL_STEP_PREFIX}, got: {checkpoint_path}"
                )

        # Validate that the path exists
        # NOTE: checkpoint_path is a str in the LATEST branch and a Path in the
        # resume_path branch; both are normalized via str()/Path() below.
        if not io.exists(str(checkpoint_path)):
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")

        logger.info(f"Loading checkpoint from: {checkpoint_path}")

        # Extract global step from checkpoint path
        global_step = extract_step_from_path(Path(checkpoint_path))
        if global_step == -1:
            raise ValueError(f"Checkpoint path {checkpoint_path} is not a valid checkpoint path")
        logger.info(f"Resuming from global_step: {global_step}")

        # Define paths for different checkpoint components
        policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
        critic_ckpt_dir = os.path.join(checkpoint_path, "critic")
        trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
        dataloader_state_path = os.path.join(checkpoint_path, "data.pt")

        # Validate that required checkpoint files exist
        if not io.exists(trainer_state_path):
            raise FileNotFoundError(f"Trainer state file not found: {trainer_state_path}")

        # 1. Load and validate trainer state
        with io.open_file(trainer_state_path, "rb") as f:
            trainer_state = torch.load(f, map_location="cpu", weights_only=False)
        saved_global_step = trainer_state.get("global_step", global_step)
        logger.info("Successfully loaded trainer state")
        if saved_global_step != global_step:
            # The step encoded in the directory name wins over the saved state.
            logger.warning(f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value.")

        # 2. Load dataloader state if available
        # Best-effort: a missing/corrupt dataloader state only restarts the dataloader.
        if io.exists(dataloader_state_path):
            try:
                with io.open_file(dataloader_state_path, "rb") as f:
                    dataloader_state = torch.load(f, map_location="cpu", weights_only=False)
                self.train_dataloader.load_state_dict(dataloader_state)
                logger.info("Successfully loaded dataloader state")
            except Exception as e:
                logger.warning(f"Failed to load dataloader state: {e}. Dataloader will start from beginning.")
        else:
            logger.warning(
                f"No dataloader state found at {dataloader_state_path}. Dataloader will start from beginning."
            )

        # 3. Load policy checkpoint (dispatch handles offload/backload)
        logger.info(f"Loading policy checkpoint from {policy_ckpt_dir}")
        self.dispatch.load_checkpoint(
            "policy",
            policy_ckpt_dir,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        logger.info("Successfully loaded policy checkpoint")

        # 4. Load critic checkpoint if it exists and we have a critic model
        if self.has_critic:
            logger.info(f"Loading critic checkpoint from {critic_ckpt_dir}")
            self.dispatch.load_checkpoint(
                "critic",
                critic_ckpt_dir,
                load_optimizer_states=True,
                load_lr_scheduler_states=True,
            )
            logger.info("Successfully loaded critic checkpoint")

        logger.info(f"Successfully loaded complete checkpoint state from global_step_{global_step}")
        return global_step, str(checkpoint_path)

    def save_models(self):
        """
        Export model parameters in HF format under `cfg.trainer.export_path`.

        Dispatch handles offload/backload automatically for all colocation configurations.
        """
        step_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}")
        self.dispatch.save_hf_model("policy", os.path.join(step_dir, "policy"), self.tokenizer)

        if self.has_critic:
            self.dispatch.save_hf_model("critic", os.path.join(step_dir, "critic"), self.tokenizer)

        logger.info("Successfully saved model weights.")

    def update_ref_with_policy(self):
        """
        Re-initialize the reference model from the current policy weights
        (required by some algorithms).

        Dispatch handles offload/backload automatically for all colocation configurations.
        After this method, save_weights_for_sampler() should be called to sync weights.
        """
        # TODO(tgriggs): Make policy-to-ref sync faster.
        policy_export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}", "policy")

        # Round-trip through disk: export the policy in HF format, then load it as the new ref.
        self.dispatch.save_hf_model("policy", policy_export_dir, self.tokenizer)
        self.dispatch.init_model("ref", policy_export_dir)

        # Best-effort cleanup of the temporary export directory.
        try:
            shutil.rmtree(policy_export_dir)
            logger.info(f"Cleaned up temporary policy export directory: {policy_export_dir}")
        except Exception as e:
            logger.warning(f"Failed to clean up temporary policy export directory {policy_export_dir}: {e}")

        logger.info("Successfully updated ref model with policy model, training continues.")

attr cfg

cfg = cfg

attr colocate_all

colocate_all = cfg.trainer.placement.colocate_all

attr tracker

tracker = tracker

attr tokenizer

tokenizer = tokenizer

attr train_dataset

train_dataset = train_dataset

attr eval_dataset

eval_dataset = eval_dataset

attr inference_engine_client

inference_engine_client = inference_engine_client

attr generator

generator = generator

attr train_dataloader

train_dataloader = None

attr total_training_steps

total_training_steps = None

attr eval_dataloader

eval_dataloader = build_dataloader(self.cfg, eval_dataset, is_train=False) if eval_dataset is not None else None

attr colocate_pg

colocate_pg = colocate_pg

attr resume_mode

resume_mode = ResumeMode(cfg.trainer.resume_mode)

attr all_metrics

all_metrics = {}

attr all_timings

all_timings = {}

attr global_step

global_step = 0

attr policy_model

policy_model: PPORayActorGroup = None

attr critic_model

critic_model: Optional[PPORayActorGroup] = None

attr ref_model

ref_model: Optional[PPORayActorGroup] = None

attr dynamic_sampling_state

dynamic_sampling_state: Optional[DynamicSamplingState] = None

attr reward_kl_controller

reward_kl_controller: Optional[Union[FixedKLController, AdaptiveKLController]] = None

attr dispatch

dispatch: WorkerDispatch = None

attr property has_critic

has_critic: bool

Check if critic model is configured.

method async eval

eval() -> Dict[str, float]

Run generation and scoring on the evaluation dataset.

The eval metrics are recorded after having finished training self.global_step steps. Metrics recorded in global_step 0 correspond to evaluations before training.

Returns:

Type — Description
Dict[str, float] — A dictionary of evaluation metrics.
Source code in skyrl/train/trainer.py:148-175
    @torch.no_grad()
    async def eval(self) -> Dict[str, float]:
        """
        Run generation and scoring on the evaluation dataset.

        The eval metrics are recorded after having finished training `self.global_step` steps.
        Metrics recorded in global_step 0 correspond to evaluations before training.

        Returns:
            A dictionary of evaluation metrics.
        """
        # Both evaluators share the same call signature; pick by trajectory mode.
        evaluate_fn = evaluate_step_wise if self.cfg.generator.step_wise_trajectories else evaluate
        return await evaluate_fn(
            eval_dataloader=self.eval_dataloader,
            generator=self.generator,
            cfg=self.cfg,
            global_step=self.global_step,
            tokenizer=self.tokenizer,
        )

method async train

train()

Main training loop for PPO

Source code in skyrl/train/trainer.py:177-359
    async def train(self):
        """
        Main training loop for PPO

        Per batch: generate rollouts, postprocess rewards, run forward passes
        for logprobs/values/rewards, optionally apply a KL penalty to rewards,
        compute advantages/returns, train the critic and policy, then re-sync
        the updated policy weights to the inference engines. Checkpointing, HF
        model export, ref-model refresh, and evaluation run at their configured
        intervals; a final checkpoint/model save happens after the loop.
        """
        # Initialize weight sync state between policy model and inference engines.
        with Timer("init_weight_sync_state"):
            self.init_weight_sync_state()

        # Load checkpoint state if resumption is enabled.
        if self.resume_mode != ResumeMode.NONE:
            with Timer("load_checkpoints"):
                self.global_step, _ = self.load_checkpoints()

        # Prepare weights for sampling
        with Timer("sync_weights"):
            await self.dispatch.save_weights_for_sampler()

        # Eval before training
        if self.cfg.trainer.eval_interval > 0 and self.cfg.trainer.eval_before_train:
            with Timer("eval", self.all_timings):
                eval_metrics = await self.eval()
                self.tracker.log(eval_metrics, step=self.global_step, commit=True)

        # initialize kl controller
        if self.cfg.trainer.algorithm.use_kl_in_reward:
            self.reward_kl_controller = get_kl_controller(self.cfg.trainer.algorithm)

        # main training loop
        pbar = tqdm(total=self.total_training_steps, initial=self.global_step, desc="Training Batches Processed")
        # When resuming, derive the epoch to restart from the restored global_step.
        start_epoch = self.global_step // len(self.train_dataloader)
        self.global_step += 1  # start training at global_step 1
        for epoch in range(start_epoch, self.cfg.trainer.epochs):
            for iter, rand_prompts in enumerate(self.train_dataloader):
                with Timer("step", self.all_timings):
                    # for colocate_all=true, inference engine is always on GPU when starting the training step

                    # 0. truncate data to have even shards
                    rand_prompts = self._remove_tail_data(rand_prompts)
                    generator_input, uids = prepare_generator_input(
                        rand_prompts,
                        self.cfg.generator.n_samples_per_prompt,
                        get_sampling_params_for_backend(
                            self.cfg.generator.inference_engine.backend, self.cfg.generator.sampling_params
                        ),
                        self.cfg.environment.env_class,
                        "train",
                        self.global_step,
                    )

                    # 1.1. generation phase
                    with Timer("generate", self.all_timings):
                        generator_output: GeneratorOutput = await self.generate(generator_input)

                    if self.cfg.generator.step_wise_trajectories:
                        # NOTE: We use instance_ids from `trajectory_ids` here instead of re-using `uids`
                        # this is because in step-wise training, len(uids) != len(generator_output["response_ids"])
                        uids = [trajectory_id.instance_id for trajectory_id in generator_output["trajectory_ids"]]

                    # dynamic sampling
                    if self.cfg.trainer.algorithm.dynamic_sampling.type is not None:
                        generator_output, uids, keep_sampling = self.handle_dynamic_sampling(generator_output, uids)
                        if keep_sampling:  # continue sampling
                            # update progress bar for current batch (but not global step)
                            pbar.update(1)
                            continue

                    if self.colocate_all:
                        # if we are not continuing sampling, we sleep the inference engine
                        await self.inference_engine_client.sleep()

                    # 1.2 postprocess rewards
                    with Timer("postprocess_generator_output", self.all_timings):
                        generator_output = self.postprocess_generator_output(generator_output, uids)

                    # 2. print example just for debugging
                    vis = self.tokenizer.decode(generator_output["response_ids"][0])
                    log_example(
                        logger,
                        prompt=generator_input["prompts"][0],
                        response=vis,
                        reward=generator_output["rewards"][0],
                    )

                    # 3. Convert GeneratorOutput to TrainingInputBatch
                    with Timer("convert_to_training_input", self.all_timings):
                        training_input: TrainingInputBatch = self.convert_to_training_input(generator_output, uids)
                        logger.info(f"Number of sequences: {len(training_input['sequences'])}")

                    # 4. Inference and calculate values, log probs, rewards, kl divergence
                    with Timer("fwd_logprobs_values_reward", self.all_timings):
                        training_input = self.fwd_logprobs_values_reward(training_input)

                    # 5. apply kl divergence penalty to rewards
                    if self.cfg.trainer.algorithm.use_kl_in_reward:
                        with Timer("apply_reward_kl_penalty", self.all_timings):
                            training_input = self.apply_reward_kl_penalty(training_input)

                    # 6. calculate advantages and returns
                    with Timer("compute_advantages_and_returns", self.all_timings):
                        training_input = self.compute_advantages_and_returns(training_input)
                        # remove some unwanted keys
                        for key in ["rewards"]:
                            training_input.pop(key)
                        training_input.metadata.pop("uids")

                        if self.cfg.trainer.algorithm.advantage_batch_normalize:
                            training_input = normalize_advantages_dict(training_input)

                    if self.cfg.trainer.dump_data_batch:
                        # dump data to file
                        with Timer("dump_data_batch"):
                            self.dump_data(training_input, file_name=f"global_step_{self.global_step}_training_input")

                    # 7. train policy/critic model
                    # Policy model is backloaded to GPU during training
                    with Timer("train_critic_and_policy", self.all_timings):
                        status = self.train_critic_and_policy(training_input)

                    # 8. conditionally save checkpoints and hf model
                    if self.cfg.trainer.ckpt_interval > 0 and self.global_step % self.cfg.trainer.ckpt_interval == 0:
                        with Timer("save_checkpoints", self.all_timings):
                            self.save_checkpoints()
                    if (
                        self.cfg.trainer.hf_save_interval > 0
                        and self.global_step % self.cfg.trainer.hf_save_interval == 0
                    ):
                        with Timer("save_hf_model", self.all_timings):
                            self.save_models()

                    # 9. conditionally sync policy and ref at the end of the epoch
                    if (
                        self.cfg.trainer.update_ref_every_epoch
                        and self.ref_model is not None
                        and iter == len(self.train_dataloader) - 1
                        and epoch != self.cfg.trainer.epochs - 1  # skip updating ref at the end of the last epoch
                    ):
                        with Timer("update_ref_with_policy", self.all_timings):
                            self.update_ref_with_policy()

                    # 10. Prepare weights for sampling
                    with Timer("sync_weights", self.all_timings):
                        await self.dispatch.save_weights_for_sampler()

                # 11. set logs
                logger.info(status)
                # log epoch info
                self.all_metrics.update({"trainer/epoch": epoch, "trainer/global_step": self.global_step})
                # Evaluate on the configured interval (and always on the final step).
                if self.cfg.trainer.eval_interval > 0 and (
                    self.global_step % self.cfg.trainer.eval_interval == 0
                    or self.global_step == self.total_training_steps
                ):
                    with Timer("eval", self.all_timings):
                        eval_metrics = await self.eval()
                        self.all_metrics.update(eval_metrics)

                # Flush accumulated metrics and timings for this step, then reset them.
                log_payload = {
                    **self.all_metrics,
                    **{f"timing/{k}": v for k, v in self.all_timings.items()},
                }
                self.tracker.log(log_payload, step=self.global_step, commit=True)
                self.all_metrics = {}
                self.all_timings = {}

                # update progress bar after logging
                pbar.update(1)

                self.global_step += 1

                # Drop references to the large batch objects before the next iteration.
                del training_input, generator_output

        pbar.close()
        if self.colocate_all:
            await self.inference_engine_client.sleep()
        # Always persist a final checkpoint / HF model if saving is enabled at all.
        if self.cfg.trainer.ckpt_interval > 0:
            with Timer("save_checkpoints", self.all_timings):
                self.save_checkpoints()
                logger.info("Saved final checkpoint.")
        if self.cfg.trainer.hf_save_interval > 0:
            with Timer("save_hf_model", self.all_timings):
                self.save_models()
                logger.info("Saved final model.")
        self.tracker.finish()
        logger.info("Training done!")

method build_models

build_models(PolicyWorker, CriticWorker, RefWorker)

Initialize the actors for training, and handle colocation logic

Source code in skyrl/train/trainer.py:385-580
    def build_models(self, PolicyWorker, CriticWorker, RefWorker):
        """
        Initialize the actors for training, and handle colocation logic

        Builds the policy actor group, plus (optionally) the ref and critic
        groups, places them according to `cfg.trainer.placement`, initializes
        model weights on the workers, and wires everything into a single
        `WorkerDispatch`.

        Args:
            PolicyWorker: worker class for the policy actor group.
            CriticWorker: worker class for the critic actor group; only used
                when `cfg.trainer.critic.model.path` is set.
            RefWorker: worker class for the reference-model actor group; only
                used when a KL term (loss or in-reward) requires a ref model.
        """
        cfg = self.cfg
        pg = None

        # A ref model is only needed when some KL term compares policy vs. reference.
        use_ref_model = cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward

        if cfg.trainer.placement.colocate_all:
            # All training models and the inference engines share the same GPUs,
            # so every group must cover exactly the same number of GPUs.
            num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
            num_critic_gpus = cfg.trainer.placement.critic_num_gpus_per_node * cfg.trainer.placement.critic_num_nodes
            num_ref_gpus = cfg.trainer.placement.ref_num_gpus_per_node * cfg.trainer.placement.ref_num_nodes
            ie_cfg = cfg.generator.inference_engine
            num_rollout_gpus = (
                ie_cfg.num_engines
                * ie_cfg.tensor_parallel_size
                * ie_cfg.pipeline_parallel_size
                * ie_cfg.data_parallel_size
            )
            assert (
                num_policy_gpus == num_rollout_gpus
            ), "num_policy_gpus and num_rollout_gpus must be the same when colocating all models"
            pg = self.colocate_pg

            policy_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.policy_num_nodes,
                cfg.trainer.placement.policy_num_gpus_per_node,
                PolicyWorker,
                pg=pg,
                # Fractional GPU shares allow multiple colocated actors to pack onto one GPU.
                num_gpus_per_actor=0.2 if pg else 1,
                colocate_all=True,
                sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
                record_memory=cfg.trainer.policy.record_memory,
            )
            if use_ref_model:
                assert (
                    num_policy_gpus == num_ref_gpus
                ), "num_policy_gpus and num_ref_gpus must be the same when colocating policy and ref model"
                ref_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.ref_num_nodes,
                    cfg.trainer.placement.ref_num_gpus_per_node,
                    RefWorker,
                    pg=pg,
                    num_gpus_per_actor=0.2 if pg else 1,
                    colocate_all=True,
                    sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
                )
            else:
                ref_model = None

            if cfg.trainer.critic.model.path:
                assert (
                    num_policy_gpus == num_critic_gpus
                ), "num_policy_gpus and num_critic_gpus must be the same when colocating policy and critic model"
                critic_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.critic_num_nodes,
                    cfg.trainer.placement.critic_num_gpus_per_node,
                    CriticWorker,
                    pg=pg,
                    # Fix: match the policy/ref groups above — only use a fractional
                    # GPU share when packing into a placement group (was hard-coded 0.2,
                    # which under-allocated the critic when no placement group was given).
                    num_gpus_per_actor=0.2 if pg else 1,
                    colocate_all=True,
                    sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
                )
            else:
                critic_model = None

        else:
            # Only policy+ref may share GPUs in this mode; build a PACK placement
            # group so their actors land on the same bundles.
            if cfg.trainer.placement.colocate_policy_ref and use_ref_model:
                assert (
                    cfg.trainer.placement.policy_num_nodes == cfg.trainer.placement.ref_num_nodes
                    and cfg.trainer.placement.policy_num_gpus_per_node == cfg.trainer.placement.ref_num_gpus_per_node
                ), "num_nodes and num_gpus_per_node must be the same when colocate policy and ref model."

                bundles = [
                    {
                        "GPU": cfg.trainer.placement.policy_num_gpus_per_node,
                        "CPU": cfg.trainer.placement.policy_num_gpus_per_node,
                    }
                    for _ in range(cfg.trainer.placement.policy_num_nodes)
                ]
                pg = placement_group(bundles, strategy="PACK")
                get_ray_pg_ready_with_timeout(pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)

            policy_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.policy_num_nodes,
                cfg.trainer.placement.policy_num_gpus_per_node,
                PolicyWorker,
                pg=pg,
                # Policy takes the larger share (0.75) when sharing a GPU with the ref model.
                num_gpus_per_actor=0.75 if pg else 1,
                colocate_all=False,
                sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
            )
            if use_ref_model:
                ref_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.ref_num_nodes,
                    cfg.trainer.placement.ref_num_gpus_per_node,
                    RefWorker,
                    pg=pg,
                    num_gpus_per_actor=0.25 if pg else 1,
                    colocate_all=False,
                    sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
                )
            else:
                ref_model = None

            if cfg.trainer.critic.model.path:
                # Critic is never colocated in this mode; it gets whole GPUs.
                critic_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.critic_num_nodes,
                    cfg.trainer.placement.critic_num_gpus_per_node,
                    CriticWorker,
                    num_gpus_per_actor=1,
                    colocate_all=False,
                    sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
                )
            else:
                critic_model = None

        # Optimizer steps each model takes per training batch (drives LR schedules).
        policy_steps_per_train_batch = (
            cfg.trainer.train_batch_size // cfg.trainer.policy_mini_batch_size * cfg.trainer.update_epochs_per_batch
        )
        critic_steps_per_train_batch = 0
        if cfg.trainer.critic.model.path:
            critic_steps_per_train_batch = (
                cfg.trainer.train_batch_size // cfg.trainer.critic_mini_batch_size * cfg.trainer.update_epochs_per_batch
            )
        policy_num_training_steps = (
            self.total_training_steps * policy_steps_per_train_batch if self.total_training_steps is not None else None
        )
        critic_num_training_steps = (
            self.total_training_steps * critic_steps_per_train_batch if self.total_training_steps is not None else None
        )
        if not cfg.trainer.placement.colocate_all:
            # Initialize all groups concurrently; they live on disjoint GPUs.
            refs = []
            if ref_model is not None:
                refs.extend(ref_model.async_init_model(cfg.trainer.ref.model.path))
            refs.extend(
                policy_model.async_init_model(
                    cfg.trainer.policy.model.path,
                    num_training_steps=policy_num_training_steps,
                )
            )
            if cfg.trainer.critic.model.path:
                refs.extend(
                    critic_model.async_init_model(
                        cfg.trainer.critic.model.path,
                        num_training_steps=critic_num_training_steps,
                    )
                )
            ray.get(refs)
            ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
        else:
            # Colocated: initialize sequentially and offload each model to CPU so
            # the shared GPUs have room for the next one.
            if ref_model is not None:
                ray.get(ref_model.async_init_model(cfg.trainer.ref.model.path))
                ref_model.offload_to_cpu()
            ray.get(
                policy_model.async_init_model(
                    cfg.trainer.policy.model.path,
                    num_training_steps=policy_num_training_steps,
                )
            )
            ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
            policy_model.offload_to_cpu()
            if cfg.trainer.critic.model.path:
                ray.get(
                    critic_model.async_init_model(
                        cfg.trainer.critic.model.path,
                        num_training_steps=critic_num_training_steps,
                    )
                )
                critic_model.offload_to_cpu()

        self.policy_model: PPORayActorGroup = policy_model
        self.critic_model: Optional[PPORayActorGroup] = critic_model
        self.ref_model: Optional[PPORayActorGroup] = ref_model

        # Create unified dispatch that manages all actor groups
        self.dispatch = WorkerDispatch(
            cfg=self.cfg,
            policy_actor_group=policy_model,
            critic_actor_group=critic_model,
            ref_actor_group=ref_model,
            inference_engine_client=self.inference_engine_client,
        )

        # Mark all models as offloaded if colocate_all (they were offloaded above)
        if self.colocate_all:
            self.dispatch.mark_all_offloaded()

        logger.info("init policy/ref/critic models done")

method init_weight_sync_state

init_weight_sync_state()

Setup the connection between policy model and inference engine for weight syncing.

Source code in skyrl/train/trainer.py:582-587
    def init_weight_sync_state(self):
        """Establish the weight-sync channel between the policy model and the
        inference engines (setup only; no weights are transferred here)."""
        engine_client = self.inference_engine_client
        self.dispatch.init_weight_sync_state(engine_client)
        logger.info("Initialized weight sync state for policy model and inference engines.")

method sync_policy_weights_to_inference_engines

sync_policy_weights_to_inference_engines() -> List[ObjectRef]

Broadcast policy weights to inference engines.

Note: For new code, prefer using dispatch.save_weights_for_sampler() which handles the full weight sync protocol including offload/backload. This method is kept for backward compatibility with subclasses. TODO(tgriggs): Remove this method when migration is complete.

Source code in skyrl/train/trainer.py:589-602
    def sync_policy_weights_to_inference_engines(self) -> List[ObjectRef]:
        """Broadcast policy weights to inference engines.

        Note: For new code, prefer using dispatch.save_weights_for_sampler() which
        handles the full weight sync protocol including offload/backload.
        This method is kept for backward compatibility with subclasses.
        TODO(tgriggs): Remove this method when migration is complete.
        """
        engine_cfg = self.cfg.generator.inference_engine
        refs = self.policy_model.async_run_ray_method(
            "pass_through",
            "broadcast_to_inference_engines",
            self.inference_engine_client,
            engine_cfg,
        )
        return refs

method convert_to_training_input

convert_to_training_input(generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch

Converts lists to a padded batch of tensors for training

Source code in skyrl/train/trainer.py:604-678
    def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch:
        """Converts lists to a padded batch of tensors for training

        Args:
            generator_output: rollout data (prompt/response token ids, rewards,
                loss masks, optional rollout logprobs / step-wise fields).
            uids: one group id per sample; stored in batch metadata for
                advantage grouping.

        Returns:
            A `TrainingInputBatch` of padded tensors plus metadata (`uids`,
            `response_length`, `avg_response_length`, and `trajectory_ids` for
            step-wise training), padded via `pad_batch` to be divisible by dp size.
        """
        prompt_ids: List[List[int]] = generator_output["prompt_token_ids"]
        response_ids: List[List[int]] = generator_output["response_ids"]
        rewards: List[List[float]] = generator_output["rewards"]
        loss_masks: List[List[int]] = generator_output["loss_masks"]

        # Rollout logprobs are optional; only present when the backend records them.
        logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None)

        (
            sequences_tensor,
            attention_masks_tensor,
            response_masks_tensor,
            rewards_tensor,
            loss_masks_tensor,
            rollout_logprobs_tensor,
        ) = convert_prompts_responses_to_batch_tensors(
            self.tokenizer,
            prompt_ids,
            response_ids,
            rewards,
            loss_masks,
            logprobs,
        )

        # sanity check for off_policy_correction
        off_policy_correction = self.cfg.trainer.algorithm.off_policy_correction
        tis_ratio_type = off_policy_correction.tis_ratio_type
        sequence_mask_metric = off_policy_correction.sequence_mask_metric
        if tis_ratio_type is not None or sequence_mask_metric is not None:
            assert (
                rollout_logprobs_tensor is not None
            ), "expected non-null rollout logprobs tensor when off_policy_correction is enabled"
            assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses"

        training_input = TrainingInputBatch(
            {
                "sequences": sequences_tensor,  # Full trajectories (padded and concatenated prompts and responses)
                "attention_mask": attention_masks_tensor,
                "response_mask": response_masks_tensor,
                "rewards": rewards_tensor,
                "loss_mask": loss_masks_tensor,
                "rollout_logprobs": rollout_logprobs_tensor,
                # Step-wise training marks which samples are trajectory-final steps.
                "is_last_step": (
                    torch.tensor(generator_output["is_last_step"], dtype=torch.bool)
                    if generator_output.get("is_last_step", None) is not None
                    else None
                ),
            },
        )
        training_input.metadata = {"uids": uids}
        # padded response length
        training_input.metadata["response_length"] = response_masks_tensor.shape[1]
        if self.cfg.generator.step_wise_trajectories:
            assert (
                "trajectory_ids" in generator_output
            ), "Expected `trajectory_ids` in generator output for step wise training"
            training_input.metadata["trajectory_ids"] = [
                trajectory_id.to_string() for trajectory_id in generator_output["trajectory_ids"]
            ]
            # NOTE(review): sums only last-step response lengths but divides by the
            # total number of responses (all steps) — confirm this denominator is intended.
            training_input.metadata["avg_response_length"] = sum(
                len(sample_response_ids)
                for sample_response_ids, is_last_step in zip(response_ids, generator_output["is_last_step"])
                if is_last_step
            ) / len(response_ids)
        else:
            training_input.metadata["avg_response_length"] = sum(
                len(sample_response_ids) for sample_response_ids in response_ids
            ) / len(response_ids)

        # Pad the batch so it shards evenly across data-parallel workers.
        logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
        training_input = self.pad_batch(training_input)
        logger.info(f"Number of sequences after padding: {len(training_input['sequences'])}")

        return training_input

method async generate

generate(input_batch: GeneratorInput) -> GeneratorOutput

Generate rollouts.

If colocate_all is enabled:

  • before calling this method, the policy model should be on CPU and inference engine should be awake (i.e. on GPU).
  • after calling this method, the same model placement still holds.
Source code in skyrl/train/trainer.py:680-703
    @torch.no_grad()
    async def generate(
        self,
        input_batch: GeneratorInput,
    ) -> GeneratorOutput:
        """
        Generate rollouts.

        If colocate_all is enabled:
        - before calling this method, the policy model should be on CPU and inference engine should
            be awake (i.e. on GPU).
        - after calling this method, the same model placement still holds.
        """
        # NOTE: we assume that .generate returns samples in the same order as passed in
        output: GeneratorOutput = await self.generator.generate(input_batch)

        # Fold rollout metrics (if any) into the shared metrics dict.
        rollout_metrics = output["rollout_metrics"]
        if rollout_metrics is not None:
            self.all_metrics.update(rollout_metrics)

        # Step-wise trajectories may legitimately differ in count from the prompts,
        # so output validation is skipped in that mode.
        if not self.cfg.generator.step_wise_trajectories:
            validate_generator_output(len(input_batch["prompts"]), output)

        return output

method postprocess_generator_output

postprocess_generator_output(generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput

Converts to per token rewards and computes pass@N.

In the future, algorithm-specific reward or loss-mask post-processing should be done here.

Source code in skyrl/train/trainer.py:705-769
    @torch.no_grad()
    def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
        """
        Converts to per token rewards and computes pass@N.

        In the future, algorithm-specific reward or loss-mask post-processing should be done here.

        Args:
            generator_output: rollout data; `rewards` may be response-level
                (List[float]) or already token-level (List[List[float]]).
            uids: group id per sample, used for pass@N and variance filtering.

        Returns:
            The same `generator_output` with `rewards` replaced by per-token
            rewards (and `loss_masks` possibly zeroed by the variance filter).
        """
        generator_output_for_metrics = generator_output
        uids_for_metrics = uids
        if self.cfg.generator.step_wise_trajectories:
            # For step-wise training, restrict metrics to trajectory-final steps only.
            generator_output_for_metrics = defaultdict(list)
            for key in generator_output:
                if isinstance(generator_output[key], list):
                    generator_output_for_metrics[key] = [
                        generator_output[key][i]
                        for i in range(len(generator_output[key]))
                        if generator_output["is_last_step"][i]
                    ]
            uids_for_metrics = [
                uid for uid, is_last_step in zip(uids, generator_output["is_last_step"]) if is_last_step
            ]

        # only use `generator_output_for_metrics` for metrics calculation
        # For step-wise training, we only calculate metrics for the last step of each trajectory
        overall_metrics = get_metrics_from_generator_output(
            generator_output_for_metrics,
            uids_for_metrics,
        )

        # these use the full generator output
        rewards: Union[List[float], List[List[float]]] = generator_output["rewards"]
        responses: List[List[int]] = generator_output["response_ids"]
        per_token_rewards: List[List[float]] = []

        # Check if rewards are already token-level (List[List[float]]) or response-level (List[float])
        if rewards and isinstance(rewards[0], list):
            # Token-level rewards: rewards is List[List[float]]
            per_token_rewards = rewards
        else:
            if self.cfg.trainer.algorithm.zero_variance_filter:
                # Zero out the loss masks of samples whose reward group has no
                # variance, so they contribute nothing to the loss.
                kept_indices_set = set(zero_variance_filter(rewards, uids))
                generator_output["loss_masks"] = [
                    [0] * len(mask) if i not in kept_indices_set else mask
                    for i, mask in enumerate(generator_output["loss_masks"])
                ]
            # Response-level rewards: rewards is List[float], convert to per-token rewards
            # by placing the whole reward on the final token.
            # NOTE(review): assumes every response is non-empty — confirm upstream guarantees this.
            for reward, response in zip(rewards, responses):
                per_token_reward = [0.0] * len(response)
                per_token_reward[-1] = float(reward)
                per_token_rewards.append(per_token_reward)

        n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt

        reward_metrics = {
            f"reward/avg_pass_at_{n_samples_per_prompt}": overall_metrics["pass_at_n"],
            "reward/avg_raw_reward": overall_metrics["avg_score"],
            "reward/mean_positive_reward": overall_metrics["mean_positive_reward"],
        }
        self.all_metrics.update(reward_metrics)
        logger.info(
            f"reward/avg_pass_at_{n_samples_per_prompt}: {overall_metrics['pass_at_n']}, reward/avg_raw_reward: {overall_metrics['avg_score']}, reward/mean_positive_reward: {overall_metrics['mean_positive_reward']}"
        )
        # re-assign reward but now it's per token rewards
        generator_output["rewards"] = per_token_rewards
        return generator_output

method compute_advantages_and_returns

compute_advantages_and_returns(data: TrainingInputBatch) -> TrainingInputBatch

Calculate advantages and returns for the data batch.

Expects:

  • ["sequences"]: Integer[torch.Tensor, "batch_size seqlen"]
  • ["response_mask"]: Integer[torch.Tensor, "batch_size seqlen"]
  • ["loss_mask"]: Integer[torch.Tensor, "batch_size seqlen"]
  • ["values"]: Float[torch.Tensor, "batch_size seqlen"]
  • ["rewards"]: Float[torch.Tensor, "batch_size seqlen"]
  • .metadata["uids"]: List[str]

Adds:

  • ["advantages"]: Float[torch.Tensor, "batch_size seqlen"]
  • ["returns"]: Float[torch.Tensor, "batch_size seqlen"]
Source code in skyrl/train/trainer.py:771-869
    @torch.no_grad()
    def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingInputBatch:
        """Calculate advantages and returns for the data batch.

        Expects:
            - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["response_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["loss_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["rewards"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `.metadata["uids"]`: List[str]

        Adds:
            - `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["returns"]`: Float[torch.Tensor, "batch_size seqlen"]

        Also records reward/advantage summary metrics in `data.metadata["metrics"]`
        and `self.all_metrics`, and moves the batch to CPU before returning.
        """
        token_level_rewards = data["rewards"]

        if self.cfg.generator.step_wise_trajectories:
            # Step-wise mode: each row is one step of a trajectory; only rows flagged
            # as the trajectory's last step carry the reward used for the estimator.
            is_last_step = data["is_last_step"].bool()
            index = np.array(data.metadata["uids"])
            values = data["values"]
            # Use the last step of each trajectory to compute advantages. Compatible with any advantage estimator
            # NOTE(Charlie): so we ignore per-step rewards in step-wise training.
            last_step_advantages, last_step_returns = ppo_utils.compute_advantages_and_returns(
                token_level_rewards=token_level_rewards[is_last_step],
                response_mask=data["response_mask"][is_last_step],
                index=index[is_last_step.cpu().numpy()],
                adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
                values=values[is_last_step] if values is not None else None,
                config=self.cfg.trainer.algorithm,
                gamma=self.cfg.trainer.algorithm.gamma,
                lambd=self.cfg.trainer.algorithm.lambd,
                grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
            )
            # Broadcast each trajectory's advantage and return to all steps of each trajectory.
            # Shifting `is_last_step` right by one and cumsum-ing yields a 0-based
            # trajectory id per row: the id increments on the row *after* each last step.
            traj_ids = (
                torch.cat([torch.tensor([False], device=is_last_step.device), is_last_step[:-1]]).int().cumsum(dim=0)
            )
            num_groups = traj_ids[-1].item() + 1
            assert num_groups == len(
                last_step_advantages
            ), f"number of groups {num_groups} doesn't match the number of trajectories as given by `is_last_step` {len(last_step_advantages)}. The `is_last_step` tensor is likely malformed"
            advantages = last_step_advantages[traj_ids]
            returns = last_step_returns[traj_ids]
        else:
            advantages, returns = ppo_utils.compute_advantages_and_returns(
                token_level_rewards=token_level_rewards,
                response_mask=data["response_mask"],
                index=data.metadata["uids"],
                adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
                config=self.cfg.trainer.algorithm,
                values=data["values"],
                gamma=self.cfg.trainer.algorithm.gamma,
                lambd=self.cfg.trainer.algorithm.lambd,
                grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
            )
        data["returns"] = returns
        data["advantages"] = advantages

        # remove padding while calculating metrics
        pad_size = data.metadata.get("pad_size", 0)
        num_samples = len(token_level_rewards)

        # Per-sample reward sums, excluding the `pad_size` rows appended by `pad_batch`.
        return_sums = token_level_rewards.sum(dim=-1)[: num_samples - pad_size]
        if self.cfg.generator.step_wise_trajectories:
            # NOTE(review): indexes with `is_last_step` without `.bool()` (unlike the
            # estimator path above); assumes the stored tensor is already boolean —
            # if it is an int 0/1 tensor this becomes fancy indexing. Confirm dtype upstream.
            avg_rewards: float = return_sums[data["is_last_step"][: num_samples - pad_size]].mean().item()
        else:
            avg_rewards: float = return_sums.mean().item()

        avg_response_length = data.metadata["avg_response_length"]
        data = data.to("cpu")

        # Advantage statistics are computed only over response tokens of non-padded samples.
        valid_advantages = torch.masked_select(
            data["advantages"][: num_samples - pad_size, ...], data["response_mask"][: num_samples - pad_size].bool()
        )
        avg_advantages: float = valid_advantages.mean().item()
        avg_advantages_abs: float = valid_advantages.abs().mean().item()

        if "metrics" not in data.metadata:
            data.metadata["metrics"] = {}
        data.metadata["metrics"].update(
            {
                "avg_final_rewards": avg_rewards,
                "avg_response_length": avg_response_length,
                "avg_advantages": avg_advantages,
                "avg_advantages_abs": avg_advantages_abs,
            }
        )

        logger.info(f"avg_final_rewards: {avg_rewards}, avg_response_length: {avg_response_length}")
        self.all_metrics.update(
            {
                "loss/avg_final_rewards": avg_rewards,
                "loss/avg_raw_advantages": avg_advantages,
                "loss/avg_raw_advantages_abs": avg_advantages_abs,
            }
        )
        return data

method dump_data

dump_data(data: TrainingInputBatch, file_name: str)

Dump data to pickle file

Source code in skyrl/train/trainer.py:871-877
    def dump_data(self, data: TrainingInputBatch, file_name: str):
        """Persist `data` as a pickle file named `<file_name>.pkl` under
        `<cfg.trainer.export_path>/dumped_data`, creating the directory if needed."""
        target_dir = Path(self.cfg.trainer.export_path) / "dumped_data"
        target_dir.mkdir(parents=True, exist_ok=True)
        target_file = target_dir / f"{file_name}.pkl"
        data.save(target_file)

method pad_batch

pad_batch(training_input: TrainingInputBatch) -> TrainingInputBatch

Pad the batch to be divisible by dp size

Source code in skyrl/train/trainer.py:879-914
    def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch:
        """Pad the batch so its size is divisible by the (LCM) data-parallel size.

        Padding rows are cloned from the head of the batch so every padded tensor
        stays well-formed; `loss_mask` padding is zeroed so it never contributes to
        the loss, and `is_last_step` padding is set to one. The amount of padding is
        recorded in `metadata["pad_size"]` (also on the original input when it is 0).
        """
        dp_size = self.dispatch.get_lcm_dp_size()
        batch_size = training_input.batch_size
        # Smallest non-negative amount that rounds batch_size up to a multiple of dp_size.
        pad_size = -batch_size % dp_size
        training_input.metadata["pad_size"] = pad_size
        if not pad_size:
            return training_input

        padded_tensors = {}
        for key, tensor in training_input.items():
            if tensor is None:
                continue
            trailing_dims = tuple(tensor.shape[1:])
            if key == "is_last_step":
                filler = torch.ones(pad_size, *trailing_dims, dtype=tensor.dtype, device=tensor.device)
            elif key == "loss_mask":
                # Padding must never count towards the loss.
                filler = torch.zeros(pad_size, *trailing_dims, dtype=tensor.dtype, device=tensor.device)
            else:
                # Clone real rows so the padding is always in a valid format;
                # `pad_size` is guaranteed to be smaller than `batch_size`.
                filler = tensor[:pad_size].clone()
            padded_tensors[key] = torch.cat([tensor, filler], dim=0)

        padded_input = TrainingInputBatch(padded_tensors)
        pad_ids = [f"pad{i}" for i in range(pad_size)]
        padded_input.metadata = {"uids": training_input.metadata["uids"] + pad_ids}
        if "trajectory_ids" in training_input.metadata:
            padded_input.metadata["trajectory_ids"] = training_input.metadata["trajectory_ids"] + pad_ids
        for key, value in training_input.metadata.items():
            if key not in ("uids", "trajectory_ids"):
                padded_input.metadata[key] = copy.deepcopy(value)
        return padded_input

method fwd_logprobs_values_reward

fwd_logprobs_values_reward(training_input: TrainingInputBatch)

Calculate values from the critic, log probs from the policy and ref model.

Dispatch handles offload/backload automatically for all colocation configurations.

Expects:

  • ["sequences"]: Integer[torch.Tensor, "batch_size seqlen"]
  • ["attention_mask"]: Integer[torch.Tensor, "batch_size seqlen"]
  • .metadata["response_length"]: Int

Adds:

  • ["base_action_log_probs"]: Float[torch.Tensor, "batch_size seqlen"]
  • ["action_log_probs"]: Float[torch.Tensor, "batch_size seqlen"]
  • ["values"]: Float[torch.Tensor, "batch_size seqlen"]
Source code in skyrl/train/trainer.py:916-986
    @torch.no_grad()
    def fwd_logprobs_values_reward(
        self,
        training_input: TrainingInputBatch,
    ) -> TrainingInputBatch:
        """
        Calculate values from the critic, log probs from the policy and ref model.

        Dispatch handles offload/backload automatically for all colocation configurations.

        Expects:
            - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["attention_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `.metadata["response_length"]`: Int

        Adds:
            - `["base_action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]

        Also records rollout-vs-train logprob difference metrics in `self.all_metrics`
        when `["rollout_logprobs"]` is present in the batch.
        """
        # Ship only the fields needed for the forward passes to the workers.
        data_fwd_pass = training_input.select(keys=["sequences", "attention_mask"], metadata_keys=["response_length"])

        values = None
        base_log_probs = None
        action_log_probs = None

        # Critic forward (dispatch handles offload/backload automatically)
        if self.has_critic:
            critic_output = self.dispatch.forward("critic", data_fwd_pass)
            values = critic_output["output"]

        # Ref forward (skipped entirely when no reference model is configured)
        if self.ref_model is not None:
            ref_output = self.dispatch.forward("ref", data_fwd_pass)
            base_log_probs = ref_output["output"]
            self.dispatch.empty_cache("ref")

        # Policy forward
        policy_output = self.dispatch.forward("policy", data_fwd_pass)
        action_log_probs = policy_output["output"]

        # Empty cache after all forward passes
        self.dispatch.empty_cache()

        sequences_all: torch.Tensor = training_input["sequences"]
        # NOTE (sumanthrh): The slicing is needed to make sure that the batch dimension doesn't change for the tensordict.
        base_log_probs = base_log_probs[: len(sequences_all)] if base_log_probs is not None else None
        action_log_probs = action_log_probs[: len(sequences_all)]
        values = values[: len(sequences_all)] if values is not None else None

        training_input["base_action_log_probs"] = base_log_probs
        training_input["action_log_probs"] = action_log_probs
        training_input["values"] = values

        if training_input.get("rollout_logprobs", None) is not None:
            # calculates the difference in probs between inference and trainer components
            # only consider response tokens
            logprobs_diff = (
                training_input["rollout_logprobs"][training_input["loss_mask"] > 0]
                - action_log_probs[training_input["loss_mask"] > 0]
            ).abs()

            logprobs_diff_mean = logprobs_diff.mean().item()
            logprobs_diff_std = logprobs_diff.std().item()
            self.all_metrics.update(
                {
                    "policy/rollout_train_logprobs_abs_diff_mean": logprobs_diff_mean,
                    "policy/rollout_train_logprobs_abs_diff_std": logprobs_diff_std,
                }
            )
        return training_input

method apply_reward_kl_penalty

apply_reward_kl_penalty(data: TrainingInputBatch) -> TrainingInputBatch

Applies a penalty for KL divergence between the policy log probs and the base model log probs to the rewards.

Source code in skyrl/train/trainer.py:988-1042
    def apply_reward_kl_penalty(
        self,
        data: TrainingInputBatch,
    ) -> TrainingInputBatch:
        """Applies a penalty for KL divergence between the policy log probs and the base model log probs to the rewards.

        Expects:
            - `["loss_mask"]`: response-token mask used when computing/aggregating KL
            - `["rewards"]`: per-token rewards to penalize
            - `["base_action_log_probs"]` / `["action_log_probs"]`: ref and policy log probs

        Side effects:
            - Replaces `data["rewards"]` with the KL-penalized rewards.
            - Updates the adaptive KL controller (when configured) and records KL
              metrics in `data.metadata["metrics"]` and `self.all_metrics`.
        """
        loss_masks_all: torch.Tensor = data["loss_mask"]
        rewards: torch.Tensor = data["rewards"]
        base_action_log_probs: torch.Tensor = data["base_action_log_probs"]
        action_log_probs: torch.Tensor = data["action_log_probs"]

        # single batched computation
        kl: Float[torch.Tensor, "batch_size seqlen"] = compute_approx_kl(  # type: ignore
            action_log_probs,
            base_action_log_probs,
            loss_mask=loss_masks_all,
            kl_estimator_type=self.cfg.trainer.algorithm.kl_estimator_type,
        )
        kl_max: Float[torch.Tensor, "batch_size"] = torch.max(kl.abs(), dim=-1)[0]  # noqa: F821
        kl_mean: Float[torch.Tensor, "batch_size"] = masked_mean(kl, loss_masks_all, dim=-1)  # noqa: F821

        # NOTE (erictang000): only supporting custom rewards currently
        # Prefer the adaptive controller's current coefficient when one is configured,
        # otherwise fall back to the fixed config value.
        kl_loss_coef = (
            self.reward_kl_controller.value
            if self.reward_kl_controller is not None
            else self.cfg.trainer.algorithm.kl_loss_coef
        )
        # max(0, ...) clamps a negative coefficient to zero, so the KL term can only
        # subtract from the reward, never add to it.
        rewards = rewards - kl * max(0, kl_loss_coef)
        data["rewards"] = rewards

        avg_kl: float = kl_mean.mean().item()
        avg_kl_max: float = kl_max.mean().item()

        # update the kl controller
        if self.reward_kl_controller is not None:
            self.reward_kl_controller.update(current=avg_kl, n_steps=kl.shape[0])  # n_steps is just the batch size
        if "metrics" not in data.metadata:
            data.metadata["metrics"] = {}

        data.metadata["metrics"].update(
            {
                "avg_kl": avg_kl,
                "avg_kl_max": avg_kl_max,
                "kl_loss_coef": kl_loss_coef,
            }
        )

        self.all_metrics.update(
            {
                "loss/avg_kl": avg_kl,
                "loss/avg_kl_max": avg_kl_max,
                "loss/kl_loss_coef": kl_loss_coef,
            }
        )

        return data

method train_critic_and_policy

train_critic_and_policy(data: TrainingInputBatch)

Run the training step for the policy and critic models.

Uses forward_backward + optim_step for both FSDP and Megatron strategies.

Source code in skyrl/train/trainer.py:1096-1122
    def train_critic_and_policy(self, data: TrainingInputBatch):
        """
        Run one training step for the critic (when present) and then the policy.

        Uses the unified forward_backward + optim_step path shared by the FSDP and
        Megatron strategies. Per-model training metrics are folded into
        `self.all_metrics` under the `critic/` and `policy/` prefixes, and the
        policy's status dict is returned.
        """
        data.metadata["global_step"] = self.global_step

        critic_status = None
        if self.has_critic:
            with Timer("critic_train", self.all_timings):
                critic_status = self._execute_training_step("critic", data)

        with Timer("policy_train", self.all_timings):
            policy_status = self._execute_training_step("policy", data)

        # Record per-model metrics under namespaced keys.
        if critic_status is not None:
            self.all_metrics.update({f"critic/{key}": val for key, val in critic_status.items()})
        self.all_metrics.update({f"policy/{key}": val for key, val in policy_status.items()})

        self.dispatch.empty_cache()

        return policy_status

method handle_dynamic_sampling

handle_dynamic_sampling(generator_output: GeneratorOutput, uids: List[str]) -> Tuple[GeneratorOutput, List[str], bool]

Handle dynamic sampling for the current batch.

Accumulates the generator output and UIDs across batches if we are sampling repeatedly and applies the dynamic sampling strategy (e.g. filter or replace) to the current batch. If we hit the limit of max sample batches, we raise an error.

Parameters:

NameTypeDescriptionDefault
generator_outputGeneratorOutputCurrent batch generator outputrequired
uidsList[str]Current batch UIDsrequired

Returns:

NameTypeDescription
processed_outputGeneratorOutputFiltered generator output
processed_uidsList[str]Filtered UIDs
keep_samplingboolWhether to keep sampling
Source code in skyrl/train/trainer.py:1124-1184
    def handle_dynamic_sampling(
        self, generator_output: GeneratorOutput, uids: List[str]
    ) -> Tuple[GeneratorOutput, List[str], bool]:
        """
        Handle dynamic sampling for the current batch.

        Accumulates generator output and UIDs across batches while repeated sampling
        is in progress, and applies the configured dynamic sampling strategy
        (e.g. filter or replace) to the current batch. Raises a RuntimeError when
        the maximum number of sample batches is hit.

        Args:
            generator_output: Current batch generator output
            uids: Current batch UIDs

        Returns:
            processed_output: Filtered generator output
            processed_uids: Filtered UIDs
            keep_sampling: Whether to keep sampling
        """
        # Assemble the sampling configuration from the trainer/generator config.
        ds_cfg = self.cfg.trainer.algorithm.dynamic_sampling
        max_sample_batches = ds_cfg.max_sample_batches
        sampling_config = {
            "type": ds_cfg.type,
            "max_sample_batches": max_sample_batches,
            "min_replace_ratio": ds_cfg.min_replace_ratio,
            "train_batch_size": self.cfg.trainer.train_batch_size,
            "n_samples_per_prompt": self.cfg.generator.n_samples_per_prompt,
        }

        # Track how many batches have been sampled for the current training step.
        if self.dynamic_sampling_state is None:
            self.dynamic_sampling_state = {"sample_batch_count": 1}
        else:
            self.dynamic_sampling_state["sample_batch_count"] += 1

        # Delegate the actual filter/replace logic to the shared utility.
        processed_output, processed_uids, keep_sampling, updated_state = trainer_utils.handle_dynamic_sampling(
            generator_output, uids, sampling_config, self.dynamic_sampling_state
        )

        # If more samples are still needed but the resample budget is exhausted, abort.
        hit_resample_limit = (
            keep_sampling
            and max_sample_batches > 0
            and self.dynamic_sampling_state["sample_batch_count"] >= max_sample_batches
        )
        if hit_resample_limit:
            raise RuntimeError(
                f"Exiting training loop due to hitting dynamic sampling limit for "
                f"{ds_cfg.type} strategy with "
                f"{ds_cfg.max_sample_batches} max sample batches. "
                f"Please check your data difficulty distribution."
            )

        # Carry accumulated state forward while sampling continues; reset it once done.
        self.dynamic_sampling_state = updated_state if keep_sampling else None

        return processed_output, processed_uids, keep_sampling

method save_checkpoints

save_checkpoints()

Save the model, optimizer, and training states to disk.

Dispatch handles offload/backload automatically for all colocation configurations.

Source code in skyrl/train/trainer.py:1195-1244
    def save_checkpoints(self):
        """
        Save the model, optimizer, and training states to disk.

        Checkpoint layout under `cfg.trainer.ckpt_path`:
            - `global_step_<N>/policy`: policy checkpoint
            - `global_step_<N>/critic`: critic checkpoint (only when a critic exists)
            - `global_step_<N>/data.pt`: dataloader state (best effort)
            - `global_step_<N>/trainer_state.pt`: global step + config snapshot
            - `latest_ckpt_global_step.txt`: written last, so a partially written
              checkpoint is never marked as the latest one

        Dispatch handles offload/backload automatically for all colocation configurations.
        """
        # Create global step folder structure
        global_step_folder = os.path.join(self.cfg.trainer.ckpt_path, f"global_step_{self.global_step}")
        policy_save_dir = os.path.join(global_step_folder, "policy")
        critic_save_dir = os.path.join(global_step_folder, "critic")

        io.makedirs(global_step_folder, exist_ok=True)

        # Save policy checkpoint (dispatch handles offload/backload)
        self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer)

        # Save critic checkpoint (if it exists)
        if self.has_critic:
            self.dispatch.save_checkpoint("critic", critic_save_dir, self.tokenizer)

        # Save dataloader state (best effort: a failure only loses the resume position, not weights)
        dataloader_save_path = os.path.join(global_step_folder, "data.pt")
        try:
            dataloader_state_dict = self.train_dataloader.state_dict()
            with io.open_file(dataloader_save_path, "wb") as f:
                torch.save(dataloader_state_dict, f)
            logger.info(f"Saved dataloader state to {dataloader_save_path}")
        except Exception as e:
            logger.warning(f"Failed to save dataloader state: {e}")

        # Save additional trainer state
        trainer_state = {
            "global_step": self.global_step,
            "config": asdict(self.cfg),
        }
        trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
        with io.open_file(trainer_state_path, "wb") as f:
            torch.save(trainer_state, f)
        logger.info(f"Saved trainer state to {trainer_state_path}")

        # Atomic tracking - write this last after all saves succeed
        latest_checkpoint_file = os.path.join(self.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt")
        with io.open_file(latest_checkpoint_file, "w") as f:
            f.write(str(self.global_step))

        logger.info(f"Successfully saved checkpoint for global_step_{self.global_step} to: {global_step_folder}")

        # Clean up old checkpoints after successful save
        with Timer("cleanup_old_checkpoints", self.all_timings):
            self._cleanup_old_checkpoints()

method load_checkpoints

load_checkpoints() -> Tuple[int, str]

Load complete checkpoint state and return the global_step to resume from. Returns 0 if no checkpoint is loaded.

If colocate_all is True, assumes that the policy model is currently on GPU.

Returns:

NameTypeDescription
global_stepintThe global step to resume from.
checkpoint_pathstrThe path to the checkpoint.
Source code in skyrl/train/trainer.py:1259-1370
    def load_checkpoints(self) -> Tuple[int, str]:
        """
        Load complete checkpoint state and return the global_step to resume from.
        Returns 0 (with a None path) if no checkpoint is loaded.

        Resolution order: if resumption is disabled, start from scratch; in `latest`
        mode, read the step from `latest_ckpt_global_step.txt`; otherwise
        `trainer.resume_path` must point to a `global_step_*` directory.

        If colocate_all is True, assumes that the policy model is currently on GPU.

        Returns:
            global_step: The global step to resume from.
            checkpoint_path: The path to the checkpoint.
        """
        checkpoint_path = None
        # Check if resumption is enabled
        if self.resume_mode == ResumeMode.NONE:
            logger.info("Checkpoint resumption disabled, starting training from scratch")
            return 0, None
        # first, let's get resume_path
        elif self.resume_mode == ResumeMode.LATEST:
            latest_checkpoint_file = os.path.join(self.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt")
            if not io.exists(latest_checkpoint_file):
                logger.info("No checkpoint found, starting training from scratch")
                return 0, None
            with io.open_file(latest_checkpoint_file, "r") as f:
                ckpt_iteration = int(f.read().strip())
            checkpoint_path = os.path.join(self.cfg.trainer.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_iteration}")
            # Run validation: Make sure ckpt folder is consistent with latest_ckpt_global_step.txt
            validate_consistency_for_latest_checkpoint(
                self.cfg.trainer.ckpt_path,
                ckpt_iteration,
                checkpoint_path,
                latest_checkpoint_file,
                self.cfg.trainer.ckpt_interval,
            )
        else:
            # Get and validate resume path
            checkpoint_path = Path(self.cfg.trainer.resume_path)
            if not checkpoint_path:
                raise ValueError("`trainer.resume_path` must be specified when resume_mode is 'from_path'")

            # Validate that it's a global_step directory
            if GLOBAL_STEP_PREFIX not in checkpoint_path.name:
                # Fixed grammar of the user-facing error message ("whose name starting with").
                raise ValueError(
                    f"`trainer.resume_path` must point to a directory whose name starts with {GLOBAL_STEP_PREFIX}, got: {checkpoint_path}"
                )

        # Validate that the path exists
        if not io.exists(str(checkpoint_path)):
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")

        logger.info(f"Loading checkpoint from: {checkpoint_path}")

        # Extract global step from checkpoint path
        global_step = extract_step_from_path(Path(checkpoint_path))
        if global_step == -1:
            raise ValueError(f"Checkpoint path {checkpoint_path} is not a valid checkpoint path")
        logger.info(f"Resuming from global_step: {global_step}")

        # Define paths for different checkpoint components
        policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
        critic_ckpt_dir = os.path.join(checkpoint_path, "critic")
        trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
        dataloader_state_path = os.path.join(checkpoint_path, "data.pt")

        # Validate that required checkpoint files exist
        if not io.exists(trainer_state_path):
            raise FileNotFoundError(f"Trainer state file not found: {trainer_state_path}")

        # 1. Load and validate trainer state
        with io.open_file(trainer_state_path, "rb") as f:
            trainer_state = torch.load(f, map_location="cpu", weights_only=False)
        saved_global_step = trainer_state.get("global_step", global_step)
        logger.info("Successfully loaded trainer state")
        if saved_global_step != global_step:
            logger.warning(f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value.")

        # 2. Load dataloader state if available (best effort: restart from the beginning on failure)
        if io.exists(dataloader_state_path):
            try:
                with io.open_file(dataloader_state_path, "rb") as f:
                    dataloader_state = torch.load(f, map_location="cpu", weights_only=False)
                self.train_dataloader.load_state_dict(dataloader_state)
                logger.info("Successfully loaded dataloader state")
            except Exception as e:
                logger.warning(f"Failed to load dataloader state: {e}. Dataloader will start from beginning.")
        else:
            logger.warning(
                f"No dataloader state found at {dataloader_state_path}. Dataloader will start from beginning."
            )

        # 3. Load policy checkpoint (dispatch handles offload/backload)
        logger.info(f"Loading policy checkpoint from {policy_ckpt_dir}")
        self.dispatch.load_checkpoint(
            "policy",
            policy_ckpt_dir,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        logger.info("Successfully loaded policy checkpoint")

        # 4. Load critic checkpoint if it exists and we have a critic model
        if self.has_critic:
            logger.info(f"Loading critic checkpoint from {critic_ckpt_dir}")
            self.dispatch.load_checkpoint(
                "critic",
                critic_ckpt_dir,
                load_optimizer_states=True,
                load_lr_scheduler_states=True,
            )
            logger.info("Successfully loaded critic checkpoint")

        logger.info(f"Successfully loaded complete checkpoint state from global_step_{global_step}")
        return global_step, str(checkpoint_path)

method save_models

save_models()

Save the model parameters in HF format at cfg.trainer.export_path.

Dispatch handles offload/backload automatically for all colocation configurations.

Source code in skyrl/train/trainer.py:1372-1385
    def save_models(self):
        """
        Export model parameters in HF format under
        `cfg.trainer.export_path/global_step_<N>/{policy,critic}`.

        Dispatch handles offload/backload automatically for all colocation configurations.
        """
        step_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}")

        self.dispatch.save_hf_model("policy", os.path.join(step_dir, "policy"), self.tokenizer)
        if self.has_critic:
            self.dispatch.save_hf_model("critic", os.path.join(step_dir, "critic"), self.tokenizer)

        logger.info("Successfully saved model weights.")

method update_ref_with_policy

update_ref_with_policy()

Update the reference model with the policy model weights (required by some algorithms).

Dispatch handles offload/backload automatically for all colocation configurations. After this method, save_weights_for_sampler() should be called to sync weights.

Source code in skyrl/train/trainer.py:1387-1410
    def update_ref_with_policy(self):
        """
        Refresh the reference model from the current policy weights (required by
        some algorithms).

        Dispatch handles offload/backload automatically for all colocation
        configurations. After this method, save_weights_for_sampler() should be
        called to sync weights.
        """
        # TODO(tgriggs): Make policy-to-ref sync faster.
        export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}", "policy")

        # Persist the policy in HF format, then re-initialize the ref model from it
        # (dispatch offloads the policy first as needed).
        self.dispatch.save_hf_model("policy", export_dir, self.tokenizer)
        self.dispatch.init_model("ref", export_dir)

        # The exported files are only needed for the reload above; clean up best-effort.
        try:
            shutil.rmtree(export_dir)
            logger.info(f"Cleaned up temporary policy export directory: {export_dir}")
        except Exception as e:
            logger.warning(f"Failed to clean up temporary policy export directory {export_dir}: {e}")

        logger.info("Successfully updated ref model with policy model, training continues.")

Dispatch APIs

class Dispatch

Bases: ABC

Base class for dispatch types

Dispatch types are responsible for:

  • dispatching method calls to actors handling data sharding if necessary
  • collecting results from actors and concatenating results if necessary
  • validating arguments for dispatch

Functions:

NameDescription
dispatchDispatches method calls to the actors with data sharding if necessary.
async_collectCollects results from the actors asynchronously in an asyncio-compatible way.
sync_collectCollects results from the actors synchronously and returns a TrainingOutputBatch.
validate_dispatch_argsValidate and process arguments for dispatch.
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:61-98
class Dispatch(ABC):
    """Abstract base class for dispatch types.

    A dispatch type is responsible for:
    - routing method calls to actors, sharding the input data when necessary
    - gathering results from actors and concatenating them when necessary
    - validating the arguments used for dispatch
    """

    @classmethod
    @abstractmethod
    def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
        """Route a method call to the actors, sharding data across them if necessary."""
        ...

    @classmethod
    @abstractmethod
    async def async_collect(
        cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
    ) -> Optional[TrainingOutputBatch]:
        """Gather results from the actors asynchronously in an asyncio-compatible way."""
        ...

    @classmethod
    @abstractmethod
    def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
        """Gather results from the actors synchronously as a `TrainingOutputBatch`."""
        ...

    @classmethod
    @abstractmethod
    def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
        """Validate and normalize arguments before dispatch.

        Returns:
            Tuple of (args, kwargs) to be passed to `dispatch`
        """
        ...

method abstractmethod classmethod dispatch

dispatch(actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]

Dispatches method calls to the actors with data sharding if necessary.

Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:70-74
    @classmethod
    @abstractmethod
    def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
        """Dispatches method calls to the actors with data sharding if necessary."""
        pass

method abstractmethod async classmethod async_collect

async_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]

Collects results from the actors asynchronously in an asyncio-compatible way.

Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:76-82
    @classmethod
    @abstractmethod
    async def async_collect(
        cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
    ) -> Optional[TrainingOutputBatch]:
        """Collects results from the actors asynchronously in an asyncio-compatible way."""
        pass

method abstractmethod classmethod sync_collect

sync_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]

Collects results from the actors synchronously and returns a TrainingOutputBatch.

Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:84-88
    @classmethod
    @abstractmethod
    def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
        """Collects results from the actors synchronously and returns a `TrainingOutputBatch`."""
        pass

method abstractmethod classmethod validate_dispatch_args

validate_dispatch_args(*args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]

Validate and process arguments for dispatch.

Returns:

TypeDescription
Tuple[Tuple, Dict[str, Any]]Tuple of (args, kwargs) to be passed to dispatch
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:90-98
    @classmethod
    @abstractmethod
    def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
        """Validate and process arguments for dispatch.

        Returns:
            Tuple of (args, kwargs) to be passed to dispatch
        """
        pass

class MeshDispatch

Bases: Dispatch

Mesh dispatch type to dispatch data to a group of actors along the device mesh.

Supports DP (Data Parallel), SP (Sequence Parallel), TP (Tensor Parallel) and PP (Pipeline Parallel) parallelism. The actor method should accept a single argument - the data batch.

For data dispatch:

  • The input data is chunked into dp_size equal chunks, where dp_size is the size of data parallelism.
  • Each actor with the same DP rank processes the same data chunk in parallel.

For data collection:

  • Data is collected only from the primary rank of each model/sequence parallel group.
  • The primary rank is defined as the rank with (SP=0, TP=0, PP=0).
  • The collected chunks are concatenated in order of DP rank to reconstruct the full data.

Example: For a world size of 8, with DP size=2, SP size=2, TP size=2, PP size=1:

  • Data dispatch: The data is chunked into 2 chunks. All actors with DP rank 0 process the first chunk, and all actors with DP rank 1 process the second chunk.
  • Data collection: Only two actors contribute to the final output - the primary rank from each DP group: (DP=0, SP=0, TP=0, PP=0) and (DP=1, SP=0, TP=0, PP=0). Their chunks are concatenated in order.

Functions:

NameDescription
dispatch
async_collect
sync_collect
dispatch_from_stagedDispatch to workers using pre-staged data from object store.
validate_dispatch_args
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:101-227
class MeshDispatch(Dispatch):
    """Mesh dispatch type to dispatch data to a group of actors along the device mesh.

    Supports DP (Data Parallel), SP (Sequence Parallel), TP (Tensor Parallel) and PP (Pipeline Parallel) parallelism.
    The actor method should accept a single argument - the data batch.

    For data dispatch:

    * The input data is chunked into `dp_size` equal chunks, where `dp_size` is the size of data parallelism.
    * Each actor with the same DP rank processes the same data chunk in parallel.

    For data collection:

    * Data is collected only from the primary rank of each model/sequence parallel group.
    * The primary rank is defined as the rank with (SP=0, TP=0, PP=0).
    * The collected chunks are concatenated in order of DP rank to reconstruct the full data.

    Example: For a world size of 8, with DP size=2, SP size=2, TP size=2, PP size=1:

    * Data dispatch: The data is chunked into 2 chunks. All actors with DP rank 0 process the first chunk,
      and all actors with DP rank 1 process the second chunk.
    * Data collection: Only two actors contribute to the final output - the primary rank from each DP group:
      (DP=0, SP=0, TP=0, PP=0) and (DP=1, SP=0, TP=0, PP=0). Their chunks are concatenated in order.

    """

    @classmethod
    def dispatch(cls, actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch, **kwargs) -> List[ObjectRef]:
        """Chunk ``data`` along DP and call ``method`` on every actor with its DP rank's chunk."""
        assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
        object_refs = []
        # All actors in the group share one device mesh, so the first actor's dp_size applies to all.
        dp_size = actor_infos[0].rank.dp_size
        assert len(data) % dp_size == 0, "data batch size must be divisible by dp_size, got {} and {}".format(
            len(data), dp_size
        )
        chunk_size = len(data) // dp_size
        # NOTE(review): `chunk()` appears to take the per-chunk batch size (yielding dp_size
        # chunks), consistent with its use in Worker.forward — confirm against TrainingInputBatch.chunk.
        data_chunks: List[TrainingInputBatch] = data.chunk(chunk_size)

        # Put each unique chunk in object store ONCE to avoid redundant serialization
        # when the same chunk is sent to multiple workers (e.g., SP/TP replicas)
        chunk_refs: List[ObjectRef] = [ray.put(chunk) for chunk in data_chunks]

        for actor_info in actor_infos:
            # Pass ObjectRef instead of data - workers will fetch from object store
            chunk_ref = chunk_refs[actor_info.rank.dp]
            object_refs.append(getattr(actor_info.handle, method).remote(chunk_ref, **kwargs))
        return object_refs

    @classmethod
    async def async_collect(
        cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
    ) -> Optional[TrainingOutputBatch]:
        """Await all refs and concatenate primary-rank outputs in DP order; ``None`` if no output."""
        assert len(actor_infos) == len(object_refs), "`actor_infos` and `object_refs` must have the same length"
        all_objects = await asyncio.gather(*object_refs)
        if len(all_objects) and all_objects[0] is not None:
            return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
        return

    @classmethod
    def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
        """Blocking counterpart of ``async_collect``; also sanity-checks an all-``None`` result."""
        assert len(actor_infos) == len(object_refs), "`actor_infos` and `object_refs` must have the same length"
        all_objects = ray.get(object_refs)
        if len(all_objects) and all_objects[0] is not None:
            return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
        # all should be none
        assert all(obj is None for obj in all_objects), "Got a mix of `None` and non-`None` objects"
        return

    @classmethod
    def dispatch_from_staged(
        cls, actor_infos: List[ActorInfo], method: str, data_ref: ObjectRef, start_idx: int, end_idx: int, **kwargs
    ) -> List[ObjectRef]:
        """
        Dispatch to workers using pre-staged data from object store.

        Workers receive the full batch ObjectRef and slice indices, then fetch
        and slice locally. This avoids serialization on each mini-batch iteration.

        Args:
            actor_infos: List of actor info objects
            method: Name of method to call on workers
            data_ref: ObjectRef to full TrainingInputBatch in object store
            start_idx: Start index for mini-batch slice (before DP chunking)
            end_idx: End index for mini-batch slice (before DP chunking)
            **kwargs: Additional keyword arguments to pass to the method

        Returns:
            List of ObjectRefs for worker results
        """
        assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
        object_refs = []
        dp_size = actor_infos[0].rank.dp_size

        # Compute per-DP-rank slice indices
        mini_batch_size = end_idx - start_idx
        assert (
            mini_batch_size % dp_size == 0
        ), f"mini_batch_size must be divisible by dp_size, got {mini_batch_size} and {dp_size}"
        chunk_size = mini_batch_size // dp_size

        # Each DP rank gets a contiguous, equal-sized slice of [start_idx, end_idx).
        for actor_info in actor_infos:
            dp_rank = actor_info.rank.dp
            # Compute this worker's slice of the mini-batch
            worker_start = start_idx + dp_rank * chunk_size
            worker_end = worker_start + chunk_size
            object_refs.append(
                getattr(actor_info.handle, method).remote(
                    data_ref, start_idx=worker_start, end_idx=worker_end, **kwargs
                )
            )
        return object_refs

    @classmethod
    def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
        """Require a ``TrainingInputBatch`` as the ``data`` argument (positional or keyword)."""
        # Extract data from either positional arg or kwarg
        if args:
            data = args[0]
            remaining_kwargs = kwargs
        elif "data" in kwargs:
            data = kwargs.pop("data")
            remaining_kwargs = kwargs
        else:
            raise ValueError("MeshDispatch requires 'data' as first positional argument or keyword argument")

        if not isinstance(data, TrainingInputBatch):
            raise ValueError(f"For MeshDispatch, `data` entry should be a `TrainingInputBatch`, got {type(data)}")
        # Pass through data as positional arg, and any other kwargs (e.g., loss_fn, loss_fn_config)
        return (data,), remaining_kwargs

method classmethod dispatch

dispatch(actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch, **kwargs) -> List[ObjectRef]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:127-146
    @classmethod
    def dispatch(cls, actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch, **kwargs) -> List[ObjectRef]:
        """Chunk ``data`` along DP and call ``method`` on every actor with its DP rank's chunk."""
        assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
        object_refs = []
        # All actors in the group share one device mesh, so the first actor's dp_size applies to all.
        dp_size = actor_infos[0].rank.dp_size
        assert len(data) % dp_size == 0, "data batch size must be divisible by dp_size, got {} and {}".format(
            len(data), dp_size
        )
        chunk_size = len(data) // dp_size
        # NOTE(review): `chunk()` appears to take the per-chunk batch size (yielding dp_size
        # chunks) — confirm against TrainingInputBatch.chunk.
        data_chunks: List[TrainingInputBatch] = data.chunk(chunk_size)

        # Put each unique chunk in object store ONCE to avoid redundant serialization
        # when the same chunk is sent to multiple workers (e.g., SP/TP replicas)
        chunk_refs: List[ObjectRef] = [ray.put(chunk) for chunk in data_chunks]

        for actor_info in actor_infos:
            # Pass ObjectRef instead of data - workers will fetch from object store
            chunk_ref = chunk_refs[actor_info.rank.dp]
            object_refs.append(getattr(actor_info.handle, method).remote(chunk_ref, **kwargs))
        return object_refs

method async classmethod async_collect

async_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:148-156
    @classmethod
    async def async_collect(
        cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
    ) -> Optional[TrainingOutputBatch]:
        """Await all refs and concatenate primary-rank outputs in DP order; ``None`` if no output."""
        assert len(actor_infos) == len(object_refs), "`actor_infos` and `object_refs` must have the same length"
        all_objects = await asyncio.gather(*object_refs)
        if len(all_objects) and all_objects[0] is not None:
            return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
        return

method classmethod sync_collect

sync_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:158-166
    @classmethod
    def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
        """Blocking counterpart of ``async_collect``; also sanity-checks an all-``None`` result."""
        assert len(actor_infos) == len(object_refs), "`actor_infos` and `object_refs` must have the same length"
        all_objects = ray.get(object_refs)
        if len(all_objects) and all_objects[0] is not None:
            return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
        # all should be none
        assert all(obj is None for obj in all_objects), "Got a mix of `None` and non-`None` objects"
        return

method classmethod dispatch_from_staged

dispatch_from_staged(actor_infos: List[ActorInfo], method: str, data_ref: ObjectRef, start_idx: int, end_idx: int, **kwargs) -> List[ObjectRef]

Dispatch to workers using pre-staged data from object store.

Workers receive the full batch ObjectRef and slice indices, then fetch and slice locally. This avoids serialization on each mini-batch iteration.

Parameters:

NameTypeDescriptionDefault
actor_infosList[ActorInfo]List of actor info objectsrequired
methodstrName of method to call on workersrequired
data_refObjectRefObjectRef to full TrainingInputBatch in object storerequired
start_idxintStart index for mini-batch slice (before DP chunking)required
end_idxintEnd index for mini-batch slice (before DP chunking)required
**kwargsAdditional keyword arguments to pass to the method{}

Returns:

TypeDescription
List[ObjectRef]List of ObjectRefs for worker results
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:168-210
    @classmethod
    def dispatch_from_staged(
        cls, actor_infos: List[ActorInfo], method: str, data_ref: ObjectRef, start_idx: int, end_idx: int, **kwargs
    ) -> List[ObjectRef]:
        """
        Dispatch to workers using pre-staged data from object store.

        Workers receive the full batch ObjectRef and slice indices, then fetch
        and slice locally. This avoids serialization on each mini-batch iteration.

        Args:
            actor_infos: List of actor info objects
            method: Name of method to call on workers
            data_ref: ObjectRef to full TrainingInputBatch in object store
            start_idx: Start index for mini-batch slice (before DP chunking)
            end_idx: End index for mini-batch slice (before DP chunking)
            **kwargs: Additional keyword arguments to pass to the method

        Returns:
            List of ObjectRefs for worker results
        """
        assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
        object_refs = []
        dp_size = actor_infos[0].rank.dp_size

        # Compute per-DP-rank slice indices
        mini_batch_size = end_idx - start_idx
        assert (
            mini_batch_size % dp_size == 0
        ), f"mini_batch_size must be divisible by dp_size, got {mini_batch_size} and {dp_size}"
        chunk_size = mini_batch_size // dp_size

        # Each DP rank gets a contiguous, equal-sized slice of [start_idx, end_idx).
        for actor_info in actor_infos:
            dp_rank = actor_info.rank.dp
            # Compute this worker's slice of the mini-batch
            worker_start = start_idx + dp_rank * chunk_size
            worker_end = worker_start + chunk_size
            object_refs.append(
                getattr(actor_info.handle, method).remote(
                    data_ref, start_idx=worker_start, end_idx=worker_end, **kwargs
                )
            )
        return object_refs

method classmethod validate_dispatch_args

validate_dispatch_args(*args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:212-227
    @classmethod
    def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
        """Require a ``TrainingInputBatch`` as the ``data`` argument (positional or keyword)."""
        # Extract data from either positional arg or kwarg
        if args:
            data = args[0]
            remaining_kwargs = kwargs
        elif "data" in kwargs:
            data = kwargs.pop("data")
            remaining_kwargs = kwargs
        else:
            raise ValueError("MeshDispatch requires 'data' as first positional argument or keyword argument")

        if not isinstance(data, TrainingInputBatch):
            raise ValueError(f"For MeshDispatch, `data` entry should be a `TrainingInputBatch`, got {type(data)}")
        # Pass through data as positional arg, and any other kwargs (e.g., loss_fn, loss_fn_config)
        return (data,), remaining_kwargs

class PassThroughDispatch

Bases: Dispatch

PassThrough dispatch type to dispatch data to a group of actors without any sharding.

This is useful for cases where we want to run the same method on all the actors. Supports methods with any number of arguments.

Functions:

Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:230-265
class PassThroughDispatch(Dispatch):
    """PassThrough dispatch type to dispatch data to a group of actors without any sharding.

    This is useful for cases where we want to run the same method on all the actors.
    Supports methods with any number of arguments.
    """

    @classmethod
    def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
        # Fan out the identical call (same args) to every actor; no chunking of arguments.
        return [getattr(actor_info.handle, method).remote(*args, **kwargs) for actor_info in actor_infos]

    @classmethod
    async def async_collect(
        cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
    ) -> Optional[TrainingOutputBatch]:
        """Await all refs and concatenate any returned batches; ``None`` if actors returned nothing."""
        all_objects = await asyncio.gather(*object_refs)
        if len(all_objects) and all_objects[0] is not None:
            return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
        return

    @classmethod
    def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
        """Blocking collect with a type check on the first result.

        NOTE(review): unlike ``async_collect``, this path validates the result type and
        asserts an all-``None`` result — the async path performs neither check.
        """
        data_batches = ray.get(object_refs)
        if len(data_batches) > 0 and data_batches[0] is not None:
            assert isinstance(
                data_batches[0], TrainingOutputBatch
            ), "data_batches must be a list of `TrainingOutputBatch` objects"
            return concatenate_outputs_after_mesh_dispatch(actor_infos, data_batches)
        # all should be none
        assert all(obj is None for obj in data_batches), "Got a mix of `None` and non-`None` objects"
        return

    @classmethod
    def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
        # No validation needed; pass everything through unchanged.
        return args, kwargs

method classmethod dispatch

dispatch(actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:237-239
    @classmethod
    def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
        # Fan out the identical call (same args) to every actor; no chunking of arguments.
        return [getattr(actor_info.handle, method).remote(*args, **kwargs) for actor_info in actor_infos]

method async classmethod async_collect

async_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:241-248
    @classmethod
    async def async_collect(
        cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
    ) -> Optional[TrainingOutputBatch]:
        """Await all refs and concatenate any returned batches; ``None`` if actors returned nothing."""
        all_objects = await asyncio.gather(*object_refs)
        if len(all_objects) and all_objects[0] is not None:
            return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
        return

method classmethod sync_collect

sync_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:250-260
    @classmethod
    def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
        """Blocking collect with a type check on the first result."""
        data_batches = ray.get(object_refs)
        if len(data_batches) > 0 and data_batches[0] is not None:
            assert isinstance(
                data_batches[0], TrainingOutputBatch
            ), "data_batches must be a list of `TrainingOutputBatch` objects"
            return concatenate_outputs_after_mesh_dispatch(actor_infos, data_batches)
        # all should be none
        assert all(obj is None for obj in data_batches), "Got a mix of `None` and non-`None` objects"
        return

method classmethod validate_dispatch_args

validate_dispatch_args(*args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:262-265
    @classmethod
    def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
        # No validation needed; pass everything through unchanged.
        return args, kwargs

Worker APIs

The base worker abstraction in SkyRL.

class Worker

Worker(cfg: TrainerConfig, *args, **kwargs)

Bases: DistributedTorchRayActor

Functions:

NameDescription
get_node_local_rank
init_worker_process_group
get_mesh_rank
get_gpu_id
get_ray_node_id
get_master_addr_port
init_modelInitialize worker state (model, and optimizer if applicable) on worker.
empty_cacheEmpty GPU memory cache on Worker's CUDA device
offload_to_cpuOffload all worker state to CPU.
backload_to_gpuBackload worker state to GPU
get_cuda_memoryGet CUDA memory usage on worker's CUDA device.
save_memory_snapshotSave a snapshot of memory usage on the Worker's CUDA device.
init_weight_sync_stateInitialize state for weight syncing with Inference Engine Client
forwardRun forward pass on the input batch in inference mode.

Attributes:

Source code in skyrl/backends/skyrl_train/workers/worker.py:231-395
class Worker(DistributedTorchRayActor):
    """Base Ray worker: shared model-lifecycle, GPU-memory, and weight-sync utilities."""

    def __init__(self, cfg: TrainerConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cfg = cfg
        self._transfer_strategy_cls = None  # Set in init_weight_transfer_communicator

        # Temperature is required; fail fast at construction rather than mid-training.
        if self.cfg.algorithm.temperature is None:
            raise ValueError("`cfg.algorithm.temperature` must be set")

    def init_model(self, *args, **kwargs):
        """Initialize worker state (model, and optimizer if applicable) on worker."""
        raise NotImplementedError()

    def empty_cache(self) -> None:
        """Empty GPU memory cache on Worker's CUDA device"""
        torch.cuda.empty_cache()

    def offload_to_cpu(self, pin_memory=True, non_blocking=True):
        """Offload all worker state to CPU.

        After this function runs, only temporary reserved memory and torch's pre-loaded cuda kernels (~ GB) will remain

        Args:
            pin_memory: Whether to use pinned/page-locked memory on CPU
            non_blocking: Whether the operation is non-blocking
        """
        raise NotImplementedError()

    def backload_to_gpu(self, non_blocking=True):
        """Backload worker state to GPU

        Args:
            non_blocking: Whether the operation is non-blocking
        """
        raise NotImplementedError()

    def get_cuda_memory(self) -> Dict[str, Any]:
        """Get CUDA memory usage on worker's CUDA device."""
        # Synchronize first so the numbers reflect completed kernels.
        torch.cuda.synchronize()
        free, total = torch.cuda.mem_get_info()
        return {
            "allocated": torch.cuda.memory_allocated(),
            "reserved": torch.cuda.memory_reserved(),
            "free": free,
            "total": total,
        }

    def save_memory_snapshot(self, tag: str = ""):
        """Save a snapshot of memory usage on the Worker's CUDA device.

        No-ops if record_memory is False.

        Args:
            tag: Label for the snapshot (e.g., "forward_backward", "optim_step")

        .. note::
            This function should be called on all the ranks in the worker group simultaneously.
        """
        if not self.record_memory:
            return

        # Track snapshot count for unique filenames
        if not hasattr(self, "_snapshot_count"):
            self._snapshot_count = 0
        self._snapshot_count += 1

        rank = torch.distributed.get_rank()
        save_path = os.path.join(self.cfg.ckpt_path, "memory_snapshots")
        # One process per node creates the directory; the barrier makes every rank wait for it.
        if self._local_rank == 0 and not io.exists(save_path):
            io.makedirs(save_path, exist_ok=True)
        torch.distributed.barrier()

        tag_str = f"_{tag}" if tag else ""
        file_name = f"rank_{rank}{tag_str}_{self._snapshot_count}.pickle"
        record_memory_path = os.path.join(save_path, file_name)
        if io.exists(record_memory_path):
            # seeing issues if we don't remove the file first
            io.remove(record_memory_path)
        torch.cuda.memory._dump_snapshot(record_memory_path)

    async def init_weight_sync_state(
        self,
        inference_engine_client: "Union[InferenceEngineClient, RemoteInferenceClient]",
        inference_engine_cfg: "InferenceEngineConfig",
    ):
        """Initialize state for weight syncing with Inference Engine Client

        Creates init info and sender, then sends init info to inference engines
        so they can create receivers.

        .. note::
            This function should be called on all the ranks in the worker group simultaneously.
        """
        from skyrl.backends.skyrl_train.weight_sync import get_transfer_strategy_cls

        assert inference_engine_client is not None

        # Determine transfer strategy based on inference engine config and placement
        self._transfer_strategy_cls = get_transfer_strategy_cls(
            weight_sync_backend=inference_engine_cfg.weight_sync_backend,
            colocate_all=self.cfg.placement.colocate_all,
        )

        # For new inference path, fetch world_size from servers
        # For legacy path, calculate from config
        inference_world_size = None
        if _SKYRL_USE_NEW_INFERENCE and hasattr(inference_engine_client, "get_world_size"):
            inference_world_size, _ = await inference_engine_client.get_world_size()

        # Create init info on all ranks (it's deterministic from cfg or fetched world_size)
        init_info = self._transfer_strategy_cls.create_init_info(
            inference_engine_cfg, inference_world_size=inference_world_size
        )

        # Create sender on all ranks
        # Strategy implementations may have different logic for different ranks
        tasks = [
            asyncio.to_thread(
                self._transfer_strategy_cls.create_sender,
                init_info=init_info,
                inference_client=inference_engine_client,
            ),
        ]

        # Only rank 0 initializes receivers on inference engines
        # NOTE: For broadcast strategy, sender and receiver init must run concurrently
        # because both need to join the same process group to avoid deadlock
        if torch.distributed.get_rank() == 0:
            tasks.append(inference_engine_client.init_weight_update_communicator(init_info))

        results = await asyncio.gather(*tasks)
        self._weight_transfer_sender = results[0]  # sender is always first task

        # # Register signal handlers for termination only on rank 0
        # NOTE (sumanthrh): This doesn't work yet, and is thus commented out.
        # The better way is to just have this specified in __del__, but there is
        # no guarantee that __del__ will be called in general. Ray also doesn't
        # explicitly call __del__ when the actor shuts down.
        # It's commented out so that we can fix this in the future.
        # atexit.register(self._handle_termination)

        torch.distributed.barrier()

    def forward(
        self,
        data: TrainingInputBatch,
    ) -> TrainingOutputBatch:
        """Run forward pass on the input batch in inference mode.

        This is a wrapper around `_forward_micro_batch` that runs in micro batches of `cfg.micro_forward_batch_size_per_gpu`.
        """
        # run in micro batches of cfg.micro_forward_batch_size_per_gpu
        # TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc.
        micro_batches = data.chunk(self.cfg.micro_forward_batch_size_per_gpu)

        outputs = []
        for micro_batch in micro_batches:
            outputs.append(self._forward_micro_batch(micro_batch))
        output = TrainingOutputBatch.cat(outputs)
        # Move results to CPU so callers can collect them without holding GPU memory.
        if output.device is not None and output.device != torch.device("cpu"):
            output = output.to("cpu")
        return output

    def _forward_micro_batch(self, micro_batch: TrainingInputBatch) -> TrainingOutputBatch:
        # Per-micro-batch forward hook; concrete workers (policy/critic) must implement.
        raise NotImplementedError()

attr sequence_parallel_size

sequence_parallel_size: int = sequence_parallel_size

attr record_memory

record_memory = record_memory

get_node_local_rank

get_node_local_rank()

init_worker_process_group

init_worker_process_group()

get_mesh_rank

get_mesh_rank()

get_gpu_id

get_gpu_id()

get_ray_node_id

get_ray_node_id()

get_master_addr_port

get_master_addr_port()

attr cfg

cfg = cfg

method init_model

init_model(*args, **kwargs)

Initialize worker state (model, and optimizer if applicable) on worker.

Source code in skyrl/backends/skyrl_train/workers/worker.py:240-242
    def init_model(self, *args, **kwargs):
        """Initialize worker state (model, and optimizer if applicable) on worker.

        Abstract: concrete worker subclasses must override.
        """
        raise NotImplementedError()

method empty_cache

empty_cache() -> None

Empty GPU memory cache on Worker's CUDA device

Source code in skyrl/backends/skyrl_train/workers/worker.py:244-246
    def empty_cache(self) -> None:
        """Empty GPU memory cache on Worker's CUDA device"""
        # Releases cached allocator blocks back to the driver; live tensors are untouched.
        torch.cuda.empty_cache()

method offload_to_cpu

offload_to_cpu(pin_memory = True, non_blocking = True)

Offload all worker state to CPU.

After this function runs, only temporary reserved memory and torch's pre-loaded cuda kernels (~ GB) will remain

Parameters:

NameTypeDescriptionDefault
pin_memoryWhether to use pinned/ paged-locked memory on CPUTrue
non_blockingWhether the operation is non-blockingTrue
Source code in skyrl/backends/skyrl_train/workers/worker.py:248-257
    def offload_to_cpu(self, pin_memory=True, non_blocking=True):
        """Offload all worker state to CPU.

        After this function runs, only temporary reserved memory and torch's pre-loaded cuda kernels (~ GB) will remain

        Args:
            pin_memory: Whether to use pinned/page-locked memory on CPU
            non_blocking: Whether the operation is non-blocking

        Abstract: concrete worker subclasses must override.
        """
        raise NotImplementedError()

method backload_to_gpu

backload_to_gpu(non_blocking = True)

Backload worker state to GPU

Parameters:

NameTypeDescriptionDefault
non_blockingWhether the operation is non-blockingTrue
Source code in skyrl/backends/skyrl_train/workers/worker.py:259-265
    def backload_to_gpu(self, non_blocking=True):
        """Backload worker state to GPU

        Args:
            non_blocking: Whether the operation is non-blocking

        Abstract: concrete worker subclasses must override.
        """
        raise NotImplementedError()

method get_cuda_memory

get_cuda_memory() -> Dict[str, Any]

Get CUDA memory usage on worker's CUDA device.

Source code in skyrl/backends/skyrl_train/workers/worker.py:267-276
    def get_cuda_memory(self) -> Dict[str, Any]:
        """Get CUDA memory usage on worker's CUDA device.

        Returns:
            Dict of byte counts: the torch allocator's ``allocated``/``reserved`` plus the
            device-wide ``free``/``total`` from ``torch.cuda.mem_get_info``.
        """
        # Synchronize first so the numbers reflect completed kernels.
        torch.cuda.synchronize()
        free, total = torch.cuda.mem_get_info()
        return {
            "allocated": torch.cuda.memory_allocated(),
            "reserved": torch.cuda.memory_reserved(),
            "free": free,
            "total": total,
        }

method save_memory_snapshot

save_memory_snapshot(tag: str = '')

Save a snapshot of memory usage on the Worker's CUDA device.

No-ops if record_memory is False.

Parameters:

NameTypeDescriptionDefault
tagstrLabel for the snapshot (e.g., "forward_backward", "optim_step")''

.. note:: This function should be called on all the ranks in the worker group simultaneously.

Source code in skyrl/backends/skyrl_train/workers/worker.py:278-309
    def save_memory_snapshot(self, tag: str = ""):
        """Save a snapshot of memory usage on the Worker's CUDA device.

        No-ops if record_memory is False.

        Args:
            tag: Label for the snapshot (e.g., "forward_backward", "optim_step")

        .. note::
            This function should be called on all the ranks in the worker group simultaneously.
        """
        if not self.record_memory:
            return

        # Track snapshot count for unique filenames
        if not hasattr(self, "_snapshot_count"):
            self._snapshot_count = 0
        self._snapshot_count += 1

        rank = torch.distributed.get_rank()
        save_path = os.path.join(self.cfg.ckpt_path, "memory_snapshots")
        # One process per node creates the directory; the barrier makes every rank wait for it.
        if self._local_rank == 0 and not io.exists(save_path):
            io.makedirs(save_path, exist_ok=True)
        torch.distributed.barrier()

        tag_str = f"_{tag}" if tag else ""
        file_name = f"rank_{rank}{tag_str}_{self._snapshot_count}.pickle"
        record_memory_path = os.path.join(save_path, file_name)
        if io.exists(record_memory_path):
            # seeing issues if we don't remove the file first
            io.remove(record_memory_path)
        torch.cuda.memory._dump_snapshot(record_memory_path)

method init_weight_sync_state

init_weight_sync_state(inference_engine_client: Union[InferenceEngineClient, RemoteInferenceClient], inference_engine_cfg: InferenceEngineConfig)

Initialize state for weight syncing with Inference Engine Client

Creates init info and sender, then sends init info to inference engines so they can create receivers.

.. note:: This function should be called on all the ranks in the worker group simultaneously.

Source code in skyrl/backends/skyrl_train/workers/worker.py:311-372
    async def init_weight_sync_state(
        self,
        inference_engine_client: "Union[InferenceEngineClient, RemoteInferenceClient]",
        inference_engine_cfg: "InferenceEngineConfig",
    ):
        """Initialize state for weight syncing with Inference Engine Client

        Creates init info and sender, then sends init info to inference engines
        so they can create receivers.

        .. note::
            This function should be called on all the ranks in the worker group simultaneously.
        """
        from skyrl.backends.skyrl_train.weight_sync import get_transfer_strategy_cls

        assert inference_engine_client is not None

        # Determine transfer strategy based on inference engine config and placement
        self._transfer_strategy_cls = get_transfer_strategy_cls(
            weight_sync_backend=inference_engine_cfg.weight_sync_backend,
            colocate_all=self.cfg.placement.colocate_all,
        )

        # For new inference path, fetch world_size from servers
        # For legacy path, calculate from config
        inference_world_size = None
        if _SKYRL_USE_NEW_INFERENCE and hasattr(inference_engine_client, "get_world_size"):
            inference_world_size, _ = await inference_engine_client.get_world_size()

        # Create init info on all ranks (it's deterministic from cfg or fetched world_size)
        init_info = self._transfer_strategy_cls.create_init_info(
            inference_engine_cfg, inference_world_size=inference_world_size
        )

        # Create sender on all ranks
        # Strategy implementations may have different logic for different ranks
        tasks = [
            asyncio.to_thread(
                self._transfer_strategy_cls.create_sender,
                init_info=init_info,
                inference_client=inference_engine_client,
            ),
        ]

        # Only rank 0 initializes receivers on inference engines
        # NOTE: For broadcast strategy, sender and receiver init must run concurrently
        # because both need to join the same process group to avoid deadlock
        if torch.distributed.get_rank() == 0:
            tasks.append(inference_engine_client.init_weight_update_communicator(init_info))

        results = await asyncio.gather(*tasks)
        self._weight_transfer_sender = results[0]  # sender is always first task

        # # Register signal handlers for termination only on rank 0
        # NOTE (sumanthrh): This doesn't work yet, and is thus commented out.
        # The better way is to just have this specified in __del__, but there is
        # no guarantee that __del__ will be called in general. Ray also doesn't
        # explicitly call __del__ when the actor shuts down.
        # It's commented out so that we can fix this in the future.
        # atexit.register(self._handle_termination)

        torch.distributed.barrier()

method forward

forward(data: TrainingInputBatch) -> TrainingOutputBatch

Run forward pass on the input batch in inference mode.

This is a wrapper around _forward_micro_batch that runs in micro batches of cfg.micro_forward_batch_size_per_gpu.

Source code in skyrl/backends/skyrl_train/workers/worker.py:374-392
    def forward(
        self,
        data: TrainingInputBatch,
    ) -> TrainingOutputBatch:
        """Run forward pass on the input batch in inference mode.

        This is a wrapper around `_forward_micro_batch` that runs in micro batches of `cfg.micro_forward_batch_size_per_gpu`.
        """
        # run in micro batches of cfg.micro_forward_batch_size_per_gpu
        # TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc.
        micro_batches = data.chunk(self.cfg.micro_forward_batch_size_per_gpu)

        outputs = []
        for micro_batch in micro_batches:
            outputs.append(self._forward_micro_batch(micro_batch))
        output = TrainingOutputBatch.cat(outputs)
        # Move results to CPU so callers can collect them without holding GPU memory.
        if output.device is not None and output.device != torch.device("cpu"):
            output = output.to("cpu")
        return output

class PPORayActorGroup

PPORayActorGroup(cfg: TrainerConfig, num_nodes: int, num_gpus_per_node: int, ray_actor_type: Type[Worker], pg: Optional[PlacementGroup] = None, num_gpus_per_actor: float = 1.0, resources: Optional[Dict[str, float]] = None, num_resources_per_node: Optional[int] = None, colocate_all: bool = False, sequence_parallel_size: int = 1, record_memory: bool = False) -> None

A group of ray actors Functions start with 'async' should return list of object refs

Parameters:

NameTypeDescriptionDefault
cfgTrainerConfigconfig object for workersrequired
num_nodesintNumber of nodes for this actor group.required
num_gpus_per_nodeintNumber of gpus for this actor group.required
ray_actor_typeType[Worker]PPO model type that this actor group serve on.required
pgOptional[PlacementGroup]Placement group to schedule actor on. If none, create new placement group automatically. Defaults to None.None
num_gpus_per_actorfloatNumber of gpus allocated for each actor. If < 1.0, multiple models can share same gpu. Defaults to 1.1.0

Functions:

NameDescription
async_init_modelAsynchronously initialize worker state (model, and optimizer if applicable) from model path
offload_to_cpuOffload all worker state to CPU.
backload_to_gpuBackload worker state to GPU
run_methodRun a method on all actors using specified dispatch type synchronously.
async_run_ray_methodRun a method on all actors using specified dispatch type asynchronously.
async_run_methodRun a method on all actors using specified dispatch type in an asyncio-compatible way.

Attributes:

Parameters:

NameTypeDescriptionDefault
pgOptional[PlacementGroup]Placement group for the worker group. Accepts a single PlacementGroup, or None. Note that if colocate_all is True, the number of bundles in the placement group must match world_size.None
Source code in skyrl/backends/skyrl_train/workers/worker.py:399-663
class PPORayActorGroup:
    """
    A group of ray actors
    Functions start with 'async' should return list of object refs

    Args:
        cfg: config object for workers
        num_nodes (int): Number of nodes for this actor group.
        num_gpus_per_node (int): Number of gpus for this actor group.
        ray_actor_type (Type[Worker]): PPO model type that this actor group serve on.
        pg (Optional[PlacementGroup]): Placement group to schedule actor on.
            If none, create new placement group automatically. Defaults to None.
        num_gpus_per_actor (float, optional): Number of gpus allocated for each actor.
            If < 1.0, multiple models can share same gpu. Defaults to 1.
    """

    def __init__(
        self,
        cfg: TrainerConfig,
        num_nodes: int,
        num_gpus_per_node: int,
        ray_actor_type: Type[Worker],
        pg: Optional[PlacementGroup] = None,
        num_gpus_per_actor: float = 1.0,
        resources: Optional[Dict[str, float]] = None,
        num_resources_per_node: Optional[int] = None,
        colocate_all: bool = False,
        sequence_parallel_size: int = 1,
        record_memory: bool = False,
    ) -> None:
        """
        Args:
            pg: Placement group for the worker group. Accepts a single PlacementGroup, or None.
                Note that if colocate_all is True, the number of bundles in the placement group must match world_size.
        """
        self.cfg = cfg
        self._num_nodes = num_nodes
        self._num_gpus_per_node = num_gpus_per_node
        self.ray_actor_type = ray_actor_type

        # custom resources, see https://docs.ray.io/en/latest/ray-core/scheduling/resources.html
        self._resources = resources
        self._num_resources_per_node = num_resources_per_node

        self.colocate_all = colocate_all
        self.sequence_parallel_size = sequence_parallel_size
        self.record_memory = record_memory
        # Actor creation happens eagerly at construction time; after this call
        # `self._actor_handlers` and `self.actor_infos` are populated.
        self._initiate_actors(pg, num_gpus_per_actor)

    def _initiate_actors(self, pg: Optional[PlacementGroup], num_gpus_per_actor: float):
        """Initialize Ray actors in the worker group.

        Args:
            pg: A single placement group for the worker group, or None.
            num_gpus_per_actor: The number of gpus to allocate per actor.
        """
        world_size = self._num_nodes * self._num_gpus_per_node

        # With full colocation, every rank must map onto its own bundle of the
        # shared placement group, so the bundle count has to equal world_size.
        if self.colocate_all:
            assert (
                pg is not None
            ), "if colocate_all is True, the shared placement group must be provided to PPORayActorGroup"
            pg_data = placement_group_table(pg)
            assert len(pg_data["bundles"]) == world_size, (
                f"if colocate_all is True, the number of bundles in the placement group "
                f"must match world_size. Got {len(pg_data['bundles'])} bundles but world_size={world_size}"
            )

        # Build rank → bundle_index assignments sorted by (node_id, gpu_id)
        # for deterministic ordering.
        # NOTE: only applies when the PG has exactly one bundle per rank;
        # otherwise ranks fall back to the per-node bundle mapping below.
        reordered_bundle_indices = []
        if pg is not None:
            pg_data = placement_group_table(pg)
            if len(pg_data["bundles"]) == world_size:
                reordered_bundle_indices = get_reordered_bundle_indices(pg)

        # If no PG provided, create one internally
        # (one bundle per node; the single-GPU-per-node case needs no PG at all).
        if pg is None and self._num_gpus_per_node > 1:
            bundles = [{"GPU": self._num_gpus_per_node, "CPU": self._num_gpus_per_node} for _ in range(self._num_nodes)]
            if self._resources:
                # NOTE(review): only the first custom resource key is propagated
                # into the bundles — confirm multiple custom resources are not expected here.
                resources_name = list(self._resources.keys())[0]
                for i in range(len(bundles)):
                    bundles[i][resources_name] = self._num_resources_per_node

            internal_pg = placement_group(bundles, strategy="PACK")
            get_ray_pg_ready_with_timeout(internal_pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
            pg = internal_pg

        def _scheduling_strategy_for_rank(rank):
            # Per-rank bundles take precedence; else pack ranks node-by-node.
            if reordered_bundle_indices:
                return PlacementGroupSchedulingStrategy(
                    placement_group=pg,
                    placement_group_bundle_index=reordered_bundle_indices[rank],
                )
            elif pg is not None:
                return PlacementGroupSchedulingStrategy(
                    placement_group=pg,
                    placement_group_bundle_index=rank // self._num_gpus_per_node,
                )
            # else we are in the single gpu case per node case in which case we don't need to set
            # bundle indices
            return None

        # Create rank 0 (the "master") first: the remaining ranks need its
        # rendezvous address/port before they can be constructed.
        sched = _scheduling_strategy_for_rank(0)
        actor_options = {
            # NOTE(review): num_cpus deliberately mirrors num_gpus_per_actor — confirm intended.
            "num_cpus": num_gpus_per_actor,
            "num_gpus": num_gpus_per_actor,
            "resources": self._resources,
        }
        if sched is not None:
            actor_options["scheduling_strategy"] = sched

        master_actor = self.ray_actor_type.options(**actor_options).remote(
            cfg=self.cfg,
            world_size=world_size,
            rank=0,
            local_rank=0,
            master_addr=None,
            master_port=None,
            sequence_parallel_size=self.sequence_parallel_size,
            record_memory=self.record_memory,
        )
        self._actor_handlers = [master_actor]

        if world_size > 1:
            # Blocking fetch: the master must be alive before workers start.
            master_addr, master_port = ray.get(master_actor.get_master_addr_port.remote())
            for rank in range(1, world_size):
                local_rank = rank % self._num_gpus_per_node

                sched = _scheduling_strategy_for_rank(rank)
                actor_options = {
                    "num_cpus": num_gpus_per_actor,
                    "num_gpus": num_gpus_per_actor,
                    "resources": self._resources,
                }
                if sched is not None:
                    actor_options["scheduling_strategy"] = sched

                worker_actor = self.ray_actor_type.options(**actor_options).remote(
                    cfg=self.cfg,
                    world_size=world_size,
                    rank=rank,
                    local_rank=local_rank,
                    master_addr=master_addr,
                    master_port=master_port,
                    sequence_parallel_size=self.sequence_parallel_size,
                    record_memory=self.record_memory,
                )
                self._actor_handlers.append(worker_actor)

        # Initialize process group
        logger.info("Initializing process group for RayActorGroup")
        ray.get([actor.init_worker_process_group.remote() for actor in self._actor_handlers])
        logger.info("Initialized process group for RayActorGroup")
        # Mesh ranks are used by the dispatch layer to route batches to actors.
        self.actor_infos = [ActorInfo(actor, ray.get(actor.get_mesh_rank.remote())) for actor in self._actor_handlers]
        logger.info(f"Mesh Ranks: {[actor_info.rank for actor_info in self.actor_infos]}")

    def async_init_model(
        self,
        *args,
        **kwargs,
    ) -> List[ObjectRef]:
        """Asynchronously initialize worker state (model, and optimizer if applicable) from model path
        on all the workers.

        Returns:
            A list of ray object refs.
        """
        return [actor.init_model.remote(*args, **kwargs) for actor in self._actor_handlers]

    def offload_to_cpu(self, nonblocking=False, offload_optimizer=True, offload_model=True):
        """Offload all worker state to CPU.

        Args:
            nonblocking: Whether this operation is synchronous or asynchronous.
            If `nonblocking=True`, then the function returns a list of object refs.
        """
        refs = [
            actor.offload_to_cpu.remote(offload_optimizer=offload_optimizer, offload_model=offload_model)
            for actor in self._actor_handlers
        ]
        if nonblocking:
            return refs
        return ray.get(refs)

    def backload_to_gpu(self, nonblocking=False, backload_optimizer=True, backload_model=True):
        """Backload worker state to GPU

        Args:
            nonblocking: Whether this operation is synchronous or asynchronous.
            If `nonblocking=True`, then the function returns a list of ObjectRefs.
        """
        refs = [
            actor.backload_to_gpu.remote(backload_optimizer=backload_optimizer, backload_model=backload_model)
            for actor in self._actor_handlers
        ]
        if nonblocking:
            return refs
        return ray.get(refs)

    def run_method(self, dispatch_type: str, method_name: str, *args, **kwargs) -> Optional[TrainingOutputBatch]:
        """Run a method on all actors using specified dispatch type synchronously.

        The method should either return `None` or a `TrainingOutputBatch` object.

        Args:
            dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
            method_name: Name of the method to call on actors
            *args: Positional arguments to pass to the method
            **kwargs: Keyword arguments to pass to the method

        Returns:
            Collect results from all the actors.
        """
        dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type)
        # validate the dispatch args to be sent to `.dispatch`
        args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs)

        # Dispatch the method call
        object_refs = dispatch_class.dispatch(self.actor_infos, method_name, *args, **kwargs)
        # Collect results from all the actors
        ret = dispatch_class.sync_collect(self.actor_infos, object_refs)
        return ret

    def async_run_ray_method(self, dispatch_type: str, method_name: str, *args, **kwargs) -> List[ObjectRef]:
        """Run a method on all actors using specified dispatch type asynchronously.

        Args:
            dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
            method_name: Name of the method to call on actors
            *args: Positional arguments to pass to the method
            **kwargs: Keyword arguments to pass to the method

        Returns:
            List of object references
        """
        dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type)
        # validate the dispatch args to be sent to `.dispatch`
        args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs)

        # Dispatch the method call
        object_refs = dispatch_class.dispatch(self.actor_infos, method_name, *args, **kwargs)
        return object_refs

    async def async_run_method(
        self, dispatch_type: str, method_name: str, *args, **kwargs
    ) -> Optional[TrainingOutputBatch]:
        """Run a method on all actors using specified dispatch type in an asyncio-compatible way.

        Args:
            dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
            method_name: Name of the method to call on actors
            *args: Positional arguments to pass to the method
            **kwargs: Keyword arguments to pass to the method

        Returns:
            TrainingOutputBatch: concatenated results from all actors
        """
        dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type)
        # validate the dispatch args to be sent to `.dispatch`
        args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs)

        # Dispatch the method call
        object_refs = dispatch_class.dispatch(self.actor_infos, method_name, *args, **kwargs)
        return await dispatch_class.async_collect(self.actor_infos, object_refs)

attr cfg

cfg = cfg

attr ray_actor_type

ray_actor_type = ray_actor_type

attr colocate_all

colocate_all = colocate_all

attr sequence_parallel_size

sequence_parallel_size = sequence_parallel_size

attr record_memory

record_memory = record_memory

method async_init_model

async_init_model(*args, **kwargs) -> List[ObjectRef]

Asynchronously initialize worker state (model, and optimizer if applicable) from model path on all the workers.

Returns:

TypeDescription
List[ObjectRef]A list of ray object refs.
Source code in skyrl/backends/skyrl_train/workers/worker.py:556-567
    def async_init_model(
        self,
        *args,
        **kwargs,
    ) -> List[ObjectRef]:
        """Kick off worker state initialization (model, and optimizer if applicable)
        from the model path on every worker, without blocking.

        Returns:
            One Ray object ref per worker.
        """
        refs = []
        for handler in self._actor_handlers:
            refs.append(handler.init_model.remote(*args, **kwargs))
        return refs

method offload_to_cpu

offload_to_cpu(nonblocking = False, offload_optimizer = True, offload_model = True)

Offload all worker state to CPU.

Parameters:

NameTypeDescriptionDefault
nonblockingWhether this operation is synchronous or asynchronous.False
Source code in skyrl/backends/skyrl_train/workers/worker.py:569-582
    def offload_to_cpu(self, nonblocking=False, offload_optimizer=True, offload_model=True):
        """Move all worker state to CPU memory.

        Args:
            nonblocking: Whether this operation is synchronous or asynchronous.
                When `nonblocking=True`, the pending Ray object refs are returned
                instead of waiting for every worker to finish.
        """
        pending = [
            handler.offload_to_cpu.remote(offload_optimizer=offload_optimizer, offload_model=offload_model)
            for handler in self._actor_handlers
        ]
        return pending if nonblocking else ray.get(pending)

method backload_to_gpu

backload_to_gpu(nonblocking = False, backload_optimizer = True, backload_model = True)

Backload worker state to GPU

Parameters:

NameTypeDescriptionDefault
nonblockingWhether this operation is synchronous or asynchronous.False
Source code in skyrl/backends/skyrl_train/workers/worker.py:584-597
    def backload_to_gpu(self, nonblocking=False, backload_optimizer=True, backload_model=True):
        """Move worker state back onto the GPU.

        Args:
            nonblocking: Whether this operation is synchronous or asynchronous.
                When `nonblocking=True`, the pending Ray ObjectRefs are returned
                instead of waiting for every worker to finish.
        """
        pending = []
        for handler in self._actor_handlers:
            pending.append(
                handler.backload_to_gpu.remote(backload_optimizer=backload_optimizer, backload_model=backload_model)
            )
        if not nonblocking:
            return ray.get(pending)
        return pending

method run_method

run_method(dispatch_type: str, method_name: str, *args, **kwargs) -> Optional[TrainingOutputBatch]

Run a method on all actors using specified dispatch type synchronously.

The method should either return None or a TrainingOutputBatch object.

Parameters:

NameTypeDescriptionDefault
dispatch_typestrType of dispatch to use ("mesh" or "pass_through")required
method_namestrName of the method to call on actorsrequired
*argsPositional arguments to pass to the method()
**kwargsKeyword arguments to pass to the method{}

Returns:

TypeDescription
Optional[TrainingOutputBatch]Collect results from all the actors.
Source code in skyrl/backends/skyrl_train/workers/worker.py:599-621
    def run_method(self, dispatch_type: str, method_name: str, *args, **kwargs) -> Optional[TrainingOutputBatch]:
        """Synchronously invoke `method_name` on all actors via the chosen dispatch.

        The target method must return either `None` or a `TrainingOutputBatch`.

        Args:
            dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
            method_name: Name of the method to call on actors
            *args: Positional arguments to pass to the method
            **kwargs: Keyword arguments to pass to the method

        Returns:
            Results collected from all the actors.
        """
        dispatcher: Dispatch = DispatchRegistry.get(dispatch_type)
        # Normalize/validate arguments before handing them to `.dispatch`.
        args, kwargs = dispatcher.validate_dispatch_args(*args, **kwargs)
        # Fan out the call, then block while gathering results.
        refs = dispatcher.dispatch(self.actor_infos, method_name, *args, **kwargs)
        return dispatcher.sync_collect(self.actor_infos, refs)

method async_run_ray_method

async_run_ray_method(dispatch_type: str, method_name: str, *args, **kwargs) -> List[ObjectRef]

Run a method on all actors using specified dispatch type asynchronously.

Parameters:

NameTypeDescriptionDefault
dispatch_typestrType of dispatch to use ("mesh" or "pass_through")required
method_namestrName of the method to call on actorsrequired
*argsPositional arguments to pass to the method()
**kwargsKeyword arguments to pass to the method{}

Returns:

TypeDescription
List[ObjectRef]List of object references
Source code in skyrl/backends/skyrl_train/workers/worker.py:623-641
    def async_run_ray_method(self, dispatch_type: str, method_name: str, *args, **kwargs) -> List[ObjectRef]:
        """Invoke `method_name` on all actors via the chosen dispatch without waiting.

        Args:
            dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
            method_name: Name of the method to call on actors
            *args: Positional arguments to pass to the method
            **kwargs: Keyword arguments to pass to the method

        Returns:
            The raw Ray object references produced by the dispatched calls.
        """
        dispatcher: Dispatch = DispatchRegistry.get(dispatch_type)
        # Normalize/validate arguments before handing them to `.dispatch`.
        args, kwargs = dispatcher.validate_dispatch_args(*args, **kwargs)
        # Fan out the call and hand the refs straight back to the caller.
        return dispatcher.dispatch(self.actor_infos, method_name, *args, **kwargs)

method async async_run_method

async_run_method(dispatch_type: str, method_name: str, *args, **kwargs) -> Optional[TrainingOutputBatch]

Run a method on all actors using specified dispatch type in an asyncio-compatible way.

Parameters:

NameTypeDescriptionDefault
dispatch_typestrType of dispatch to use ("mesh" or "pass_through")required
method_namestrName of the method to call on actorsrequired
*argsPositional arguments to pass to the method()
**kwargsKeyword arguments to pass to the method{}

Returns:

NameTypeDescription
TrainingOutputBatchOptional[TrainingOutputBatch]concatenated results from all actors
Source code in skyrl/backends/skyrl_train/workers/worker.py:643-663
    async def async_run_method(
        self, dispatch_type: str, method_name: str, *args, **kwargs
    ) -> Optional[TrainingOutputBatch]:
        """Invoke `method_name` on all actors via the chosen dispatch, awaiting results.

        Args:
            dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
            method_name: Name of the method to call on actors
            *args: Positional arguments to pass to the method
            **kwargs: Keyword arguments to pass to the method

        Returns:
            TrainingOutputBatch: concatenated results from all actors
        """
        dispatcher: Dispatch = DispatchRegistry.get(dispatch_type)
        # Normalize/validate arguments before handing them to `.dispatch`.
        args, kwargs = dispatcher.validate_dispatch_args(*args, **kwargs)
        # Fan out the call, then await collection in an asyncio-friendly way.
        refs = dispatcher.dispatch(self.actor_infos, method_name, *args, **kwargs)
        return await dispatcher.async_collect(self.actor_infos, refs)

On this page

Trainer Classclass RayPPOTrainerattr cfgattr colocate_allattr trackerattr tokenizerattr train_datasetattr eval_datasetattr inference_engine_clientattr generatorattr train_dataloaderattr total_training_stepsattr eval_dataloaderattr colocate_pgattr resume_modeattr all_metricsattr all_timingsattr global_stepattr policy_modelattr critic_modelattr ref_modelattr dynamic_sampling_stateattr reward_kl_controllerattr dispatchattr property has_criticmethod async evalmethod async trainmethod build_modelsmethod init_weight_sync_statemethod sync_policy_weights_to_inference_enginesmethod convert_to_training_inputmethod async generatemethod postprocess_generator_outputmethod compute_advantages_and_returnsmethod dump_datamethod pad_batchmethod fwd_logprobs_values_rewardmethod apply_reward_kl_penaltymethod train_critic_and_policymethod handle_dynamic_samplingmethod save_checkpointsmethod load_checkpointsmethod save_modelsmethod update_ref_with_policyDispatch APIsclass Dispatchattr dispatchmethod abstractmethod async classmethod async_collectmethod abstractmethod classmethod sync_collectmethod abstractmethod classmethod validate_dispatch_argsclass MeshDispatchattr dispatchmethod abstractmethod async classmethod async_collectmethod abstractmethod classmethod sync_collectmethod classmethod dispatch_from_stagedmethod abstractmethod classmethod validate_dispatch_argsclass PassThroughDispatchattr dispatchmethod abstractmethod async classmethod async_collectmethod abstractmethod classmethod sync_collectmethod abstractmethod classmethod validate_dispatch_argsWorker APIsclass Workerattr sequence_parallel_sizeattr record_memoryget_node_local_rankinit_worker_process_groupget_mesh_rankget_gpu_idget_ray_node_idget_master_addr_portattr cfgmethod init_modelmethod empty_cachemethod offload_to_cpumethod backload_to_gpumethod get_cuda_memorymethod save_memory_snapshotmethod init_weight_sync_statemethod abstractmethod forwardclass PPORayActorGroupattr cfgattr ray_actor_typeattr colocate_allattr 
sequence_parallel_sizeattr record_memorymethod async_init_modelmethod offload_to_cpumethod backload_to_gpumethod run_methodmethod async_run_ray_methodmethod async async_run_method