SkyRL
API Reference › SkyRL › SkyRL-Train Backend

Trainer

Trainer API — RayPPOTrainer, Dispatch, Worker APIs.

Trainer Class

class RayPPOTrainer

RayPPOTrainer(cfg: SkyRLTrainConfig, tracker: Tracking, tokenizer: AutoTokenizer, train_dataset: Optional[PromptDataset], inference_engine_client: InferenceEngineClient, generator: GeneratorInterface, colocate_pg: Optional[ResolvedPlacementGroup] = None, eval_dataset: Optional[PromptDataset] = None, callbacks: Optional[List[TrainingCallback]] = None)

Functions:

Name — Description
add_callback — Register a callback. Events fired after this call reach the new callback.
eval — Run generation and scoring on the evaluation dataset.
train — Main training loop for PPO.
build_models — Initialize the actors for training, and handle colocation logic.
init_weight_sync_state — Set up the connection between the policy model and the inference engine for weight syncing.
convert_to_training_input — Convert lists to a padded batch of tensors for training.
generate — Generate rollouts.
postprocess_generator_output — Convert to per-token rewards and compute pass@N.
compute_advantages_and_returns — Calculate advantages and returns for the data batch.
dump_data — Dump data to a pickle file.
fwd_logprobs_values_reward — Calculate values from the critic, and log probs from the policy and ref model.
apply_reward_kl_penalty — Apply a penalty for KL divergence between the policy log probs and the base model log probs to the rewards.
train_critic_and_policy — Run the training step for the policy and critic models.
handle_dynamic_sampling — Handle dynamic sampling for the current batch.
save_checkpoints — Save the model, optimizer, and training states to disk. Returns the checkpoint path.
load_checkpoints — Load the complete checkpoint state and return the global_step to resume from.
save_models — Save the model parameters in HF format at cfg.trainer.export_path.
update_ref_with_policy — Update the reference model with the policy model weights (required by some algorithms).

Attributes:

Source code in skyrl/train/trainer.py:96-1632
class RayPPOTrainer:
    def __init__(
        self,
        cfg: SkyRLTrainConfig,
        tracker: Tracking,
        tokenizer: AutoTokenizer,
        train_dataset: Optional[PromptDataset],
        inference_engine_client: InferenceEngineClient,
        generator: GeneratorInterface,
        colocate_pg: Optional[ResolvedPlacementGroup] = None,
        eval_dataset: Optional[PromptDataset] = None,
        callbacks: Optional[List[TrainingCallback]] = None,
    ) -> None:
        """Store configuration, data and engine handles for PPO training.

        Model actor groups and the worker dispatch are NOT created here; they are
        populated later by ``build_models``.

        Args:
            cfg: Full training configuration.
            tracker: Metrics tracker; ``train`` logs through ``tracker.log`` and
                closes it with ``tracker.finish``.
            tokenizer: Tokenizer used to decode logged examples and to provide the
                pad token id / batch conversion for training.
            train_dataset: Training prompts. May be ``None`` when data is provided
                externally, in which case no training dataloader is built.
            inference_engine_client: Client for the rollout inference engines.
            generator: Rollout generator used for both training and evaluation.
            colocate_pg: Placement group shared by all models when
                ``cfg.trainer.placement.colocate_all`` is enabled.
            eval_dataset: Optional evaluation prompts; no eval dataloader is built
                when this is ``None``.
            callbacks: Optional training callbacks handed to the callback handler.
        """
        self.cfg = cfg
        self.colocate_all = cfg.trainer.placement.colocate_all
        self.tracker = tracker
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.inference_engine_client = inference_engine_client
        self.generator = generator
        # Defaults for the hook below; it leaves them as None when there is no train dataset.
        self.train_dataloader = None
        self.total_training_steps = None
        # Overridable hook. NOTE: it runs before the attributes further down
        # (colocate_pg, resume_mode, callbacks, ...) are assigned, so subclass
        # overrides must not rely on them.
        self._build_train_dataloader_and_compute_training_steps()

        self.eval_dataloader = (
            build_dataloader(self.cfg, eval_dataset, is_train=False) if eval_dataset is not None else None
        )
        self.colocate_pg = colocate_pg

        self.resume_mode = ResumeMode(cfg.trainer.resume_mode)

        # Per-step metric/timing buffers; `train` logs and then clears them after every step.
        self.all_metrics = {}
        self.all_timings = {}
        self.global_step = 0

        # Optional vLLM Prometheus metrics sampler, merged into each step's log payload.
        self._vllm_metrics_scraper: Optional[VLLMMetricsScraper] = (
            VLLMMetricsScraper() if cfg.generator.inference_engine.enable_ray_prometheus_stats else None
        )

        # Optional GPU monitor, started/flushed/stopped by `train`.
        self._ray_gpu_monitor = RayGpuMonitor() if cfg.trainer.enable_ray_gpu_monitor else None

        # initialized in `build_models`
        self.policy_model: PPORayActorGroup = None
        self.critic_model: Optional[PPORayActorGroup] = None
        self.ref_model: Optional[PPORayActorGroup] = None
        # used for checkpoint cleanup
        self._node_ids: Optional[List[str]] = None

        # presumably populated by dynamic sampling during training — TODO confirm
        self.dynamic_sampling_state: Optional[DynamicSamplingState] = None

        # Set in `train` only when `algorithm.use_kl_in_reward` is enabled.
        self.reward_kl_controller: Optional[Union[FixedKLController, AdaptiveKLController]] = None
        # initialized in `build_models`
        self.dispatch: WorkerDispatch = None

        self._callback_handler = CallbackHandler(callbacks)
        # Flags callbacks can set (e.g. should_save / should_evaluate); consumed by `train`.
        self._training_control = TrainingControl()
        self._current_epoch: int = 0

        configure_ray_worker_logging()

        # Total policy GPUs; used to normalize the tokens-per-second throughput metric.
        self._num_training_gpus = (
            cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
        )

    def add_callback(self, callback: TrainingCallback) -> None:
        """Register a callback. Events fired after this call reach the new callback."""
        self._callback_handler.add(callback)

    def _build_callback_input(self, **fields) -> CallbackInput:
        """Package the current loop counters plus event-specific ``fields``.

        ``total_steps`` falls back to 0 when training steps are unknown, and
        ``steps_per_epoch`` falls back to 0 when no training dataloader exists.
        """
        loader = self.train_dataloader
        return CallbackInput(
            global_step=self.global_step,
            epoch=self._current_epoch,
            total_steps=self.total_training_steps or 0,
            steps_per_epoch=0 if loader is None else len(loader),
            **fields,
        )

    def _fire(self, event_name: str, **fields) -> None:
        """Build a CallbackInput and dispatch the given event to all callbacks."""
        cb_input = self._build_callback_input(**fields)
        getattr(self._callback_handler, event_name)(self, cb_input, self._training_control)

    @property
    def has_critic(self) -> bool:
        """Check if critic model is configured."""
        return bool(self.cfg.trainer.critic.model.path)

    def _build_train_dataloader_and_compute_training_steps(self):
        """
        Hook for constructing the training dataloader. Subclasses can override
        this to customize dataloader behavior. For instance, fully async training
        needs a batch size of 1, among other features.
        Defaults to `trainer_utils.build_dataloader` with `is_train=True`.
        When train_dataset is None (e.g. Tinker backend provides data externally),
        the dataloader is not built.
        """
        if self.train_dataset is not None:
            self.train_dataloader = build_dataloader(self.cfg, self.train_dataset, is_train=True)
            self.total_training_steps = len(self.train_dataloader) * self.cfg.trainer.epochs

    async def eval(self) -> Dict[str, float]:
        """
        Run generation and scoring on the evaluation dataset.

        The eval metrics are recorded after having finished training `self.global_step` steps.
        Metrics recorded in global_step 0 correspond to evaluations before training.

        Uses `evaluate_step_wise` when `cfg.generator.step_wise_trajectories` is enabled,
        and `evaluate` otherwise.

        Returns:
            A dictionary of evaluation metrics.
        """
        evaluate_fn = evaluate_step_wise if self.cfg.generator.step_wise_trajectories else evaluate
        # NOTE: `@torch.no_grad()` applied as a decorator to an `async def` only wraps the
        # creation of the coroutine object, not its execution. Grad mode would therefore be
        # restored before any evaluation code runs. Entering the context inside the coroutine
        # keeps it active across the awaited evaluation. Grad mode is thread-local, so other
        # coroutines interleaved on this event-loop thread also see it disabled meanwhile.
        with torch.no_grad():
            return await evaluate_fn(
                eval_dataloader=self.eval_dataloader,
                generator=self.generator,
                cfg=self.cfg,
                global_step=self.global_step,
                tokenizer=self.tokenizer,
            )

    async def train(self):
        """
        Main training loop for PPO.

        Setup: initialize weight sync between the policy and the inference engines,
        optionally resume from a checkpoint (``resume_mode``), and prepare weights for
        sampling. If ``eval_interval > 0`` and ``eval_before_train`` is set, a baseline
        evaluation is logged at the current ``global_step`` before any training.

        Per batch:
          * generate rollouts;
          * optionally apply dynamic sampling;
          * post-process rewards and convert to a ``TrainingInputBatch``;
          * compute log probs and values, optionally applying the KL reward penalty;
          * compute advantages and returns;
          * train the critic and policy;
          * conditionally save checkpoints and HF weights;
          * optionally update the reference model at epoch end;
          * prepare weights for sampling again;
          * log metrics, running evaluation on the configured interval.

        Teardown: save a final checkpoint / HF model if the last step did not already
        do so, close the vLLM metrics scraper, stop the GPU monitor, fire
        ``on_train_end`` and finish the tracker.

        Callback events fired through ``self._fire``:
          * ``on_train_start`` and ``on_train_end``;
          * ``on_epoch_start`` and ``on_epoch_end``;
          * ``on_step_start`` and ``on_step_end``;
          * ``on_eval_start`` and ``on_eval_end``;
          * ``on_save`` and ``on_log``.
        """
        # Start the GPU monitor first so it covers the whole run.
        if self._ray_gpu_monitor is not None:
            self._ray_gpu_monitor.start()

        # Initialize weight sync state between policy model and inference engines.
        with Timer("init_weight_sync_state"):
            self.init_weight_sync_state()

        # Load checkpoint state if resumption is enabled.
        # The second value returned by `load_checkpoints` is not used here.
        if self.resume_mode != ResumeMode.NONE:
            with Timer("load_checkpoints"):
                self.global_step, _ = self.load_checkpoints()

        # Prepare weights for sampling
        with Timer("sync_weights"):
            await self.dispatch.save_weights_for_sampler()

        # Compute start_epoch up-front so callback metadata is ready before
        # any event fires (including the baseline eval below).
        # Integer division maps a resumed global_step to the epoch it falls in.
        start_epoch = self.global_step // len(self.train_dataloader)
        self._current_epoch = start_epoch
        self._training_control.reset()

        self._fire("on_train_start")

        # Eval before training. Wrapped in eval callbacks + on_log so that e.g.
        # a best-checkpoint callback sees the baseline reading.
        if self.cfg.trainer.eval_interval > 0 and self.cfg.trainer.eval_before_train:
            self._fire("on_eval_start")
            with Timer("eval", self.all_timings):
                eval_metrics = await self.eval()
            self._fire("on_eval_end", metrics=eval_metrics)
            self._fire("on_log", logs=eval_metrics)
            self.tracker.log(eval_metrics, step=self.global_step, commit=True)

        # initialize kl controller (only needed when the KL penalty is applied to rewards)
        if self.cfg.trainer.algorithm.use_kl_in_reward:
            self.reward_kl_controller = get_kl_controller(self.cfg.trainer.algorithm)

        # main training loop
        pbar = tqdm(total=self.total_training_steps, initial=self.global_step, desc="Training Batches Processed")
        self.global_step += 1  # start training at global_step 1

        # booleans tracking whether we save ckpts
        # as well as hf model at step end.
        # They are read after the loop to avoid a redundant final save.
        will_save_ckpts = False
        hf_model_save = False
        for epoch in range(start_epoch, self.cfg.trainer.epochs):
            self._current_epoch = epoch
            self._fire("on_epoch_start")
            # ``step_started`` tracks the on_step_start/on_step_end pairing, taking
            # dynamic sampling into account: a logical step can span multiple inner
            # iterations before it completes.
            step_started = False
            for _, rand_prompts in enumerate(self.train_dataloader):
                if not step_started:
                    self._fire("on_step_start")
                    step_started = True
                with Timer("step", self.all_timings):
                    # for colocate_all=true, inference engine is always on GPU when starting the training step

                    # 0. truncate data to have even shards
                    rand_prompts = self._remove_tail_data(rand_prompts)
                    generator_input, uids = prepare_generator_input(
                        rand_prompts,
                        self.cfg.generator.n_samples_per_prompt,
                        get_sampling_params_for_backend(
                            self.cfg.generator.inference_engine.backend, self.cfg.generator.sampling_params
                        ),
                        self.cfg.environment.env_class,
                        "train",
                        self.global_step,
                    )

                    # 1.1. generation phase
                    with Timer("generate", self.all_timings):
                        generator_output: GeneratorOutput = await self.generate(generator_input)

                    if self.cfg.generator.step_wise_trajectories:
                        # NOTE: We use instance_ids from `trajectory_ids` here instead of re-using `uids`
                        # this is because in step-wise training, len(uids) != len(generator_output["response_ids"])
                        uids = [trajectory_id.instance_id for trajectory_id in generator_output["trajectory_ids"]]

                    # dynamic sampling
                    if self.cfg.trainer.algorithm.dynamic_sampling.type is not None:
                        generator_output, uids, keep_sampling = self.handle_dynamic_sampling(generator_output, uids)
                        if keep_sampling:  # continue sampling
                            # update progress bar for current batch (but not global step)
                            # `continue` skips training, saving, logging and the global_step
                            # increment; step_started stays True, so on_step_start is not re-fired.
                            pbar.update(1)
                            continue

                    if self.colocate_all:
                        # if we are not continuing sampling, we sleep the inference engine
                        await self.inference_engine_client.sleep()

                    # 1.2 postprocess rewards (and merge step-wise turns if enabled)
                    with Timer("postprocess_generator_output", self.all_timings):
                        generator_output, uids = self.postprocess_generator_output(generator_output, uids)

                    # 2. print example just for debugging
                    log_interval = self.cfg.trainer.log_example_interval
                    if log_interval > 0 and self.global_step % log_interval == 0:
                        vis = self.tokenizer.decode(generator_output["response_ids"][0])
                        log_example(
                            logger,
                            prompt=generator_input["prompts"][0],
                            response=vis,
                            reward=generator_output["rewards"][0],
                        )

                    # 3. Convert GeneratorOutput to TrainingInputBatch
                    with Timer("convert_to_training_input", self.all_timings):
                        training_input: TrainingInputBatch = self.convert_to_training_input(generator_output, uids)

                    # 4. Inference and calculate values, log probs, rewards, kl divergence
                    with Timer("fwd_logprobs_values_reward", self.all_timings):
                        training_input = self.fwd_logprobs_values_reward(training_input)

                    # 5. apply kl divergence penalty to rewards
                    if self.cfg.trainer.algorithm.use_kl_in_reward:
                        with Timer("apply_reward_kl_penalty", self.all_timings):
                            training_input = self.apply_reward_kl_penalty(training_input)

                    # 6. calculate advantages and returns
                    with Timer("compute_advantages_and_returns", self.all_timings):
                        training_input = self.compute_advantages_and_returns(training_input)
                        # remove some unwanted keys (not needed by the train step)
                        for key in ["rewards"]:
                            training_input.pop(key)
                        training_input.metadata.pop("uids")
                        training_input.metadata.pop("is_last_step", None)

                    if self.cfg.trainer.dump_data_batch:
                        # dump data to file
                        with Timer("dump_data_batch"):
                            self.dump_data(training_input, file_name=f"global_step_{self.global_step}_training_input")

                    # 7. train policy/critic model
                    # Policy model is backloaded to GPU during training
                    with Timer("train_critic_and_policy", self.all_timings):
                        status = self.train_critic_and_policy(training_input)

                    self._fire("on_step_end", batch=training_input, metrics=status)
                    step_started = False

                    # Capture callback-driven triggers, then reset.
                    # Resetting makes callback requests one-shot (they apply to this step only).
                    force_save = self._training_control.should_save
                    force_eval = self._training_control.should_evaluate
                    self._training_control.should_save = False
                    self._training_control.should_evaluate = False

                    # 8. conditionally save checkpoints and hf model
                    # NOTE(review): global_step is not incremented on dynamic-sampling `continue`
                    # iterations, so with dynamic sampling this may not coincide with the actual
                    # end of the dataloader epoch.
                    is_epoch_end = self.global_step % len(self.train_dataloader) == 0
                    hf_model_save = self.cfg.trainer.hf_save_interval > 0 and (
                        is_epoch_end or self.global_step % self.cfg.trainer.hf_save_interval == 0
                    )
                    ckpt_interval_save = self.cfg.trainer.ckpt_interval > 0 and (
                        is_epoch_end or self.global_step % self.cfg.trainer.ckpt_interval == 0
                    )
                    will_save_ckpts = force_save or ckpt_interval_save
                    if will_save_ckpts:
                        with Timer("save_checkpoints", self.all_timings):
                            ckpt_path = self.save_checkpoints()
                        self._fire("on_save", ckpt_path=ckpt_path)
                    if hf_model_save:
                        with Timer("save_hf_model", self.all_timings):
                            self.save_models()

                    # 9. conditionally sync policy and ref at the end of the epoch
                    if (
                        self.cfg.trainer.update_ref_every_epoch
                        and self.ref_model is not None
                        and is_epoch_end
                        and epoch != self.cfg.trainer.epochs - 1  # skip updating ref at the end of the last epoch
                    ):
                        with Timer("update_ref_with_policy", self.all_timings):
                            self.update_ref_with_policy()

                    # 10. Prepare weights for sampling
                    with Timer("sync_weights", self.all_timings):
                        await self.dispatch.save_weights_for_sampler()

                # 11. set logs (outside the "step" timer, so its timing is already recorded)
                logger.info(status)
                # Throughput metrics: non-padding tokens per second of train time, per policy GPU
                train_time = self.all_timings.get("train_critic_and_policy", 0.0)
                if train_time > 0 and training_input.get("attention_mask") is not None:
                    total_tokens = int(training_input["attention_mask"].sum().item())
                    self.all_metrics["trainer/tokens_per_second_per_gpu"] = total_tokens / (
                        train_time * self._num_training_gpus
                    )
                # log epoch info
                self.all_metrics.update({"trainer/epoch": epoch, "trainer/global_step": self.global_step})
                # Evaluate on the configured interval and always on the final training step.
                interval_eval = self.cfg.trainer.eval_interval > 0 and (
                    self.global_step % self.cfg.trainer.eval_interval == 0
                    or self.global_step == self.total_training_steps
                )
                if force_eval or interval_eval:
                    self._fire("on_eval_start")
                    with Timer("eval", self.all_timings):
                        eval_metrics = await self.eval()
                        self.all_metrics.update(eval_metrics)
                    self._fire("on_eval_end", metrics=eval_metrics)

                log_payload = {
                    **self.all_metrics,
                    **{f"timing/{k}": v for k, v in self.all_timings.items()},
                }
                if self._vllm_metrics_scraper is not None:
                    log_payload.update(await self._vllm_metrics_scraper.sample())

                if self._ray_gpu_monitor is not None:
                    log_payload.update(self._ray_gpu_monitor.flush())

                self._fire("on_log", logs=log_payload)

                self.tracker.log(log_payload, step=self.global_step, commit=True)
                # Reset per-step buffers so the next step logs only its own metrics/timings.
                self.all_metrics = {}
                self.all_timings = {}

                # update progress bar after logging
                pbar.update(1)

                self.global_step += 1

                # Drop references to this step's batch before fetching the next one.
                del training_input, generator_output

            self._fire("on_epoch_end")

        pbar.close()
        # With colocate_all, put the inference engines to sleep once training has finished.
        if self.colocate_all:
            await self.inference_engine_client.sleep()

        # Decrement global step by 1 to stop at the last global step
        # We use the global step value in callbacks when training finishes,
        # as well as for a final checkpoint save
        self.global_step -= 1

        # Safety net: always save final checkpoint at end of training.
        # Skip if we already saved at the last step.
        # Note: if no step completed in the loop (e.g. resumed at the end), will_save_ckpts
        # is still False, so a checkpoint is saved here.
        if self.cfg.trainer.ckpt_interval > 0 and not will_save_ckpts:
            with Timer("save_checkpoints", self.all_timings):
                ckpt_path = self.save_checkpoints()
                logger.info("Saved final checkpoint.")
            self._fire("on_save", ckpt_path=ckpt_path)
        if self.cfg.trainer.hf_save_interval > 0 and not hf_model_save:
            with Timer("save_hf_model", self.all_timings):
                self.save_models()
                logger.info("Saved final model.")
        # NOTE(review): the cleanup below is not in a `finally`, so it is skipped if
        # training raises earlier.
        if self._vllm_metrics_scraper is not None:
            await self._vllm_metrics_scraper.aclose()

        if self._ray_gpu_monitor is not None:
            self._ray_gpu_monitor.stop()

        self._fire("on_train_end")
        self.tracker.finish()
        logger.info("Training done!")

    def _remove_tail_data(self, entries: List[Any]) -> List[Any]:
        """Remove tail data to have even shards in terms of *effective* samples.

        Each prompt produces `n_samples_per_prompt` samples. For data-parallel
        training we care that the total number of samples is nicely splittable
        across the (combined) data-parallel size of all enabled models.
        """
        lcm_dp_size = self.dispatch.get_lcm_dp_size()

        n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt

        # We want the largest m <= len(entries) such that:
        #   (m * n_samples_per_prompt) % lcm_dp_size == 0
        #
        # Let g = gcd(lcm_dp_size, n_samples_per_prompt). Then this is equivalent
        # to requiring m to be a multiple of (lcm_dp_size / g).
        stride = lcm_dp_size // math.gcd(lcm_dp_size, n_samples_per_prompt)
        if stride <= 1:
            # Every prompt count is valid, keep all entries.
            return entries

        kept_prompts = (len(entries) // stride) * stride
        return entries[:kept_prompts]

    def build_models(self, PolicyWorker, CriticWorker, RefWorker):
        """
        Initialize the actors for training, and handle colocation logic.

        Always creates a policy actor group. It also creates:
          * a reference model when ``algorithm.use_kl_loss`` or
            ``algorithm.use_kl_in_reward`` is set;
          * a critic when ``critic.model.path`` is non-empty.
        The models are initialized and wrapped in a ``WorkerDispatch`` stored on
        ``self.dispatch``.

        Placement modes:
          * ``colocate_all``: every model group uses ``self.colocate_pg``. The policy GPU
            count must equal the rollout GPU count, and ref/critic GPU counts (when
            used) must equal the policy GPU count. Models are initialized one after
            another and each is offloaded to CPU after init.
          * ``colocate_policy_ref`` (only when a ref model is used): policy and ref share
            a newly created PACK placement group; the critic is placed separately.
          * otherwise: each group is placed without a shared placement group.

        Args:
            PolicyWorker: Worker class for the policy actor group.
            CriticWorker: Worker class for the critic actor group.
            RefWorker: Worker class for the reference actor group.
        """
        cfg = self.cfg
        pg = None

        use_ref_model = cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward

        if cfg.trainer.placement.colocate_all:
            num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
            num_critic_gpus = cfg.trainer.placement.critic_num_gpus_per_node * cfg.trainer.placement.critic_num_nodes
            num_ref_gpus = cfg.trainer.placement.ref_num_gpus_per_node * cfg.trainer.placement.ref_num_nodes
            ie_cfg = cfg.generator.inference_engine
            # GPUs used by rollout engines: engines x TP x PP x DP.
            num_rollout_gpus = (
                ie_cfg.num_engines
                * ie_cfg.tensor_parallel_size
                * ie_cfg.pipeline_parallel_size
                * ie_cfg.data_parallel_size
            )
            assert (
                num_policy_gpus == num_rollout_gpus
            ), "num_policy_gpus and num_rollout_gpus must be the same when colocating all models"
            pg = self.colocate_pg

            # Fractional GPU request (0.2) so several actors can be scheduled on the same GPU.
            # NOTE(review): `record_memory` is only passed on this colocate_all path, not in
            # the non-colocated branch below — confirm this is intentional.
            policy_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.policy_num_nodes,
                cfg.trainer.placement.policy_num_gpus_per_node,
                PolicyWorker,
                pg=pg,
                num_gpus_per_actor=0.2 if pg else 1,
                colocate_all=True,
                sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
                record_memory=cfg.trainer.policy.record_memory,
            )
            if use_ref_model:
                assert (
                    num_policy_gpus == num_ref_gpus
                ), "num_policy_gpus and num_ref_gpus must be the same when colocating policy and ref model"
                ref_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.ref_num_nodes,
                    cfg.trainer.placement.ref_num_gpus_per_node,
                    RefWorker,
                    pg=pg,
                    num_gpus_per_actor=0.2 if pg else 1,
                    colocate_all=True,
                    sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
                )
            else:
                ref_model = None

            if cfg.trainer.critic.model.path:
                assert (
                    num_policy_gpus == num_critic_gpus
                ), "num_policy_gpus and num_critic_gpus must be the same when colocating policy and critic model"
                # NOTE(review): unlike policy/ref, the critic always requests 0.2 GPUs, even
                # when `pg` is None — confirm this asymmetry is intentional.
                critic_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.critic_num_nodes,
                    cfg.trainer.placement.critic_num_gpus_per_node,
                    CriticWorker,
                    pg=pg,
                    num_gpus_per_actor=0.2,
                    colocate_all=True,
                    sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
                )
            else:
                critic_model = None

        else:
            if cfg.trainer.placement.colocate_policy_ref and use_ref_model:
                assert (
                    cfg.trainer.placement.policy_num_nodes == cfg.trainer.placement.ref_num_nodes
                    and cfg.trainer.placement.policy_num_gpus_per_node == cfg.trainer.placement.ref_num_gpus_per_node
                ), "num_nodes and num_gpus_per_node must be the same when colocate policy and ref model."

                # One bundle per node (GPU and CPU counts = GPUs per node), shared by policy and ref.
                bundles = [
                    {
                        "GPU": cfg.trainer.placement.policy_num_gpus_per_node,
                        "CPU": cfg.trainer.placement.policy_num_gpus_per_node,
                    }
                    for _ in range(cfg.trainer.placement.policy_num_nodes)
                ]
                raw_pg = placement_group(bundles, strategy="PACK")
                get_ray_pg_ready_with_timeout(raw_pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
                pg = ResolvedPlacementGroup(raw_pg)

            # When policy and ref share `pg`, split each GPU 0.75 (policy) / 0.25 (ref).
            policy_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.policy_num_nodes,
                cfg.trainer.placement.policy_num_gpus_per_node,
                PolicyWorker,
                pg=pg,
                num_gpus_per_actor=0.75 if pg else 1,
                colocate_all=False,
                sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
            )
            if use_ref_model:
                ref_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.ref_num_nodes,
                    cfg.trainer.placement.ref_num_gpus_per_node,
                    RefWorker,
                    pg=pg,
                    num_gpus_per_actor=0.25 if pg else 1,
                    colocate_all=False,
                    sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
                )
                if pg is not None:
                    # The shared policy/ref placement group `pg` is set only when colocate_policy_ref is enabled
                    logger.info(
                        "Colocating policy and ref on the same GPUs across "
                        f"{cfg.trainer.placement.policy_num_nodes} node(s)."
                    )
            else:
                ref_model = None

            if cfg.trainer.critic.model.path:
                # The critic is never placed in the shared group here; one full GPU per actor.
                critic_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.critic_num_nodes,
                    cfg.trainer.placement.critic_num_gpus_per_node,
                    CriticWorker,
                    num_gpus_per_actor=1,
                    colocate_all=False,
                    sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
                )
            else:
                critic_model = None

        # Optimizer steps per train batch = mini-batches per batch x update epochs.
        # Multiplied by the total training steps, this is passed as `num_training_steps`
        # (presumably for the LR scheduler — TODO confirm in the worker).
        policy_steps_per_train_batch = (
            cfg.trainer.train_batch_size // cfg.trainer.policy_mini_batch_size * cfg.trainer.update_epochs_per_batch
        )
        critic_steps_per_train_batch = 0
        if cfg.trainer.critic.model.path:
            critic_steps_per_train_batch = (
                cfg.trainer.train_batch_size // cfg.trainer.critic_mini_batch_size * cfg.trainer.update_epochs_per_batch
            )
        # None when the training step count is unknown (e.g. no train dataset).
        policy_num_training_steps = (
            self.total_training_steps * policy_steps_per_train_batch if self.total_training_steps is not None else None
        )
        critic_num_training_steps = (
            self.total_training_steps * critic_steps_per_train_batch if self.total_training_steps is not None else None
        )
        if not cfg.trainer.placement.colocate_all:
            # Not colocated: launch all model inits, then wait on them together with one ray.get.
            refs = []
            if ref_model is not None:
                refs.extend(ref_model.async_init_model(cfg.trainer.ref.model.path))
            refs.extend(
                policy_model.async_init_model(
                    cfg.trainer.policy.model.path,
                    num_training_steps=policy_num_training_steps,
                )
            )
            if cfg.trainer.critic.model.path:
                refs.extend(
                    critic_model.async_init_model(
                        cfg.trainer.critic.model.path,
                        num_training_steps=critic_num_training_steps,
                    )
                )
            ray.get(refs)
            ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
        else:
            # Colocated: init models one at a time and offload each to CPU right after init
            # (presumably so they do not hold GPU memory simultaneously — TODO confirm).
            if ref_model is not None:
                ray.get(ref_model.async_init_model(cfg.trainer.ref.model.path))
                ref_model.offload_to_cpu()
            ray.get(
                policy_model.async_init_model(
                    cfg.trainer.policy.model.path,
                    num_training_steps=policy_num_training_steps,
                )
            )
            ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
            policy_model.offload_to_cpu()
            if cfg.trainer.critic.model.path:
                ray.get(
                    critic_model.async_init_model(
                        cfg.trainer.critic.model.path,
                        num_training_steps=critic_num_training_steps,
                    )
                )
                critic_model.offload_to_cpu()

        self.policy_model: PPORayActorGroup = policy_model
        self.critic_model: Optional[PPORayActorGroup] = critic_model
        self.ref_model: Optional[PPORayActorGroup] = ref_model

        # Create unified dispatch that manages all actor groups
        self.dispatch = WorkerDispatch(
            cfg=self.cfg,
            policy_actor_group=policy_model,
            critic_actor_group=critic_model,
            ref_actor_group=ref_model,
            inference_engine_client=self.inference_engine_client,
        )

        # Mark all models as offloaded if colocate_all (they were offloaded above)
        if self.colocate_all:
            self.dispatch.mark_all_offloaded()

        logger.info("init policy/ref/critic models done")

    def init_weight_sync_state(self):
        """
        Establish the weight-sync channel between the policy model and the inference engines.

        Delegates to ``self.dispatch.init_weight_sync_state`` with the trainer's
        inference engine client, then logs completion.
        """
        engine_client = self.inference_engine_client
        self.dispatch.init_weight_sync_state(engine_client)
        logger.info("Initialized weight sync state for policy model and inference engines.")

    def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch:
        """Converts lists to a padded batch of tensors for training

        Args:
            generator_output (GeneratorOutput): Generated rollouts and associated data.
            uids (List[str]): List of prompt-unique identifiers for each generator output in the same
                order as `generator_output`. Used to identify which prompt each generated rollout belongs to.
        Returns:
            training_input (TrainingInputBatch): Padded batch of tensors for training. It preserves the
                order of `generator_output` and hence `uids`. The batch is padded with
                `pad_training_input_batch` so its size is a multiple of `dispatch.get_lcm_dp_size()`.
                Its metadata holds the uids, mini-batch/prompt boundaries, response lengths and
                (for step-wise outputs) `is_last_step`.

        Raises:
            AssertionError: if only one of `pixel_values` / `image_grid_thw` is present, if their
                lengths differ, or if off-policy correction is enabled without matching rollout logprobs.
        """
        # 1. Extract generator output fields.
        prompt_ids: List[List[int]] = generator_output["prompt_token_ids"]
        response_ids: List[List[int]] = generator_output["response_ids"]
        rewards: List[List[float]] = generator_output["rewards"]
        loss_masks: List[List[int]] = generator_output["loss_masks"]

        logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None)
        rollout_expert_indices: Optional[List[List[List[List[int]]]]] = generator_output.get(
            "rollout_expert_indices", None
        )

        pixel_values = generator_output.get("pixel_values", None)
        image_grid_thw = generator_output.get("image_grid_thw", None)
        # Multi-modal inputs need both fields. Check whenever either one is present, so a
        # half-populated output fails here instead of reaching the workers.
        if pixel_values is not None or image_grid_thw is not None:
            assert (
                pixel_values is not None and image_grid_thw is not None
            ), "Both pixel_values and image_grid_thw must exist for multi-modal inputs"
            assert len(pixel_values) == len(
                image_grid_thw
            ), "Number of pixel values should match number of image grid thw"
            pixel_values = TensorList(pixel_values)
            image_grid_thw = TensorList(image_grid_thw)

        # 2. Convert to tensors.
        (
            sequences_tensor,
            attention_masks_tensor,
            response_masks_tensor,
            rewards_tensor,
            loss_masks_tensor,
            rollout_logprobs_tensor,
            rollout_expert_indices_tensor,
        ) = convert_prompts_responses_to_batch_tensors(
            self.tokenizer,
            prompt_ids,
            response_ids,
            rewards,
            loss_masks,
            logprobs,
            rollout_expert_indices,
            max_seq_len=self.cfg.trainer.algorithm.max_seq_len,
        )

        # sanity check for off_policy_correction
        off_policy_correction = self.cfg.trainer.algorithm.off_policy_correction
        tis_ratio_type = off_policy_correction.tis_ratio_type
        sequence_mask_metric = off_policy_correction.sequence_mask_metric
        if tis_ratio_type is not None or sequence_mask_metric is not None:
            assert (
                rollout_logprobs_tensor is not None
            ), "expected non-null rollout logprobs tensor when off_policy_correction is enabled"
            assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses"

        # 3. Create training input batch.
        training_input = TrainingInputBatch(
            {
                "sequences": sequences_tensor,  # Full trajectories (padded and concatenated prompts and responses)
                "attention_mask": attention_masks_tensor,
                "response_mask": response_masks_tensor,
                "rewards": rewards_tensor,
                "loss_mask": loss_masks_tensor,
                "rollout_logprobs": rollout_logprobs_tensor,
                "rollout_expert_indices": rollout_expert_indices_tensor,
                "pixel_values": pixel_values,
                "image_grid_thw": image_grid_thw,
            },
        )
        training_input.metadata = {"uids": uids}
        if generator_output.get("is_last_step", None) is not None:
            training_input.metadata["is_last_step"] = generator_output["is_last_step"]

        # 4. Compute mini-batch boundaries for train_critic_and_policy(). It excludes the ones
        # we will add in pad_training_input_batch().
        train_batch_size = self.cfg.trainer.train_batch_size
        n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt
        is_stepwise = self.cfg.generator.step_wise_trajectories
        training_input.metadata["policy_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries(
            uids, self.cfg.trainer.policy_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt
        )
        # Per-prompt boundaries (used by the `prompt_mean` loss reduction). Policy-only,
        # since advantage normalization only applies to the policy.
        training_input.metadata["policy_prompt_boundaries"] = compute_prompt_boundaries(uids)
        if self.cfg.trainer.critic.model.path is not None:
            training_input.metadata["critic_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries(
                uids, self.cfg.trainer.critic_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt
            )

        # 5. Record metadata and metrics.
        training_input.metadata["response_length"] = response_masks_tensor.shape[1]
        batch_num_seq, batch_padded_seq_len = sequences_tensor.shape
        logger.info(f"batch_num_seq: {batch_num_seq}, batch_padded_seq_len: {batch_padded_seq_len}")
        self.all_metrics.update(
            {
                "generate/batch_num_seq": batch_num_seq,
                "generate/batch_padded_seq_len": batch_padded_seq_len,
            }
        )
        training_input.metadata["avg_response_length"] = sum(
            len(sample_response_ids) for sample_response_ids in response_ids
        ) / len(response_ids)

        # 6. Pad the batch, only needed for step-wise training's `fwd_logprobs_values_reward()`.
        logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
        dp_size = self.dispatch.get_lcm_dp_size()
        # Smallest pad that makes the batch size divisible by the LCM of the DP sizes.
        pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size
        training_input = pad_training_input_batch(training_input, pad_size)
        logger.info(f"Number of sequences after padding: {len(training_input['sequences'])}")

        return training_input

    @torch.no_grad()
    async def generate(
        self,
        input_batch: GeneratorInput,
    ) -> GeneratorOutput:
        """
        Generate rollouts for ``input_batch`` and validate the generator output.

        Any ``rollout_metrics`` reported by the generator are merged into ``self.all_metrics``
        and removed from the returned output.

        If colocate_all is enabled:
        - before calling this method, the policy model should be on CPU and inference engine should
            be awake (i.e. on GPU).
        - after calling this method, the same model placement still holds.

        Args:
            input_batch: Generator input; ``input_batch["prompts"]`` is used to validate the output size.

        Returns:
            The validated generator output, without a ``rollout_metrics`` entry.
        """
        # NOTE: we assume that .generate returns samples in the same order as passed in
        generator_output: GeneratorOutput = await self.generator.generate(input_batch)

        # Pop once with a default so a generator that omits `rollout_metrics` entirely
        # does not raise KeyError; metrics are still recorded when present.
        rollout_metrics = generator_output.pop("rollout_metrics", None)
        if rollout_metrics is not None:
            self.all_metrics.update(rollout_metrics)

        validate_generator_output(
            len(input_batch["prompts"]),
            generator_output,
            step_wise=self.cfg.generator.step_wise_trajectories,
        )

        return generator_output

    @torch.no_grad()
    def postprocess_generator_output(
        self, generator_output: GeneratorOutput, uids: List[str]
    ) -> Tuple[GeneratorOutput, List[str]]:
        """
        Converts to per token rewards and computes pass@N.

        For step-wise training with ``merge_stepwise_output=true``, also collapses
        consecutive turns sharing a common prefix into a single sequence; ``uids``
        is shortened to match.

        When rewards are response-level (one float per sample), each reward is placed on
        the last token of its response. If ``zero_variance_filter`` is enabled in that case,
        the loss masks of samples not kept by ``zero_variance_filter(rewards, uids)`` are
        set to all zeros. Token-level rewards (a list per sample) are passed through as-is.

        Reward metrics (pass@N, average raw reward, mean positive reward) are computed on the
        un-merged output and written to ``self.all_metrics``. For step-wise training only the
        last step of each trajectory contributes to them.

        In the future algorithm specific reward or loss mask post processing should be done here.

        Args:
            generator_output: Generator output aligned with ``uids``. ``loss_masks`` and
                ``rewards`` are reassigned on the (possibly merged) output dict.
            uids: Prompt uid of each sample, in the same order as ``generator_output``.

        Returns:
            (generator_output, uids) — uids may be shorter than the input when merging.
        """
        generator_output_for_metrics = generator_output
        uids_for_metrics = uids
        if self.cfg.generator.step_wise_trajectories:
            # Keep only the list-valued fields, filtered down to each trajectory's last step.
            generator_output_for_metrics = defaultdict(list)
            for key in generator_output:
                if isinstance(generator_output[key], list):
                    generator_output_for_metrics[key] = [
                        generator_output[key][i]
                        for i in range(len(generator_output[key]))
                        if generator_output["is_last_step"][i]
                    ]
            uids_for_metrics = [
                uid for uid, is_last_step in zip(uids, generator_output["is_last_step"]) if is_last_step
            ]

        # only use `generator_output_for_metrics` for metrics calculation
        # For step-wise training, we only calculate metrics for the last step of each trajectory
        overall_metrics = get_metrics_from_generator_output(
            generator_output_for_metrics,
            uids_for_metrics,
        )

        # Prefix-aware merging of step-wise turns.
        if self.cfg.generator.merge_stepwise_output:
            assert self.cfg.generator.step_wise_trajectories, "merge_stepwise_output requires step-wise training"
            num_seq_before_merge = len(generator_output["response_ids"])
            generator_output = merge_stepwise_output(generator_output)
            num_seq_after_merge = len(generator_output["response_ids"])
            logger.info(f"Merged step wise: {num_seq_before_merge} sequences -> {num_seq_after_merge} sequences")
            self.all_metrics.update(
                {
                    "generate/num_seq_before_merge": num_seq_before_merge,
                    "generate/num_seq_after_merge": num_seq_after_merge,
                }
            )
            # The merged output no longer lines up with the input uids; rebuild them
            # from the merged trajectory ids.
            uids = [tid.instance_id for tid in generator_output["trajectory_ids"]]

        # these use the full generator output
        rewards: Union[List[float], List[List[float]]] = generator_output["rewards"]
        responses: List[List[int]] = generator_output["response_ids"]
        per_token_rewards: List[List[float]] = []

        # Check if rewards are already token-level (List[List[float]]) or response-level (List[float])
        if rewards and isinstance(rewards[0], list):
            # Token-level rewards: rewards is List[List[float]]
            per_token_rewards = rewards
        else:
            if self.cfg.trainer.algorithm.zero_variance_filter:
                # Zero out the loss mask of every sample the filter did not keep
                # (presumably groups whose rewards have no variance — see zero_variance_filter).
                kept_indices_set = set(zero_variance_filter(rewards, uids))
                generator_output["loss_masks"] = [
                    [0] * len(mask) if i not in kept_indices_set else mask
                    for i, mask in enumerate(generator_output["loss_masks"])
                ]
            # Response-level rewards: rewards is List[float], convert to per-token rewards
            for reward, response in zip(rewards, responses):
                # The whole response reward goes on its final token.
                # NOTE(review): an empty response would raise IndexError on the next line.
                per_token_reward = [0.0] * len(response)
                per_token_reward[-1] = float(reward)
                per_token_rewards.append(per_token_reward)

        n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt

        reward_metrics = {
            f"reward/avg_pass_at_{n_samples_per_prompt}": overall_metrics["pass_at_n"],
            "reward/avg_raw_reward": overall_metrics["avg_score"],
            "reward/mean_positive_reward": overall_metrics["mean_positive_reward"],
        }
        self.all_metrics.update(reward_metrics)
        logger.info(
            f"reward/avg_pass_at_{n_samples_per_prompt}: {overall_metrics['pass_at_n']}, reward/avg_raw_reward: {overall_metrics['avg_score']}, reward/mean_positive_reward: {overall_metrics['mean_positive_reward']}"
        )
        # re-assign reward but now it's per token rewards
        generator_output["rewards"] = per_token_rewards
        return generator_output, uids

    @torch.no_grad()
    def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingInputBatch:
        """Calculate advantages and returns for the data batch.

        Without step-wise training the whole batch is passed to
        ``ppo_utils.compute_advantages_and_returns``. With step-wise training, advantages and
        returns are computed from each trajectory's last step only, then broadcast to every step
        of that trajectory and masked with that step's ``response_mask``.

        Also records ``avg_final_rewards``, ``avg_response_length``, ``avg_advantages`` and
        ``avg_advantages_abs`` in ``data.metadata["metrics"]``, and ``loss/avg_final_rewards``,
        ``loss/avg_raw_advantages`` and ``loss/avg_raw_advantages_abs`` in ``self.all_metrics``.
        The trailing ``metadata["pad_size"]`` rows are excluded from these metrics.

        Expects:
            - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["response_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["loss_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["rewards"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `.metadata["uids"]`: List[str]
            - `.metadata["is_last_step"]`: List[bool] for step-wise training
            - `.metadata["avg_response_length"]`: float
            - `.metadata["pad_size"]` (optional, default 0): number of padding rows at the end

        Adds:
            - `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["returns"]`: Float[torch.Tensor, "batch_size seqlen"]

        Returns:
            The batch after ``data.to("cpu")``, with advantages, returns and metrics added.
        """
        token_level_rewards = data["rewards"]

        if self.cfg.generator.step_wise_trajectories:
            is_last_step = torch.tensor(data.metadata["is_last_step"], dtype=torch.bool)
            index = np.array(data.metadata["uids"])
            values = data["values"]
            # Step-wise only supports outcome-based estimators (GRPO, RLOO, MAXRL); ensured by `validate_cfg`.
            # We use the last step of each trajectory to compute advantages and broadcast them to
            # all steps of that trajectory, so we ignore per-step rewards in step-wise training.
            # We pass an all-ones mask here so the estimator returns the scalar advantage at every
            # position. The real per-step `response_mask` is re-applied on broadcast below.
            # Shapes:
            #   traj_ids, (batch_size,):         trajectory id per step (cumsum of shifted is_last_step)
            #   last_step_advantages/returns,
            #       (num_traj, seqlen):          scalar advantage/return per trajectory at every position
            #   last_step_advantages/returns[traj_ids],
            #       (batch_size, seqlen):        broadcast to every step of the owning trajectory
            #   response_mask_float,
            #       (batch_size, seqlen):        per-step response mask
            last_step_response_mask = data["response_mask"][is_last_step]
            last_step_advantages, last_step_returns = ppo_utils.compute_advantages_and_returns(
                token_level_rewards=token_level_rewards[is_last_step],
                response_mask=torch.ones_like(last_step_response_mask, dtype=torch.float),
                index=index[is_last_step.cpu().numpy()],
                adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
                values=values[is_last_step] if values is not None else None,
                config=self.cfg.trainer.algorithm,
                gamma=self.cfg.trainer.algorithm.gamma,
                lambd=self.cfg.trainer.algorithm.lambd,
                grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
            )
            # traj_ids[i] = number of trajectory ends strictly before row i, i.e. the index of the
            # trajectory row i belongs to. This assumes the steps of a trajectory are contiguous
            # and each trajectory ends with is_last_step=True; the assert below catches violations.
            traj_ids = (
                torch.cat([torch.tensor([False], device=is_last_step.device), is_last_step[:-1]]).int().cumsum(dim=0)
            )
            num_traj = traj_ids[-1].item() + 1
            assert num_traj == len(
                last_step_advantages
            ), f"num_traj {num_traj} doesn't match the number of trajectories as given by `is_last_step` {len(last_step_advantages)}. The `is_last_step` tensor is likely malformed"
            response_mask_float = data["response_mask"].to(last_step_advantages.dtype)
            advantages = last_step_advantages[traj_ids] * response_mask_float
            returns = last_step_returns[traj_ids] * response_mask_float
        else:
            advantages, returns = ppo_utils.compute_advantages_and_returns(
                token_level_rewards=token_level_rewards,
                response_mask=data["response_mask"],
                index=data.metadata["uids"],
                adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
                config=self.cfg.trainer.algorithm,
                values=data["values"],
                gamma=self.cfg.trainer.algorithm.gamma,
                lambd=self.cfg.trainer.algorithm.lambd,
                grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
            )
        data["returns"] = returns
        data["advantages"] = advantages

        # remove padding while calculating metrics
        # (the slices below drop the last `pad_size` rows, so padding is assumed to be trailing)
        pad_size = data.metadata.get("pad_size", 0)
        num_samples = len(token_level_rewards)

        return_sums = token_level_rewards.sum(dim=-1)[: num_samples - pad_size]
        if self.cfg.generator.step_wise_trajectories:
            # Step-wise: average only over last steps, i.e. one value per trajectory.
            avg_rewards: float = return_sums[is_last_step[: num_samples - pad_size]].mean().item()
        else:
            avg_rewards: float = return_sums.mean().item()

        avg_response_length = data.metadata["avg_response_length"]
        data = data.to("cpu")

        # Advantage statistics over response tokens of the non-padding rows only.
        valid_advantages = torch.masked_select(
            data["advantages"][: num_samples - pad_size, ...], data["response_mask"][: num_samples - pad_size].bool()
        )
        avg_advantages: float = valid_advantages.mean().item()
        avg_advantages_abs: float = valid_advantages.abs().mean().item()

        if "metrics" not in data.metadata:
            data.metadata["metrics"] = {}
        data.metadata["metrics"].update(
            {
                "avg_final_rewards": avg_rewards,
                "avg_response_length": avg_response_length,
                "avg_advantages": avg_advantages,
                "avg_advantages_abs": avg_advantages_abs,
            }
        )

        logger.info(f"avg_final_rewards: {avg_rewards}, avg_response_length: {avg_response_length}")
        self.all_metrics.update(
            {
                "loss/avg_final_rewards": avg_rewards,
                "loss/avg_raw_advantages": avg_advantages,
                "loss/avg_raw_advantages_abs": avg_advantages_abs,
            }
        )
        return data

    def dump_data(self, data: TrainingInputBatch, file_name: str):
        """
        Dump data to pickle file
        """
        data_save_dir = Path(self.cfg.trainer.export_path) / "dumped_data"
        data_save_dir.mkdir(parents=True, exist_ok=True)
        data.save(data_save_dir / f"{file_name}.pkl")

    @torch.no_grad()
    def fwd_logprobs_values_reward(
        self,
        training_input: TrainingInputBatch,
    ):
        """
        Calculate values from the critic, log probs from the policy and ref model.

        Dispatch handles offload/backload automatically for all colocation configurations.
        When rollout logprobs are present, the absolute difference between rollout (inference)
        and trainer logprobs over loss-masked tokens is summarized into ``self.all_metrics``.
        Those metrics are skipped when no token is loss-masked, e.g. when ``zero_variance_filter``
        zeroed every loss mask.

        Expects:
            - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["attention_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `.metadata["response_length"]`: Int

        Adds:
            - `["base_action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]

        Returns:
            ``training_input`` with the keys above set (``values`` / ``base_action_log_probs``
            are None when there is no critic / ref model).
        """
        fwd_keys = ["sequences", "attention_mask"]
        if training_input.get("rollout_expert_indices") is not None:
            fwd_keys.append("rollout_expert_indices")
        if training_input.get("pixel_values") is not None:
            fwd_keys.append("pixel_values")
        if training_input.get("image_grid_thw") is not None:
            fwd_keys.append("image_grid_thw")
        data_fwd_pass = training_input.select(keys=fwd_keys, metadata_keys=["response_length"])

        values = None
        base_log_probs = None
        action_log_probs = None

        # Critic forward (dispatch handles offload/backload automatically)
        if self.has_critic:
            critic_output = self.dispatch.forward("critic", data_fwd_pass)
            values = loss_fn_outputs_to_tensor(critic_output.loss_fn_outputs, key="values")

        # Ref forward
        if self.ref_model is not None:
            ref_output = self.dispatch.forward("ref", data_fwd_pass)
            base_log_probs = loss_fn_outputs_to_tensor(ref_output.loss_fn_outputs, key="logprobs")
            self.dispatch.empty_cache("ref")

        # Policy forward
        policy_output = self.dispatch.forward("policy", data_fwd_pass)
        action_log_probs = loss_fn_outputs_to_tensor(policy_output.loss_fn_outputs, key="logprobs")

        # Empty cache after all forward passes
        self.dispatch.empty_cache()

        sequences_all: torch.Tensor = training_input["sequences"]
        # NOTE (sumanthrh): The slicing is needed to make sure that the batch dimension doesn't change for the tensordict.
        base_log_probs = base_log_probs[: len(sequences_all)] if base_log_probs is not None else None
        action_log_probs = action_log_probs[: len(sequences_all)]
        values = values[: len(sequences_all)] if values is not None else None

        training_input["base_action_log_probs"] = base_log_probs
        training_input["action_log_probs"] = action_log_probs
        training_input["values"] = values

        if training_input.get("rollout_logprobs", None) is not None:
            # calculates the difference in probs between inference and trainer components
            # only consider response tokens
            loss_token_mask = training_input["loss_mask"] > 0
            logprobs_diff = (
                training_input["rollout_logprobs"][loss_token_mask] - action_log_probs[loss_token_mask]
            ).abs()

            # The loss mask can be all zeros (e.g. zero_variance_filter zeroes the masks of every
            # filtered group). max()/min() raise on an empty tensor, so skip the metrics then.
            if logprobs_diff.numel() > 0:
                self.all_metrics.update(
                    {
                        "policy/rollout_train_logprobs_abs_diff_max": logprobs_diff.max().item(),
                        "policy/rollout_train_logprobs_abs_diff_min": logprobs_diff.min().item(),
                        "policy/rollout_train_logprobs_abs_diff_mean": logprobs_diff.mean().item(),
                        "policy/rollout_train_logprobs_abs_diff_std": logprobs_diff.std().item(),
                    }
                )
        return training_input

    def apply_reward_kl_penalty(
        self,
        data: TrainingInputBatch,
    ) -> TrainingInputBatch:
        """Subtract a KL penalty between the policy and reference (base) log probs from the rewards.

        The per-token KL is computed with ``compute_approx_kl`` using the configured
        ``kl_estimator_type`` and subtracted from ``data["rewards"]``, scaled by the KL
        coefficient. The coefficient comes from ``self.reward_kl_controller.value`` when a
        controller exists, otherwise from ``cfg.trainer.algorithm.kl_loss_coef``, and is
        clipped below at 0 when applied. The controller (if any) is then updated with the
        batch-average KL. ``avg_kl``, ``avg_kl_max`` and ``kl_loss_coef`` are recorded in
        ``data.metadata["metrics"]`` and, with a ``loss/`` prefix, in ``self.all_metrics``.

        Returns:
            ``data`` with penalized ``rewards``.
        """
        loss_mask: torch.Tensor = data["loss_mask"]
        policy_logprobs: torch.Tensor = data["action_log_probs"]
        ref_logprobs: torch.Tensor = data["base_action_log_probs"]

        # Per-token KL for the whole batch in a single call.
        with torch.no_grad():
            per_token_kl = compute_approx_kl(  # type: ignore
                policy_logprobs,
                ref_logprobs,
                loss_mask=loss_mask,
                kl_estimator_type=self.cfg.trainer.algorithm.kl_estimator_type,
            )
        per_seq_kl_max = per_token_kl.abs().max(dim=-1).values
        per_seq_kl_mean = masked_mean(per_token_kl, loss_mask, dim=-1)

        # NOTE (erictang000): only supporting custom rewards currently
        controller = self.reward_kl_controller
        if controller is None:
            kl_loss_coef = self.cfg.trainer.algorithm.kl_loss_coef
        else:
            kl_loss_coef = controller.value
        data["rewards"] = data["rewards"] - per_token_kl * max(0, kl_loss_coef)

        avg_kl: float = per_seq_kl_mean.mean().item()
        avg_kl_max: float = per_seq_kl_max.mean().item()

        if controller is not None:
            # n_steps is just the batch size
            controller.update(current=avg_kl, n_steps=per_token_kl.shape[0])

        if "metrics" not in data.metadata:
            data.metadata["metrics"] = {}
        data.metadata["metrics"].update({"avg_kl": avg_kl, "avg_kl_max": avg_kl_max, "kl_loss_coef": kl_loss_coef})

        self.all_metrics.update(
            {"loss/avg_kl": avg_kl, "loss/avg_kl_max": avg_kl_max, "loss/kl_loss_coef": kl_loss_coef}
        )
        return data

    @torch.no_grad()
    def _normalize_advantages(
        self,
        data: TrainingInputBatch,
        mini_batch_boundaries: List[Tuple[int, int]],
        prompt_boundaries: Optional[List[Tuple[int, int]]] = None,
    ) -> TrainingInputBatch:
        """Normalize ``data["advantages"]`` before policy training.

        Step 1 (only if ``cfg.trainer.algorithm.advantage_batch_normalize``): whiten the advantages
        using the mean and variance over valid response tokens (``response_mask``) of the whole
        batch. Positions outside ``response_mask`` are kept at zero.

        Step 2: rescale each mini-batch with ``apply_loss_reduction_to_advantages_minibatch``
        according to ``cfg.trainer.algorithm.loss_reduction``.

        Args:
            data: Batch holding ``advantages``, ``response_mask`` and ``loss_mask``.
            mini_batch_boundaries: ``(start, end)`` row ranges, one per mini-batch.
            prompt_boundaries: Optional ``(start, end)`` row ranges of each prompt group, used by the
                ``prompt_mean`` loss reduction. They are rebased to mini-batch-relative indices.

        Returns:
            ``data`` with ``advantages`` replaced by the normalized tensor.
        """
        advantages = data["advantages"]
        response_mask = data["response_mask"]

        # Step 1: Z-score normalization (if enabled)
        if self.cfg.trainer.algorithm.advantage_batch_normalize:
            mask = response_mask.to(advantages.dtype)
            # Clamp so a batch with no valid tokens does not divide by zero.
            num_actions = mask.sum().clamp(min=1)
            # Take both statistics over valid tokens only. An unmasked `advantages.mean()` would also
            # average in non-response positions and would not match the masked variance below.
            mean = (advantages * mask).sum() / num_actions
            var = ((advantages - mean).pow(2) * mask).sum() / num_actions
            rstd = var.clamp(min=1e-8).rsqrt()
            # Re-apply the mask so positions outside the response stay zero after centering.
            data["advantages"] = (advantages - mean) * rstd * mask

        # Step 2: Loss reduction normalization per mini-batch
        normalized_advantages = torch.zeros_like(advantages)
        for start_idx, end_idx in mini_batch_boundaries:
            mini_batch = data[start_idx:end_idx]
            # For prompt_mean, select the prompt boundaries falling within this mini-batch
            # and rebase them to mini-batch-relative indices.
            mb_prompt_boundaries = None
            if prompt_boundaries is not None:
                mb_prompt_boundaries = [
                    (p_start - start_idx, p_end - start_idx)
                    for p_start, p_end in prompt_boundaries
                    if start_idx <= p_start < end_idx
                ]
            normalized_advantages[start_idx:end_idx] = apply_loss_reduction_to_advantages_minibatch(
                advantages=mini_batch["advantages"],
                loss_mask=mini_batch["loss_mask"],
                loss_reduction=self.cfg.trainer.algorithm.loss_reduction,
                micro_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu,
                max_seq_len=self.cfg.trainer.algorithm.max_seq_len,
                prompt_boundaries=mb_prompt_boundaries,
            )

        data["advantages"] = normalized_advantages
        return data

    def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]:
        """
        Train ``model`` on ``data`` with forward_backward + optim_step over every epoch and mini-batch.

        For the policy, advantages are first normalized with ``_normalize_advantages``. The per-DP
        mini-batch chunks are staged through ``dispatch.stage_data`` before the loop starts, so
        serialization happens up front rather than between training steps. Workers handle
        micro-batching (gradient accumulation) internally.

        Args:
            model: Model name ("policy" or "critic")
            data: Training data batch

        Returns:
            Metrics reduced across all mini-batches and epochs (including ``grad_norm`` when reported).
        """
        boundaries = data.metadata[f"{model}_mini_batch_boundaries"]

        # Only the policy consumes advantages, so only the policy normalizes them.
        if model == "policy":
            data = self._normalize_advantages(data, boundaries, data.metadata.get("policy_prompt_boundaries"))

        staged_chunks = self.dispatch.stage_data(model, data, boundaries)

        collected: Dict[str, List[float]] = defaultdict(list)
        for _epoch in range(self.cfg.trainer.update_epochs_per_batch):
            for chunk_refs in staged_chunks:
                fb_status = self.dispatch.forward_backward_from_staged(model, chunk_refs)
                for metric_name, metric_value in fb_status.metrics.items():
                    collected[metric_name].append(metric_value)

                # One optimizer step per mini-batch.
                grad_norm = self.dispatch.optim_step(model)
                if grad_norm is not None:
                    collected["grad_norm"].append(grad_norm)

        return reduce_metrics(collected, sum_loss_metrics=False)

    def train_critic_and_policy(self, data: TrainingInputBatch):
        """
        Run one training step for the critic (if any) and then the policy.

        Both models go through the same forward_backward + optim_step path, for both FSDP and
        Megatron. Timings are recorded in ``self.all_timings``. Metrics are recorded in
        ``self.all_metrics`` under ``critic/`` and ``policy/`` prefixes.

        Returns:
            The reduced policy training metrics.
        """
        data.metadata["global_step"] = self.global_step

        critic_status = None
        if self.has_critic:
            with Timer("critic_train", self.all_timings):
                critic_status = self._execute_training_step("critic", data)

        with Timer("policy_train", self.all_timings):
            policy_status = self._execute_training_step("policy", data)

        if critic_status is not None:
            self.all_metrics.update({f"critic/{name}": value for name, value in critic_status.items()})
        self.all_metrics.update({f"policy/{name}": value for name, value in policy_status.items()})

        self.dispatch.empty_cache()
        return policy_status

    def handle_dynamic_sampling(
        self, generator_output: GeneratorOutput, uids: List[str]
    ) -> Tuple[GeneratorOutput, List[str], bool]:
        """
        Apply the configured dynamic sampling strategy (e.g. filter, replace) to the current batch.

        Across repeated sampling rounds, the running state is kept in
        ``self.dynamic_sampling_state``. Its ``sample_batch_count`` is incremented on every call,
        and the state is cleared once sampling is complete.

        Args:
            generator_output: Current batch generator output
            uids: Current batch UIDs

        Returns:
            processed_output: Filtered generator output
            processed_uids: Filtered UIDs
            keep_sampling: Whether to keep sampling

        Raises:
            RuntimeError: if more sampling is still needed after ``max_sample_batches`` batches
                (only when ``max_sample_batches > 0``).
        """
        ds_cfg = self.cfg.trainer.algorithm.dynamic_sampling
        max_sample_batches = ds_cfg.max_sample_batches
        dynamic_sampling_config = {
            "type": ds_cfg.type,
            "max_sample_batches": max_sample_batches,
            "min_replace_ratio": ds_cfg.min_replace_ratio,
            "train_batch_size": self.cfg.trainer.train_batch_size,
            "n_samples_per_prompt": self.cfg.generator.n_samples_per_prompt,
        }

        # Count this batch: bump the existing counter, or start fresh on the first round.
        if self.dynamic_sampling_state is not None:
            self.dynamic_sampling_state["sample_batch_count"] += 1
        else:
            self.dynamic_sampling_state = {"sample_batch_count": 1}

        processed_output, processed_uids, keep_sampling, updated_state = trainer_utils.handle_dynamic_sampling(
            generator_output, uids, dynamic_sampling_config, self.dynamic_sampling_state
        )

        # Still short of a full batch after the allowed number of rounds: give up loudly.
        if keep_sampling and 0 < max_sample_batches <= self.dynamic_sampling_state["sample_batch_count"]:
            raise RuntimeError(
                f"Exiting training loop due to hitting dynamic sampling limit for "
                f"{ds_cfg.type} strategy with "
                f"{ds_cfg.max_sample_batches} max sample batches. "
                f"Please check your data difficulty distribution."
            )

        # Carry the state forward while sampling continues; reset it once the batch is complete.
        self.dynamic_sampling_state = updated_state if keep_sampling else None

        return processed_output, processed_uids, keep_sampling

    def _get_dp_group_models(self, rank: int, model_type: str = ""):
        model = getattr(self, model_type)
        return model._actor_handlers[rank]

    def _get_mesh_rank(self, rank: int, model_type: str = "") -> MeshRank:
        model: PPORayActorGroup = getattr(self, model_type)
        actor_info: ActorInfo = model.actor_infos[rank]
        return actor_info.rank

    def save_checkpoints(self) -> str:
        """
        Save the model, optimizer, and training states to disk for the current step.

        Everything is written under ``<cfg.trainer.ckpt_path>/global_step_<N>/``:

        * ``policy/``: policy checkpoint, via ``dispatch.save_checkpoint``.
        * ``critic/``: critic checkpoint, only when ``self.has_critic``.
        * ``data.pt``: dataloader state (best-effort; failures are only logged).
        * ``trainer_state.pt``: ``{"global_step": ..., "config": asdict(cfg)}``.

        ``<ckpt_path>/latest_ckpt_global_step.txt`` is rewritten only after the saves
        above have completed, and old checkpoints are pruned afterwards.

        Dispatch handles offload/backload automatically for all colocation configurations.

        Returns:
            The checkpoint folder path for this step.
        """
        step = self.global_step
        ckpt_root = self.cfg.trainer.ckpt_path
        global_step_folder = os.path.join(ckpt_root, f"global_step_{step}")
        io.makedirs(global_step_folder, exist_ok=True)

        # Model + optimizer state (dispatch moves weights on/off GPU as needed).
        self.dispatch.save_checkpoint("policy", os.path.join(global_step_folder, "policy"), self.tokenizer)
        if self.has_critic:
            self.dispatch.save_checkpoint("critic", os.path.join(global_step_folder, "critic"), self.tokenizer)

        # Dataloader state is best-effort: a failure is logged but does not abort the save.
        dataloader_save_path = os.path.join(global_step_folder, "data.pt")
        try:
            dataloader_state_dict = self.train_dataloader.state_dict()
            with io.open_file(dataloader_save_path, "wb") as f:
                torch.save(dataloader_state_dict, f)
            logger.info(f"Saved dataloader state to {dataloader_save_path}")
        except Exception as e:
            logger.warning(f"Failed to save dataloader state: {e}")

        trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
        with io.open_file(trainer_state_path, "wb") as f:
            torch.save({"global_step": step, "config": asdict(self.cfg)}, f)
        logger.info(f"Saved trainer state to {trainer_state_path}")

        # Update the "latest" pointer last, so it only ever names a step whose saves
        # above all returned without raising. (The write itself is a plain overwrite.)
        with io.open_file(os.path.join(ckpt_root, "latest_ckpt_global_step.txt"), "w") as f:
            f.write(str(step))

        logger.info(f"Successfully saved checkpoint for global_step_{step} to: {global_step_folder}")

        # Prune old checkpoints only once this one is fully written.
        with Timer("cleanup_old_checkpoints", self.all_timings):
            self._cleanup_old_checkpoints()

        return global_step_folder

    def _cleanup_old_checkpoints(self):
        """
        Prune old checkpoints under ``cfg.trainer.ckpt_path``.

        ``cleanup_old_checkpoints(ckpt_path, max_ckpts_to_keep)`` is dispatched to every
        node returned by ``self.dispatch.get_node_ids()``. The node ids are fetched
        lazily and cached on ``self._node_ids``. The function is then also run locally
        on the driver.
        """
        if not self._node_ids:
            self._node_ids = self.dispatch.get_node_ids()

        ckpt_root = self.cfg.trainer.ckpt_path
        keep = self.cfg.trainer.max_ckpts_to_keep
        run_on_each_node(self._node_ids, cleanup_old_checkpoints, ckpt_root, keep)

        # Also run on the driver. NOTE (sumanthrh): on the node hosting the driver this
        # executes twice, which is fine because the cleanup is idempotent.
        cleanup_old_checkpoints(ckpt_root, keep)

    def load_checkpoints(self) -> Tuple[int, Optional[str]]:
        """
        Load complete checkpoint state and return the global_step to resume from.
        Returns ``(0, None)`` if no checkpoint is loaded.

        The checkpoint is selected by ``self.resume_mode``:

        * ``ResumeMode.NONE``: nothing is loaded.
        * ``ResumeMode.LATEST``: the step recorded in
          ``<cfg.trainer.ckpt_path>/latest_ckpt_global_step.txt`` (if that file exists).
        * otherwise: ``cfg.trainer.resume_path``, whose final path component must
          contain ``GLOBAL_STEP_PREFIX``.

        If colocate_all is True, assumes that the policy model is currently on GPU.

        Returns:
            global_step: The global step to resume from.
            checkpoint_path: The path to the checkpoint, or None if nothing was loaded.

        Raises:
            ValueError: If ``resume_path`` is unset or malformed, or no step can be
                parsed from the checkpoint path.
            FileNotFoundError: If the checkpoint directory or its ``trainer_state.pt``
                does not exist.
        """
        checkpoint_path = None
        # Check if resumption is enabled
        if self.resume_mode == ResumeMode.NONE:
            logger.info("Checkpoint resumption disabled, starting training from scratch")
            return 0, None
        # first, let's get resume_path
        elif self.resume_mode == ResumeMode.LATEST:
            latest_checkpoint_file = os.path.join(self.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt")
            if not io.exists(latest_checkpoint_file):
                logger.info("No checkpoint found, starting training from scratch")
                return 0, None
            with io.open_file(latest_checkpoint_file, "r") as f:
                ckpt_iteration = int(f.read().strip())
            checkpoint_path = os.path.join(self.cfg.trainer.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_iteration}")
            # Run validation: Make sure ckpt folder is consistent with latest_ckpt_global_step.txt
            validate_consistency_for_latest_checkpoint(
                self.cfg.trainer.ckpt_path,
                ckpt_iteration,
                checkpoint_path,
                latest_checkpoint_file,
                self.cfg.trainer.ckpt_interval,
            )
        else:
            # Validate the raw config value *before* building a path object. A ``Path``
            # instance is always truthy, so checking ``Path(resume_path)`` could never
            # fail, and ``Path(None)`` would raise an unhelpful TypeError instead.
            resume_path = self.cfg.trainer.resume_path
            if not resume_path:
                raise ValueError("`trainer.resume_path` must be specified when resume_mode is 'from_path'")
            # Keep a plain string for joins and I/O. ``Path`` collapses the ``//`` in URIs
            # such as ``s3://bucket/...``, which would corrupt remote paths given to ``io``.
            checkpoint_path = str(resume_path)

            # Validate that it's a global_step directory
            if GLOBAL_STEP_PREFIX not in Path(checkpoint_path).name:
                raise ValueError(
                    f"`trainer.resume_path` must point to a directory whose name starting with {GLOBAL_STEP_PREFIX}, got: {checkpoint_path}"
                )

        # Validate that the path exists
        if not io.exists(str(checkpoint_path)):
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")

        logger.info(f"Loading checkpoint from: {checkpoint_path}")

        # Extract global step from checkpoint path
        global_step = extract_step_from_path(Path(checkpoint_path))
        if global_step == -1:
            raise ValueError(f"Checkpoint path {checkpoint_path} is not a valid checkpoint path")
        logger.info(f"Resuming from global_step: {global_step}")

        # Define paths for different checkpoint components
        policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
        critic_ckpt_dir = os.path.join(checkpoint_path, "critic")
        trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
        dataloader_state_path = os.path.join(checkpoint_path, "data.pt")

        # Validate that required checkpoint files exist
        if not io.exists(trainer_state_path):
            raise FileNotFoundError(f"Trainer state file not found: {trainer_state_path}")

        # 1. Load and validate trainer state.
        # NOTE: weights_only=False performs full unpickling; only resume from trusted locations.
        with io.open_file(trainer_state_path, "rb") as f:
            trainer_state = torch.load(f, map_location="cpu", weights_only=False)
        saved_global_step = trainer_state.get("global_step", global_step)
        logger.info("Successfully loaded trainer state")
        if saved_global_step != global_step:
            logger.warning(f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value.")

        # 2. Load dataloader state if available (best-effort: failures only log a warning)
        if io.exists(dataloader_state_path):
            try:
                with io.open_file(dataloader_state_path, "rb") as f:
                    dataloader_state = torch.load(f, map_location="cpu", weights_only=False)
                self.train_dataloader.load_state_dict(dataloader_state)
                logger.info("Successfully loaded dataloader state")
            except Exception as e:
                logger.warning(f"Failed to load dataloader state: {e}. Dataloader will start from beginning.")
        else:
            logger.warning(
                f"No dataloader state found at {dataloader_state_path}. Dataloader will start from beginning."
            )

        # 3. Load policy checkpoint (dispatch handles offload/backload)
        logger.info(f"Loading policy checkpoint from {policy_ckpt_dir}")
        self.dispatch.load_checkpoint(
            "policy",
            policy_ckpt_dir,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        logger.info("Successfully loaded policy checkpoint")

        # 4. Load critic checkpoint if it exists and we have a critic model
        if self.has_critic:
            logger.info(f"Loading critic checkpoint from {critic_ckpt_dir}")
            self.dispatch.load_checkpoint(
                "critic",
                critic_ckpt_dir,
                load_optimizer_states=True,
                load_lr_scheduler_states=True,
            )
            logger.info("Successfully loaded critic checkpoint")

        logger.info(f"Successfully loaded complete checkpoint state from global_step_{global_step}")
        return global_step, str(checkpoint_path)

    def save_models(self):
        """
        Export model weights in HF format for the current step.

        The policy is written to ``<cfg.trainer.export_path>/global_step_<N>/policy``.
        If a critic is configured, it is written to the sibling ``critic`` directory.

        Dispatch handles offload/backload automatically for all colocation configurations.
        """
        step_export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}")

        self.dispatch.save_hf_model("policy", os.path.join(step_export_dir, "policy"), self.tokenizer)
        if self.has_critic:
            self.dispatch.save_hf_model("critic", os.path.join(step_export_dir, "critic"), self.tokenizer)

        logger.info("Successfully saved model weights.")

    def update_ref_with_policy(self):
        """
        Update the reference model with the policy model weights (required by some algorithms).

        The policy is exported in HF format to a scratch directory. The ref model is
        re-initialized from that export, and the scratch directory is then removed.
        Removal is best-effort: failures are logged as warnings.

        The scratch directory is deliberately *not* ``<export_path>/global_step_<N>/policy``.
        That is where ``save_models`` writes the user-facing HF export, and the training
        loop can run both at the same epoch-end step. Reusing it would delete that export.

        Dispatch handles offload/backload automatically for all colocation configurations.
        After this method, save_weights_for_sampler() should be called to sync weights.
        """
        # TODO(tgriggs): Make policy-to-ref sync faster.
        ref_sync_dir = os.path.join(
            self.cfg.trainer.export_path, f"global_step_{self.global_step}", "policy_for_ref_sync"
        )

        # Save policy model (dispatch handles GPU state)
        self.dispatch.save_hf_model("policy", ref_sync_dir, self.tokenizer)

        # Re-initialize ref model from saved policy (dispatch handles offloading policy first)
        self.dispatch.init_model("ref", ref_sync_dir)

        # Clean up the scratch export; failure here must not interrupt training.
        try:
            shutil.rmtree(ref_sync_dir)
            logger.info(f"Cleaned up temporary policy export directory: {ref_sync_dir}")
        except Exception as e:
            logger.warning(f"Failed to clean up temporary policy export directory {ref_sync_dir}: {e}")

        logger.info("Successfully updated ref model with policy model, training continues.")

attr cfg

cfg = cfg

attr colocate_all

colocate_all = cfg.trainer.placement.colocate_all

attr tracker

tracker = tracker

attr tokenizer

tokenizer = tokenizer

attr train_dataset

train_dataset = train_dataset

attr eval_dataset

eval_dataset = eval_dataset

attr inference_engine_client

inference_engine_client = inference_engine_client

attr generator

generator = generator

attr train_dataloader

train_dataloader = None

attr total_training_steps

total_training_steps = None

attr eval_dataloader

eval_dataloader = build_dataloader(self.cfg, eval_dataset, is_train=False) if eval_dataset is not None else None

attr colocate_pg

colocate_pg = colocate_pg

attr resume_mode

resume_mode = ResumeMode(cfg.trainer.resume_mode)

attr all_metrics

all_metrics = {}

attr all_timings

all_timings = {}

attr global_step

global_step = 0

attr policy_model

policy_model: PPORayActorGroup = None

attr critic_model

critic_model: Optional[PPORayActorGroup] = None

attr ref_model

ref_model: Optional[PPORayActorGroup] = None

attr dynamic_sampling_state

dynamic_sampling_state: Optional[DynamicSamplingState] = None

attr reward_kl_controller

reward_kl_controller: Optional[Union[FixedKLController, AdaptiveKLController]] = None

attr dispatch

dispatch: WorkerDispatch = None

method add_callback

add_callback(callback: TrainingCallback) -> None

Register a callback. Events fired after this call reach the new callback.

Source code in skyrl/train/trainer.py:160-162
    def add_callback(self, callback: TrainingCallback) -> None:
        """
        Register ``callback`` with the trainer's callback handler.

        Only events fired after this call are delivered to the new callback.
        """
        handler = self._callback_handler
        handler.add(callback)

attr property has_critic

has_critic: bool

Check if critic model is configured.

method async eval

eval() -> Dict[str, float]

Run generation and scoring on the evaluation dataset.

The eval metrics are recorded after training has finished self.global_step steps. Metrics recorded at global_step 0 correspond to the evaluation run before training.

Returns:

Type | Description
Dict[str, float] | A dictionary of evaluation metrics.
Source code in skyrl/train/trainer.py:199-226
    @torch.no_grad()
    async def eval(self) -> Dict[str, float]:
        """
        Run generation and scoring on the evaluation dataset.

        The eval metrics are recorded after training has finished `self.global_step` steps.
        Metrics recorded at global_step 0 correspond to the evaluation run before training.

        Evaluation uses ``evaluate_step_wise`` when
        ``cfg.generator.step_wise_trajectories`` is set, and ``evaluate`` otherwise.
        Both are called with the same keyword arguments.

        Returns:
            A dictionary of evaluation metrics.
        """
        # NOTE(review): decorating an ``async def`` with ``torch.no_grad()`` most likely only
        # wraps creation of the coroutine, not the awaited body, so grad mode may not be
        # disabled while evaluating. Confirm whether that matters here.
        evaluate_fn = evaluate_step_wise if self.cfg.generator.step_wise_trajectories else evaluate
        return await evaluate_fn(
            eval_dataloader=self.eval_dataloader,
            generator=self.generator,
            cfg=self.cfg,
            global_step=self.global_step,
            tokenizer=self.tokenizer,
        )

method train

train()

Main training loop for PPO

Source code in skyrl/train/trainer.py:228-488
    async def train(self):
        """
        Main training loop for PPO.

        Flow, as implemented below:

        * Setup: start the optional Ray GPU monitor, initialize weight-sync state, load a
          checkpoint unless ``self.resume_mode`` is ``ResumeMode.NONE``, and prepare the
          current policy weights for sampling via ``dispatch.save_weights_for_sampler()``.
        * Fire ``on_train_start``. If ``eval_interval > 0`` and ``eval_before_train``, run
          a baseline eval logged at the current ``global_step``.
        * Per dataloader batch:
          - generate rollouts, optionally applying dynamic sampling;
          - post-process rewards and build a ``TrainingInputBatch``;
          - compute log-probs/values, optionally apply the KL reward penalty, and
            compute advantages;
          - train policy/critic;
          - conditionally checkpoint, export, or refresh the ref model;
          - sync weights, optionally evaluate, and log.
        * After the loop: save a final checkpoint / HF export if the last completed step
          did not already do so, stop monitors, fire ``on_train_end``, and finish the tracker.
        """
        if self._ray_gpu_monitor is not None:
            self._ray_gpu_monitor.start()

        # Initialize weight sync state between policy model and inference engines.
        # NOTE(review): setup Timers here are not given ``self.all_timings``, so presumably
        # they are not part of the logged ``timing/*`` metrics.
        with Timer("init_weight_sync_state"):
            self.init_weight_sync_state()

        # Load checkpoint state if resumption is enabled.
        # ``load_checkpoints`` returns the step to resume from (0 when nothing is loaded).
        if self.resume_mode != ResumeMode.NONE:
            with Timer("load_checkpoints"):
                self.global_step, _ = self.load_checkpoints()

        # Prepare weights for sampling
        with Timer("sync_weights"):
            await self.dispatch.save_weights_for_sampler()

        # Compute start_epoch up-front so callback metadata is ready before
        # any event fires (including the baseline eval below).
        start_epoch = self.global_step // len(self.train_dataloader)
        self._current_epoch = start_epoch
        self._training_control.reset()

        self._fire("on_train_start")

        # Eval before training. Wrapped in eval callbacks + on_log so that e.g.
        # a best-checkpoint callback sees the baseline reading.
        if self.cfg.trainer.eval_interval > 0 and self.cfg.trainer.eval_before_train:
            self._fire("on_eval_start")
            with Timer("eval", self.all_timings):
                eval_metrics = await self.eval()
            self._fire("on_eval_end", metrics=eval_metrics)
            self._fire("on_log", logs=eval_metrics)
            self.tracker.log(eval_metrics, step=self.global_step, commit=True)

        # initialize kl controller
        if self.cfg.trainer.algorithm.use_kl_in_reward:
            self.reward_kl_controller = get_kl_controller(self.cfg.trainer.algorithm)

        # main training loop
        pbar = tqdm(total=self.total_training_steps, initial=self.global_step, desc="Training Batches Processed")
        # When resuming, this moves to the step after the one that was loaded.
        self.global_step += 1  # start training at global_step 1

        # booleans tracking whether we save ckpts
        # as well as hf model at step end
        # They keep the value from the last completed step and are consulted after the
        # loop to avoid saving the same final step twice.
        will_save_ckpts = False
        hf_model_save = False
        for epoch in range(start_epoch, self.cfg.trainer.epochs):
            self._current_epoch = epoch
            self._fire("on_epoch_start")
            # ``step_started`` tracks the on_step_start/on_step_end pairing taking
            # dynamic-sampling into account (which span multiple inner iterations
            # before completing a logical step).
            step_started = False
            for _, rand_prompts in enumerate(self.train_dataloader):
                if not step_started:
                    self._fire("on_step_start")
                    step_started = True
                with Timer("step", self.all_timings):
                    # for colocate_all=true, inference engine is always on GPU when starting the training step

                    # 0. truncate data to have even shards
                    rand_prompts = self._remove_tail_data(rand_prompts)
                    generator_input, uids = prepare_generator_input(
                        rand_prompts,
                        self.cfg.generator.n_samples_per_prompt,
                        get_sampling_params_for_backend(
                            self.cfg.generator.inference_engine.backend, self.cfg.generator.sampling_params
                        ),
                        self.cfg.environment.env_class,
                        "train",
                        self.global_step,
                    )

                    # 1.1. generation phase
                    with Timer("generate", self.all_timings):
                        generator_output: GeneratorOutput = await self.generate(generator_input)

                    if self.cfg.generator.step_wise_trajectories:
                        # NOTE: We use instance_ids from `trajectory_ids` here instead of re-using `uids`
                        # this is because in step-wise training, len(uids) != len(generator_output["response_ids"])
                        uids = [trajectory_id.instance_id for trajectory_id in generator_output["trajectory_ids"]]

                    # dynamic sampling
                    if self.cfg.trainer.algorithm.dynamic_sampling.type is not None:
                        generator_output, uids, keep_sampling = self.handle_dynamic_sampling(generator_output, uids)
                        if keep_sampling:  # continue sampling
                            # update progress bar for current batch (but not global step)
                            # ``continue`` exits the ``step`` Timer and skips the logging and
                            # ``global_step`` increment below. The next dataloader batch is
                            # therefore folded into the same logical step.
                            # NOTE(review): since ``global_step`` does not advance here while a
                            # dataloader batch is consumed, ``is_epoch_end`` (derived from
                            # ``global_step`` below) may not match the dataloader's real epoch
                            # boundary when resampling occurs. Confirm this is intended.
                            pbar.update(1)
                            continue

                    if self.colocate_all:
                        # if we are not continuing sampling, we sleep the inference engine
                        await self.inference_engine_client.sleep()

                    # 1.2 postprocess rewards (and merge step-wise turns if enabled)
                    with Timer("postprocess_generator_output", self.all_timings):
                        generator_output, uids = self.postprocess_generator_output(generator_output, uids)

                    # 2. print example just for debugging
                    log_interval = self.cfg.trainer.log_example_interval
                    if log_interval > 0 and self.global_step % log_interval == 0:
                        vis = self.tokenizer.decode(generator_output["response_ids"][0])
                        log_example(
                            logger,
                            prompt=generator_input["prompts"][0],
                            response=vis,
                            reward=generator_output["rewards"][0],
                        )

                    # 3. Convert GeneratorOutput to TrainingInputBatch
                    with Timer("convert_to_training_input", self.all_timings):
                        training_input: TrainingInputBatch = self.convert_to_training_input(generator_output, uids)

                    # 4. Inference and calculate values, log probs, rewards, kl divergence
                    with Timer("fwd_logprobs_values_reward", self.all_timings):
                        training_input = self.fwd_logprobs_values_reward(training_input)

                    # 5. apply kl divergence penalty to rewards
                    if self.cfg.trainer.algorithm.use_kl_in_reward:
                        with Timer("apply_reward_kl_penalty", self.all_timings):
                            training_input = self.apply_reward_kl_penalty(training_input)

                    # 6. calculate advantages and returns
                    with Timer("compute_advantages_and_returns", self.all_timings):
                        training_input = self.compute_advantages_and_returns(training_input)
                        # remove some unwanted keys
                        for key in ["rewards"]:
                            training_input.pop(key)
                        training_input.metadata.pop("uids")
                        training_input.metadata.pop("is_last_step", None)

                    if self.cfg.trainer.dump_data_batch:
                        # dump data to file
                        with Timer("dump_data_batch"):
                            self.dump_data(training_input, file_name=f"global_step_{self.global_step}_training_input")

                    # 7. train policy/critic model
                    # Policy model is backloaded to GPU during training
                    with Timer("train_critic_and_policy", self.all_timings):
                        status = self.train_critic_and_policy(training_input)

                    self._fire("on_step_end", batch=training_input, metrics=status)
                    step_started = False

                    # Capture callback-driven triggers, then reset.
                    force_save = self._training_control.should_save
                    force_eval = self._training_control.should_evaluate
                    self._training_control.should_save = False
                    self._training_control.should_evaluate = False

                    # 8. conditionally save checkpoints and hf model
                    # When the respective interval is > 0, a save also happens at every
                    # epoch end (``is_epoch_end`` is derived from ``global_step``).
                    is_epoch_end = self.global_step % len(self.train_dataloader) == 0
                    hf_model_save = self.cfg.trainer.hf_save_interval > 0 and (
                        is_epoch_end or self.global_step % self.cfg.trainer.hf_save_interval == 0
                    )
                    ckpt_interval_save = self.cfg.trainer.ckpt_interval > 0 and (
                        is_epoch_end or self.global_step % self.cfg.trainer.ckpt_interval == 0
                    )
                    will_save_ckpts = force_save or ckpt_interval_save
                    if will_save_ckpts:
                        with Timer("save_checkpoints", self.all_timings):
                            ckpt_path = self.save_checkpoints()
                        self._fire("on_save", ckpt_path=ckpt_path)
                    if hf_model_save:
                        with Timer("save_hf_model", self.all_timings):
                            self.save_models()

                    # 9. conditionally sync policy and ref at the end of the epoch
                    if (
                        self.cfg.trainer.update_ref_every_epoch
                        and self.ref_model is not None
                        and is_epoch_end
                        and epoch != self.cfg.trainer.epochs - 1  # skip updating ref at the end of the last epoch
                    ):
                        with Timer("update_ref_with_policy", self.all_timings):
                            self.update_ref_with_policy()

                    # 10. Prepare weights for sampling
                    with Timer("sync_weights", self.all_timings):
                        await self.dispatch.save_weights_for_sampler()

                # 11. set logs
                # This section runs after the ``step`` Timer has exited.
                logger.info(status)
                # Throughput metrics
                # Tokens come from this step's attention_mask, divided by the train time
                # and ``self._num_training_gpus``.
                train_time = self.all_timings.get("train_critic_and_policy", 0.0)
                if train_time > 0 and training_input.get("attention_mask") is not None:
                    total_tokens = int(training_input["attention_mask"].sum().item())
                    self.all_metrics["trainer/tokens_per_second_per_gpu"] = total_tokens / (
                        train_time * self._num_training_gpus
                    )
                # log epoch info
                self.all_metrics.update({"trainer/epoch": epoch, "trainer/global_step": self.global_step})
                interval_eval = self.cfg.trainer.eval_interval > 0 and (
                    self.global_step % self.cfg.trainer.eval_interval == 0
                    or self.global_step == self.total_training_steps
                )
                if force_eval or interval_eval:
                    self._fire("on_eval_start")
                    with Timer("eval", self.all_timings):
                        eval_metrics = await self.eval()
                        self.all_metrics.update(eval_metrics)
                    self._fire("on_eval_end", metrics=eval_metrics)

                log_payload = {
                    **self.all_metrics,
                    **{f"timing/{k}": v for k, v in self.all_timings.items()},
                }
                if self._vllm_metrics_scraper is not None:
                    log_payload.update(await self._vllm_metrics_scraper.sample())

                if self._ray_gpu_monitor is not None:
                    log_payload.update(self._ray_gpu_monitor.flush())

                self._fire("on_log", logs=log_payload)

                self.tracker.log(log_payload, step=self.global_step, commit=True)
                # Metrics and timings are per-step; reset them for the next step.
                self.all_metrics = {}
                self.all_timings = {}

                # update progress bar after logging
                pbar.update(1)

                self.global_step += 1

                del training_input, generator_output

            self._fire("on_epoch_end")

        pbar.close()
        if self.colocate_all:
            await self.inference_engine_client.sleep()

        # Decrement global step by 1 to stop at the last global step
        # We use the global step value in callbacks when training finishes,
        # as well as for a final checkpoint save
        self.global_step -= 1

        # Safety net: always save final checkpoint at end of training.
        # Skip if we already saved at the last step
        if self.cfg.trainer.ckpt_interval > 0 and not will_save_ckpts:
            with Timer("save_checkpoints", self.all_timings):
                ckpt_path = self.save_checkpoints()
                logger.info("Saved final checkpoint.")
            self._fire("on_save", ckpt_path=ckpt_path)
        if self.cfg.trainer.hf_save_interval > 0 and not hf_model_save:
            with Timer("save_hf_model", self.all_timings):
                self.save_models()
                logger.info("Saved final model.")
        if self._vllm_metrics_scraper is not None:
            await self._vllm_metrics_scraper.aclose()

        if self._ray_gpu_monitor is not None:
            self._ray_gpu_monitor.stop()

        self._fire("on_train_end")
        self.tracker.finish()
        logger.info("Training done!")

method build_models

build_models(PolicyWorker, CriticWorker, RefWorker)

Initialize the training actors and handle colocation logic.

Source code in skyrl/train/trainer.py:514-716
    def build_models(self, PolicyWorker, CriticWorker, RefWorker):
        """
        Initialize the actors for training, and handle colocation logic
        """
        cfg = self.cfg
        pg = None

        use_ref_model = cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward

        if cfg.trainer.placement.colocate_all:
            num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
            num_critic_gpus = cfg.trainer.placement.critic_num_gpus_per_node * cfg.trainer.placement.critic_num_nodes
            num_ref_gpus = cfg.trainer.placement.ref_num_gpus_per_node * cfg.trainer.placement.ref_num_nodes
            ie_cfg = cfg.generator.inference_engine
            num_rollout_gpus = (
                ie_cfg.num_engines
                * ie_cfg.tensor_parallel_size
                * ie_cfg.pipeline_parallel_size
                * ie_cfg.data_parallel_size
            )
            assert (
                num_policy_gpus == num_rollout_gpus
            ), "num_policy_gpus and num_rollout_gpus must be the same when colocating all models"
            pg = self.colocate_pg

            policy_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.policy_num_nodes,
                cfg.trainer.placement.policy_num_gpus_per_node,
                PolicyWorker,
                pg=pg,
                num_gpus_per_actor=0.2 if pg else 1,
                colocate_all=True,
                sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
                record_memory=cfg.trainer.policy.record_memory,
            )
            if use_ref_model:
                assert (
                    num_policy_gpus == num_ref_gpus
                ), "num_policy_gpus and num_ref_gpus must be the same when colocating policy and ref model"
                ref_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.ref_num_nodes,
                    cfg.trainer.placement.ref_num_gpus_per_node,
                    RefWorker,
                    pg=pg,
                    num_gpus_per_actor=0.2 if pg else 1,
                    colocate_all=True,
                    sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
                )
            else:
                ref_model = None

            if cfg.trainer.critic.model.path:
                assert (
                    num_policy_gpus == num_critic_gpus
                ), "num_policy_gpus and num_critic_gpus must be the same when colocating policy and critic model"
                critic_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.critic_num_nodes,
                    cfg.trainer.placement.critic_num_gpus_per_node,
                    CriticWorker,
                    pg=pg,
                    num_gpus_per_actor=0.2,
                    colocate_all=True,
                    sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
                )
            else:
                critic_model = None

        else:
            if cfg.trainer.placement.colocate_policy_ref and use_ref_model:
                assert (
                    cfg.trainer.placement.policy_num_nodes == cfg.trainer.placement.ref_num_nodes
                    and cfg.trainer.placement.policy_num_gpus_per_node == cfg.trainer.placement.ref_num_gpus_per_node
                ), "num_nodes and num_gpus_per_node must be the same when colocate policy and ref model."

                bundles = [
                    {
                        "GPU": cfg.trainer.placement.policy_num_gpus_per_node,
                        "CPU": cfg.trainer.placement.policy_num_gpus_per_node,
                    }
                    for _ in range(cfg.trainer.placement.policy_num_nodes)
                ]
                raw_pg = placement_group(bundles, strategy="PACK")
                get_ray_pg_ready_with_timeout(raw_pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
                pg = ResolvedPlacementGroup(raw_pg)

            policy_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.policy_num_nodes,
                cfg.trainer.placement.policy_num_gpus_per_node,
                PolicyWorker,
                pg=pg,
                num_gpus_per_actor=0.75 if pg else 1,
                colocate_all=False,
                sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
            )
            if use_ref_model:
                ref_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.ref_num_nodes,
                    cfg.trainer.placement.ref_num_gpus_per_node,
                    RefWorker,
                    pg=pg,
                    num_gpus_per_actor=0.25 if pg else 1,
                    colocate_all=False,
                    sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
                )
                if pg is not None:
                    # The shared policy/ref placement group `pg` is set only when colocate_policy_ref is enabled
                    logger.info(
                        "Colocating policy and ref on the same GPUs across "
                        f"{cfg.trainer.placement.policy_num_nodes} node(s)."
                    )
            else:
                ref_model = None

            if cfg.trainer.critic.model.path:
                critic_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.critic_num_nodes,
                    cfg.trainer.placement.critic_num_gpus_per_node,
                    CriticWorker,
                    num_gpus_per_actor=1,
                    colocate_all=False,
                    sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
                )
            else:
                critic_model = None

        policy_steps_per_train_batch = (
            cfg.trainer.train_batch_size // cfg.trainer.policy_mini_batch_size * cfg.trainer.update_epochs_per_batch
        )
        critic_steps_per_train_batch = 0
        if cfg.trainer.critic.model.path:
            critic_steps_per_train_batch = (
                cfg.trainer.train_batch_size // cfg.trainer.critic_mini_batch_size * cfg.trainer.update_epochs_per_batch
            )
        policy_num_training_steps = (
            self.total_training_steps * policy_steps_per_train_batch if self.total_training_steps is not None else None
        )
        critic_num_training_steps = (
            self.total_training_steps * critic_steps_per_train_batch if self.total_training_steps is not None else None
        )
        if not cfg.trainer.placement.colocate_all:
            refs = []
            if ref_model is not None:
                refs.extend(ref_model.async_init_model(cfg.trainer.ref.model.path))
            refs.extend(
                policy_model.async_init_model(
                    cfg.trainer.policy.model.path,
                    num_training_steps=policy_num_training_steps,
                )
            )
            if cfg.trainer.critic.model.path:
                refs.extend(
                    critic_model.async_init_model(
                        cfg.trainer.critic.model.path,
                        num_training_steps=critic_num_training_steps,
                    )
                )
            ray.get(refs)
            ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
        else:
            if ref_model is not None:
                ray.get(ref_model.async_init_model(cfg.trainer.ref.model.path))
                ref_model.offload_to_cpu()
            ray.get(
                policy_model.async_init_model(
                    cfg.trainer.policy.model.path,
                    num_training_steps=policy_num_training_steps,
                )
            )
            ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
            policy_model.offload_to_cpu()
            if cfg.trainer.critic.model.path:
                ray.get(
                    critic_model.async_init_model(
                        cfg.trainer.critic.model.path,
                        num_training_steps=critic_num_training_steps,
                    )
                )
                critic_model.offload_to_cpu()

        self.policy_model: PPORayActorGroup = policy_model
        self.critic_model: Optional[PPORayActorGroup] = critic_model
        self.ref_model: Optional[PPORayActorGroup] = ref_model

        # Create unified dispatch that manages all actor groups
        self.dispatch = WorkerDispatch(
            cfg=self.cfg,
            policy_actor_group=policy_model,
            critic_actor_group=critic_model,
            ref_actor_group=ref_model,
            inference_engine_client=self.inference_engine_client,
        )

        # Mark all models as offloaded if colocate_all (they were offloaded above)
        if self.colocate_all:
            self.dispatch.mark_all_offloaded()

        logger.info("init policy/ref/critic models done")

method init_weight_sync_state

init_weight_sync_state()

Setup the connection between policy model and inference engine for weight syncing.

Source code in skyrl/train/trainer.py:718-723
    def init_weight_sync_state(self):
        """
        Wire up weight syncing between the policy workers and the inference engines.

        The actual setup is delegated to the worker dispatch, which receives this trainer's
        inference engine client.
        """
        engine_client = self.inference_engine_client
        self.dispatch.init_weight_sync_state(engine_client)
        logger.info("Initialized weight sync state for policy model and inference engines.")

method convert_to_training_input

convert_to_training_input(generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch

Converts lists to a padded batch of tensors for training

Parameters:

NameTypeDescriptionDefault
generator_outputGeneratorOutputGenerated rollouts and associated data.required
uidsList[str]List of prompt-unique identifiers for each generator output in the same order as generator_output. Used to identify which prompt each generated rollout belongs to.required

Returns: training_input (TrainingInputBatch): Padded batch of tensors for training. It preserves the order of generator_output and hence uids.

Source code in skyrl/train/trainer.py:725-844
    def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch:
        """Converts lists to a padded batch of tensors for training

        Args:
            generator_output (GeneratorOutput): Generated rollouts and associated data.
            uids (List[str]): List of prompt-unique identifiers for each generator output in the same
                order as `generator_output`. Used to identify which prompt each generated rollout belongs to.
        Returns:
            training_input (TrainingInputBatch): Padded batch of tensors for training. It preserves the
                order of `generator_output` and hence `uids`. The batch is padded (via
                `pad_training_input_batch`) up to a multiple of `self.dispatch.get_lcm_dp_size()`.
        Raises:
            AssertionError: if only one of `pixel_values` / `image_grid_thw` is provided, if their
                lengths differ, or if off-policy correction is enabled but rollout logprobs are
                missing or not shaped like the loss masks.
        """
        # 1. Extract generator output fields.
        prompt_ids: List[List[int]] = generator_output["prompt_token_ids"]
        response_ids: List[List[int]] = generator_output["response_ids"]
        rewards: List[List[float]] = generator_output["rewards"]
        loss_masks: List[List[int]] = generator_output["loss_masks"]

        logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None)
        rollout_expert_indices: Optional[List[List[List[List[int]]]]] = generator_output.get(
            "rollout_expert_indices", None
        )

        # Multi-modal inputs must come as a pair. Trigger the check on EITHER field so that an
        # orphaned `image_grid_thw` is rejected instead of being passed through without
        # `TensorList` wrapping.
        pixel_values = generator_output.get("pixel_values", None)
        image_grid_thw = generator_output.get("image_grid_thw", None)
        if pixel_values is not None or image_grid_thw is not None:
            assert (
                pixel_values is not None and image_grid_thw is not None
            ), "Both pixel_values and image_grid_thw must exist for multi-modal inputs"
            assert len(pixel_values) == len(
                image_grid_thw
            ), "Number of pixel values should match number of image grid thw"
            pixel_values = TensorList(pixel_values)
            image_grid_thw = TensorList(image_grid_thw)

        # 2. Convert to tensors.
        (
            sequences_tensor,
            attention_masks_tensor,
            response_masks_tensor,
            rewards_tensor,
            loss_masks_tensor,
            rollout_logprobs_tensor,
            rollout_expert_indices_tensor,
        ) = convert_prompts_responses_to_batch_tensors(
            self.tokenizer,
            prompt_ids,
            response_ids,
            rewards,
            loss_masks,
            logprobs,
            rollout_expert_indices,
            max_seq_len=self.cfg.trainer.algorithm.max_seq_len,
        )

        # sanity check for off_policy_correction
        off_policy_correction = self.cfg.trainer.algorithm.off_policy_correction
        tis_ratio_type = off_policy_correction.tis_ratio_type
        sequence_mask_metric = off_policy_correction.sequence_mask_metric
        if tis_ratio_type is not None or sequence_mask_metric is not None:
            assert (
                rollout_logprobs_tensor is not None
            ), "expected non-null rollout logprobs tensor when off_policy_correction is enabled"
            assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses"

        # 3. Create training input batch.
        training_input = TrainingInputBatch(
            {
                "sequences": sequences_tensor,  # Full trajectories (padded and concatenated prompts and responses)
                "attention_mask": attention_masks_tensor,
                "response_mask": response_masks_tensor,
                "rewards": rewards_tensor,
                "loss_mask": loss_masks_tensor,
                "rollout_logprobs": rollout_logprobs_tensor,
                "rollout_expert_indices": rollout_expert_indices_tensor,
                "pixel_values": pixel_values,
                "image_grid_thw": image_grid_thw,
            },
        )
        training_input.metadata = {"uids": uids}
        if generator_output.get("is_last_step", None) is not None:
            training_input.metadata["is_last_step"] = generator_output["is_last_step"]

        # 4. Compute mini-batch boundaries for train_critic_and_policy(). It excludes the ones
        # we will add in pad_training_input_batch().
        train_batch_size = self.cfg.trainer.train_batch_size
        n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt
        is_stepwise = self.cfg.generator.step_wise_trajectories
        training_input.metadata["policy_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries(
            uids, self.cfg.trainer.policy_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt
        )
        # Per-prompt boundaries (used by the `prompt_mean` loss reduction). Policy-only,
        # since advantage normalization only applies to the policy.
        training_input.metadata["policy_prompt_boundaries"] = compute_prompt_boundaries(uids)
        # Truthiness check, matching how the critic is enabled when models are built, so an empty
        # critic path does not produce critic boundaries.
        if self.cfg.trainer.critic.model.path:
            training_input.metadata["critic_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries(
                uids, self.cfg.trainer.critic_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt
            )

        # 5. Record metadata and metrics.
        training_input.metadata["response_length"] = response_masks_tensor.shape[1]
        batch_num_seq, batch_padded_seq_len = sequences_tensor.shape
        logger.info(f"batch_num_seq: {batch_num_seq}, batch_padded_seq_len: {batch_padded_seq_len}")
        self.all_metrics.update(
            {
                "generate/batch_num_seq": batch_num_seq,
                "generate/batch_padded_seq_len": batch_padded_seq_len,
            }
        )
        training_input.metadata["avg_response_length"] = sum(
            len(sample_response_ids) for sample_response_ids in response_ids
        ) / len(response_ids)

        # 6. Pad the batch, only needed for step-wise training's `fwd_logprobs_values_reward()`.
        # Pad up to the next multiple of the LCM data-parallel size across worker groups.
        logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
        dp_size = self.dispatch.get_lcm_dp_size()
        pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size
        training_input = pad_training_input_batch(training_input, pad_size)
        logger.info(f"Number of sequences after padding: {len(training_input['sequences'])}")

        return training_input

method async generate

generate(input_batch: GeneratorInput) -> GeneratorOutput

Generate rollouts.

If colocate_all is enabled:

  • before calling this method, the policy model should be on CPU and the inference engine should be awake (i.e. on GPU).
  • after calling this method, the same model placement still holds.
Source code in skyrl/train/trainer.py:846-873
    @torch.no_grad()
    async def generate(
        self,
        input_batch: GeneratorInput,
    ) -> GeneratorOutput:
        """
        Generate rollouts.

        If colocate_all is enabled:
        - before calling this method, the policy model should be on CPU and the inference engine
            should be awake (i.e. on GPU).
        - after calling this method, the same model placement still holds.

        Any ``"rollout_metrics"`` reported by the generator are merged into ``self.all_metrics``
        and removed from the output. The output is then checked with
        ``validate_generator_output`` against the number of prompts.

        Args:
            input_batch: Generator input; ``input_batch["prompts"]`` is used for validation.

        Returns:
            The generator output, without the ``"rollout_metrics"`` key.
        """
        # NOTE(review): `torch.no_grad()` used as a decorator on an `async def` only wraps creation
        # of the coroutine object, not the awaited body, so it likely does not disable grad here.
        # Kept unchanged. Do not hold a grad-mode context across `await`s instead: grad mode is
        # thread-local and would leak into other tasks on the event loop.

        # NOTE: we assume that .generate returns samples in the same order as passed in
        generator_output: GeneratorOutput = await self.generator.generate(input_batch)

        # Move rollout metrics into self.all_metrics. A single `pop` with a default tolerates
        # generators that omit the key entirely; indexing it directly would raise KeyError.
        rollout_metrics = generator_output.pop("rollout_metrics", None)
        if rollout_metrics is not None:
            self.all_metrics.update(rollout_metrics)

        validate_generator_output(
            len(input_batch["prompts"]),
            generator_output,
            step_wise=self.cfg.generator.step_wise_trajectories,
        )

        return generator_output

method postprocess_generator_output

postprocess_generator_output(generator_output: GeneratorOutput, uids: List[str]) -> Tuple[GeneratorOutput, List[str]]

Converts to per token rewards and computes pass@N.

For step-wise training with merge_stepwise_output=true, also collapses consecutive turns sharing a common prefix into a single sequence; uids is shortened to match.

In the future, algorithm-specific reward or loss-mask post-processing should be done here.

Returns:

TypeDescription
Tuple[GeneratorOutput, List[str]](generator_output, uids) — uids may be shorter than the input when merging.
Source code in skyrl/train/trainer.py:875-963
    @torch.no_grad()
    def postprocess_generator_output(
        self, generator_output: GeneratorOutput, uids: List[str]
    ) -> Tuple[GeneratorOutput, List[str]]:
        """
        Turn rewards into per-token rewards and record reward metrics (including pass@N).

        In step-wise mode, metrics are computed only over the final step of each trajectory.
        With ``merge_stepwise_output=true`` (which requires step-wise mode), consecutive turns
        sharing a common prefix are merged into one sequence. ``uids`` is then rebuilt from the
        merged ``trajectory_ids`` so it stays aligned with the output.

        When rewards are response-level and ``zero_variance_filter`` is enabled, the loss masks of
        samples not kept by the filter are zeroed.

        Algorithm-specific reward or loss-mask post-processing is intended to live here in the future.

        Returns:
            (generator_output, uids). ``generator_output["rewards"]`` now holds per-token rewards,
            and ``uids`` may be shorter than the input when merging.
        """
        step_wise = self.cfg.generator.step_wise_trajectories

        if step_wise:
            # Restrict every list-valued field to the final step of each trajectory.
            last_step_flags = generator_output["is_last_step"]
            metrics_source = defaultdict(list)
            for field_name, field_values in generator_output.items():
                if isinstance(field_values, list):
                    metrics_source[field_name] = [
                        value for pos, value in enumerate(field_values) if last_step_flags[pos]
                    ]
            metrics_uids = [uid for uid, is_final in zip(uids, last_step_flags) if is_final]
        else:
            metrics_source = generator_output
            metrics_uids = uids

        # Metrics are taken before any merging, from the (possibly last-step-only) view above.
        overall_metrics = get_metrics_from_generator_output(metrics_source, metrics_uids)

        if self.cfg.generator.merge_stepwise_output:
            assert step_wise, "merge_stepwise_output requires step-wise training"
            seqs_before = len(generator_output["response_ids"])
            generator_output = merge_stepwise_output(generator_output)
            seqs_after = len(generator_output["response_ids"])
            logger.info(f"Merged step wise: {seqs_before} sequences -> {seqs_after} sequences")
            self.all_metrics.update(
                {
                    "generate/num_seq_before_merge": seqs_before,
                    "generate/num_seq_after_merge": seqs_after,
                }
            )
            # Re-derive uids from the merged sequences' trajectory ids.
            uids = [trajectory_id.instance_id for trajectory_id in generator_output["trajectory_ids"]]

        # From here on the full (possibly merged) output is used.
        rewards = generator_output["rewards"]
        responses = generator_output["response_ids"]

        if rewards and isinstance(rewards[0], list):
            # Rewards are already token-level; use them as-is.
            per_token_rewards = rewards
        else:
            if self.cfg.trainer.algorithm.zero_variance_filter:
                kept = set(zero_variance_filter(rewards, uids))
                generator_output["loss_masks"] = [
                    mask if idx in kept else [0] * len(mask)
                    for idx, mask in enumerate(generator_output["loss_masks"])
                ]
            # Response-level rewards: put the scalar on the last response token, zeros elsewhere.
            per_token_rewards = []
            for reward, response in zip(rewards, responses):
                token_rewards = [0.0] * len(response)
                token_rewards[-1] = float(reward)
                per_token_rewards.append(token_rewards)

        n_samples = self.cfg.generator.n_samples_per_prompt
        pass_at_n = overall_metrics["pass_at_n"]
        avg_raw = overall_metrics["avg_score"]
        mean_positive = overall_metrics["mean_positive_reward"]
        self.all_metrics.update(
            {
                f"reward/avg_pass_at_{n_samples}": pass_at_n,
                "reward/avg_raw_reward": avg_raw,
                "reward/mean_positive_reward": mean_positive,
            }
        )
        logger.info(
            f"reward/avg_pass_at_{n_samples}: {pass_at_n}, reward/avg_raw_reward: {avg_raw}, "
            f"reward/mean_positive_reward: {mean_positive}"
        )

        # Swap in the per-token form of the rewards.
        generator_output["rewards"] = per_token_rewards
        return generator_output, uids

method compute_advantages_and_returns

compute_advantages_and_returns(data: TrainingInputBatch) -> TrainingInputBatch

Calculate advantages and returns for the data batch.

Expects:

  • ["sequences"]: Integer[torch.Tensor, "batch_size seqlen"]
  • ["response_mask"]: Integer[torch.Tensor, "batch_size seqlen"]
  • ["loss_mask"]: Integer[torch.Tensor, "batch_size seqlen"]
  • ["values"]: Float[torch.Tensor, "batch_size seqlen"]
  • ["rewards"]: Float[torch.Tensor, "batch_size seqlen"]
  • .metadata["uids"]: List[str]
  • .metadata["is_last_step"]: List[bool] for step-wise training

Adds:

  • ["advantages"]: Float[torch.Tensor, "batch_size seqlen"]
  • ["returns"]: Float[torch.Tensor, "batch_size seqlen"]
Source code in skyrl/train/trainer.py:965-1076
    @torch.no_grad()
    def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingInputBatch:
        """Calculate advantages and returns for the data batch.

        It also records reward/advantage summary metrics in `data.metadata["metrics"]` and in
        `self.all_metrics`. Those metrics exclude the trailing `pad_size` padding rows.

        Expects:
            - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["response_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["loss_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["rewards"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `.metadata["uids"]`: List[str]
            - `.metadata["is_last_step"]`: List[bool] for step-wise training
            - `.metadata["avg_response_length"]`: logged as a metric
            - `.metadata["pad_size"]` (optional, default 0): number of padding rows at the end of the batch

        Adds:
            - `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["returns"]`: Float[torch.Tensor, "batch_size seqlen"]

        Returns:
            The batch after `.to("cpu")`, including the added keys and metrics.
        """
        token_level_rewards = data["rewards"]

        if self.cfg.generator.step_wise_trajectories:
            is_last_step = torch.tensor(data.metadata["is_last_step"], dtype=torch.bool)
            index = np.array(data.metadata["uids"])
            values = data["values"]
            # Step-wise only supports outcome-based estimators (GRPO, RLOO, MAXRL); ensured by `validate_cfg`.
            # We use the last step of each trajectory to compute advantages and broadcast them to
            # all steps of that trajectory, so we ignore per-step rewards in step-wise training.
            # We pass an all-ones mask here so the estimator returns the scalar advantage at every
            # position. The real per-step `response_mask` is re-applied on broadcast below.
            # Shapes:
            #   traj_ids, (batch_size,):         trajectory id per step (cumsum of shifted is_last_step)
            #   last_step_advantages/returns,
            #       (num_traj, seqlen):          scalar advantage/return per trajectory at every position
            #   last_step_advantages/returns[traj_ids],
            #       (batch_size, seqlen):        broadcast to every step of the owning trajectory
            #   response_mask_float,
            #       (batch_size, seqlen):        per-step response mask
            last_step_response_mask = data["response_mask"][is_last_step]
            last_step_advantages, last_step_returns = ppo_utils.compute_advantages_and_returns(
                token_level_rewards=token_level_rewards[is_last_step],
                response_mask=torch.ones_like(last_step_response_mask, dtype=torch.float),
                index=index[is_last_step.cpu().numpy()],
                adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
                values=values[is_last_step] if values is not None else None,
                config=self.cfg.trainer.algorithm,
                gamma=self.cfg.trainer.algorithm.gamma,
                lambd=self.cfg.trainer.algorithm.lambd,
                grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
            )
            # Shift `is_last_step` right by one (prepend False, drop the last flag) and cumsum. Row i's
            # id is the number of last-step flags strictly before row i, so the id increments on the
            # row right AFTER each trajectory's final step.
            traj_ids = (
                torch.cat([torch.tensor([False], device=is_last_step.device), is_last_step[:-1]]).int().cumsum(dim=0)
            )
            num_traj = traj_ids[-1].item() + 1
            # The trajectory count implied by the ids must equal the number of last-step rows.
            assert num_traj == len(
                last_step_advantages
            ), f"num_traj {num_traj} doesn't match the number of trajectories as given by `is_last_step` {len(last_step_advantages)}. The `is_last_step` tensor is likely malformed"
            response_mask_float = data["response_mask"].to(last_step_advantages.dtype)
            advantages = last_step_advantages[traj_ids] * response_mask_float
            returns = last_step_returns[traj_ids] * response_mask_float
        else:
            advantages, returns = ppo_utils.compute_advantages_and_returns(
                token_level_rewards=token_level_rewards,
                response_mask=data["response_mask"],
                index=data.metadata["uids"],
                adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
                config=self.cfg.trainer.algorithm,
                values=data["values"],
                gamma=self.cfg.trainer.algorithm.gamma,
                lambd=self.cfg.trainer.algorithm.lambd,
                grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
            )
        data["returns"] = returns
        data["advantages"] = advantages

        # remove padding while calculating metrics
        # Padding rows are treated as the LAST `pad_size` rows, hence the `[: num_samples - pad_size]`
        # slices below. (`pad_size` presumably set by `pad_training_input_batch` — TODO confirm.)
        pad_size = data.metadata.get("pad_size", 0)
        num_samples = len(token_level_rewards)

        # Per-sequence reward sums; in step-wise mode only last steps are averaged, since only they
        # carry the trajectory's outcome reward.
        return_sums = token_level_rewards.sum(dim=-1)[: num_samples - pad_size]
        if self.cfg.generator.step_wise_trajectories:
            avg_rewards: float = return_sums[is_last_step[: num_samples - pad_size]].mean().item()
        else:
            avg_rewards: float = return_sums.mean().item()

        avg_response_length = data.metadata["avg_response_length"]
        data = data.to("cpu")

        # Advantage statistics over response tokens of non-padding rows only.
        valid_advantages = torch.masked_select(
            data["advantages"][: num_samples - pad_size, ...], data["response_mask"][: num_samples - pad_size].bool()
        )
        avg_advantages: float = valid_advantages.mean().item()
        avg_advantages_abs: float = valid_advantages.abs().mean().item()

        if "metrics" not in data.metadata:
            data.metadata["metrics"] = {}
        data.metadata["metrics"].update(
            {
                "avg_final_rewards": avg_rewards,
                "avg_response_length": avg_response_length,
                "avg_advantages": avg_advantages,
                "avg_advantages_abs": avg_advantages_abs,
            }
        )

        logger.info(f"avg_final_rewards: {avg_rewards}, avg_response_length: {avg_response_length}")
        self.all_metrics.update(
            {
                "loss/avg_final_rewards": avg_rewards,
                "loss/avg_raw_advantages": avg_advantages,
                "loss/avg_raw_advantages_abs": avg_advantages_abs,
            }
        )
        return data

method dump_data

dump_data(data: TrainingInputBatch, file_name: str)

Dump data to a pickle file under <export_path>/dumped_data/.

Source code in skyrl/train/trainer.py:1078-1084
    def dump_data(self, data: TrainingInputBatch, file_name: str):
        """
        Dump data to pickle file
        """
        data_save_dir = Path(self.cfg.trainer.export_path) / "dumped_data"
        data_save_dir.mkdir(parents=True, exist_ok=True)
        data.save(data_save_dir / f"{file_name}.pkl")

method fwd_logprobs_values_reward

fwd_logprobs_values_reward(training_input: TrainingInputBatch)

Calculate values from the critic, log probs from the policy and ref model.

Dispatch handles offload/backload automatically for all colocation configurations.

Expects:

  • ["sequences"]: Integer[torch.Tensor, "batch_size seqlen"]
  • ["attention_mask"]: Integer[torch.Tensor, "batch_size seqlen"]
  • .metadata["response_length"]: Int

Adds:

  • ["base_action_log_probs"]: Float[torch.Tensor, "batch_size seqlen"]
  • ["action_log_probs"]: Float[torch.Tensor, "batch_size seqlen"]
  • ["values"]: Float[torch.Tensor, "batch_size seqlen"]
Source code in skyrl/train/trainer.py:1086-1167
    @torch.no_grad()
    def fwd_logprobs_values_reward(
        self,
        training_input: TrainingInputBatch,
    ):
        """
        Calculate values from the critic, log probs from the policy and ref model.

        Dispatch handles offload/backload automatically for all colocation configurations.

        Expects:
            - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `["attention_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
            - `.metadata["response_length"]`: Int

        `["rollout_expert_indices"]`, `["pixel_values"]` and `["image_grid_thw"]` are also
        forwarded when present.

        Adds:
            - `["base_action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"] (None without a ref model)
            - `["action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
            - `["values"]`: Float[torch.Tensor, "batch_size seqlen"] (None without a critic)

        When `["rollout_logprobs"]` is present and at least one token has a positive loss mask,
        rollout-vs-trainer logprob difference statistics are also recorded in `self.all_metrics`.

        Returns:
            The same `training_input`, updated in place with the keys above.
        """
        fwd_keys = ["sequences", "attention_mask"]
        for optional_key in ("rollout_expert_indices", "pixel_values", "image_grid_thw"):
            if training_input.get(optional_key) is not None:
                fwd_keys.append(optional_key)
        data_fwd_pass = training_input.select(keys=fwd_keys, metadata_keys=["response_length"])

        values = None
        base_log_probs = None
        action_log_probs = None

        # Critic forward (dispatch handles offload/backload automatically)
        if self.has_critic:
            critic_output = self.dispatch.forward("critic", data_fwd_pass)
            values = loss_fn_outputs_to_tensor(critic_output.loss_fn_outputs, key="values")

        # Ref forward
        if self.ref_model is not None:
            ref_output = self.dispatch.forward("ref", data_fwd_pass)
            base_log_probs = loss_fn_outputs_to_tensor(ref_output.loss_fn_outputs, key="logprobs")
            self.dispatch.empty_cache("ref")

        # Policy forward
        policy_output = self.dispatch.forward("policy", data_fwd_pass)
        action_log_probs = loss_fn_outputs_to_tensor(policy_output.loss_fn_outputs, key="logprobs")

        # Empty cache after all forward passes
        self.dispatch.empty_cache()

        sequences_all: torch.Tensor = training_input["sequences"]
        # NOTE (sumanthrh): The slicing is needed to make sure that the batch dimension doesn't change for the tensordict.
        base_log_probs = base_log_probs[: len(sequences_all)] if base_log_probs is not None else None
        action_log_probs = action_log_probs[: len(sequences_all)]
        values = values[: len(sequences_all)] if values is not None else None

        training_input["base_action_log_probs"] = base_log_probs
        training_input["action_log_probs"] = action_log_probs
        training_input["values"] = values

        if training_input.get("rollout_logprobs", None) is not None:
            # calculates the difference in probs between inference and trainer components
            # only consider response tokens
            response_token_mask = training_input["loss_mask"] > 0
            logprobs_diff = (
                training_input["rollout_logprobs"][response_token_mask] - action_log_probs[response_token_mask]
            ).abs()

            # The loss mask can be entirely zero (e.g. when the zero-variance filter masks every
            # sample). Reductions such as `max()`/`min()` raise on an empty tensor, so skip the
            # metrics in that case.
            if logprobs_diff.numel() > 0:
                self.all_metrics.update(
                    {
                        "policy/rollout_train_logprobs_abs_diff_max": logprobs_diff.max().item(),
                        "policy/rollout_train_logprobs_abs_diff_min": logprobs_diff.min().item(),
                        "policy/rollout_train_logprobs_abs_diff_mean": logprobs_diff.mean().item(),
                        "policy/rollout_train_logprobs_abs_diff_std": logprobs_diff.std().item(),
                    }
                )
        return training_input

method apply_reward_kl_penalty

apply_reward_kl_penalty(data: TrainingInputBatch) -> TrainingInputBatch

Applies a penalty for KL divergence between the policy log probs and the base model log probs to the rewards.

Source code in skyrl/train/trainer.py:1169-1224
    def apply_reward_kl_penalty(
        self,
        data: TrainingInputBatch,
    ) -> TrainingInputBatch:
        """Subtract a KL penalty (policy vs. reference log probs) from the token-level rewards.

        A per-token KL estimate is scaled by the current coefficient and subtracted from
        ``data["rewards"]``. The coefficient is the KL controller's value when
        ``self.reward_kl_controller`` is set, otherwise ``cfg.trainer.algorithm.kl_loss_coef``,
        clamped at zero. The controller (if any) is then updated with the batch-mean KL. KL
        statistics are written to ``data.metadata["metrics"]`` and ``self.all_metrics``.
        """
        mask: torch.Tensor = data["loss_mask"]
        token_rewards: torch.Tensor = data["rewards"]
        ref_logprobs: torch.Tensor = data["base_action_log_probs"]
        policy_logprobs: torch.Tensor = data["action_log_probs"]

        # One batched KL estimate over the whole batch.
        with torch.no_grad():
            kl = compute_approx_kl(
                policy_logprobs,
                ref_logprobs,
                loss_mask=mask,
                kl_estimator_type=self.cfg.trainer.algorithm.kl_estimator_type,
            )
        seq_kl_max = kl.abs().max(dim=-1).values
        seq_kl_mean = masked_mean(kl, mask, dim=-1)

        # NOTE (erictang000): only supporting custom rewards currently
        controller = self.reward_kl_controller
        if controller is not None:
            coef = controller.value
        else:
            coef = self.cfg.trainer.algorithm.kl_loss_coef
        data["rewards"] = token_rewards - kl * max(0, coef)

        avg_kl: float = seq_kl_mean.mean().item()
        avg_kl_max: float = seq_kl_max.mean().item()

        # Feed the batch-mean KL back into the controller; n_steps is just the batch size.
        if controller is not None:
            controller.update(current=avg_kl, n_steps=kl.shape[0])

        kl_metrics = {"avg_kl": avg_kl, "avg_kl_max": avg_kl_max, "kl_loss_coef": coef}
        data.metadata.setdefault("metrics", {}).update(kl_metrics)
        self.all_metrics.update({f"loss/{key}": value for key, value in kl_metrics.items()})

        return data

method train_critic_and_policy

train_critic_and_policy(data: TrainingInputBatch)

Run the training step for the policy and critic models.

Uses forward_backward + optim_step for both FSDP and Megatron strategies.

Source code in skyrl/train/trainer.py:1315-1341
    def train_critic_and_policy(self, data: TrainingInputBatch):
        """
        Run one training step for the critic (when present) and then for the policy.

        Uses forward_backward + optim_step for both FSDP and Megatron strategies, through
        ``_execute_training_step``. Each model's status dict is merged into
        ``self.all_metrics`` under a ``critic/`` or ``policy/`` prefix. The dispatch cache is
        emptied at the end.

        Args:
            data: Training batch; ``data.metadata["global_step"]`` is set to the current step.

        Returns:
            The status dict returned by the policy training step.
        """
        data.metadata["global_step"] = self.global_step

        critic_status = None
        if self.has_critic:
            with Timer("critic_train", self.all_timings):
                critic_status = self._execute_training_step("critic", data)

        with Timer("policy_train", self.all_timings):
            policy_status = self._execute_training_step("policy", data)

        # Namespace each model's status keys before merging them into the step metrics.
        if critic_status is not None:
            self.all_metrics.update({f"critic/{key}": val for key, val in critic_status.items()})
        self.all_metrics.update({f"policy/{key}": val for key, val in policy_status.items()})

        self.dispatch.empty_cache()
        return policy_status

method handle_dynamic_sampling

handle_dynamic_sampling(generator_output: GeneratorOutput, uids: List[str]) -> Tuple[GeneratorOutput, List[str], bool]

Handle dynamic sampling for the current batch.

Accumulates the generator output and UIDs across batches if we are sampling repeatedly and applies the dynamic sampling strategy (i.e. filter, replace) to the current batch. If we hit the limit of max sample batches, we raise an error.

Parameters:

NameTypeDescriptionDefault
generator_outputGeneratorOutputCurrent batch generator outputrequired
uidsList[str]Current batch UIDsrequired

Returns:

NameTypeDescription
processed_outputGeneratorOutputFiltered generator output
processed_uidsList[str]Filtered UIDs
keep_samplingboolWhether to keep sampling
Source code in skyrl/train/trainer.py:1343-1403
    def handle_dynamic_sampling(
        self, generator_output: GeneratorOutput, uids: List[str]
    ) -> Tuple[GeneratorOutput, List[str], bool]:
        """
        Run one round of dynamic sampling on the current batch.

        Keeps a per-round batch counter in ``self.dynamic_sampling_state``. Accumulation and
        the filter/replace strategy are delegated to ``trainer_utils.handle_dynamic_sampling``,
        configured from ``trainer.algorithm.dynamic_sampling``. The state returned by that helper
        is kept while more sampling is needed, and reset to ``None`` once the batch is complete.

        Args:
            generator_output: Current batch generator output
            uids: Current batch UIDs

        Returns:
            processed_output: Filtered generator output
            processed_uids: Filtered UIDs
            keep_sampling: Whether to keep sampling

        Raises:
            RuntimeError: if more sampling is still needed once ``max_sample_batches`` batches
                have been drawn (only enforced when ``max_sample_batches > 0``).
        """
        ds_cfg = self.cfg.trainer.algorithm.dynamic_sampling
        max_sample_batches = ds_cfg.max_sample_batches
        dynamic_sampling_config = {
            "type": ds_cfg.type,
            "max_sample_batches": max_sample_batches,
            "min_replace_ratio": ds_cfg.min_replace_ratio,
            "train_batch_size": self.cfg.trainer.train_batch_size,
            "n_samples_per_prompt": self.cfg.generator.n_samples_per_prompt,
        }

        # Count this batch: first batch of a round starts a fresh state, later ones increment.
        if self.dynamic_sampling_state is not None:
            self.dynamic_sampling_state["sample_batch_count"] += 1
        else:
            self.dynamic_sampling_state: DynamicSamplingState = {"sample_batch_count": 1}

        processed_output, processed_uids, keep_sampling, updated_state = trainer_utils.handle_dynamic_sampling(
            generator_output, uids, dynamic_sampling_config, self.dynamic_sampling_state
        )

        # Still short of a full batch after the allowed number of sample batches: give up.
        if (
            keep_sampling
            and max_sample_batches > 0
            and self.dynamic_sampling_state["sample_batch_count"] >= max_sample_batches
        ):
            raise RuntimeError(
                "Exiting training loop due to hitting dynamic sampling limit for "
                f"{ds_cfg.type} strategy with "
                f"{ds_cfg.max_sample_batches} max sample batches. "
                "Please check your data difficulty distribution."
            )

        # Carry the accumulated state into the next round, or clear it once sampling is done.
        self.dynamic_sampling_state = updated_state if keep_sampling else None
        return processed_output, processed_uids, keep_sampling

method save_checkpoints

save_checkpoints() -> str

Save the model, optimizer, and training states to disk. Returns the checkpoint folder path.

Dispatch handles offload/backload automatically for all colocation configurations.

Source code in skyrl/train/trainer.py:1414-1466
    def save_checkpoints(self) -> str:
        """
        Write a full training checkpoint for the current global step and return its folder.

        Layout under ``<trainer.ckpt_path>/global_step_<step>/``:

        * ``policy/``, plus ``critic/`` when a critic exists. These are saved via dispatch,
          which handles offload/backload automatically for all colocation configurations.
        * ``data.pt``: dataloader state. This is best-effort; a failure is only logged.
        * ``trainer_state.pt``: a dict with ``global_step`` and ``config``.

        ``<trainer.ckpt_path>/latest_ckpt_global_step.txt`` is then updated to this step, and
        old checkpoints are cleaned up.

        Returns:
            The path of the ``global_step_<step>`` folder.
        """
        ckpt_root = self.cfg.trainer.ckpt_path
        global_step_folder = os.path.join(ckpt_root, f"global_step_{self.global_step}")
        io.makedirs(global_step_folder, exist_ok=True)

        # Model/optimizer state for each trainable model (dispatch handles offload/backload).
        self.dispatch.save_checkpoint("policy", os.path.join(global_step_folder, "policy"), self.tokenizer)
        if self.has_critic:
            self.dispatch.save_checkpoint("critic", os.path.join(global_step_folder, "critic"), self.tokenizer)

        # Dataloader position is best-effort: a failure here does not abort the checkpoint.
        dataloader_save_path = os.path.join(global_step_folder, "data.pt")
        try:
            loader_state = self.train_dataloader.state_dict()
            with io.open_file(dataloader_save_path, "wb") as f:
                torch.save(loader_state, f)
            logger.info(f"Saved dataloader state to {dataloader_save_path}")
        except Exception as e:
            logger.warning(f"Failed to save dataloader state: {e}")

        # Build the trainer state before opening the file it is written to.
        trainer_state = {"global_step": self.global_step, "config": asdict(self.cfg)}
        trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
        with io.open_file(trainer_state_path, "wb") as f:
            torch.save(trainer_state, f)
        logger.info(f"Saved trainer state to {trainer_state_path}")

        # The "latest" pointer is written last, after every save above has completed.
        with io.open_file(os.path.join(ckpt_root, "latest_ckpt_global_step.txt"), "w") as f:
            f.write(str(self.global_step))
        logger.info(f"Successfully saved checkpoint for global_step_{self.global_step} to: {global_step_folder}")

        # Only prune old checkpoints once the new one is fully on disk.
        with Timer("cleanup_old_checkpoints", self.all_timings):
            self._cleanup_old_checkpoints()
        return global_step_folder

method load_checkpoints

load_checkpoints() -> Tuple[int, Optional[str]]

Load complete checkpoint state and return the global_step to resume from. Returns (0, None) if no checkpoint is loaded.

If colocate_all is True, assumes that the policy model is currently on GPU.

Returns:

NameTypeDescription
global_stepintThe global step to resume from.
checkpoint_pathOptional[str]The path to the checkpoint, or None if no checkpoint was loaded.
Source code in skyrl/train/trainer.py:1481-1592
    def load_checkpoints(self) -> Tuple[int, Optional[str]]:
        """
        Load complete checkpoint state and return the global_step to resume from.

        The checkpoint is selected by ``self.resume_mode``:

        * ``NONE``: nothing is loaded.
        * ``LATEST``: uses the step recorded in ``<ckpt_path>/latest_ckpt_global_step.txt``.
          Nothing is loaded if that file does not exist.
        * otherwise: uses ``trainer.resume_path``. It must be set, and its final path
          component must contain ``GLOBAL_STEP_PREFIX``.

        Components are loaded in this order:

        1. trainer state (required);
        2. dataloader state (best-effort);
        3. policy model, optimizer and LR-scheduler state;
        4. the same for the critic, when a critic is configured.

        If colocate_all is True, assumes that the policy model is currently on GPU.

        Returns:
            global_step: The global step to resume from (0 when no checkpoint is loaded).
            checkpoint_path: The path to the checkpoint, or ``None`` when no checkpoint is loaded.

        Raises:
            ValueError: if ``trainer.resume_path`` is missing or not a global-step directory,
                or the step cannot be extracted from the checkpoint path.
            FileNotFoundError: if the checkpoint directory or its trainer state file is missing.
        """
        # Check if resumption is enabled
        if self.resume_mode == ResumeMode.NONE:
            logger.info("Checkpoint resumption disabled, starting training from scratch")
            return 0, None

        # First, resolve the checkpoint directory to resume from.
        if self.resume_mode == ResumeMode.LATEST:
            latest_checkpoint_file = os.path.join(self.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt")
            if not io.exists(latest_checkpoint_file):
                logger.info("No checkpoint found, starting training from scratch")
                return 0, None
            with io.open_file(latest_checkpoint_file, "r") as f:
                ckpt_iteration = int(f.read().strip())
            checkpoint_path = os.path.join(self.cfg.trainer.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_iteration}")
            # Run validation: Make sure ckpt folder is consistent with latest_ckpt_global_step.txt
            validate_consistency_for_latest_checkpoint(
                self.cfg.trainer.ckpt_path,
                ckpt_iteration,
                checkpoint_path,
                latest_checkpoint_file,
                self.cfg.trainer.ckpt_interval,
            )
        else:
            resume_path = self.cfg.trainer.resume_path
            # Validate the raw config value. A ``Path`` object is always truthy, and
            # ``Path(None)`` raises TypeError, so a check after wrapping could never fire.
            if not resume_path:
                raise ValueError("`trainer.resume_path` must be specified when resume_mode is 'from_path'")
            # Keep the path as a string, like the LATEST branch. ``Path`` collapses ``//``
            # (``s3://b/x`` -> ``s3:/b/x``), which would corrupt URI-style paths given to ``io``.
            checkpoint_path = str(resume_path)

            # Validate that it's a global_step directory
            if GLOBAL_STEP_PREFIX not in Path(checkpoint_path).name:
                raise ValueError(
                    f"`trainer.resume_path` must point to a directory whose name starting with {GLOBAL_STEP_PREFIX}, got: {checkpoint_path}"
                )

        # Validate that the path exists
        if not io.exists(str(checkpoint_path)):
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")

        logger.info(f"Loading checkpoint from: {checkpoint_path}")

        # Extract global step from checkpoint path
        global_step = extract_step_from_path(Path(checkpoint_path))
        if global_step == -1:
            raise ValueError(f"Checkpoint path {checkpoint_path} is not a valid checkpoint path")
        logger.info(f"Resuming from global_step: {global_step}")

        # Define paths for different checkpoint components
        policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
        critic_ckpt_dir = os.path.join(checkpoint_path, "critic")
        trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
        dataloader_state_path = os.path.join(checkpoint_path, "data.pt")

        # Validate that required checkpoint files exist
        if not io.exists(trainer_state_path):
            raise FileNotFoundError(f"Trainer state file not found: {trainer_state_path}")

        # 1. Load and validate trainer state.
        # NOTE: weights_only=False unpickles arbitrary objects; only resume from trusted checkpoints.
        with io.open_file(trainer_state_path, "rb") as f:
            trainer_state = torch.load(f, map_location="cpu", weights_only=False)
        saved_global_step = trainer_state.get("global_step", global_step)
        logger.info("Successfully loaded trainer state")
        if saved_global_step != global_step:
            logger.warning(f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value.")

        # 2. Load dataloader state if available (best-effort: failures fall back to a fresh dataloader)
        if io.exists(dataloader_state_path):
            try:
                with io.open_file(dataloader_state_path, "rb") as f:
                    dataloader_state = torch.load(f, map_location="cpu", weights_only=False)
                self.train_dataloader.load_state_dict(dataloader_state)
                logger.info("Successfully loaded dataloader state")
            except Exception as e:
                logger.warning(f"Failed to load dataloader state: {e}. Dataloader will start from beginning.")
        else:
            logger.warning(
                f"No dataloader state found at {dataloader_state_path}. Dataloader will start from beginning."
            )

        # 3. Load policy checkpoint (dispatch handles offload/backload)
        logger.info(f"Loading policy checkpoint from {policy_ckpt_dir}")
        self.dispatch.load_checkpoint(
            "policy",
            policy_ckpt_dir,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        logger.info("Successfully loaded policy checkpoint")

        # 4. Load critic checkpoint if it exists and we have a critic model
        if self.has_critic:
            logger.info(f"Loading critic checkpoint from {critic_ckpt_dir}")
            self.dispatch.load_checkpoint(
                "critic",
                critic_ckpt_dir,
                load_optimizer_states=True,
                load_lr_scheduler_states=True,
            )
            logger.info("Successfully loaded critic checkpoint")

        logger.info(f"Successfully loaded complete checkpoint state from global_step_{global_step}")
        return global_step, str(checkpoint_path)

method save_models

save_models()

Save the model parameters in HF format at cfg.trainer.export_path.

Dispatch handles offload/backload automatically for all colocation configurations.

Source code in skyrl/train/trainer.py:1594-1607
    def save_models(self):
        """
        Export model weights in HF format under ``<cfg.trainer.export_path>/global_step_<step>/``.

        The policy is written to ``policy/``. When a critic is configured, it is written to
        ``critic/``. The tokenizer is passed along with each export.

        Dispatch handles offload/backload automatically for all colocation configurations.
        """
        step_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}")
        model_names = ["policy", "critic"] if self.has_critic else ["policy"]
        for model_name in model_names:
            self.dispatch.save_hf_model(model_name, os.path.join(step_dir, model_name), self.tokenizer)

        logger.info("Successfully saved model weights.")

method update_ref_with_policy

update_ref_with_policy()

Update the reference model with the policy model weights (required by some algorithms).

Dispatch handles offload/backload automatically for all colocation configurations. After this method, save_weights_for_sampler() should be called to sync weights.

Source code in skyrl/train/trainer.py:1609-1632
    def update_ref_with_policy(self):
        """
        Re-initialize the reference model from the current policy weights (required by some algorithms).

        Steps:

        1. Export the policy in HF format to ``<export_path>/global_step_<step>/policy``.
        2. Re-initialize the ref model from that directory.
        3. Remove the directory. This is best-effort; a failed removal is only logged.

        Dispatch handles offload/backload automatically for all colocation configurations.
        After this method, save_weights_for_sampler() should be called to sync weights.
        """
        # TODO(tgriggs): Make policy-to-ref sync faster.
        # NOTE(review): save_models() exports the policy to this same directory for the same
        # global_step. If both run at one step, the cleanup below would delete that export.
        # Confirm the call order in the training loop.
        export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}", "policy")

        # Hand the weights over through disk: dump the policy, then build the ref model from it.
        self.dispatch.save_hf_model("policy", export_dir, self.tokenizer)
        self.dispatch.init_model("ref", export_dir)

        # The export was only a temporary hand-off; failing to delete it is not fatal.
        try:
            shutil.rmtree(export_dir)
            logger.info(f"Cleaned up temporary policy export directory: {export_dir}")
        except Exception as e:
            logger.warning(f"Failed to clean up temporary policy export directory {export_dir}: {e}")

        logger.info("Successfully updated ref model with policy model, training continues.")

Dispatch APIs

class Dispatch

Bases: ABC

Base class for dispatch types

Dispatch types are responsible for:

  • dispatching method calls to actors handling data sharding if necessary
  • validating arguments for dispatch

Functions:

NameDescription
dispatchDispatches method calls to the actors with data sharding if necessary.
validate_dispatch_argsValidate and process arguments for dispatch.
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:62-84
class Dispatch(ABC):
    """Abstract base for strategies that fan a method call out to a group of actors.

    A concrete dispatch type decides:
    - how a method call is sent to the actors, sharding the data when required;
    - how the call's arguments are validated and normalized before dispatching.
    """

    @classmethod
    @abstractmethod
    def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
        """Send ``method`` to the actors in ``actor_infos`` (sharding data if needed); return the result refs."""
        ...

    @classmethod
    @abstractmethod
    def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
        """Check and normalize the arguments of a dispatch call.

        Returns:
            ``(args, kwargs)`` to be forwarded to ``dispatch``.
        """
        ...

method abstractmethod classmethod dispatch

dispatch(actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]

Dispatches method calls to the actors with data sharding if necessary.

Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:70-74
    @classmethod
    @abstractmethod
    def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
        """Send ``method`` to the actors in ``actor_infos``, sharding data if needed (implemented by subclasses)."""
        ...

method abstractmethod classmethod validate_dispatch_args

validate_dispatch_args(*args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]

Validate and process arguments for dispatch.

Returns:

TypeDescription
Tuple[Tuple, Dict[str, Any]]Tuple of (args, kwargs) to be passed to dispatch
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:76-84
    @classmethod
    @abstractmethod
    def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
        """Check and normalize the arguments of a dispatch call (implemented by subclasses).

        Returns:
            ``(args, kwargs)`` to be forwarded to ``dispatch``.
        """
        ...

class MeshDispatch

Bases: Dispatch

Mesh dispatch type to dispatch data to a group of actors along the device mesh.

Supports DP (Data Parallel), SP (Sequence Parallel), TP (Tensor Parallel) and PP (Pipeline Parallel) parallelism. The actor method should accept a single argument - the data batch.

For data dispatch:

  • The input data is chunked into dp_size equal chunks, where dp_size is the size of data parallelism.
  • Each actor with the same DP rank processes the same data chunk in parallel.

Example: For a world size of 8, with DP size=2, SP size=2, TP size=2, PP size=1:

  • Data dispatch: The data is chunked into 2 chunks. All actors with DP rank 0 process the first chunk, and all actors with DP rank 1 process the second chunk.

Functions:

NameDescription
dispatchChunk the data batch across DP ranks and call the method on each actor with its rank's chunk.
stage_chunksPre-stage mini-batch chunks into the object store.
dispatch_from_stagedDispatch pre-staged per-DP chunks to workers.
validate_dispatch_argsExtract the data batch argument and check that it is a TrainingInputBatch.
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:87-211
class MeshDispatch(Dispatch):
    """Mesh dispatch type to dispatch data to a group of actors along the device mesh.

    Supports DP (Data Parallel), SP (Sequence Parallel), TP (Tensor Parallel) and PP (Pipeline Parallel) parallelism.
    The actor method should accept a single argument - the data batch.

    For data dispatch:

    * The input data is chunked into `dp_size` equal chunks, where `dp_size` is the size of data parallelism.
    * Each actor with the same DP rank processes the same data chunk in parallel.

    Example: For a world size of 8, with DP size=2, SP size=2, TP size=2, PP size=1:

    * Data dispatch: The data is chunked into 2 chunks. All actors with DP rank 0 process the first chunk,
      and all actors with DP rank 1 process the second chunk.
    """

    @staticmethod
    def _call_with_chunk_refs(
        actor_infos: List[ActorInfo],
        method: str,
        chunk_refs: List[ObjectRef],
        method_kwargs: Dict[str, Any],
    ) -> List[ObjectRef]:
        """Call ``method`` on each actor with the chunk ref of its DP rank; return result refs in actor order."""
        # ``method_kwargs`` is taken as a dict (not ``**kwargs``) so that worker kwargs can never
        # collide with this helper's own parameter names.
        object_refs = []
        for actor_info in actor_infos:
            # Pass the ObjectRef instead of the data - workers fetch their chunk from the object store.
            chunk_ref = chunk_refs[actor_info.rank.dp]
            object_refs.append(getattr(actor_info.handle, method).remote(chunk_ref, **method_kwargs))
        return object_refs

    @classmethod
    def dispatch(cls, actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch, **kwargs) -> List[ObjectRef]:
        """Chunk ``data`` across DP ranks and call ``method`` on every actor with its rank's chunk.

        Args:
            actor_infos: Actors to call; ``actor_infos[0].rank.dp_size`` gives the DP size.
            method: Name of the actor method to invoke (receives a single data chunk).
            data: Full batch; its length must be divisible by the DP size.
            **kwargs: Extra keyword arguments forwarded to ``method``.

        Returns:
            One ObjectRef per actor, in ``actor_infos`` order.
        """
        assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
        dp_size = actor_infos[0].rank.dp_size
        assert len(data) % dp_size == 0, "data batch size must be divisible by dp_size, got {} and {}".format(
            len(data), dp_size
        )
        chunk_size = len(data) // dp_size
        data_chunks: List[TrainingInputBatch] = data.chunk(chunk_size)

        # Put each unique chunk in object store ONCE to avoid redundant serialization
        # when the same chunk is sent to multiple workers (e.g., SP/TP replicas)
        chunk_refs: List[ObjectRef] = [ray.put(chunk) for chunk in data_chunks]

        return cls._call_with_chunk_refs(actor_infos, method, chunk_refs, kwargs)

    @classmethod
    def stage_chunks(
        cls,
        dp_size: int,
        data: TrainingInputBatch,
        mini_batch_boundaries: List[Tuple[int, int]],
    ) -> List[List[ObjectRef]]:
        """Pre-stage mini-batch chunks into the object store.

        Each mini-batch is defined by a ``(start, end)`` index pair from mini_batch_boundaries.
        Mini-batches are individually padded so that their size is divisible by dp_size, using dummy
        entries with ``loss_mask=0`` that do not affect the loss.

        Args:
            dp_size: Number of data-parallel ranks.
            data: Full TrainingInputBatch to slice from.
            mini_batch_boundaries: List of ``(start, end)`` index pairs.  The i-th mini-batch is
                data[mini_batch_boundaries[i][0]:mini_batch_boundaries[i][1]].

        Returns:
            ``result[i][dp_rank]`` - ObjectRef for mini-batch *i*, DP rank *dp_rank*.
        """
        all_chunk_refs: List[List[ObjectRef]] = []
        for start, end in mini_batch_boundaries:
            mini_batch = data[start:end]
            mb_size = end - start

            # Pad to make divisible by dp_size. Will only be non-zero for step-wise training.
            pad_size = (-mb_size) % dp_size
            if pad_size > 0:
                mini_batch = pad_training_input_batch(mini_batch, pad_size)

            mini_batch_size = len(mini_batch)
            assert (
                mini_batch_size % dp_size == 0
            ), f"mini_batch_size % dp_size != 0, got {mini_batch_size} and {dp_size}"
            chunk_size = mini_batch_size // dp_size
            chunks = mini_batch.chunk(chunk_size)
            all_chunk_refs.append([ray.put(chunk) for chunk in chunks])
        return all_chunk_refs

    @classmethod
    def dispatch_from_staged(
        cls,
        actor_infos: List[ActorInfo],
        method: str,
        chunk_refs: List[ObjectRef],
        **kwargs,
    ) -> List[ObjectRef]:
        """
        Dispatch pre-staged per-DP chunks to workers.

        Each worker receives only its own chunk (already in the object
        store), avoiding unnecessary deserialization overhead.

        Args:
            actor_infos: List of actor info objects
            method: Name of method to call on workers (receives a single data chunk)
            chunk_refs: Pre-staged ObjectRefs, one per DP rank (from ``stage_chunks``)
            **kwargs: Additional keyword arguments to pass to the method

        Returns:
            List of ObjectRefs for worker results
        """
        assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
        return cls._call_with_chunk_refs(actor_infos, method, chunk_refs, kwargs)

    @classmethod
    def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
        """Extract the ``data`` batch (first positional arg or ``data=`` kwarg) and check its type.

        Returns:
            ``((data,), remaining_kwargs)``.

        Raises:
            ValueError: if ``data`` is missing or is not a ``TrainingInputBatch``.
        """
        # Extract data from either positional arg or kwarg
        if args:
            data = args[0]
            remaining_kwargs = kwargs
        elif "data" in kwargs:
            data = kwargs.pop("data")
            remaining_kwargs = kwargs
        else:
            raise ValueError("MeshDispatch requires 'data' as first positional argument or keyword argument")

        if not isinstance(data, TrainingInputBatch):
            raise ValueError(f"For MeshDispatch, `data` entry should be a `TrainingInputBatch`, got {type(data)}")
        # Pass through data as positional arg, and any other kwargs (e.g., loss_fn, loss_fn_config)
        return (data,), remaining_kwargs

method classmethod dispatch

dispatch(actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch, **kwargs) -> List[ObjectRef]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:104-123
    @classmethod
    def dispatch(cls, actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch, **kwargs) -> List[ObjectRef]:
        """Shard ``data`` evenly across DP ranks and invoke ``method`` on each actor with its rank's shard.

        Args:
            actor_infos: Actors to call; ``actor_infos[0].rank.dp_size`` gives the DP size.
            method: Name of the actor method to invoke (receives a single data shard).
            data: Full batch; its length must be divisible by the DP size.
            **kwargs: Extra keyword arguments forwarded to ``method``.

        Returns:
            One ObjectRef per actor, in ``actor_infos`` order.
        """
        assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
        dp_size = actor_infos[0].rank.dp_size
        batch_size = len(data)
        assert batch_size % dp_size == 0, "data batch size must be divisible by dp_size, got {} and {}".format(
            batch_size, dp_size
        )

        # ray.put each shard once; SP/TP replicas sharing a DP rank then reuse the same object.
        shard_refs: List[ObjectRef] = [ray.put(shard) for shard in data.chunk(batch_size // dp_size)]

        # Actors receive a reference and fetch their shard from the object store.
        return [getattr(info.handle, method).remote(shard_refs[info.rank.dp], **kwargs) for info in actor_infos]

method classmethod stage_chunks

stage_chunks(dp_size: int, data: TrainingInputBatch, mini_batch_boundaries: List[Tuple[int, int]]) -> List[List[ObjectRef]]

Pre-stage mini-batch chunks into the object store.

Each mini-batch is defined by a (start, end) index pair from mini_batch_boundaries. Mini-batches are individually padded so that their size is divisible by dp_size, using dummy entries with loss_mask=0 that do not affect the loss.

Parameters:

NameTypeDescriptionDefault
dp_sizeintNumber of data-parallel ranks.required
dataTrainingInputBatchFull TrainingInputBatch to slice from.required
mini_batch_boundariesList[Tuple[int, int]]List of (start, end) index pairs. The i-th mini-batch is data[mini_batch_boundaries[i][0]:mini_batch_boundaries[i][1]].required

Returns:

TypeDescription
List[List[ObjectRef]]result[i][dp_rank] - ObjectRef for mini-batch i, DP rank dp_rank.
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:125-164
    @classmethod
    def stage_chunks(
        cls,
        dp_size: int,
        data: TrainingInputBatch,
        mini_batch_boundaries: List[Tuple[int, int]],
    ) -> List[List[ObjectRef]]:
        """Slice ``data`` into mini-batches, split each across DP ranks, and ``ray.put`` every piece.

        Mini-batch *i* is ``data[start_i:end_i]`` for the *i*-th ``(start, end)`` pair in
        ``mini_batch_boundaries``. Each mini-batch is padded on its own with dummy
        ``loss_mask=0`` entries (which do not affect the loss) until its size is a multiple of
        ``dp_size``.

        Args:
            dp_size: Number of data-parallel ranks.
            data: Full TrainingInputBatch to slice from.
            mini_batch_boundaries: ``(start, end)`` index pairs, one per mini-batch.

        Returns:
            ``result[i][dp_rank]`` - ObjectRef for mini-batch *i*, DP rank *dp_rank*.
        """
        staged: List[List[ObjectRef]] = []
        for start, end in mini_batch_boundaries:
            mini_batch = data[start:end]

            # Number of dummy rows needed to reach a multiple of dp_size
            # (only non-zero for step-wise training).
            shortfall = (start - end) % dp_size
            if shortfall > 0:
                mini_batch = pad_training_input_batch(mini_batch, shortfall)

            padded_size = len(mini_batch)
            assert padded_size % dp_size == 0, f"mini_batch_size % dp_size != 0, got {padded_size} and {dp_size}"
            per_rank = padded_size // dp_size
            staged.append([ray.put(piece) for piece in mini_batch.chunk(per_rank)])
        return staged

method classmethod dispatch_from_staged

dispatch_from_staged(actor_infos: List[ActorInfo], method: str, chunk_refs: List[ObjectRef], **kwargs) -> List[ObjectRef]

Dispatch pre-staged per-DP chunks to workers.

Each worker receives only its own chunk (already in the object store), avoiding unnecessary deserialization overhead.

Parameters:

NameTypeDescriptionDefault
actor_infosList[ActorInfo]List of actor info objectsrequired
methodstrName of method to call on workers (receives a single data chunk)required
chunk_refsList[ObjectRef]Pre-staged ObjectRefs, one per DP rank (from stage_chunks)required
**kwargsAdditional keyword arguments to pass to the method{}

Returns:

TypeDescription
List[ObjectRef]List of ObjectRefs for worker results
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:166-194
    @classmethod
    def dispatch_from_staged(
        cls,
        actor_infos: List[ActorInfo],
        method: str,
        chunk_refs: List[ObjectRef],
        **kwargs,
    ) -> List[ObjectRef]:
        """
        Call ``method`` on every actor with the pre-staged chunk for its DP rank.

        The chunks are already in the object store (see ``stage_chunks``), so each worker
        only receives a reference to its own chunk.

        Args:
            actor_infos: Actors to call (must be non-empty).
            method: Name of the worker method to invoke (receives a single data chunk).
            chunk_refs: Pre-staged ObjectRefs indexed by DP rank.
            **kwargs: Extra keyword arguments forwarded to ``method``.

        Returns:
            One ObjectRef per actor, in ``actor_infos`` order.
        """
        assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
        return [getattr(info.handle, method).remote(chunk_refs[info.rank.dp], **kwargs) for info in actor_infos]

method classmethod validate_dispatch_args

validate_dispatch_args(*args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:196-211
    @classmethod
    def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
        """Pull the ``data`` batch out of the call arguments and check its type.

        ``data`` is either the first positional argument or the ``data`` keyword. Any further
        positional arguments are ignored. All other keyword arguments (e.g. ``loss_fn``,
        ``loss_fn_config``) are passed through unchanged.

        Returns:
            ``((data,), remaining_kwargs)``.

        Raises:
            ValueError: if ``data`` is missing or is not a ``TrainingInputBatch``.
        """
        if not args and "data" not in kwargs:
            raise ValueError("MeshDispatch requires 'data' as first positional argument or keyword argument")
        data = args[0] if args else kwargs.pop("data")

        if not isinstance(data, TrainingInputBatch):
            raise ValueError(f"For MeshDispatch, `data` entry should be a `TrainingInputBatch`, got {type(data)}")
        return (data,), kwargs

class PassThroughDispatch

Bases: Dispatch

PassThrough dispatch type to dispatch data to a group of actors without any sharding.

This is useful for cases where we want to run the same method on all the actors. Supports methods with any number of arguments.

Functions:

Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:214-228
class PassThroughDispatch(Dispatch):
    """Broadcast dispatch: every actor in the group receives the same call.

    Nothing is sharded. Positional and keyword arguments are forwarded unchanged
    to each actor, so target methods may have any signature.
    """

    @classmethod
    def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
        """Invoke ``method`` remotely on every actor with identical arguments.

        Returns:
            One ObjectRef per actor, in the same order as ``actor_infos``.
        """
        refs = []
        for info in actor_infos:
            remote_fn = getattr(info.handle, method)
            refs.append(remote_fn.remote(*args, **kwargs))
        return refs

    @classmethod
    def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
        """Accept any arguments; they are returned untouched for broadcasting."""
        return args, kwargs

method classmethod dispatch

dispatch(actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:221-223
    @classmethod
    def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
        """Call ``method`` remotely on each actor, forwarding identical arguments.

        Returns:
            One ObjectRef per actor, ordered like ``actor_infos``.
        """
        refs = []
        for info in actor_infos:
            remote_fn = getattr(info.handle, method)
            refs.append(remote_fn.remote(*args, **kwargs))
        return refs

method classmethod validate_dispatch_args

validate_dispatch_args(*args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:225-228
    @classmethod
    def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
        """Pass-through validation: there is nothing to check, so inputs are returned as-is."""
        passthrough_args, passthrough_kwargs = args, kwargs
        return passthrough_args, passthrough_kwargs

Worker APIs

The base worker abstraction in SkyRL.

class Worker

Worker(cfg: TrainerConfig, *args, **kwargs)

Bases: DistributedTorchRayActor

Functions:

NameDescription
get_node_local_rank
init_worker_process_group
get_mesh_rank
get_gpu_id
get_ray_node_id
get_master_addr_port
init_modelInitialize worker state (model, and optimizer if applicable) on worker.
empty_cacheEmpty GPU memory cache on Worker's CUDA device
set_algorithm_config
offload_to_cpuOffload all worker state to CPU.
backload_to_gpuBackload worker state to GPU.
get_cuda_memoryGet CUDA memory usage on worker's CUDA device.
save_memory_snapshotSave a snapshot of memory usage on the Worker's CUDA device.
init_weight_sync_stateInitialize state for weight syncing with Inference Engine Client
forwardRun forward pass on the input batch.
save_checkpoint
load_checkpoint
save_hf_model
get_lrGet current learning rate from optimizer. Returns None when the worker was initialized with policy.inference_only_init=True (no optimizer constructed).
set_lrSet learning rate for the optimizer.

Attributes:

Source code in skyrl/backends/skyrl_train/workers/worker.py:231-454
class Worker(DistributedTorchRayActor):
    """Base Ray actor for SkyRL training workers.

    Provides shared helpers for GPU memory management (offload/backload, cache
    clearing, memory stats and snapshots), checkpoint save/load, HuggingFace
    export, learning-rate access, and weight-sync setup with inference engines.
    Subclasses must implement :meth:`init_model` and :meth:`forward`; the base
    versions raise ``NotImplementedError``.

    NOTE(review): several attributes used below are not assigned in this class.
    They are presumably provided by ``DistributedTorchRayActor`` or by concrete
    subclasses (confirm there): ``self.strategy``, ``self.model``,
    ``self.optimizer``, ``self.scheduler``, ``self.record_memory`` and
    ``self._local_rank``.
    """

    def __init__(self, cfg: TrainerConfig, *args, **kwargs):
        """
        Args:
            cfg: Trainer config. ``cfg.algorithm.temperature`` must be set.
            *args, **kwargs: Forwarded to the base actor's ``__init__``.

        Raises:
            ValueError: If ``cfg.algorithm.temperature`` is None.
        """
        super().__init__(*args, **kwargs)
        self.cfg = cfg
        self._transfer_strategy_cls = None  # Set in init_weight_sync_state

        if self.cfg.algorithm.temperature is None:
            raise ValueError("`cfg.algorithm.temperature` must be set")

    def init_model(self, *args, **kwargs):
        """Initialize worker state (model, and optimizer if applicable) on worker."""
        raise NotImplementedError()

    def empty_cache(self) -> None:
        """Empty GPU memory cache on Worker's CUDA device"""
        torch.cuda.empty_cache()

    def set_algorithm_config(self, **kwargs) -> None:
        """Override fields of ``self.cfg.algorithm`` in place, one ``setattr`` per keyword argument."""
        for key, value in kwargs.items():
            setattr(self.cfg.algorithm, key, value)

    def _get_module_for_offload(self):
        """Return the model module(s) to be offloaded/backloaded.

        Megatron offloads `self.actor_module`. FSDP workers use `self.model` directly.
        """
        return self.model

    def offload_to_cpu(self, offload_optimizer=True, offload_model=True):
        """Offload all worker state to CPU.

        After this function runs, only temporary reserved memory and torch's pre-loaded cuda kernels (~ GB) will remain.

        Args:
            offload_optimizer: Whether to offload optimizer state (no-op when there is no optimizer, e.g. Ref worker).
            offload_model: Whether to offload model parameters.
        """
        # NOTE(review): assumes global rank modulo the visible device count identifies this
        # worker's local GPU for NUMA pinning; confirm under Ray's per-actor CUDA_VISIBLE_DEVICES.
        self._set_numa_affinity(torch.distributed.get_rank() % torch.cuda.device_count())
        self.strategy.offload_to_cpu(
            self._get_module_for_offload(),
            self.optimizer,
            offload_optimizer=offload_optimizer,
            offload_model=offload_model,
        )

    def backload_to_gpu(self, backload_optimizer=True, backload_model=True):
        """Backload worker state to GPU.

        Args:
            backload_optimizer: Whether to backload optimizer state (no-op when there is no optimizer).
            backload_model: Whether to backload model parameters.
        """
        self.strategy.backload_to_gpu(
            self._get_module_for_offload(),
            self.optimizer,
            backload_optimizer=backload_optimizer,
            backload_model=backload_model,
        )

    def get_cuda_memory(self) -> Dict[str, Any]:
        """Get CUDA memory usage on worker's CUDA device.

        Returns:
            Dict with keys ``allocated``, ``reserved``, ``free`` and ``total`` (bytes, per the
            ``torch.cuda`` memory APIs).
        """
        # Synchronize so that in-flight kernels do not skew the numbers.
        torch.cuda.synchronize()
        free, total = torch.cuda.mem_get_info()
        return {
            "allocated": torch.cuda.memory_allocated(),
            "reserved": torch.cuda.memory_reserved(),
            "free": free,
            "total": total,
        }

    def save_memory_snapshot(self, tag: str = ""):
        """Save a snapshot of memory usage on the Worker's CUDA device.

        No-ops if record_memory is False. Files are written to
        ``<cfg.ckpt_path>/memory_snapshots/rank_<rank>[_<tag>]_<n>.pickle``, where ``n``
        counts the snapshots taken by this worker.

        Args:
            tag: Label for the snapshot (e.g., "forward_backward", "optim_step")

        .. note::
            This function should be called on all the ranks in the worker group simultaneously
            (it contains a ``torch.distributed.barrier()``).
        """
        if not self.record_memory:
            return

        # Track snapshot count for unique filenames
        if not hasattr(self, "_snapshot_count"):
            self._snapshot_count = 0
        self._snapshot_count += 1

        rank = torch.distributed.get_rank()
        save_path = os.path.join(self.cfg.ckpt_path, "memory_snapshots")
        # Only the worker with _local_rank == 0 creates the directory; the barrier makes
        # every rank wait until that has happened before writing.
        if self._local_rank == 0 and not io.exists(save_path):
            io.makedirs(save_path, exist_ok=True)
        torch.distributed.barrier()

        tag_str = f"_{tag}" if tag else ""
        file_name = f"rank_{rank}{tag_str}_{self._snapshot_count}.pickle"
        record_memory_path = os.path.join(save_path, file_name)
        if io.exists(record_memory_path):
            # seeing issues if we don't remove the file first
            io.remove(record_memory_path)
        torch.cuda.memory._dump_snapshot(record_memory_path)

    async def init_weight_sync_state(
        self,
        inference_engine_client: "Union[InferenceEngineClient, RemoteInferenceClient]",
        inference_engine_cfg: "InferenceEngineConfig",
    ):
        """Initialize state for weight syncing with Inference Engine Client

        Creates init info and sender, then sends init info to inference engines
        so they can create receivers.

        Args:
            inference_engine_client: Client for the inference engines. Rank 0 calls its
                ``init_weight_update_communicator``. When the new inference path is enabled
                and the client has ``get_world_size``, the inference world size is fetched from it.
            inference_engine_cfg: Inference engine config; its ``weight_sync_backend`` (together
                with ``cfg.placement.colocate_all``) selects the transfer strategy.

        .. note::
            This function should be called on all the ranks in the worker group simultaneously.
        """
        from skyrl.backends.skyrl_train.weight_sync import get_transfer_strategy_cls

        assert inference_engine_client is not None

        # Determine transfer strategy based on inference engine config and placement
        self._transfer_strategy_cls = get_transfer_strategy_cls(
            weight_sync_backend=inference_engine_cfg.weight_sync_backend,
            colocate_all=self.cfg.placement.colocate_all,
        )

        # For new inference path, fetch world_size from servers
        # For legacy path, calculate from config
        inference_world_size = None
        if _SKYRL_USE_NEW_INFERENCE and hasattr(inference_engine_client, "get_world_size"):
            inference_world_size, _ = await inference_engine_client.get_world_size()

        # Create init info on all ranks (it's deterministic from cfg or fetched world_size)
        init_info = self._transfer_strategy_cls.create_init_info(
            inference_engine_cfg, inference_world_size=inference_world_size
        )

        # Create sender on all ranks
        # Strategy implementations may have different logic for different ranks
        tasks = [
            asyncio.to_thread(
                self._transfer_strategy_cls.create_sender,
                init_info=init_info,
                inference_client=inference_engine_client,
            ),
        ]

        # Only rank 0 initializes receivers on inference engines
        # NOTE: For broadcast strategy, sender and receiver init must run concurrently
        # because both need to join the same process group to avoid deadlock
        if torch.distributed.get_rank() == 0:
            tasks.append(inference_engine_client.init_weight_update_communicator(init_info))

        results = await asyncio.gather(*tasks)
        self._weight_transfer_sender = results[0]  # sender is always first task

        # Register signal handlers for termination only on rank 0
        # NOTE (sumanthrh): This doesn't work yet, and is thus commented out.
        # The better way is to just have this specified in __del__, but there is
        # no guarantee that __del__ will be called in general. Ray also doesn't
        # explicitly call __del__ when the actor shuts down.
        # It's commented out so that we can fix this in the future.
        # atexit.register(self._handle_termination)

        torch.distributed.barrier()

    def forward(self, *args, **kwargs) -> WorkerOutput:
        """Run forward pass on the input batch.

        Each worker subclass declares its own concrete signature and returns a
        :class:`WorkerOutput` so callers can program against a uniform API.
        """
        raise NotImplementedError()

    def _forward_micro_batch(self, micro_batch: TrainingInputBatch) -> TrainingOutputBatch:
        """Run the forward pass on a single micro-batch; must be implemented by subclasses."""
        raise NotImplementedError()

    def save_checkpoint(self, ckpt_dir: str, tokenizer=None):
        """Save model, optimizer and scheduler state by delegating to ``self.strategy.save_checkpoint``.

        Args:
            ckpt_dir: Destination checkpoint directory.
            tokenizer: Optional tokenizer, forwarded to the strategy.
        """
        self.strategy.save_checkpoint(
            model=self.model,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            ckpt_dir=ckpt_dir,
            node_local_rank=self.get_node_local_rank(),
            tokenizer=tokenizer,
        )

    def load_checkpoint(self, ckpt_dir: str, load_optimizer_states: bool = True, load_lr_scheduler_states: bool = True):
        """Load a checkpoint from ``ckpt_dir`` by delegating to ``self.strategy.load_checkpoint``.

        Args:
            ckpt_dir: Checkpoint directory to load from.
            load_optimizer_states: If False, ``None`` is passed to the strategy instead of the optimizer.
            load_lr_scheduler_states: If False, ``None`` is passed to the strategy instead of the scheduler.

        Returns:
            The second element of the pair returned by the strategy (presumably the saved
            training states; confirm against the strategy implementation).
        """
        _, states = self.strategy.load_checkpoint(
            model=self.model,
            optimizer=self.optimizer if load_optimizer_states else None,
            scheduler=self.scheduler if load_lr_scheduler_states else None,
            ckpt_dir=ckpt_dir,
            load_optimizer_states=load_optimizer_states,
            load_lr_scheduler_states=load_lr_scheduler_states,
        )
        return states

    def save_hf_model(self, export_dir: str, tokenizer):
        """Export the model to ``export_dir`` via ``self.strategy.save_hf_model``, forwarding ``tokenizer``."""
        # Save model in HuggingFace safetensors format
        self.strategy.save_hf_model(
            self.model,
            export_dir,
            tokenizer=tokenizer,
        )

    def get_lr(self) -> Optional[float]:
        """
        Get current learning rate from optimizer. Returns None when the worker was
        initialized with ``policy.inference_only_init=True`` (no optimizer constructed).
        """
        if self.optimizer is None:
            return None
        # Only the first param group's lr is reported.
        return self.optimizer.param_groups[0]["lr"]

    def set_lr(self, learning_rate: float) -> None:
        """
        Set learning rate for the optimizer.

        This directly updates the optimizer's param_groups, bypassing the scheduler.
        Useful for external learning rate schedules (e.g., from Tinker). No-op when
        ``policy.inference_only_init=True`` (no optimizer constructed).
        """
        if self.optimizer is None:
            return
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = learning_rate

attr sequence_parallel_size

sequence_parallel_size: int = sequence_parallel_size

attr record_memory

record_memory = record_memory

get_node_local_rank

get_node_local_rank()

init_worker_process_group

init_worker_process_group()

get_mesh_rank

get_mesh_rank()

get_gpu_id

get_gpu_id()

get_ray_node_id

get_ray_node_id()

get_master_addr_port

get_master_addr_port()

attr cfg

cfg = cfg

method init_model

init_model(*args, **kwargs)

Initialize worker state (model, and optimizer if applicable) on worker.

Source code in skyrl/backends/skyrl_train/workers/worker.py:240-242
    def init_model(self, *args, **kwargs):
        """Set up the worker's model (and an optimizer, where one applies).

        This is an abstract hook; concrete worker subclasses must override it.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

method empty_cache

empty_cache() -> None

Empty GPU memory cache on Worker's CUDA device

Source code in skyrl/backends/skyrl_train/workers/worker.py:244-246
    def empty_cache(self) -> None:
        """Release unoccupied cached blocks held by PyTorch's CUDA caching allocator on this worker."""
        cuda_module = torch.cuda
        cuda_module.empty_cache()

method set_algorithm_config

set_algorithm_config(**kwargs) -> None
Source code in skyrl/backends/skyrl_train/workers/worker.py:248-250
    def set_algorithm_config(self, **kwargs) -> None:
        """Overwrite attributes of ``self.cfg.algorithm`` from keyword arguments.

        Every ``name=value`` pair is applied with ``setattr``, in the order the
        keywords were passed.
        """
        for name in kwargs:
            setattr(self.cfg.algorithm, name, kwargs[name])

method offload_to_cpu

offload_to_cpu(offload_optimizer = True, offload_model = True)

Offload all worker state to CPU.

After this function runs, only temporary reserved memory and torch's pre-loaded cuda kernels (~ GB) will remain.

Parameters:

NameTypeDescriptionDefault
offload_optimizerWhether to offload optimizer state (no-op when there is no optimizer, e.g. Ref worker).True
offload_modelWhether to offload model parameters.True
Source code in skyrl/backends/skyrl_train/workers/worker.py:256-271
    def offload_to_cpu(self, offload_optimizer=True, offload_model=True):
        """Move this worker's state off the GPU into host memory.

        Afterwards, what stays on the device is essentially temporary reserved
        memory plus torch's pre-loaded CUDA kernels.

        Args:
            offload_optimizer: Also offload optimizer state (no-op when the worker has no optimizer, e.g. Ref worker).
            offload_model: Also offload the model parameters.
        """
        # Pin to the NUMA node associated with this device index before moving memory to host.
        device_idx = torch.distributed.get_rank() % torch.cuda.device_count()
        self._set_numa_affinity(device_idx)

        module = self._get_module_for_offload()
        self.strategy.offload_to_cpu(
            module,
            self.optimizer,
            offload_optimizer=offload_optimizer,
            offload_model=offload_model,
        )

method backload_to_gpu

backload_to_gpu(backload_optimizer = True, backload_model = True)

Backload worker state to GPU.

Parameters:

NameTypeDescriptionDefault
backload_optimizerWhether to backload optimizer state (no-op when there is no optimizer).True
backload_modelWhether to backload model parameters.True
Source code in skyrl/backends/skyrl_train/workers/worker.py:273-285
    def backload_to_gpu(self, backload_optimizer=True, backload_model=True):
        """Bring this worker's state back onto the GPU (the counterpart of ``offload_to_cpu``).

        Args:
            backload_optimizer: Also restore optimizer state (no-op when the worker has no optimizer).
            backload_model: Also restore the model parameters.
        """
        module = self._get_module_for_offload()
        optimizer = self.optimizer
        self.strategy.backload_to_gpu(
            module,
            optimizer,
            backload_optimizer=backload_optimizer,
            backload_model=backload_model,
        )

method get_cuda_memory

get_cuda_memory() -> Dict[str, Any]

Get CUDA memory usage on worker's CUDA device.

Source code in skyrl/backends/skyrl_train/workers/worker.py:287-296
    def get_cuda_memory(self) -> Dict[str, Any]:
        """Report CUDA memory statistics (bytes) for this worker's current device.

        The device is synchronized first, so queued kernels are reflected in the numbers.

        Returns:
            Dict with keys ``allocated``, ``reserved``, ``free`` and ``total``.
        """
        torch.cuda.synchronize()
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        stats = {}
        stats["allocated"] = torch.cuda.memory_allocated()
        stats["reserved"] = torch.cuda.memory_reserved()
        stats["free"] = free_bytes
        stats["total"] = total_bytes
        return stats

method save_memory_snapshot

save_memory_snapshot(tag: str = '')

Save a snapshot of memory usage on the Worker's CUDA device.

No-ops if record_memory is False.

Parameters:

NameTypeDescriptionDefault
tagstrLabel for the snapshot (e.g., "forward_backward", "optim_step")''

Note: This function should be called on all the ranks in the worker group simultaneously (it contains a collective barrier).

Source code in skyrl/backends/skyrl_train/workers/worker.py:298-329
    def save_memory_snapshot(self, tag: str = ""):
        """Dump a CUDA memory-allocator snapshot for this rank to a pickle file.

        Does nothing unless ``self.record_memory`` is truthy. Files are written to
        ``<cfg.ckpt_path>/memory_snapshots/rank_<rank>[_<tag>]_<n>.pickle``, where ``n``
        counts this worker's snapshots (starting at 1), so repeated calls never collide.

        Args:
            tag: Optional label embedded in the file name (e.g. "forward_backward", "optim_step").

        .. note::
            Contains a ``torch.distributed.barrier()``, so every rank in the worker
            group must call it at the same time.
        """
        if not self.record_memory:
            return

        self._snapshot_count = getattr(self, "_snapshot_count", 0) + 1

        rank = torch.distributed.get_rank()
        snapshot_dir = os.path.join(self.cfg.ckpt_path, "memory_snapshots")

        # Only the worker with _local_rank == 0 creates the directory; all ranks then sync.
        if self._local_rank == 0 and not io.exists(snapshot_dir):
            io.makedirs(snapshot_dir, exist_ok=True)
        torch.distributed.barrier()

        label = f"_{tag}" if tag else ""
        snapshot_path = os.path.join(snapshot_dir, f"rank_{rank}{label}_{self._snapshot_count}.pickle")
        # Remove a stale file first; overwriting an existing snapshot has caused issues.
        if io.exists(snapshot_path):
            io.remove(snapshot_path)
        torch.cuda.memory._dump_snapshot(snapshot_path)

method init_weight_sync_state

init_weight_sync_state(inference_engine_client: Union[InferenceEngineClient, RemoteInferenceClient], inference_engine_cfg: InferenceEngineConfig)

Initialize state for weight syncing with Inference Engine Client

Creates init info and sender, then sends init info to inference engines so they can create receivers.

Note: This function should be called on all the ranks in the worker group simultaneously (it ends with a collective barrier).

Source code in skyrl/backends/skyrl_train/workers/worker.py:331-392
    async def init_weight_sync_state(
        self,
        inference_engine_client: "Union[InferenceEngineClient, RemoteInferenceClient]",
        inference_engine_cfg: "InferenceEngineConfig",
    ):
        """Initialize state for weight syncing with Inference Engine Client

        Creates init info and sender, then sends init info to inference engines
        so they can create receivers.

        Args:
            inference_engine_client: Client for the inference engines. Rank 0 calls its
                ``init_weight_update_communicator``. When the new inference path is enabled
                and the client has ``get_world_size``, the inference world size is fetched from it.
            inference_engine_cfg: Inference engine config; its ``weight_sync_backend`` (together
                with ``cfg.placement.colocate_all``) selects the transfer strategy.

        .. note::
            This function should be called on all the ranks in the worker group simultaneously.
        """
        from skyrl.backends.skyrl_train.weight_sync import get_transfer_strategy_cls

        assert inference_engine_client is not None

        # Determine transfer strategy based on inference engine config and placement
        self._transfer_strategy_cls = get_transfer_strategy_cls(
            weight_sync_backend=inference_engine_cfg.weight_sync_backend,
            colocate_all=self.cfg.placement.colocate_all,
        )

        # For new inference path, fetch world_size from servers
        # For legacy path, calculate from config
        inference_world_size = None
        if _SKYRL_USE_NEW_INFERENCE and hasattr(inference_engine_client, "get_world_size"):
            inference_world_size, _ = await inference_engine_client.get_world_size()

        # Create init info on all ranks (it's deterministic from cfg or fetched world_size)
        init_info = self._transfer_strategy_cls.create_init_info(
            inference_engine_cfg, inference_world_size=inference_world_size
        )

        # Create sender on all ranks
        # Strategy implementations may have different logic for different ranks
        # (create_sender runs in a worker thread so the event loop stays free for the
        # concurrent receiver-init call below)
        tasks = [
            asyncio.to_thread(
                self._transfer_strategy_cls.create_sender,
                init_info=init_info,
                inference_client=inference_engine_client,
            ),
        ]

        # Only rank 0 initializes receivers on inference engines
        # NOTE: For broadcast strategy, sender and receiver init must run concurrently
        # because both need to join the same process group to avoid deadlock
        if torch.distributed.get_rank() == 0:
            tasks.append(inference_engine_client.init_weight_update_communicator(init_info))

        results = await asyncio.gather(*tasks)
        self._weight_transfer_sender = results[0]  # sender is always first task

        # Register signal handlers for termination only on rank 0
        # NOTE (sumanthrh): This doesn't work yet, and is thus commented out.
        # The better way is to just have this specified in __del__, but there is
        # no guarantee that __del__ will be called in general. Ray also doesn't
        # explicitly call __del__ when the actor shuts down.
        # It's commented out so that we can fix this in the future.
        # atexit.register(self._handle_termination)

        torch.distributed.barrier()

method abstractmethod forward

forward(*args, **kwargs) -> WorkerOutput

Run forward pass on the input batch.

Each worker subclass declares its own concrete signature and returns a WorkerOutput, so callers can program against a uniform API.

Source code in skyrl/backends/skyrl_train/workers/worker.py:394-400
    def forward(self, *args, **kwargs) -> WorkerOutput:
        """Forward pass over an input batch; must be overridden.

        Concrete workers declare their own precise signature, and each returns a
        :class:`WorkerOutput` so callers can rely on one uniform result type.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

method save_checkpoint

save_checkpoint(ckpt_dir: str, tokenizer = None)
Source code in skyrl/backends/skyrl_train/workers/worker.py:405-413
    def save_checkpoint(self, ckpt_dir: str, tokenizer=None):
        """Persist model, optimizer and LR-scheduler state into ``ckpt_dir``.

        The work is delegated to ``self.strategy.save_checkpoint``. This worker's
        node-local rank and the optional ``tokenizer`` are forwarded alongside.
        """
        checkpoint_kwargs = dict(
            model=self.model,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            ckpt_dir=ckpt_dir,
            node_local_rank=self.get_node_local_rank(),
            tokenizer=tokenizer,
        )
        self.strategy.save_checkpoint(**checkpoint_kwargs)

method load_checkpoint

load_checkpoint(ckpt_dir: str, load_optimizer_states: bool = True, load_lr_scheduler_states: bool = True)
Source code in skyrl/backends/skyrl_train/workers/worker.py:415-424
    def load_checkpoint(self, ckpt_dir: str, load_optimizer_states: bool = True, load_lr_scheduler_states: bool = True):
        """Restore worker state from ``ckpt_dir`` through ``self.strategy.load_checkpoint``.

        When a ``load_*`` flag is False, ``None`` is handed to the strategy in place of
        the corresponding object. The flags themselves are forwarded as well.

        Returns:
            The second item of the pair returned by the strategy (the first is discarded).
        """
        optimizer = self.optimizer if load_optimizer_states else None
        scheduler = self.scheduler if load_lr_scheduler_states else None
        result = self.strategy.load_checkpoint(
            model=self.model,
            optimizer=optimizer,
            scheduler=scheduler,
            ckpt_dir=ckpt_dir,
            load_optimizer_states=load_optimizer_states,
            load_lr_scheduler_states=load_lr_scheduler_states,
        )
        _, loaded_states = result
        return loaded_states

method save_hf_model

save_hf_model(export_dir: str, tokenizer)
Source code in skyrl/backends/skyrl_train/workers/worker.py:426-432
    def save_hf_model(self, export_dir: str, tokenizer):
        """Export the model in HuggingFace safetensors format to ``export_dir``.

        Delegates to ``self.strategy.save_hf_model``; ``tokenizer`` is passed by keyword.
        """
        strategy = self.strategy
        strategy.save_hf_model(self.model, export_dir, tokenizer=tokenizer)

method get_lr

get_lr() -> Optional[float]

Get current learning rate from optimizer. Returns None when the worker was initialized with policy.inference_only_init=True (no optimizer constructed).

Source code in skyrl/backends/skyrl_train/workers/worker.py:434-441
    def get_lr(self) -> Optional[float]:
        """
        Return the learning rate of the optimizer's first param group.

        ``None`` is returned when no optimizer exists, i.e. when the worker was
        initialized with ``policy.inference_only_init=True``.
        """
        optimizer = self.optimizer
        return None if optimizer is None else optimizer.param_groups[0]["lr"]

method set_lr

set_lr(learning_rate: float) -> None

Set learning rate for the optimizer.

This directly updates the optimizer's param_groups, bypassing the scheduler. Useful for external learning rate schedules (e.g., from Tinker). No-op when policy.inference_only_init=True (no optimizer constructed).

Source code in skyrl/backends/skyrl_train/workers/worker.py:443-454
    def set_lr(self, learning_rate: float) -> None:
        """
        Write ``learning_rate`` into every optimizer param group.

        The scheduler is bypassed entirely, which suits externally driven schedules
        (e.g. from Tinker). Does nothing when there is no optimizer
        (``policy.inference_only_init=True``).
        """
        optimizer = self.optimizer
        if optimizer is not None:
            for group in optimizer.param_groups:
                group["lr"] = learning_rate

class PPORayActorGroup

PPORayActorGroup(cfg: TrainerConfig, num_nodes, num_gpus_per_node, ray_actor_type: Type[Worker], pg: Optional[ResolvedPlacementGroup] = None, num_gpus_per_actor: float = 1.0, resources: Optional[Dict[str, float]] = None, num_resources_per_node: Optional[int] = None, colocate_all: bool = False, sequence_parallel_size: int = 1, record_memory: bool = False) -> None

A group of Ray actors. Functions that start with 'async' should return a list of object refs.

Parameters:

NameTypeDescriptionDefault
cfgTrainerConfigconfig object for workersrequired
num_nodesintNumber of nodes for this actor group.required
num_gpus_per_nodeintNumber of GPUs per node for this actor group.required
ray_actor_typeType[Worker]PPO model type that this actor group serves.required
pgResolvedPlacementGroupPlacement group to schedule actor on. If none, create new placement group automatically. Defaults to None.None
num_gpus_per_actorfloatNumber of gpus allocated for each actor. If < 1.0, multiple models can share same gpu. Defaults to 1.1.0

Functions:

NameDescription
async_init_modelAsynchronously initialize worker state (model, and optimizer if applicable) from model path
offload_to_cpuOffload all worker state to CPU.
backload_to_gpuBackload worker state to GPU
async_run_ray_methodRun a method on all actors using specified dispatch type asynchronously.

Attributes:

Parameters:

NameTypeDescriptionDefault
pgOptional[ResolvedPlacementGroup]Placement group for the worker group. Accepts a single PlacementGroup, or None. Note that if colocate_all is True, the number of bundles in the placement group must match world_size.None
Source code in skyrl/backends/skyrl_train/workers/worker.py:458-679
class PPORayActorGroup:
    """
    A group of Ray actors of the same worker type.

    Methods whose names start with ``async`` return a list of Ray object refs.

    Args:
        cfg: config object for workers
        num_nodes (int): Number of nodes for this actor group.
        num_gpus_per_node (int): Number of GPUs per node for this actor group. The group
            creates ``num_nodes * num_gpus_per_node`` actors (the world size).
        ray_actor_type (Type[Worker]): PPO model type that this actor group serves.
        pg (ResolvedPlacementGroup, optional): Placement group to schedule actors on.
            If None and ``num_gpus_per_node > 1``, a placement group is created
            automatically. Defaults to None.
        num_gpus_per_actor (float, optional): Number of GPUs allocated for each actor.
            If < 1.0, multiple models can share the same GPU. Defaults to 1.0.
        resources (Dict[str, float], optional): Custom Ray resources requested by each actor.
        num_resources_per_node (int, optional): Amount of the first custom resource added to
            each bundle when the placement group is created internally.
        colocate_all (bool): If True, ``pg`` must be provided and contain exactly
            world-size bundles.
        sequence_parallel_size (int): Forwarded to every worker actor's constructor.
        record_memory (bool): Forwarded to every worker actor's constructor.
    """

    def __init__(
        self,
        cfg: TrainerConfig,
        num_nodes: int,
        num_gpus_per_node: int,
        ray_actor_type: Type[Worker],
        pg: Optional[ResolvedPlacementGroup] = None,
        num_gpus_per_actor: float = 1.0,
        resources: Optional[Dict[str, float]] = None,
        num_resources_per_node: Optional[int] = None,
        colocate_all: bool = False,
        sequence_parallel_size: int = 1,
        record_memory: bool = False,
    ) -> None:
        """
        Args:
            pg: Placement group for the worker group. Accepts a single PlacementGroup, or None.
                Note that if colocate_all is True, the number of bundles in the placement group must match world_size.
        """
        self.cfg = cfg
        self._num_nodes = num_nodes
        self._num_gpus_per_node = num_gpus_per_node
        self.ray_actor_type = ray_actor_type

        # custom resources, see https://docs.ray.io/en/latest/ray-core/scheduling/resources.html
        self._resources = resources
        self._num_resources_per_node = num_resources_per_node

        self.colocate_all = colocate_all
        self.sequence_parallel_size = sequence_parallel_size
        self.record_memory = record_memory
        self._initiate_actors(pg, num_gpus_per_actor)

    def _initiate_actors(self, pg: Optional[ResolvedPlacementGroup], num_gpus_per_actor: float):
        """Create the group's Ray actors and initialize their process group.

        Rank 0 is created first so that its master address/port can be passed to the
        remaining ranks. Afterwards every actor's ``init_worker_process_group`` is run
        and its mesh rank is recorded in ``self.actor_infos``.

        Args:
            pg: A single placement group for the worker group, or None.
            num_gpus_per_actor: The number of gpus to allocate per actor.
        """
        world_size = self._num_nodes * self._num_gpus_per_node

        # Extract raw Ray PlacementGroup and pre-computed reordered indices from ResolvedPlacementGroup.
        # Only use reordered indices when the PG has one bundle per GPU (single-GPU bundles),
        # i.e. the bundle count matches world_size. Multi-GPU bundles (whole-node bundles)
        # don't need reordering since each bundle already represents a full node.
        reordered_bundle_indices = []
        raw_pg = None
        num_bundles = None
        if pg is not None:
            assert isinstance(pg, ResolvedPlacementGroup), f"pg must be a `ResolvedPlacementGroup` got {type(pg)}."
            raw_pg = pg.pg
            # Query the placement group table once; the bundle count is reused below.
            num_bundles = len(placement_group_table(raw_pg)["bundles"])
            if num_bundles == world_size:
                reordered_bundle_indices = pg.reordered_bundle_indices

        if self.colocate_all:
            assert (
                raw_pg is not None
            ), "if colocate_all is True, the shared placement group must be provided to PPORayActorGroup"
            assert num_bundles == world_size, (
                f"if colocate_all is True, the number of bundles in the placement group "
                f"must match world_size. Got {num_bundles} bundles but world_size={world_size}"
            )

        # If no PG provided, create one internally
        if raw_pg is None and self._num_gpus_per_node > 1:
            bundles = [{"GPU": self._num_gpus_per_node, "CPU": self._num_gpus_per_node} for _ in range(self._num_nodes)]
            if self._resources:
                # Only the first custom resource key is attached to the bundles.
                resources_name = list(self._resources.keys())[0]
                for bundle in bundles:
                    bundle[resources_name] = self._num_resources_per_node

            raw_pg = placement_group(bundles, strategy="PACK")
            get_ray_pg_ready_with_timeout(raw_pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)

        def _actor_options_for_rank(rank: int) -> dict:
            """Build the Ray ``.options(...)`` kwargs for the actor at ``rank``."""
            options = {
                "num_cpus": num_gpus_per_actor,
                "num_gpus": num_gpus_per_actor,
                "resources": self._resources,
            }
            if reordered_bundle_indices:
                options["scheduling_strategy"] = PlacementGroupSchedulingStrategy(
                    placement_group=raw_pg,
                    placement_group_bundle_index=reordered_bundle_indices[rank],
                )
            elif raw_pg is not None:
                # Whole-node bundles: all ranks of one node share that node's bundle.
                options["scheduling_strategy"] = PlacementGroupSchedulingStrategy(
                    placement_group=raw_pg,
                    placement_group_bundle_index=rank // self._num_gpus_per_node,
                )
            # else we are in the single gpu per node case, in which case we don't need to set
            # bundle indices
            return options

        def _create_actor(rank: int, master_addr, master_port):
            """Launch the worker actor for ``rank`` and return its handle."""
            return self.ray_actor_type.options(**_actor_options_for_rank(rank)).remote(
                cfg=self.cfg,
                world_size=world_size,
                rank=rank,
                local_rank=rank % self._num_gpus_per_node,
                master_addr=master_addr,
                master_port=master_port,
                sequence_parallel_size=self.sequence_parallel_size,
                record_memory=self.record_memory,
            )

        master_actor = _create_actor(0, master_addr=None, master_port=None)
        self._actor_handlers = [master_actor]

        if world_size > 1:
            # Non-master ranks need rank 0's address/port to join the process group.
            master_addr, master_port = ray.get(master_actor.get_master_addr_port.remote())
            self._actor_handlers.extend(
                _create_actor(rank, master_addr, master_port) for rank in range(1, world_size)
            )

        # Initialize process group
        logger.info("Initializing process group for RayActorGroup")
        ray.get([actor.init_worker_process_group.remote() for actor in self._actor_handlers])
        logger.info("Initialized process group for RayActorGroup")
        # Fetch all mesh ranks with a single ray.get instead of one blocking round trip per actor.
        mesh_ranks = ray.get([actor.get_mesh_rank.remote() for actor in self._actor_handlers])
        self.actor_infos = [ActorInfo(actor, mesh_rank) for actor, mesh_rank in zip(self._actor_handlers, mesh_ranks)]
        logger.info(f"Mesh Ranks: {[actor_info.rank for actor_info in self.actor_infos]}")

    def async_init_model(
        self,
        *args,
        **kwargs,
    ) -> List[ObjectRef]:
        """Asynchronously initialize worker state (model, and optimizer if applicable) from model path
        on all the workers.

        All arguments are forwarded unchanged to each worker's ``init_model``.

        Returns:
            A list of ray object refs, one per worker.
        """
        return [actor.init_model.remote(*args, **kwargs) for actor in self._actor_handlers]

    def offload_to_cpu(self, nonblocking: bool = False, offload_optimizer: bool = True, offload_model: bool = True):
        """Offload all worker state to CPU.

        Args:
            nonblocking: If True, return the Ray object refs immediately; if False,
                block until every worker has finished.
            offload_optimizer: Forwarded to each worker's ``offload_to_cpu``.
            offload_model: Forwarded to each worker's ``offload_to_cpu``.

        Returns:
            The list of object refs when ``nonblocking`` is True, otherwise the list of
            per-worker results obtained via ``ray.get``.
        """
        refs = [
            actor.offload_to_cpu.remote(offload_optimizer=offload_optimizer, offload_model=offload_model)
            for actor in self._actor_handlers
        ]
        if nonblocking:
            return refs
        return ray.get(refs)

    def backload_to_gpu(self, nonblocking: bool = False, backload_optimizer: bool = True, backload_model: bool = True):
        """Backload worker state to GPU.

        Args:
            nonblocking: If True, return the Ray object refs immediately; if False,
                block until every worker has finished.
            backload_optimizer: Forwarded to each worker's ``backload_to_gpu``.
            backload_model: Forwarded to each worker's ``backload_to_gpu``.

        Returns:
            The list of object refs when ``nonblocking`` is True, otherwise the list of
            per-worker results obtained via ``ray.get``.
        """
        refs = [
            actor.backload_to_gpu.remote(backload_optimizer=backload_optimizer, backload_model=backload_model)
            for actor in self._actor_handlers
        ]
        if nonblocking:
            return refs
        return ray.get(refs)

    def async_run_ray_method(self, dispatch_type: str, method_name: str, *args, **kwargs) -> List[ObjectRef]:
        """Run a method on all actors using specified dispatch type asynchronously.

        Args:
            dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
            method_name: Name of the method to call on actors
            *args: Positional arguments to pass to the method
            **kwargs: Keyword arguments to pass to the method

        Returns:
            List of object references
        """
        dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type)
        # validate the dispatch args to be sent to `.dispatch`
        args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs)

        # Dispatch the method call
        object_refs = dispatch_class.dispatch(self.actor_infos, method_name, *args, **kwargs)
        return object_refs

attr cfg

cfg = cfg

attr ray_actor_type

ray_actor_type = ray_actor_type

attr colocate_all

colocate_all = colocate_all

attr sequence_parallel_size

sequence_parallel_size = sequence_parallel_size

attr record_memory

record_memory = record_memory

method async_init_model

async_init_model(*args, **kwargs) -> List[ObjectRef]

Asynchronously initialize worker state (model, and optimizer if applicable) from model path on all the workers.

Returns:

TypeDescription
List[ObjectRef]A list of ray object refs.
Source code in skyrl/backends/skyrl_train/workers/worker.py:618-629
    def async_init_model(self, *args, **kwargs) -> List[ObjectRef]:
        """Start model (and optimizer, if applicable) initialization on every worker.

        Every positional and keyword argument is forwarded unchanged to each
        worker's ``init_model``. No ``ray.get`` is performed here.

        Returns:
            One Ray object ref per worker, in worker order.
        """
        pending = []
        for handle in self._actor_handlers:
            pending.append(handle.init_model.remote(*args, **kwargs))
        return pending

method offload_to_cpu

offload_to_cpu(nonblocking = False, offload_optimizer = True, offload_model = True)

Offload all worker state to CPU.

Parameters:

NameTypeDescriptionDefault
nonblockingIf True, return the list of Ray object refs immediately; if False, wait for all workers to finish and return their results.False
Source code in skyrl/backends/skyrl_train/workers/worker.py:631-644
    def offload_to_cpu(self, nonblocking: bool = False, offload_optimizer: bool = True, offload_model: bool = True):
        """Offload all worker state to CPU.

        Args:
            nonblocking: If True, return the Ray object refs immediately; if False,
                block until every worker has finished.
            offload_optimizer: Forwarded to each worker's ``offload_to_cpu``.
            offload_model: Forwarded to each worker's ``offload_to_cpu``.

        Returns:
            The list of object refs when ``nonblocking`` is True, otherwise the list of
            per-worker results obtained via ``ray.get``.
        """
        refs = [
            actor.offload_to_cpu.remote(offload_optimizer=offload_optimizer, offload_model=offload_model)
            for actor in self._actor_handlers
        ]
        if nonblocking:
            return refs
        return ray.get(refs)

method backload_to_gpu

backload_to_gpu(nonblocking = False, backload_optimizer = True, backload_model = True)

Backload worker state to GPU

Parameters:

NameTypeDescriptionDefault
nonblockingIf True, return the list of Ray object refs immediately; if False, wait for all workers to finish and return their results.False
Source code in skyrl/backends/skyrl_train/workers/worker.py:646-659
    def backload_to_gpu(self, nonblocking: bool = False, backload_optimizer: bool = True, backload_model: bool = True):
        """Backload worker state to GPU.

        Args:
            nonblocking: If True, return the Ray object refs immediately; if False,
                block until every worker has finished.
            backload_optimizer: Forwarded to each worker's ``backload_to_gpu``.
            backload_model: Forwarded to each worker's ``backload_to_gpu``.

        Returns:
            The list of object refs when ``nonblocking`` is True, otherwise the list of
            per-worker results obtained via ``ray.get``.
        """
        refs = [
            actor.backload_to_gpu.remote(backload_optimizer=backload_optimizer, backload_model=backload_model)
            for actor in self._actor_handlers
        ]
        if nonblocking:
            return refs
        return ray.get(refs)

method async_run_ray_method

async_run_ray_method(dispatch_type: str, method_name: str, *args, **kwargs) -> List[ObjectRef]

Run a method on all actors using specified dispatch type asynchronously.

Parameters:

NameTypeDescriptionDefault
dispatch_typestrType of dispatch to use ("mesh" or "pass_through")required
method_namestrName of the method to call on actorsrequired
*argsPositional arguments to pass to the method()
**kwargsKeyword arguments to pass to the method{}

Returns:

TypeDescription
List[ObjectRef]List of object references
Source code in skyrl/backends/skyrl_train/workers/worker.py:661-679
    def async_run_ray_method(self, dispatch_type: str, method_name: str, *args, **kwargs) -> List[ObjectRef]:
        """Invoke ``method_name`` on the group's actors through a registered dispatch strategy.

        Args:
            dispatch_type: Registry key of the dispatch strategy ("mesh" or "pass_through").
            method_name: Name of the actor method to invoke.
            *args: Positional arguments; validated by the dispatch class before dispatching.
            **kwargs: Keyword arguments; validated by the dispatch class before dispatching.

        Returns:
            The list of object references produced by the dispatch class's ``dispatch``.
        """
        dispatcher: Dispatch = DispatchRegistry.get(dispatch_type)
        checked_args, checked_kwargs = dispatcher.validate_dispatch_args(*args, **kwargs)
        return dispatcher.dispatch(self.actor_infos, method_name, *checked_args, **checked_kwargs)

On this page

Trainer Classclass RayPPOTrainerattr cfgattr colocate_allattr trackerattr tokenizerattr train_datasetattr eval_datasetattr inference_engine_clientattr generatorattr train_dataloaderattr total_training_stepsattr eval_dataloaderattr colocate_pgattr resume_modeattr all_metricsattr all_timingsattr global_stepattr policy_modelattr critic_modelattr ref_modelattr dynamic_sampling_stateattr reward_kl_controllerattr dispatchmethod add_callbackattr property has_criticmethod async evalmethod trainmethod build_modelsmethod init_weight_sync_statemethod convert_to_training_inputmethod async generatemethod postprocess_generator_outputmethod compute_advantages_and_returnsmethod dump_datamethod fwd_logprobs_values_rewardmethod apply_reward_kl_penaltymethod train_critic_and_policymethod handle_dynamic_samplingmethod save_checkpointsmethod load_checkpointsmethod save_modelsmethod update_ref_with_policyDispatch APIsclass Dispatchattr dispatchmethod abstractmethod classmethod validate_dispatch_argsclass MeshDispatchattr dispatchmethod classmethod stage_chunksmethod classmethod dispatch_from_stagedmethod abstractmethod classmethod validate_dispatch_argsclass PassThroughDispatchattr dispatchmethod abstractmethod classmethod validate_dispatch_argsWorker APIsclass Workerattr sequence_parallel_sizeattr record_memoryget_node_local_rankinit_worker_process_groupget_mesh_rankget_gpu_idget_ray_node_idget_master_addr_portattr cfgmethod init_modelmethod empty_cachemethod set_algorithm_configmethod offload_to_cpumethod backload_to_gpumethod get_cuda_memorymethod save_memory_snapshotmethod init_weight_sync_statemethod abstractmethod forwardmethod abstractmethod save_checkpointmethod abstractmethod load_checkpointmethod save_hf_modelmethod get_lrmethod set_lrclass PPORayActorGroupattr cfgattr ray_actor_typeattr colocate_allattr sequence_parallel_sizeattr record_memorymethod async_init_modelmethod offload_to_cpumethod backload_to_gpumethod async_run_ray_method