Trainer
Trainer API — RayPPOTrainer, Dispatch, Worker APIs.
Trainer Class
class RayPPOTrainer
RayPPOTrainer(cfg: SkyRLTrainConfig, tracker: Tracking, tokenizer: AutoTokenizer, train_dataset: Optional[PromptDataset], inference_engine_client: InferenceEngineClient, generator: GeneratorInterface, colocate_pg: Optional[ResolvedPlacementGroup] = None, eval_dataset: Optional[PromptDataset] = None)

Functions:
| Name | Description |
|---|---|
| eval | Run generation and scoring on the evaluation dataset. |
| train | Main training loop for PPO. |
| build_models | Initialize the actors for training, and handle colocation logic. |
| init_weight_sync_state | Set up the connection between the policy model and inference engines for weight syncing. |
| sync_policy_weights_to_inference_engines | Broadcast policy weights to inference engines. |
| convert_to_training_input | Convert lists to a padded batch of tensors for training. |
| generate | Generate rollouts. |
| postprocess_generator_output | Convert to per-token rewards and compute pass@N. |
| compute_advantages_and_returns | Calculate advantages and returns for the data batch. |
| dump_data | Dump data to a pickle file. |
| fwd_logprobs_values_reward | Calculate values from the critic, and log probs from the policy and ref models. |
| apply_reward_kl_penalty | Apply a penalty to the rewards for KL divergence between the policy log probs and the base model log probs. |
| train_critic_and_policy | Run the training step for the policy and critic models. |
| handle_dynamic_sampling | Handle dynamic sampling for the current batch. |
| save_checkpoints | Save the model, optimizer, and training states to disk. |
| load_checkpoints | Load the complete checkpoint state and return the global_step to resume from. |
| save_models | Save the model parameters in HF format at cfg.trainer.export_path. |
| update_ref_with_policy | Update the reference model with the policy model weights (required by some algorithms). |
Attributes:
| Name | Type | Description |
|---|---|---|
| cfg | | |
| colocate_all | | |
| tracker | | |
| tokenizer | | |
| train_dataset | | |
| eval_dataset | | |
| inference_engine_client | | |
| generator | | |
| train_dataloader | | |
| total_training_steps | | |
| eval_dataloader | | |
| colocate_pg | | |
| resume_mode | | |
| all_metrics | | |
| all_timings | | |
| global_step | | |
| policy_model | PPORayActorGroup | |
| critic_model | Optional[PPORayActorGroup] | |
| ref_model | Optional[PPORayActorGroup] | |
| dynamic_sampling_state | Optional[DynamicSamplingState] | |
| reward_kl_controller | Optional[Union[FixedKLController, AdaptiveKLController]] | |
| dispatch | WorkerDispatch | |
| has_critic | bool | Check if the critic model is configured. |
Source code in skyrl/train/trainer.py:87-1496
class RayPPOTrainer:
    """Orchestrates PPO-style RL training over Ray actor groups.

    Owns the policy/critic/ref actor groups, the rollout generator and its
    inference engines, the dataloaders, and the main training loop.
    """

    def __init__(
        self,
        cfg: SkyRLTrainConfig,
        tracker: Tracking,
        tokenizer: AutoTokenizer,
        train_dataset: Optional[PromptDataset],
        inference_engine_client: InferenceEngineClient,
        generator: GeneratorInterface,
        colocate_pg: Optional[ResolvedPlacementGroup] = None,
        eval_dataset: Optional[PromptDataset] = None,
    ):
        """Store configuration/collaborators and build the dataloaders.

        Args:
            cfg: Full training configuration.
            tracker: Experiment tracker used for metric logging.
            tokenizer: Tokenizer used for decoding/padding rollouts.
            train_dataset: Training prompts; may be None when data is
                provided externally (dataloader is then not built).
            inference_engine_client: Client for the rollout inference engines.
            generator: Rollout generator implementation.
            colocate_pg: Placement group used when all models are colocated.
            eval_dataset: Optional evaluation prompts.
        """
        self.cfg = cfg
        self.colocate_all = cfg.trainer.placement.colocate_all
        self.tracker = tracker
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.inference_engine_client = inference_engine_client
        self.generator = generator
        # Dataloader/step count are filled in by the hook below (subclasses may
        # override it), so they must be initialized to None first.
        self.train_dataloader = None
        self.total_training_steps = None
        self._build_train_dataloader_and_compute_training_steps()
        self.eval_dataloader = (
            build_dataloader(self.cfg, eval_dataset, is_train=False) if eval_dataset is not None else None
        )
        self.colocate_pg = colocate_pg
        self.resume_mode = ResumeMode(cfg.trainer.resume_mode)
        # Per-step metric/timing accumulators, flushed to the tracker each step.
        self.all_metrics = {}
        self.all_timings = {}
        self.global_step = 0
        # initialized in `build_models`
        self.policy_model: PPORayActorGroup = None
        self.critic_model: Optional[PPORayActorGroup] = None
        self.ref_model: Optional[PPORayActorGroup] = None
        # used for checkpoint cleanup
        self._node_ids: Optional[List[str]] = None
        self.dynamic_sampling_state: Optional[DynamicSamplingState] = None
        self.reward_kl_controller: Optional[Union[FixedKLController, AdaptiveKLController]] = None
        self.dispatch: WorkerDispatch = None
        configure_ray_worker_logging()
@property
def has_critic(self) -> bool:
"""Check if critic model is configured."""
return bool(self.cfg.trainer.critic.model.path)
def _build_train_dataloader_and_compute_training_steps(self):
"""
Hook for constructing the training dataloader. Subclasses can override
this to customize dataloader behavior. For instance, fully async training
needs a batch size of 1, among other features.
Defaults to `trainer_utils.build_dataloader` with `is_train=True`.
When train_dataset is None (e.g. Tinker backend provides data externally),
the dataloader is not built.
"""
if self.train_dataset is not None:
self.train_dataloader = build_dataloader(self.cfg, self.train_dataset, is_train=True)
self.total_training_steps = len(self.train_dataloader) * self.cfg.trainer.epochs
@torch.no_grad()
async def eval(self) -> Dict[str, float]:
"""
Run generation and scoring on the evaluation dataset.
The eval metrics are recorded after having finished training `self.global_step` steps.
Metrics recorded in global_step 0 corresponds to evaluations before training.
Returns:
A dictionary of evaluation metrics.
"""
if self.cfg.generator.step_wise_trajectories:
eval_metrics = await evaluate_step_wise(
eval_dataloader=self.eval_dataloader,
generator=self.generator,
cfg=self.cfg,
global_step=self.global_step,
tokenizer=self.tokenizer,
)
else:
eval_metrics = await evaluate(
eval_dataloader=self.eval_dataloader,
generator=self.generator,
cfg=self.cfg,
global_step=self.global_step,
tokenizer=self.tokenizer,
)
return eval_metrics
    async def train(self):
        """
        Main training loop for PPO.

        Per batch: generate rollouts, postprocess rewards, compute
        logprobs/values, compute advantages, train policy/critic, then
        checkpoint/sync/log as configured. With colocate_all enabled, the
        inference engines and training models trade GPU residency across the
        step (engine awake for generation, slept for training).
        """
        # Initialize weight sync state between policy model and inference engines.
        with Timer("init_weight_sync_state"):
            self.init_weight_sync_state()
        # Load checkpoint state if resumption is enabled.
        if self.resume_mode != ResumeMode.NONE:
            with Timer("load_checkpoints"):
                self.global_step, _ = self.load_checkpoints()
        # Prepare weights for sampling
        with Timer("sync_weights"):
            await self.dispatch.save_weights_for_sampler()
        # Eval before training
        if self.cfg.trainer.eval_interval > 0 and self.cfg.trainer.eval_before_train:
            with Timer("eval", self.all_timings):
                eval_metrics = await self.eval()
                self.tracker.log(eval_metrics, step=self.global_step, commit=True)
        # initialize kl controller
        if self.cfg.trainer.algorithm.use_kl_in_reward:
            self.reward_kl_controller = get_kl_controller(self.cfg.trainer.algorithm)
        # main training loop
        pbar = tqdm(total=self.total_training_steps, initial=self.global_step, desc="Training Batches Processed")
        # When resuming, skip the epochs already completed.
        start_epoch = self.global_step // len(self.train_dataloader)
        self.global_step += 1  # start training at global_step 1
        for epoch in range(start_epoch, self.cfg.trainer.epochs):
            for _, rand_prompts in enumerate(self.train_dataloader):
                with Timer("step", self.all_timings):
                    # for colocate_all=true, inference engine is always on GPU when starting the training step
                    # 0. truncate data to have even shards
                    rand_prompts = self._remove_tail_data(rand_prompts)
                    generator_input, uids = prepare_generator_input(
                        rand_prompts,
                        self.cfg.generator.n_samples_per_prompt,
                        get_sampling_params_for_backend(
                            self.cfg.generator.inference_engine.backend, self.cfg.generator.sampling_params
                        ),
                        self.cfg.environment.env_class,
                        "train",
                        self.global_step,
                    )
                    # 1.1. generation phase
                    with Timer("generate", self.all_timings):
                        generator_output: GeneratorOutput = await self.generate(generator_input)
                    if self.cfg.generator.step_wise_trajectories:
                        # NOTE: We use instance_ids from `trajectory_ids` here instead of re-using `uids`
                        # this is because in step-wise training, len(uids) != len(generator_output["response_ids"])
                        uids = [trajectory_id.instance_id for trajectory_id in generator_output["trajectory_ids"]]
                    # dynamic sampling
                    if self.cfg.trainer.algorithm.dynamic_sampling.type is not None:
                        generator_output, uids, keep_sampling = self.handle_dynamic_sampling(generator_output, uids)
                        if keep_sampling:  # continue sampling
                            # update progress bar for current batch (but not global step)
                            pbar.update(1)
                            continue
                    if self.colocate_all:
                        # if we are not continuing sampling, we sleep the inference engine
                        await self.inference_engine_client.sleep()
                    # 1.2 postprocess rewards (and merge step-wise turns if enabled)
                    with Timer("postprocess_generator_output", self.all_timings):
                        generator_output, uids = self.postprocess_generator_output(generator_output, uids)
                    # 2. print example just for debugging
                    vis = self.tokenizer.decode(generator_output["response_ids"][0])
                    log_example(
                        logger,
                        prompt=generator_input["prompts"][0],
                        response=vis,
                        reward=generator_output["rewards"][0],
                    )
                    # 3. Convert GeneratorOutput to TrainingInputBatch
                    with Timer("convert_to_training_input", self.all_timings):
                        training_input: TrainingInputBatch = self.convert_to_training_input(generator_output, uids)
                    # 4. Inference and calculate values, log probs, rewards, kl divergence
                    with Timer("fwd_logprobs_values_reward", self.all_timings):
                        training_input = self.fwd_logprobs_values_reward(training_input)
                    # 5. apply kl divergence penalty to rewards
                    if self.cfg.trainer.algorithm.use_kl_in_reward:
                        with Timer("apply_reward_kl_penalty", self.all_timings):
                            training_input = self.apply_reward_kl_penalty(training_input)
                    # 6. calculate advantages and returns
                    with Timer("compute_advantages_and_returns", self.all_timings):
                        training_input = self.compute_advantages_and_returns(training_input)
                        # remove some unwanted keys
                        for key in ["rewards"]:
                            training_input.pop(key)
                        training_input.metadata.pop("uids")
                        training_input.metadata.pop("is_last_step", None)
                    if self.cfg.trainer.dump_data_batch:
                        # dump data to file
                        with Timer("dump_data_batch"):
                            self.dump_data(training_input, file_name=f"global_step_{self.global_step}_training_input")
                    # 7. train policy/critic model
                    # Policy model is backloaded to GPU during training
                    with Timer("train_critic_and_policy", self.all_timings):
                        status = self.train_critic_and_policy(training_input)
                    # 8. conditionally save checkpoints and hf model
                    is_epoch_end = self.global_step % len(self.train_dataloader) == 0
                    if self.cfg.trainer.ckpt_interval > 0:
                        if is_epoch_end or self.global_step % self.cfg.trainer.ckpt_interval == 0:
                            with Timer("save_checkpoints", self.all_timings):
                                self.save_checkpoints()
                    if self.cfg.trainer.hf_save_interval > 0:
                        if is_epoch_end or self.global_step % self.cfg.trainer.hf_save_interval == 0:
                            with Timer("save_hf_model", self.all_timings):
                                self.save_models()
                    # 9. conditionally sync policy and ref at the end of the epoch
                    if (
                        self.cfg.trainer.update_ref_every_epoch
                        and self.ref_model is not None
                        and is_epoch_end
                        and epoch != self.cfg.trainer.epochs - 1  # skip updating ref at the end of the last epoch
                    ):
                        with Timer("update_ref_with_policy", self.all_timings):
                            self.update_ref_with_policy()
                    # 10. Prepare weights for sampling
                    with Timer("sync_weights", self.all_timings):
                        await self.dispatch.save_weights_for_sampler()
                    # 11. set logs
                    logger.info(status)
                    # log epoch info
                    self.all_metrics.update({"trainer/epoch": epoch, "trainer/global_step": self.global_step})
                    if self.cfg.trainer.eval_interval > 0 and (
                        self.global_step % self.cfg.trainer.eval_interval == 0
                        or self.global_step == self.total_training_steps
                    ):
                        with Timer("eval", self.all_timings):
                            eval_metrics = await self.eval()
                            self.all_metrics.update(eval_metrics)
                    log_payload = {
                        **self.all_metrics,
                        **{f"timing/{k}": v for k, v in self.all_timings.items()},
                    }
                    self.tracker.log(log_payload, step=self.global_step, commit=True)
                    self.all_metrics = {}
                    self.all_timings = {}
                    # update progress bar after logging
                    pbar.update(1)
                    self.global_step += 1
                # Release the large per-step objects before the next iteration.
                del training_input, generator_output
        pbar.close()
        if self.colocate_all:
            await self.inference_engine_client.sleep()
        # Safety net: always save final checkpoint at end of training.
        if self.cfg.trainer.ckpt_interval > 0:
            with Timer("save_checkpoints", self.all_timings):
                self.save_checkpoints()
            logger.info("Saved final checkpoint.")
        if self.cfg.trainer.hf_save_interval > 0:
            with Timer("save_hf_model", self.all_timings):
                self.save_models()
            logger.info("Saved final model.")
        self.tracker.finish()
        logger.info("Training done!")
def _remove_tail_data(self, entries: List[Any]) -> List[Any]:
"""Remove tail data to have even shards in terms of *effective* samples.
Each prompt produces `n_samples_per_prompt` samples. For data-parallel
training we care that the total number of samples is nicely splittable
across the (combined) data-parallel size of all enabled models.
"""
lcm_dp_size = self.dispatch.get_lcm_dp_size()
n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt
# We want the largest m <= len(entries) such that:
# (m * n_samples_per_prompt) % lcm_dp_size == 0
#
# Let g = gcd(lcm_dp_size, n_samples_per_prompt). Then this is equivalent
# to requiring m to be a multiple of (lcm_dp_size / g).
stride = lcm_dp_size // math.gcd(lcm_dp_size, n_samples_per_prompt)
if stride <= 1:
# Every prompt count is valid, keep all entries.
return entries
kept_prompts = (len(entries) // stride) * stride
return entries[:kept_prompts]
    def build_models(self, PolicyWorker, CriticWorker, RefWorker):
        """
        Initialize the actors for training, and handle colocation logic.

        Creates the policy (always), reference (when a KL term is enabled),
        and critic (when a critic model path is set) PPORayActorGroups,
        initializes them, and builds the WorkerDispatch that coordinates them.

        Args:
            PolicyWorker: Worker class for the policy actor group.
            CriticWorker: Worker class for the critic actor group.
            RefWorker: Worker class for the reference actor group.
        """
        cfg = self.cfg
        pg = None
        # A reference model is only needed when some KL term is enabled.
        use_ref_model = cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward
        if cfg.trainer.placement.colocate_all:
            # All models plus the inference engines share one placement group,
            # so their GPU counts must line up exactly.
            num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
            num_critic_gpus = cfg.trainer.placement.critic_num_gpus_per_node * cfg.trainer.placement.critic_num_nodes
            num_ref_gpus = cfg.trainer.placement.ref_num_gpus_per_node * cfg.trainer.placement.ref_num_nodes
            ie_cfg = cfg.generator.inference_engine
            num_rollout_gpus = (
                ie_cfg.num_engines
                * ie_cfg.tensor_parallel_size
                * ie_cfg.pipeline_parallel_size
                * ie_cfg.data_parallel_size
            )
            assert (
                num_policy_gpus == num_rollout_gpus
            ), "num_policy_gpus and num_rollout_gpus must be the same when colocating all models"
            pg = self.colocate_pg
            policy_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.policy_num_nodes,
                cfg.trainer.placement.policy_num_gpus_per_node,
                PolicyWorker,
                pg=pg,
                # Fractional GPU share so colocated actors can be co-scheduled
                # on the same devices.
                num_gpus_per_actor=0.2 if pg else 1,
                colocate_all=True,
                sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
                record_memory=cfg.trainer.policy.record_memory,
            )
            if use_ref_model:
                assert (
                    num_policy_gpus == num_ref_gpus
                ), "num_policy_gpus and num_ref_gpus must be the same when colocating policy and ref model"
                ref_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.ref_num_nodes,
                    cfg.trainer.placement.ref_num_gpus_per_node,
                    RefWorker,
                    pg=pg,
                    num_gpus_per_actor=0.2 if pg else 1,
                    colocate_all=True,
                    sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
                )
            else:
                ref_model = None
            if cfg.trainer.critic.model.path:
                assert (
                    num_policy_gpus == num_critic_gpus
                ), "num_policy_gpus and num_critic_gpus must be the same when colocating policy and critic model"
                critic_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.critic_num_nodes,
                    cfg.trainer.placement.critic_num_gpus_per_node,
                    CriticWorker,
                    pg=pg,
                    # NOTE(review): unlike policy/ref above this is a bare 0.2
                    # without the `if pg else 1` fallback — presumably pg is
                    # always set when colocate_all; confirm intent.
                    num_gpus_per_actor=0.2,
                    colocate_all=True,
                    sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
                )
            else:
                critic_model = None
        else:
            # Not fully colocated: optionally colocate policy+ref in a shared
            # placement group; critic (if any) gets its own resources.
            if cfg.trainer.placement.colocate_policy_ref and use_ref_model:
                assert (
                    cfg.trainer.placement.policy_num_nodes == cfg.trainer.placement.ref_num_nodes
                    and cfg.trainer.placement.policy_num_gpus_per_node == cfg.trainer.placement.ref_num_gpus_per_node
                ), "num_nodes and num_gpus_per_node must be the same when colocate policy and ref model."
                bundles = [
                    {
                        "GPU": cfg.trainer.placement.policy_num_gpus_per_node,
                        "CPU": cfg.trainer.placement.policy_num_gpus_per_node,
                    }
                    for _ in range(cfg.trainer.placement.policy_num_nodes)
                ]
                raw_pg = placement_group(bundles, strategy="PACK")
                get_ray_pg_ready_with_timeout(raw_pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
                pg = ResolvedPlacementGroup(raw_pg)
            policy_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.policy_num_nodes,
                cfg.trainer.placement.policy_num_gpus_per_node,
                PolicyWorker,
                pg=pg,
                # Policy takes the larger GPU share when sharing with ref.
                num_gpus_per_actor=0.75 if pg else 1,
                colocate_all=False,
                sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
            )
            if use_ref_model:
                ref_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.ref_num_nodes,
                    cfg.trainer.placement.ref_num_gpus_per_node,
                    RefWorker,
                    pg=pg,
                    num_gpus_per_actor=0.25 if pg else 1,
                    colocate_all=False,
                    sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
                )
            else:
                ref_model = None
            if cfg.trainer.critic.model.path:
                critic_model = PPORayActorGroup(
                    cfg.trainer,
                    cfg.trainer.placement.critic_num_nodes,
                    cfg.trainer.placement.critic_num_gpus_per_node,
                    CriticWorker,
                    num_gpus_per_actor=1,
                    colocate_all=False,
                    sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
                )
            else:
                critic_model = None
        # Optimizer steps per training batch = (#mini-batches) * (#update epochs).
        policy_steps_per_train_batch = (
            cfg.trainer.train_batch_size // cfg.trainer.policy_mini_batch_size * cfg.trainer.update_epochs_per_batch
        )
        critic_steps_per_train_batch = 0
        if cfg.trainer.critic.model.path:
            critic_steps_per_train_batch = (
                cfg.trainer.train_batch_size // cfg.trainer.critic_mini_batch_size * cfg.trainer.update_epochs_per_batch
            )
        policy_num_training_steps = (
            self.total_training_steps * policy_steps_per_train_batch if self.total_training_steps is not None else None
        )
        critic_num_training_steps = (
            self.total_training_steps * critic_steps_per_train_batch if self.total_training_steps is not None else None
        )
        if not cfg.trainer.placement.colocate_all:
            # Models live on disjoint GPUs: initialize all groups concurrently.
            refs = []
            if ref_model is not None:
                refs.extend(ref_model.async_init_model(cfg.trainer.ref.model.path))
            refs.extend(
                policy_model.async_init_model(
                    cfg.trainer.policy.model.path,
                    num_training_steps=policy_num_training_steps,
                )
            )
            if cfg.trainer.critic.model.path:
                refs.extend(
                    critic_model.async_init_model(
                        cfg.trainer.critic.model.path,
                        num_training_steps=critic_num_training_steps,
                    )
                )
            ray.get(refs)
            ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
        else:
            # Models share GPUs: initialize sequentially, offloading each to
            # CPU before the next one is brought up to avoid OOM.
            if ref_model is not None:
                ray.get(ref_model.async_init_model(cfg.trainer.ref.model.path))
                ref_model.offload_to_cpu()
            ray.get(
                policy_model.async_init_model(
                    cfg.trainer.policy.model.path,
                    num_training_steps=policy_num_training_steps,
                )
            )
            ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
            policy_model.offload_to_cpu()
            if cfg.trainer.critic.model.path:
                ray.get(
                    critic_model.async_init_model(
                        cfg.trainer.critic.model.path,
                        num_training_steps=critic_num_training_steps,
                    )
                )
                critic_model.offload_to_cpu()
        self.policy_model: PPORayActorGroup = policy_model
        self.critic_model: Optional[PPORayActorGroup] = critic_model
        self.ref_model: Optional[PPORayActorGroup] = ref_model
        # Create unified dispatch that manages all actor groups
        self.dispatch = WorkerDispatch(
            cfg=self.cfg,
            policy_actor_group=policy_model,
            critic_actor_group=critic_model,
            ref_actor_group=ref_model,
            inference_engine_client=self.inference_engine_client,
        )
        # Mark all models as offloaded if colocate_all (they were offloaded above)
        if self.colocate_all:
            self.dispatch.mark_all_offloaded()
        logger.info("init policy/ref/critic models done")
def init_weight_sync_state(self):
"""
Setup the connection between policy model and inference engine for weight syncing.
"""
self.dispatch.init_weight_sync_state(self.inference_engine_client)
logger.info("Initialized weight sync state for policy model and inference engines.")
def sync_policy_weights_to_inference_engines(self) -> List[ObjectRef]:
"""Broadcast policy weights to inference engines.
Note: For new code, prefer using dispatch.save_weights_for_sampler() which
handles the full weight sync protocol including offload/backload.
This method is kept for backward compatibility with subclasses.
TODO(tgriggs): Remove this method when migration is complete.
"""
return self.policy_model.async_run_ray_method(
"pass_through",
"broadcast_to_inference_engines",
self.inference_engine_client,
self.cfg.generator.inference_engine,
)
    def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch:
        """Converts lists to a padded batch of tensors for training.

        Args:
            generator_output (GeneratorOutput): Generated rollouts and associated data.
            uids (List[str]): List of prompt-unique identifiers for each generator output in the same
                order as `generator_output`. Used to identify which prompt each generated rollout belongs to.

        Returns:
            training_input (TrainingInputBatch): Padded batch of tensors for training. It preserves the
                order of `generator_output` and hence `uids`.
        """
        # 1. Extract generator output fields.
        prompt_ids: List[List[int]] = generator_output["prompt_token_ids"]
        response_ids: List[List[int]] = generator_output["response_ids"]
        rewards: List[List[float]] = generator_output["rewards"]
        loss_masks: List[List[int]] = generator_output["loss_masks"]
        logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None)
        rollout_expert_indices: Optional[List[List[List[List[int]]]]] = generator_output.get(
            "rollout_expert_indices", None
        )
        # Multi-modal fields are optional and must come as a pair.
        pixel_values = generator_output.get("pixel_values", None)
        image_grid_thw = generator_output.get("image_grid_thw", None)
        if pixel_values is not None:
            assert (
                pixel_values is not None and image_grid_thw is not None
            ), "Both pixel_values and image_grid_thw must exist for multi-modal inputs"
            assert len(pixel_values) == len(
                image_grid_thw
            ), "Number of pixel values should match number of image grid thw"
            pixel_values = TensorList(pixel_values)
            image_grid_thw = TensorList(image_grid_thw)
        # 2. Convert to tensors.
        (
            sequences_tensor,
            attention_masks_tensor,
            response_masks_tensor,
            rewards_tensor,
            loss_masks_tensor,
            rollout_logprobs_tensor,
            rollout_expert_indices_tensor,
        ) = convert_prompts_responses_to_batch_tensors(
            self.tokenizer,
            prompt_ids,
            response_ids,
            rewards,
            loss_masks,
            logprobs,
            rollout_expert_indices,
            max_seq_len=self.cfg.trainer.algorithm.max_seq_len,
        )
        # sanity check for off_policy_correction: it needs rollout logprobs
        # aligned with the response/loss-mask layout.
        off_policy_correction = self.cfg.trainer.algorithm.off_policy_correction
        tis_ratio_type = off_policy_correction.tis_ratio_type
        sequence_mask_metric = off_policy_correction.sequence_mask_metric
        if tis_ratio_type is not None or sequence_mask_metric is not None:
            assert (
                rollout_logprobs_tensor is not None
            ), "expected non-null rollout logprobs tensor when off_policy_correction is enabled"
            assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses"
        # 3. Create training input batch.
        training_input = TrainingInputBatch(
            {
                "sequences": sequences_tensor,  # Full trajectories (padded and concatenated prompts and responses)
                "attention_mask": attention_masks_tensor,
                "response_mask": response_masks_tensor,
                "rewards": rewards_tensor,
                "loss_mask": loss_masks_tensor,
                "rollout_logprobs": rollout_logprobs_tensor,
                "rollout_expert_indices": rollout_expert_indices_tensor,
                "pixel_values": pixel_values,
                "image_grid_thw": image_grid_thw,
            },
        )
        training_input.metadata = {"uids": uids}
        if generator_output.get("is_last_step", None) is not None:
            training_input.metadata["is_last_step"] = generator_output["is_last_step"]
        # 4. Compute mini-batch boundaries for train_critic_and_policy(). It excludes the ones
        # we will add in pad_training_input_batch().
        train_batch_size = self.cfg.trainer.train_batch_size
        n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt
        is_stepwise = self.cfg.generator.step_wise_trajectories
        training_input.metadata["policy_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries(
            uids, self.cfg.trainer.policy_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt
        )
        # NOTE(review): this uses `is not None` while `has_critic`/`build_models`
        # use truthiness of the same path — an empty-string path would diverge
        # here; confirm whether that is intended.
        if self.cfg.trainer.critic.model.path is not None:
            training_input.metadata["critic_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries(
                uids, self.cfg.trainer.critic_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt
            )
        # 5. Record metadata and metrics.
        training_input.metadata["response_length"] = response_masks_tensor.shape[1]
        batch_num_seq, batch_padded_seq_len = sequences_tensor.shape
        logger.info(f"batch_num_seq: {batch_num_seq}, batch_padded_seq_len: {batch_padded_seq_len}")
        self.all_metrics.update(
            {
                "generate/batch_num_seq": batch_num_seq,
                "generate/batch_padded_seq_len": batch_padded_seq_len,
            }
        )
        training_input.metadata["avg_response_length"] = sum(
            len(sample_response_ids) for sample_response_ids in response_ids
        ) / len(response_ids)
        # 6. Pad the batch, only needed for step-wise training's `fwd_logprobs_values_reward()`.
        logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
        dp_size = self.dispatch.get_lcm_dp_size()
        # Round the batch size up to a multiple of the combined DP size.
        pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size
        training_input = pad_training_input_batch(training_input, pad_size)
        logger.info(f"Number of sequences after padding: {len(training_input['sequences'])}")
        return training_input
    @torch.no_grad()
    async def generate(
        self,
        input_batch: GeneratorInput,
    ) -> GeneratorOutput:
        """
        Generate rollouts.

        If colocate_all is enabled:
        - before calling this method, the policy model should be on CPU and inference engine should
          be awake (i.e. on GPU).
        - after calling this method, the same model placement still holds.

        Args:
            input_batch: Prompts and sampling parameters for the generator.

        Returns:
            The generator output, with `rollout_metrics` folded into
            `self.all_metrics` and removed from the returned dict.
        """
        # NOTE: we assume that .generate returns samples in the same order as passed in
        generator_output: GeneratorOutput = await self.generator.generate(input_batch)
        # add rollout metrics to self.all_metrics
        if generator_output["rollout_metrics"] is not None:
            self.all_metrics.update(generator_output["rollout_metrics"])
        generator_output.pop("rollout_metrics", None)
        # Sanity-check shapes/lengths of the generator output before training on it.
        validate_generator_output(
            len(input_batch["prompts"]),
            generator_output,
            step_wise=self.cfg.generator.step_wise_trajectories,
        )
        return generator_output
    @torch.no_grad()
    def postprocess_generator_output(
        self, generator_output: GeneratorOutput, uids: List[str]
    ) -> Tuple[GeneratorOutput, List[str]]:
        """
        Converts to per token rewards and computes pass@N.

        For step-wise training with ``merge_stepwise_output=true``, also collapses
        consecutive turns sharing a common prefix into a single sequence; ``uids``
        is shortened to match.

        In the future algorithm specific reward or loss mask post processing should be done here.

        Returns:
            (generator_output, uids) — uids may be shorter than the input when merging.
        """
        generator_output_for_metrics = generator_output
        uids_for_metrics = uids
        if self.cfg.generator.step_wise_trajectories:
            # Keep only the final step of each trajectory for metric computation.
            generator_output_for_metrics = defaultdict(list)
            for key in generator_output:
                if isinstance(generator_output[key], list):
                    generator_output_for_metrics[key] = [
                        generator_output[key][i]
                        for i in range(len(generator_output[key]))
                        if generator_output["is_last_step"][i]
                    ]
            uids_for_metrics = [
                uid for uid, is_last_step in zip(uids, generator_output["is_last_step"]) if is_last_step
            ]
        # only use `generator_output_for_metrics` for metrics calculation
        # For step-wise training, we only calculate metrics for the last step of each trajectory
        overall_metrics = get_metrics_from_generator_output(
            generator_output_for_metrics,
            uids_for_metrics,
        )
        # Prefix-aware merging of step-wise turns.
        if self.cfg.generator.merge_stepwise_output:
            assert self.cfg.generator.step_wise_trajectories, "merge_stepwise_output requires step-wise training"
            num_seq_before_merge = len(generator_output["response_ids"])
            generator_output = merge_stepwise_output(generator_output)
            num_seq_after_merge = len(generator_output["response_ids"])
            logger.info(f"Merged step wise: {num_seq_before_merge} sequences -> {num_seq_after_merge} sequences")
            self.all_metrics.update(
                {
                    "generate/num_seq_before_merge": num_seq_before_merge,
                    "generate/num_seq_after_merge": num_seq_after_merge,
                }
            )
            # Merging changes the number of sequences, so rebuild uids to match.
            uids = [tid.instance_id for tid in generator_output["trajectory_ids"]]
        # these use the full generator output
        rewards: Union[List[float], List[List[float]]] = generator_output["rewards"]
        responses: List[List[int]] = generator_output["response_ids"]
        per_token_rewards: List[List[float]] = []
        # Check if rewards are already token-level (List[List[float]]) or response-level (List[float])
        if rewards and isinstance(rewards[0], list):
            # Token-level rewards: rewards is List[List[float]]
            per_token_rewards = rewards
        else:
            if self.cfg.trainer.algorithm.zero_variance_filter:
                # Zero out the loss mask of rollouts whose prompt group has
                # zero reward variance so they do not contribute to the loss.
                kept_indices_set = set(zero_variance_filter(rewards, uids))
                generator_output["loss_masks"] = [
                    [0] * len(mask) if i not in kept_indices_set else mask
                    for i, mask in enumerate(generator_output["loss_masks"])
                ]
            # Response-level rewards: rewards is List[float], convert to per-token rewards
            for reward, response in zip(rewards, responses):
                # Place the full scalar reward on the final response token.
                per_token_reward = [0.0] * len(response)
                per_token_reward[-1] = float(reward)
                per_token_rewards.append(per_token_reward)
        n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt
        reward_metrics = {
            f"reward/avg_pass_at_{n_samples_per_prompt}": overall_metrics["pass_at_n"],
            "reward/avg_raw_reward": overall_metrics["avg_score"],
            "reward/mean_positive_reward": overall_metrics["mean_positive_reward"],
        }
        self.all_metrics.update(reward_metrics)
        logger.info(
            f"reward/avg_pass_at_{n_samples_per_prompt}: {overall_metrics['pass_at_n']}, reward/avg_raw_reward: {overall_metrics['avg_score']}, reward/mean_positive_reward: {overall_metrics['mean_positive_reward']}"
        )
        # re-assign reward but now it's per token rewards
        generator_output["rewards"] = per_token_rewards
        return generator_output, uids
@torch.no_grad()
def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingInputBatch:
    """Calculate advantages and returns for the data batch.

    Expects:
        - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `["response_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `["loss_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
        - `["rewards"]`: Float[torch.Tensor, "batch_size seqlen"]
        - `.metadata["uids"]`: List[str]
        - `.metadata["is_last_step"]`: List[bool] for step-wise training

    Adds:
        - `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"]
        - `["returns"]`: Float[torch.Tensor, "batch_size seqlen"]

    Also records reward/advantage summary metrics into `data.metadata["metrics"]`
    and `self.all_metrics`, and moves the batch to CPU before returning.
    """
    token_level_rewards = data["rewards"]
    if self.cfg.generator.step_wise_trajectories:
        is_last_step = torch.tensor(data.metadata["is_last_step"], dtype=torch.bool)
        index = np.array(data.metadata["uids"])
        values = data["values"]
        # Step-wise only supports outcome-based estimators (GRPO, RLOO, MAXRL); ensured by `validate_cfg`.
        # We use the last step of each trajectory to compute advantages and broadcast them to
        # all steps of that trajectory, so we ignore per-step rewards in step-wise training.
        # We pass an all-ones mask here so the estimator returns the scalar advantage at every
        # position. The real per-step `response_mask` is re-applied on broadcast below.
        # Shapes:
        #   traj_ids, (batch_size,): trajectory id per step (cumsum of shifted is_last_step)
        #   last_step_advantages/returns,
        #     (num_traj, seqlen): scalar advantage/return per trajectory at every position
        #   last_step_advantages/returns[traj_ids],
        #     (batch_size, seqlen): broadcast to every step of the owning trajectory
        #   response_mask_float,
        #     (batch_size, seqlen): per-step response mask
        last_step_response_mask = data["response_mask"][is_last_step]
        last_step_advantages, last_step_returns = ppo_utils.compute_advantages_and_returns(
            token_level_rewards=token_level_rewards[is_last_step],
            response_mask=torch.ones_like(last_step_response_mask, dtype=torch.float),
            index=index[is_last_step.cpu().numpy()],
            adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
            values=values[is_last_step] if values is not None else None,
            config=self.cfg.trainer.algorithm,
            gamma=self.cfg.trainer.algorithm.gamma,
            lambd=self.cfg.trainer.algorithm.lambd,
            grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
        )
        # Shift `is_last_step` right by one and cumsum: the row *after* a trajectory's
        # final step starts a new trajectory id, so `traj_ids[i]` indexes the owning
        # trajectory of batch row i.
        traj_ids = (
            torch.cat([torch.tensor([False], device=is_last_step.device), is_last_step[:-1]]).int().cumsum(dim=0)
        )
        num_traj = traj_ids[-1].item() + 1
        assert num_traj == len(
            last_step_advantages
        ), f"num_traj {num_traj} doesn't match the number of trajectories as given by `is_last_step` {len(last_step_advantages)}. The `is_last_step` tensor is likely malformed"
        response_mask_float = data["response_mask"].to(last_step_advantages.dtype)
        # Broadcast each trajectory's scalar advantage/return to all of its steps,
        # then zero out non-response positions with the real per-step mask.
        advantages = last_step_advantages[traj_ids] * response_mask_float
        returns = last_step_returns[traj_ids] * response_mask_float
    else:
        advantages, returns = ppo_utils.compute_advantages_and_returns(
            token_level_rewards=token_level_rewards,
            response_mask=data["response_mask"],
            index=data.metadata["uids"],
            adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
            config=self.cfg.trainer.algorithm,
            values=data["values"],
            gamma=self.cfg.trainer.algorithm.gamma,
            lambd=self.cfg.trainer.algorithm.lambd,
            grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
        )
    data["returns"] = returns
    data["advantages"] = advantages
    # remove padding while calculating metrics: the last `pad_size` rows are padding
    # added upstream to even out shards, so metrics slice to `num_samples - pad_size`.
    pad_size = data.metadata.get("pad_size", 0)
    num_samples = len(token_level_rewards)
    return_sums = token_level_rewards.sum(dim=-1)[: num_samples - pad_size]
    if self.cfg.generator.step_wise_trajectories:
        # In step-wise mode only trajectory-final rows carry the outcome reward,
        # so average over those rows only.
        avg_rewards: float = return_sums[is_last_step[: num_samples - pad_size]].mean().item()
    else:
        avg_rewards: float = return_sums.mean().item()
    avg_response_length = data.metadata["avg_response_length"]
    data = data.to("cpu")
    # Advantage stats are taken over response-token positions of unpadded rows only.
    valid_advantages = torch.masked_select(
        data["advantages"][: num_samples - pad_size, ...], data["response_mask"][: num_samples - pad_size].bool()
    )
    avg_advantages: float = valid_advantages.mean().item()
    avg_advantages_abs: float = valid_advantages.abs().mean().item()
    if "metrics" not in data.metadata:
        data.metadata["metrics"] = {}
    data.metadata["metrics"].update(
        {
            "avg_final_rewards": avg_rewards,
            "avg_response_length": avg_response_length,
            "avg_advantages": avg_advantages,
            "avg_advantages_abs": avg_advantages_abs,
        }
    )
    logger.info(f"avg_final_rewards: {avg_rewards}, avg_response_length: {avg_response_length}")
    self.all_metrics.update(
        {
            "loss/avg_final_rewards": avg_rewards,
            "loss/avg_raw_advantages": avg_advantages,
            "loss/avg_raw_advantages_abs": avg_advantages_abs,
        }
    )
    return data
def dump_data(self, data: TrainingInputBatch, file_name: str):
    """Persist a training batch as `<export_path>/dumped_data/<file_name>.pkl` for offline inspection."""
    dump_dir = Path(self.cfg.trainer.export_path) / "dumped_data"
    dump_dir.mkdir(parents=True, exist_ok=True)
    target = dump_dir / f"{file_name}.pkl"
    data.save(target)
@torch.no_grad()
def fwd_logprobs_values_reward(
    self,
    training_input: TrainingInputBatch,
):
    """
    Calculate values from the critic, log probs from the policy and ref model.

    Dispatch handles offload/backload automatically for all colocation configurations.

    Expects:
        - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `["attention_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `.metadata["response_length"]`: Int

    Adds:
        - `["base_action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"] (None without ref model)
        - `["action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
        - `["values"]`: Float[torch.Tensor, "batch_size seqlen"] (None without critic)

    When the generator supplied `["rollout_logprobs"]`, also logs the mean/std of the
    absolute logprob gap between inference engine and trainer on response tokens.
    """
    # Base forward-pass inputs; optional MoE-routing / multimodal tensors are
    # forwarded only when the batch actually carries them.
    fwd_keys = ["sequences", "attention_mask"]
    if training_input.get("rollout_expert_indices") is not None:
        fwd_keys.append("rollout_expert_indices")
    if training_input.get("pixel_values") is not None:
        fwd_keys.append("pixel_values")
    if training_input.get("image_grid_thw") is not None:
        fwd_keys.append("image_grid_thw")
    data_fwd_pass = training_input.select(keys=fwd_keys, metadata_keys=["response_length"])
    values = None
    base_log_probs = None
    action_log_probs = None
    # Critic forward (dispatch handles offload/backload automatically)
    if self.has_critic:
        critic_output = self.dispatch.forward("critic", data_fwd_pass)
        values = critic_output["output"]
    # Ref forward
    if self.ref_model is not None:
        ref_output = self.dispatch.forward("ref", data_fwd_pass)
        base_log_probs = ref_output["output"]
        self.dispatch.empty_cache("ref")
    # Policy forward
    policy_output = self.dispatch.forward("policy", data_fwd_pass)
    action_log_probs = policy_output["output"]
    # Empty cache after all forward passes
    self.dispatch.empty_cache()
    sequences_all: torch.Tensor = training_input["sequences"]
    # NOTE (sumanthrh): The slicing is needed to make sure that the batch dimension doesn't change for the tensordict.
    base_log_probs = base_log_probs[: len(sequences_all)] if base_log_probs is not None else None
    action_log_probs = action_log_probs[: len(sequences_all)]
    values = values[: len(sequences_all)] if values is not None else None
    training_input["base_action_log_probs"] = base_log_probs
    training_input["action_log_probs"] = action_log_probs
    training_input["values"] = values
    if training_input.get("rollout_logprobs", None) is not None:
        # calculates the difference in probs between inference and trainer components
        # only consider response tokens (loss_mask > 0)
        logprobs_diff = (
            training_input["rollout_logprobs"][training_input["loss_mask"] > 0]
            - action_log_probs[training_input["loss_mask"] > 0]
        ).abs()
        logprobs_diff_mean = logprobs_diff.mean().item()
        logprobs_diff_std = logprobs_diff.std().item()
        self.all_metrics.update(
            {
                "policy/rollout_train_logprobs_abs_diff_mean": logprobs_diff_mean,
                "policy/rollout_train_logprobs_abs_diff_std": logprobs_diff_std,
            }
        )
    return training_input
def apply_reward_kl_penalty(
    self,
    data: TrainingInputBatch,
) -> TrainingInputBatch:
    """Applies a penalty for KL divergence between the policy log probs and the base model log probs to the rewards.

    Expects `["loss_mask"]`, `["rewards"]`, `["base_action_log_probs"]` and
    `["action_log_probs"]` on `data`. Rewrites `["rewards"]` in place, updates the
    adaptive KL controller (when configured) with the batch-average KL, and records
    KL metrics into `data.metadata["metrics"]` and `self.all_metrics`.
    """
    loss_masks_all: torch.Tensor = data["loss_mask"]
    rewards: torch.Tensor = data["rewards"]
    base_action_log_probs: torch.Tensor = data["base_action_log_probs"]
    action_log_probs: torch.Tensor = data["action_log_probs"]
    # single batched computation
    with torch.no_grad():
        kl: Float[torch.Tensor, "batch_size seqlen"] = compute_approx_kl(  # type: ignore
            action_log_probs,
            base_action_log_probs,
            loss_mask=loss_masks_all,
            kl_estimator_type=self.cfg.trainer.algorithm.kl_estimator_type,
        )
        kl_max: Float[torch.Tensor, "batch_size"] = torch.max(kl.abs(), dim=-1)[0]  # noqa: F821
        kl_mean: Float[torch.Tensor, "batch_size"] = masked_mean(kl, loss_masks_all, dim=-1)  # noqa: F821
    # NOTE (erictang000): only supporting custom rewards currently
    # Use the controller's current coefficient when adaptive KL control is enabled,
    # otherwise fall back to the fixed configured coefficient.
    kl_loss_coef = (
        self.reward_kl_controller.value
        if self.reward_kl_controller is not None
        else self.cfg.trainer.algorithm.kl_loss_coef
    )
    # Coefficient is clamped at 0 so a negative value can never *reward* divergence.
    rewards = rewards - kl * max(0, kl_loss_coef)
    data["rewards"] = rewards
    avg_kl: float = kl_mean.mean().item()
    avg_kl_max: float = kl_max.mean().item()
    # update the kl controller
    if self.reward_kl_controller is not None:
        self.reward_kl_controller.update(current=avg_kl, n_steps=kl.shape[0])  # n_steps is just the batch size
    if "metrics" not in data.metadata:
        data.metadata["metrics"] = {}
    data.metadata["metrics"].update(
        {
            "avg_kl": avg_kl,
            "avg_kl_max": avg_kl_max,
            "kl_loss_coef": kl_loss_coef,
        }
    )
    self.all_metrics.update(
        {
            "loss/avg_kl": avg_kl,
            "loss/avg_kl_max": avg_kl_max,
            "loss/kl_loss_coef": kl_loss_coef,
        }
    )
    return data
@torch.no_grad()
def _normalize_advantages(
    self,
    data: TrainingInputBatch,
    mini_batch_boundaries: List[Tuple[int, int]],
) -> TrainingInputBatch:
    """Normalize advantages in two stages before policy training.

    Stage 1 (optional, `advantage_batch_normalize`): z-score normalization across the batch.
    Stage 2: per-mini-batch loss-reduction normalization so the gradient scale matches the
    configured `loss_reduction` and micro-batching; stage 2 reads the (possibly
    stage-1-normalized) advantages back from `data`.

    Args:
        data: Batch carrying `["advantages"]`, `["response_mask"]`, `["loss_mask"]`.
        mini_batch_boundaries: `(start, end)` row ranges, one per mini-batch.

    Returns:
        The same batch with `["advantages"]` replaced by normalized values.
    """
    advantages = data["advantages"]
    response_mask = data["response_mask"]
    # Step 1: Z-score normalization (if enabled)
    if self.cfg.trainer.algorithm.advantage_batch_normalize:
        num_actions = response_mask.sum()
        # NOTE(review): `mean` is an *unmasked* mean over every position, while the
        # variance below is masked and divided by `num_actions`. This is consistent
        # only if non-response positions hold zero advantage — confirm, or use a
        # masked mean here.
        mean = advantages.mean()
        # `std` here is the masked sum of squared deviations (not a standard deviation);
        # `rstd` below turns it into the reciprocal standard deviation.
        std = ((advantages - mean).pow(2) * response_mask).sum()
        rstd = (std / num_actions).clamp(min=1e-8).rsqrt()
        data["advantages"] = (advantages - mean) * rstd
    # Step 2: Loss reduction normalization per mini-batch
    normalized_advantages = torch.zeros_like(advantages)
    for start_idx, end_idx in mini_batch_boundaries:
        mini_batch = data[start_idx:end_idx]
        normalized_advantages[start_idx:end_idx] = apply_loss_reduction_to_advantages_minibatch(
            advantages=mini_batch["advantages"],
            loss_mask=mini_batch["loss_mask"],
            loss_reduction=self.cfg.trainer.algorithm.loss_reduction,
            micro_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu,
            max_seq_len=self.cfg.trainer.algorithm.max_seq_len,
        )
    data["advantages"] = normalized_advantages
    return data
def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]:
    """
    Execute training step using forward_backward + optim_step.

    The trainer loops over epochs and mini-batches. Workers handle micro-batching
    internally for gradient accumulation (memory efficiency).

    All per-DP mini-batch chunks are pre-staged in the Ray object store before
    the training loop so serialization stays off the GPU critical path.

    Args:
        model: Model name ("policy" or "critic")
        data: Training data batch

    Returns:
        Dict of reduced metrics from training
    """
    # Mini-batch boundaries are precomputed per model and stored in metadata.
    boundaries = data.metadata[f"{model}_mini_batch_boundaries"]
    if model == "policy":
        # Normalize advantages for policy training; critic training does not need this
        data = self._normalize_advantages(data, boundaries)
    all_metrics: Dict[str, List[float]] = defaultdict(list)
    # Pre-stage all per-DP mini-batch chunks in the object store so that
    # serialization is fully off the critical path during training.
    all_chunk_refs = self.dispatch.stage_data(model, data, boundaries)
    # Training loop over epochs and mini-batches
    for _epoch in range(self.cfg.trainer.update_epochs_per_batch):
        for chunk_refs in all_chunk_refs:
            status = self.dispatch.forward_backward_from_staged(model, chunk_refs)
            for k, v in status.items():
                all_metrics[k].append(v)
            # Optimizer step after each mini batch
            grad_norm = self.dispatch.optim_step(model)
            if grad_norm is not None:
                all_metrics["grad_norm"].append(grad_norm)
    # Reduce metrics across all mini-batches and epochs
    # pop out loss_fn_outputs since it's not a scalar metric and to avoid logging it
    all_metrics.pop("loss_fn_outputs", None)
    reduced_metrics = reduce_metrics(all_metrics, sum_loss_metrics=False)
    return reduced_metrics
def train_critic_and_policy(self, data: TrainingInputBatch):
    """
    Run the training step for the policy and critic models.

    Uses forward_backward + optim_step for both FSDP and Megatron strategies.
    Trains the critic first (when configured), then the policy, records per-model
    metrics into `self.all_metrics`, and returns the policy metrics.
    """
    data.metadata["global_step"] = self.global_step
    # Unified training interface for both FSDP and Megatron
    critic_metrics = None
    if self.has_critic:
        with Timer("critic_train", self.all_timings):
            critic_metrics = self._execute_training_step("critic", data)
    with Timer("policy_train", self.all_timings):
        policy_metrics = self._execute_training_step("policy", data)
    # Record metrics under per-model prefixes.
    if critic_metrics is not None:
        self.all_metrics.update({f"critic/{name}": value for name, value in critic_metrics.items()})
    self.all_metrics.update({f"policy/{name}": value for name, value in policy_metrics.items()})
    self.dispatch.empty_cache()
    return policy_metrics
def handle_dynamic_sampling(
    self, generator_output: GeneratorOutput, uids: List[str]
) -> Tuple[GeneratorOutput, List[str], bool]:
    """
    Handle dynamic sampling for the current batch.

    Accumulates the generator output and UIDs across batches if we are sampling repeatedly
    and applies the dynamic sampling strategy (i.e. filter, replace) to the current batch.
    If we hit the limit of max sample batches, we raise an error.

    Args:
        generator_output: Current batch generator output
        uids: Current batch UIDs

    Returns:
        processed_output: Filtered generator output
        processed_uids: Filtered UIDs
        keep_sampling: Whether to keep sampling
    """
    # Prepare sampling configuration
    max_sample_batches = self.cfg.trainer.algorithm.dynamic_sampling.max_sample_batches
    dynamic_sampling_config = {
        "type": self.cfg.trainer.algorithm.dynamic_sampling.type,
        "max_sample_batches": max_sample_batches,
        "min_replace_ratio": self.cfg.trainer.algorithm.dynamic_sampling.min_replace_ratio,
        "train_batch_size": self.cfg.trainer.train_batch_size,
        "n_samples_per_prompt": self.cfg.generator.n_samples_per_prompt,
    }
    # Count how many batches have been sampled for the current training step;
    # the state is reset to None below once sampling completes.
    if self.dynamic_sampling_state is None:
        self.dynamic_sampling_state: DynamicSamplingState = {
            "sample_batch_count": 1,
        }
    else:
        self.dynamic_sampling_state["sample_batch_count"] += 1
    # Handle dynamic sampling using utilities
    processed_output, processed_uids, keep_sampling, updated_state = trainer_utils.handle_dynamic_sampling(
        generator_output, uids, dynamic_sampling_config, self.dynamic_sampling_state
    )
    # Check max resample limit, and if we hit it, raise an error.
    # `max_sample_batches <= 0` disables the limit.
    if (
        keep_sampling
        and max_sample_batches > 0
        and self.dynamic_sampling_state["sample_batch_count"] >= max_sample_batches
    ):
        raise RuntimeError(
            f"Exiting training loop due to hitting dynamic sampling limit for "
            f"{self.cfg.trainer.algorithm.dynamic_sampling.type} strategy with "
            f"{self.cfg.trainer.algorithm.dynamic_sampling.max_sample_batches} max sample batches. "
            f"Please check your data difficulty distribution."
        )
    # Update state
    self.dynamic_sampling_state = updated_state
    if not keep_sampling:
        # Reset state when sampling is complete
        self.dynamic_sampling_state = None
    return processed_output, processed_uids, keep_sampling
def _get_dp_group_models(self, rank: int, model_type: str = ""):
model = getattr(self, model_type)
return model._actor_handlers[rank]
def _get_mesh_rank(self, rank: int, model_type: str = "") -> MeshRank:
    """Return the MeshRank of the actor at `rank` in the actor group attribute named `model_type`."""
    group: PPORayActorGroup = getattr(self, model_type)
    info: ActorInfo = group.actor_infos[rank]
    return info.rank
def save_checkpoints(self):
    """
    Save the model, optimizer, and training states to disk.

    Dispatch handles offload/backload automatically for all colocation configurations.

    Layout: `<ckpt_path>/global_step_<N>/{policy, critic, data.pt, trainer_state.pt}`.
    The top-level `latest_ckpt_global_step.txt` marker is written last, so a
    checkpoint is only advertised as latest after every component saved successfully.
    """
    # Create global step folder structure
    global_step_folder = os.path.join(self.cfg.trainer.ckpt_path, f"global_step_{self.global_step}")
    policy_save_dir = os.path.join(global_step_folder, "policy")
    critic_save_dir = os.path.join(global_step_folder, "critic")
    io.makedirs(global_step_folder, exist_ok=True)
    # Save policy checkpoint (dispatch handles offload/backload)
    self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer)
    # Save critic checkpoint (if it exists)
    if self.has_critic:
        self.dispatch.save_checkpoint("critic", critic_save_dir, self.tokenizer)
    # Save dataloader state (best-effort: a failure here only loses the resume
    # position within the epoch, so it is logged rather than raised)
    dataloader_save_path = os.path.join(global_step_folder, "data.pt")
    try:
        dataloader_state_dict = self.train_dataloader.state_dict()
        with io.open_file(dataloader_save_path, "wb") as f:
            torch.save(dataloader_state_dict, f)
        logger.info(f"Saved dataloader state to {dataloader_save_path}")
    except Exception as e:
        logger.warning(f"Failed to save dataloader state: {e}")
    # Save additional trainer state
    trainer_state = {
        "global_step": self.global_step,
        "config": asdict(self.cfg),
    }
    trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
    with io.open_file(trainer_state_path, "wb") as f:
        torch.save(trainer_state, f)
    logger.info(f"Saved trainer state to {trainer_state_path}")
    # Atomic tracking - write this last after all saves succeed
    latest_checkpoint_file = os.path.join(self.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt")
    with io.open_file(latest_checkpoint_file, "w") as f:
        f.write(str(self.global_step))
    logger.info(f"Successfully saved checkpoint for global_step_{self.global_step} to: {global_step_folder}")
    # Clean up old checkpoints after successful save
    with Timer("cleanup_old_checkpoints", self.all_timings):
        self._cleanup_old_checkpoints()
def _cleanup_old_checkpoints(self):
    """Prune stale checkpoints on every node and on the driver, keeping `max_ckpts_to_keep`."""
    if not self._node_ids:
        self._node_ids = self.dispatch.get_node_ids()
    ckpt_path = self.cfg.trainer.ckpt_path
    max_to_keep = self.cfg.trainer.max_ckpts_to_keep
    run_on_each_node(
        self._node_ids,
        cleanup_old_checkpoints,
        ckpt_path,
        max_to_keep,
    )
    # Also run on the driver. The node hosting the driver process ends up running
    # this twice, which is harmless because the cleanup is idempotent.
    cleanup_old_checkpoints(ckpt_path, max_to_keep)
def load_checkpoints(self) -> Tuple[int, str]:
    """
    Load complete checkpoint state and return the global_step to resume from.

    Returns 0 if no checkpoint is loaded.
    If colocate_all is True, assumes that the policy model is currently on GPU.

    Returns:
        global_step: The global step to resume from (0 when starting from scratch).
        checkpoint_path: The path to the checkpoint (None when starting from scratch).

    Raises:
        ValueError: If `trainer.resume_path` is missing or malformed in `from_path` mode,
            or if the resolved path is not a valid checkpoint directory.
        FileNotFoundError: If the checkpoint directory or its trainer state file is missing.
    """
    checkpoint_path = None
    # Check if resumption is enabled
    if self.resume_mode == ResumeMode.NONE:
        logger.info("Checkpoint resumption disabled, starting training from scratch")
        return 0, None
    # first, let's get resume_path
    elif self.resume_mode == ResumeMode.LATEST:
        latest_checkpoint_file = os.path.join(self.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt")
        if not io.exists(latest_checkpoint_file):
            logger.info("No checkpoint found, starting training from scratch")
            return 0, None
        with io.open_file(latest_checkpoint_file, "r") as f:
            ckpt_iteration = int(f.read().strip())
        checkpoint_path = os.path.join(self.cfg.trainer.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_iteration}")
        # Run validation: Make sure ckpt folder is consistent with latest_ckpt_global_step.txt
        validate_consistency_for_latest_checkpoint(
            self.cfg.trainer.ckpt_path,
            ckpt_iteration,
            checkpoint_path,
            latest_checkpoint_file,
            self.cfg.trainer.ckpt_interval,
        )
    else:
        # Get and validate resume path.
        # BUG FIX: validate the raw config value *before* wrapping it in `Path`.
        # `Path("")` normalizes to `Path(".")` and Path objects are always truthy,
        # so the previous `if not checkpoint_path:` check on the constructed Path
        # could never fire for an empty/missing `resume_path`.
        resume_path = self.cfg.trainer.resume_path
        if not resume_path:
            raise ValueError("`trainer.resume_path` must be specified when resume_mode is 'from_path'")
        checkpoint_path = Path(resume_path)
        # Validate that it's a global_step directory
        if GLOBAL_STEP_PREFIX not in checkpoint_path.name:
            raise ValueError(
                f"`trainer.resume_path` must point to a directory whose name starting with {GLOBAL_STEP_PREFIX}, got: {checkpoint_path}"
            )
        # Validate that the path exists
        if not io.exists(str(checkpoint_path)):
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
    logger.info(f"Loading checkpoint from: {checkpoint_path}")
    # Extract global step from checkpoint path
    global_step = extract_step_from_path(Path(checkpoint_path))
    if global_step == -1:
        raise ValueError(f"Checkpoint path {checkpoint_path} is not a valid checkpoint path")
    logger.info(f"Resuming from global_step: {global_step}")
    # Define paths for different checkpoint components
    policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
    critic_ckpt_dir = os.path.join(checkpoint_path, "critic")
    trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
    dataloader_state_path = os.path.join(checkpoint_path, "data.pt")
    # Validate that required checkpoint files exist
    if not io.exists(trainer_state_path):
        raise FileNotFoundError(f"Trainer state file not found: {trainer_state_path}")
    # 1. Load and validate trainer state
    with io.open_file(trainer_state_path, "rb") as f:
        trainer_state = torch.load(f, map_location="cpu", weights_only=False)
    saved_global_step = trainer_state.get("global_step", global_step)
    logger.info("Successfully loaded trainer state")
    # The step encoded in the directory name is authoritative over the saved state.
    if saved_global_step != global_step:
        logger.warning(f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value.")
    # 2. Load dataloader state if available (best-effort: a failure only restarts
    # the dataloader from the beginning of the epoch)
    if io.exists(dataloader_state_path):
        try:
            with io.open_file(dataloader_state_path, "rb") as f:
                dataloader_state = torch.load(f, map_location="cpu", weights_only=False)
            self.train_dataloader.load_state_dict(dataloader_state)
            logger.info("Successfully loaded dataloader state")
        except Exception as e:
            logger.warning(f"Failed to load dataloader state: {e}. Dataloader will start from beginning.")
    else:
        logger.warning(
            f"No dataloader state found at {dataloader_state_path}. Dataloader will start from beginning."
        )
    # 3. Load policy checkpoint (dispatch handles offload/backload)
    logger.info(f"Loading policy checkpoint from {policy_ckpt_dir}")
    self.dispatch.load_checkpoint(
        "policy",
        policy_ckpt_dir,
        load_optimizer_states=True,
        load_lr_scheduler_states=True,
    )
    logger.info("Successfully loaded policy checkpoint")
    # 4. Load critic checkpoint if it exists and we have a critic model
    if self.has_critic:
        logger.info(f"Loading critic checkpoint from {critic_ckpt_dir}")
        self.dispatch.load_checkpoint(
            "critic",
            critic_ckpt_dir,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        logger.info("Successfully loaded critic checkpoint")
    logger.info(f"Successfully loaded complete checkpoint state from global_step_{global_step}")
    return global_step, str(checkpoint_path)
def save_models(self):
    """
    Save the model parameters in HF format at `cfg.trainer.export_path`.

    Dispatch handles offload/backload automatically for all colocation configurations.
    """
    step_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}")
    self.dispatch.save_hf_model("policy", os.path.join(step_dir, "policy"), self.tokenizer)
    if self.has_critic:
        self.dispatch.save_hf_model("critic", os.path.join(step_dir, "critic"), self.tokenizer)
    logger.info("Successfully saved model weights.")
def update_ref_with_policy(self):
    """
    Update the reference model with the policy model weights (required by some algorithms).

    Dispatch handles offload/backload automatically for all colocation configurations.
    After this method, save_weights_for_sampler() should be called to sync weights.
    """
    # TODO(tgriggs): Make policy-to-ref sync faster.
    policy_export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}", "policy")
    # Save policy model (dispatch handles GPU state)
    self.dispatch.save_hf_model("policy", policy_export_dir, self.tokenizer)
    # Re-initialize ref model from saved policy (dispatch handles offloading policy first)
    self.dispatch.init_model("ref", policy_export_dir)
    # Best-effort removal of the temporary export; a leftover directory is only a disk-space
    # concern, so failures are logged rather than raised.
    try:
        shutil.rmtree(policy_export_dir)
        logger.info(f"Cleaned up temporary policy export directory: {policy_export_dir}")
    except Exception as e:
        logger.warning(f"Failed to clean up temporary policy export directory {policy_export_dir}: {e}")
    logger.info("Successfully updated ref model with policy model, training continues.")
attr cfg

    cfg = cfg

attr colocate_all

    colocate_all = cfg.trainer.placement.colocate_all

attr tracker

    tracker = tracker

attr tokenizer

    tokenizer = tokenizer

attr train_dataset

    train_dataset = train_dataset

attr eval_dataset

    eval_dataset = eval_dataset

attr inference_engine_client

    inference_engine_client = inference_engine_client

attr generator

    generator = generator

attr train_dataloader

    train_dataloader = None

attr total_training_steps

    total_training_steps = None

attr eval_dataloader

    eval_dataloader = build_dataloader(self.cfg, eval_dataset, is_train=False) if eval_dataset is not None else None

attr colocate_pg

    colocate_pg = colocate_pg

attr resume_mode

    resume_mode = ResumeMode(cfg.trainer.resume_mode)

attr all_metrics

    all_metrics = {}

attr all_timings

    all_timings = {}

attr global_step

    global_step = 0

attr policy_model

    policy_model: PPORayActorGroup = None

attr critic_model

    critic_model: Optional[PPORayActorGroup] = None

attr ref_model

    ref_model: Optional[PPORayActorGroup] = None

attr dynamic_sampling_state

    dynamic_sampling_state: Optional[DynamicSamplingState] = None

attr reward_kl_controller

    reward_kl_controller: Optional[Union[FixedKLController, AdaptiveKLController]] = None

attr dispatch

    dispatch: WorkerDispatch = None

property has_critic

    has_critic: bool

Check if critic model is configured.
method async eval

    eval() -> Dict[str, float]

Run generation and scoring on the evaluation dataset.

The eval metrics are recorded after having finished training `self.global_step` steps.
Metrics recorded at global_step 0 correspond to evaluations before training.

Returns:

| Type | Description |
|---|---|
| Dict[str, float] | A dictionary of evaluation metrics. |

Source code in skyrl/train/trainer.py:153-180
@torch.no_grad()
async def eval(self) -> Dict[str, float]:
"""
Run generation and scoring on the evaluation dataset.
The eval metrics are recorded after having finished training `self.global_step` steps.
Metrics recorded in global_step 0 corresponds to evaluations before training.
Returns:
A dictionary of evaluation metrics.
"""
if self.cfg.generator.step_wise_trajectories:
eval_metrics = await evaluate_step_wise(
eval_dataloader=self.eval_dataloader,
generator=self.generator,
cfg=self.cfg,
global_step=self.global_step,
tokenizer=self.tokenizer,
)
else:
eval_metrics = await evaluate(
eval_dataloader=self.eval_dataloader,
generator=self.generator,
cfg=self.cfg,
global_step=self.global_step,
tokenizer=self.tokenizer,
)
        return eval_metrics

method train

    train()

Main training loop for PPO

Source code in skyrl/train/trainer.py:182-363
async def train(self):
"""
Main training loop for PPO
"""
# Initialize weight sync state between policy model and inference engines.
with Timer("init_weight_sync_state"):
self.init_weight_sync_state()
# Load checkpoint state if resumption is enabled.
if self.resume_mode != ResumeMode.NONE:
with Timer("load_checkpoints"):
self.global_step, _ = self.load_checkpoints()
# Prepare weights for sampling
with Timer("sync_weights"):
await self.dispatch.save_weights_for_sampler()
# Eval before training
if self.cfg.trainer.eval_interval > 0 and self.cfg.trainer.eval_before_train:
with Timer("eval", self.all_timings):
eval_metrics = await self.eval()
self.tracker.log(eval_metrics, step=self.global_step, commit=True)
# initialize kl controller
if self.cfg.trainer.algorithm.use_kl_in_reward:
self.reward_kl_controller = get_kl_controller(self.cfg.trainer.algorithm)
# main training loop
pbar = tqdm(total=self.total_training_steps, initial=self.global_step, desc="Training Batches Processed")
start_epoch = self.global_step // len(self.train_dataloader)
self.global_step += 1 # start training at global_step 1
for epoch in range(start_epoch, self.cfg.trainer.epochs):
for _, rand_prompts in enumerate(self.train_dataloader):
with Timer("step", self.all_timings):
# for colocate_all=true, inference engine is always on GPU when starting the training step
# 0. truncate data to have even shards
rand_prompts = self._remove_tail_data(rand_prompts)
generator_input, uids = prepare_generator_input(
rand_prompts,
self.cfg.generator.n_samples_per_prompt,
get_sampling_params_for_backend(
self.cfg.generator.inference_engine.backend, self.cfg.generator.sampling_params
),
self.cfg.environment.env_class,
"train",
self.global_step,
)
# 1.1. generation phase
with Timer("generate", self.all_timings):
generator_output: GeneratorOutput = await self.generate(generator_input)
if self.cfg.generator.step_wise_trajectories:
# NOTE: We use instance_ids from `trajectory_ids` here instead of re-using `uids`
# this is because in step-wise training, len(uids) != len(generator_output["response_ids"])
uids = [trajectory_id.instance_id for trajectory_id in generator_output["trajectory_ids"]]
# dynamic sampling
if self.cfg.trainer.algorithm.dynamic_sampling.type is not None:
generator_output, uids, keep_sampling = self.handle_dynamic_sampling(generator_output, uids)
if keep_sampling: # continue sampling
# update progress bar for current batch (but not global step)
pbar.update(1)
continue
if self.colocate_all:
# if we are not continuing sampling, we sleep the inference engine
await self.inference_engine_client.sleep()
# 1.2 postprocess rewards (and merge step-wise turns if enabled)
with Timer("postprocess_generator_output", self.all_timings):
generator_output, uids = self.postprocess_generator_output(generator_output, uids)
# 2. print example just for debugging
vis = self.tokenizer.decode(generator_output["response_ids"][0])
log_example(
logger,
prompt=generator_input["prompts"][0],
response=vis,
reward=generator_output["rewards"][0],
)
# 3. Convert GeneratorOutput to TrainingInputBatch
with Timer("convert_to_training_input", self.all_timings):
training_input: TrainingInputBatch = self.convert_to_training_input(generator_output, uids)
# 4. Inference and calculate values, log probs, rewards, kl divergence
with Timer("fwd_logprobs_values_reward", self.all_timings):
training_input = self.fwd_logprobs_values_reward(training_input)
# 5. apply kl divergence penalty to rewards
if self.cfg.trainer.algorithm.use_kl_in_reward:
with Timer("apply_reward_kl_penalty", self.all_timings):
training_input = self.apply_reward_kl_penalty(training_input)
# 6. calculate advantages and returns
with Timer("compute_advantages_and_returns", self.all_timings):
training_input = self.compute_advantages_and_returns(training_input)
# remove some unwanted keys
for key in ["rewards"]:
training_input.pop(key)
training_input.metadata.pop("uids")
training_input.metadata.pop("is_last_step", None)
if self.cfg.trainer.dump_data_batch:
# dump data to file
with Timer("dump_data_batch"):
self.dump_data(training_input, file_name=f"global_step_{self.global_step}_training_input")
# 7. train policy/critic model
# Policy model is backloaded to GPU during training
with Timer("train_critic_and_policy", self.all_timings):
status = self.train_critic_and_policy(training_input)
# 8. conditionally save checkpoints and hf model
is_epoch_end = self.global_step % len(self.train_dataloader) == 0
if self.cfg.trainer.ckpt_interval > 0:
if is_epoch_end or self.global_step % self.cfg.trainer.ckpt_interval == 0:
with Timer("save_checkpoints", self.all_timings):
self.save_checkpoints()
if self.cfg.trainer.hf_save_interval > 0:
if is_epoch_end or self.global_step % self.cfg.trainer.hf_save_interval == 0:
with Timer("save_hf_model", self.all_timings):
self.save_models()
# 9. conditionally sync policy and ref at the end of the epoch
if (
self.cfg.trainer.update_ref_every_epoch
and self.ref_model is not None
and is_epoch_end
and epoch != self.cfg.trainer.epochs - 1 # skip updating ref at the end of the last epoch
):
with Timer("update_ref_with_policy", self.all_timings):
self.update_ref_with_policy()
# 10. Prepare weights for sampling
with Timer("sync_weights", self.all_timings):
await self.dispatch.save_weights_for_sampler()
# 11. set logs
logger.info(status)
# log epoch info
self.all_metrics.update({"trainer/epoch": epoch, "trainer/global_step": self.global_step})
if self.cfg.trainer.eval_interval > 0 and (
self.global_step % self.cfg.trainer.eval_interval == 0
or self.global_step == self.total_training_steps
):
with Timer("eval", self.all_timings):
eval_metrics = await self.eval()
self.all_metrics.update(eval_metrics)
log_payload = {
**self.all_metrics,
**{f"timing/{k}": v for k, v in self.all_timings.items()},
}
self.tracker.log(log_payload, step=self.global_step, commit=True)
self.all_metrics = {}
self.all_timings = {}
# update progress bar after logging
pbar.update(1)
self.global_step += 1
del training_input, generator_output
pbar.close()
if self.colocate_all:
await self.inference_engine_client.sleep()
# Safety net: always save final checkpoint at end of training.
if self.cfg.trainer.ckpt_interval > 0:
with Timer("save_checkpoints", self.all_timings):
self.save_checkpoints()
logger.info("Saved final checkpoint.")
if self.cfg.trainer.hf_save_interval > 0:
with Timer("save_hf_model", self.all_timings):
self.save_models()
logger.info("Saved final model.")
self.tracker.finish()
        logger.info("Training done!")

method build_models

    build_models(PolicyWorker, CriticWorker, RefWorker)

Initialize the actors for training, and handle colocation logic

Source code in skyrl/train/trainer.py:389-585
def build_models(self, PolicyWorker, CriticWorker, RefWorker):
    """
    Initialize the actors for training, and handle colocation logic.

    Args:
        PolicyWorker: Worker class used to construct the policy actor group.
        CriticWorker: Worker class used to construct the critic actor group
            (only used when ``cfg.trainer.critic.model.path`` is set).
        RefWorker: Worker class used to construct the reference actor group
            (only used when a KL term requires a reference model).

    Sets ``self.policy_model``, ``self.critic_model`` (may be None),
    ``self.ref_model`` (may be None) and ``self.dispatch``. When
    ``colocate_all`` is enabled, all actor groups share one placement group
    and each model is offloaded to CPU right after initialization so the
    inference engines can occupy the GPUs.
    """
    cfg = self.cfg
    pg = None
    # A reference model is only needed when a KL term (in-loss or in-reward) is used.
    use_ref_model = cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward
    if cfg.trainer.placement.colocate_all:
        num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
        num_critic_gpus = cfg.trainer.placement.critic_num_gpus_per_node * cfg.trainer.placement.critic_num_nodes
        num_ref_gpus = cfg.trainer.placement.ref_num_gpus_per_node * cfg.trainer.placement.ref_num_nodes
        ie_cfg = cfg.generator.inference_engine
        num_rollout_gpus = (
            ie_cfg.num_engines
            * ie_cfg.tensor_parallel_size
            * ie_cfg.pipeline_parallel_size
            * ie_cfg.data_parallel_size
        )
        assert (
            num_policy_gpus == num_rollout_gpus
        ), "num_policy_gpus and num_rollout_gpus must be the same when colocating all models"
        pg = self.colocate_pg
        policy_model = PPORayActorGroup(
            cfg.trainer,
            cfg.trainer.placement.policy_num_nodes,
            cfg.trainer.placement.policy_num_gpus_per_node,
            PolicyWorker,
            pg=pg,
            num_gpus_per_actor=0.2 if pg else 1,
            colocate_all=True,
            sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
            record_memory=cfg.trainer.policy.record_memory,
        )
        if use_ref_model:
            assert (
                num_policy_gpus == num_ref_gpus
            ), "num_policy_gpus and num_ref_gpus must be the same when colocating policy and ref model"
            ref_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.ref_num_nodes,
                cfg.trainer.placement.ref_num_gpus_per_node,
                RefWorker,
                pg=pg,
                num_gpus_per_actor=0.2 if pg else 1,
                colocate_all=True,
                sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
            )
        else:
            ref_model = None
        if cfg.trainer.critic.model.path:
            assert (
                num_policy_gpus == num_critic_gpus
            ), "num_policy_gpus and num_critic_gpus must be the same when colocating policy and critic model"
            critic_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.critic_num_nodes,
                cfg.trainer.placement.critic_num_gpus_per_node,
                CriticWorker,
                pg=pg,
                # Consistency fix: mirror the policy/ref fractional-GPU logic so the
                # critic also requests a full GPU when no placement group is provided
                # (previously hard-coded to 0.2 regardless of `pg`).
                num_gpus_per_actor=0.2 if pg else 1,
                colocate_all=True,
                sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
            )
        else:
            critic_model = None
    else:
        # Non-colocate-all: only policy and ref may share a placement group.
        if cfg.trainer.placement.colocate_policy_ref and use_ref_model:
            assert (
                cfg.trainer.placement.policy_num_nodes == cfg.trainer.placement.ref_num_nodes
                and cfg.trainer.placement.policy_num_gpus_per_node == cfg.trainer.placement.ref_num_gpus_per_node
            ), "num_nodes and num_gpus_per_node must be the same when colocate policy and ref model."
            bundles = [
                {
                    "GPU": cfg.trainer.placement.policy_num_gpus_per_node,
                    "CPU": cfg.trainer.placement.policy_num_gpus_per_node,
                }
                for _ in range(cfg.trainer.placement.policy_num_nodes)
            ]
            raw_pg = placement_group(bundles, strategy="PACK")
            get_ray_pg_ready_with_timeout(raw_pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
            pg = ResolvedPlacementGroup(raw_pg)
        policy_model = PPORayActorGroup(
            cfg.trainer,
            cfg.trainer.placement.policy_num_nodes,
            cfg.trainer.placement.policy_num_gpus_per_node,
            PolicyWorker,
            pg=pg,
            num_gpus_per_actor=0.75 if pg else 1,
            colocate_all=False,
            sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
        )
        if use_ref_model:
            ref_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.ref_num_nodes,
                cfg.trainer.placement.ref_num_gpus_per_node,
                RefWorker,
                pg=pg,
                num_gpus_per_actor=0.25 if pg else 1,
                colocate_all=False,
                sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
            )
        else:
            ref_model = None
        if cfg.trainer.critic.model.path:
            # Critic gets its own GPUs here (never shares the policy/ref pg).
            critic_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.critic_num_nodes,
                cfg.trainer.placement.critic_num_gpus_per_node,
                CriticWorker,
                num_gpus_per_actor=1,
                colocate_all=False,
                sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
            )
        else:
            critic_model = None
    # Number of optimizer steps per train batch (mini-batches x update epochs),
    # used to size LR schedules.
    policy_steps_per_train_batch = (
        cfg.trainer.train_batch_size // cfg.trainer.policy_mini_batch_size * cfg.trainer.update_epochs_per_batch
    )
    critic_steps_per_train_batch = 0
    if cfg.trainer.critic.model.path:
        critic_steps_per_train_batch = (
            cfg.trainer.train_batch_size // cfg.trainer.critic_mini_batch_size * cfg.trainer.update_epochs_per_batch
        )
    policy_num_training_steps = (
        self.total_training_steps * policy_steps_per_train_batch if self.total_training_steps is not None else None
    )
    critic_num_training_steps = (
        self.total_training_steps * critic_steps_per_train_batch if self.total_training_steps is not None else None
    )
    if not cfg.trainer.placement.colocate_all:
        # Models live on separate GPUs: initialize everything concurrently.
        refs = []
        if ref_model is not None:
            refs.extend(ref_model.async_init_model(cfg.trainer.ref.model.path))
        refs.extend(
            policy_model.async_init_model(
                cfg.trainer.policy.model.path,
                num_training_steps=policy_num_training_steps,
            )
        )
        if cfg.trainer.critic.model.path:
            refs.extend(
                critic_model.async_init_model(
                    cfg.trainer.critic.model.path,
                    num_training_steps=critic_num_training_steps,
                )
            )
        ray.get(refs)
        ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
    else:
        # Colocated: initialize one model at a time and offload each to CPU so the
        # shared GPUs never hold more than one model during init.
        if ref_model is not None:
            ray.get(ref_model.async_init_model(cfg.trainer.ref.model.path))
            ref_model.offload_to_cpu()
        ray.get(
            policy_model.async_init_model(
                cfg.trainer.policy.model.path,
                num_training_steps=policy_num_training_steps,
            )
        )
        ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
        policy_model.offload_to_cpu()
        if cfg.trainer.critic.model.path:
            ray.get(
                critic_model.async_init_model(
                    cfg.trainer.critic.model.path,
                    num_training_steps=critic_num_training_steps,
                )
            )
            critic_model.offload_to_cpu()
    self.policy_model: PPORayActorGroup = policy_model
    self.critic_model: Optional[PPORayActorGroup] = critic_model
    self.ref_model: Optional[PPORayActorGroup] = ref_model
    # Create unified dispatch that manages all actor groups
    self.dispatch = WorkerDispatch(
        cfg=self.cfg,
        policy_actor_group=policy_model,
        critic_actor_group=critic_model,
        ref_actor_group=ref_model,
        inference_engine_client=self.inference_engine_client,
    )
    # Mark all models as offloaded if colocate_all (they were offloaded above)
    if self.colocate_all:
        self.dispatch.mark_all_offloaded()
    logger.info("init policy/ref/critic models done")
init_weight_sync_state()Setup the connection between policy model and inference engine for weight syncing.
Source code in skyrl/train/trainer.py:587-592
def init_weight_sync_state(self):
    """Set up the policy-model <-> inference-engine connection used for weight syncing.

    Delegates to the dispatch layer, which owns the actual sync protocol.
    """
    self.dispatch.init_weight_sync_state(self.inference_engine_client)
    logger.info("Initialized weight sync state for policy model and inference engines.")
sync_policy_weights_to_inference_engines() -> List[ObjectRef]Broadcast policy weights to inference engines.
Note: For new code, prefer using dispatch.save_weights_for_sampler() which handles the full weight sync protocol including offload/backload. This method is kept for backward compatibility with subclasses. TODO(tgriggs): Remove this method when migration is complete.
Source code in skyrl/train/trainer.py:594-607
def sync_policy_weights_to_inference_engines(self) -> List[ObjectRef]:
    """Broadcast policy weights to inference engines.

    Returns the list of in-flight Ray object refs for the broadcast calls.

    Note: For new code, prefer using dispatch.save_weights_for_sampler() which
    handles the full weight sync protocol including offload/backload.
    This method is kept for backward compatibility with subclasses.
    TODO(tgriggs): Remove this method when migration is complete.
    """
    engine_cfg = self.cfg.generator.inference_engine
    return self.policy_model.async_run_ray_method(
        "pass_through",
        "broadcast_to_inference_engines",
        self.inference_engine_client,
        engine_cfg,
    )
convert_to_training_input(generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatchConverts lists to a padded batch of tensors for training
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
generator_output | GeneratorOutput | Generated rollouts and associated data. | required |
uids | List[str] | List of prompt-unique identifiers for each generator output in the same order as generator_output. Used to identify which prompt each generated rollout belongs to. | required |
Returns:
training_input (TrainingInputBatch): Padded batch of tensors for training. It preserves the
order of generator_output and hence uids.
Source code in skyrl/train/trainer.py:609-725
def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch:
    """Converts lists to a padded batch of tensors for training

    Args:
        generator_output (GeneratorOutput): Generated rollouts and associated data.
        uids (List[str]): List of prompt-unique identifiers for each generator output in the same
            order as `generator_output`. Used to identify which prompt each generated rollout belongs to.

    Returns:
        training_input (TrainingInputBatch): Padded batch of tensors for training. It preserves the
            order of `generator_output` and hence `uids`.
    """
    # 1. Extract generator output fields.
    prompt_ids: List[List[int]] = generator_output["prompt_token_ids"]
    response_ids: List[List[int]] = generator_output["response_ids"]
    rewards: List[List[float]] = generator_output["rewards"]
    loss_masks: List[List[int]] = generator_output["loss_masks"]
    logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None)
    rollout_expert_indices: Optional[List[List[List[List[int]]]]] = generator_output.get(
        "rollout_expert_indices", None
    )
    pixel_values = generator_output.get("pixel_values", None)
    image_grid_thw = generator_output.get("image_grid_thw", None)
    if pixel_values is not None:
        # Fix: the old assert re-checked `pixel_values is not None` inside this
        # branch, which is always true here; only image_grid_thw needs checking.
        assert (
            image_grid_thw is not None
        ), "Both pixel_values and image_grid_thw must exist for multi-modal inputs"
        assert len(pixel_values) == len(
            image_grid_thw
        ), "Number of pixel values should match number of image grid thw"
        pixel_values = TensorList(pixel_values)
        image_grid_thw = TensorList(image_grid_thw)
    # 2. Convert to tensors.
    (
        sequences_tensor,
        attention_masks_tensor,
        response_masks_tensor,
        rewards_tensor,
        loss_masks_tensor,
        rollout_logprobs_tensor,
        rollout_expert_indices_tensor,
    ) = convert_prompts_responses_to_batch_tensors(
        self.tokenizer,
        prompt_ids,
        response_ids,
        rewards,
        loss_masks,
        logprobs,
        rollout_expert_indices,
        max_seq_len=self.cfg.trainer.algorithm.max_seq_len,
    )
    # sanity check for off_policy_correction
    off_policy_correction = self.cfg.trainer.algorithm.off_policy_correction
    tis_ratio_type = off_policy_correction.tis_ratio_type
    sequence_mask_metric = off_policy_correction.sequence_mask_metric
    if tis_ratio_type is not None or sequence_mask_metric is not None:
        assert (
            rollout_logprobs_tensor is not None
        ), "expected non-null rollout logprobs tensor when off_policy_correction is enabled"
        assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses"
    # 3. Create training input batch.
    training_input = TrainingInputBatch(
        {
            "sequences": sequences_tensor,  # Full trajectories (padded and concatenated prompts and responses)
            "attention_mask": attention_masks_tensor,
            "response_mask": response_masks_tensor,
            "rewards": rewards_tensor,
            "loss_mask": loss_masks_tensor,
            "rollout_logprobs": rollout_logprobs_tensor,
            "rollout_expert_indices": rollout_expert_indices_tensor,
            "pixel_values": pixel_values,
            "image_grid_thw": image_grid_thw,
        },
    )
    training_input.metadata = {"uids": uids}
    if generator_output.get("is_last_step", None) is not None:
        training_input.metadata["is_last_step"] = generator_output["is_last_step"]
    # 4. Compute mini-batch boundaries for train_critic_and_policy(). It excludes the ones
    # we will add in pad_training_input_batch().
    train_batch_size = self.cfg.trainer.train_batch_size
    n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt
    is_stepwise = self.cfg.generator.step_wise_trajectories
    training_input.metadata["policy_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries(
        uids, self.cfg.trainer.policy_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt
    )
    if self.cfg.trainer.critic.model.path is not None:
        training_input.metadata["critic_mini_batch_boundaries"] = compute_prompt_mini_batch_boundaries(
            uids, self.cfg.trainer.critic_mini_batch_size, train_batch_size, is_stepwise, n_samples_per_prompt
        )
    # 5. Record metadata and metrics.
    training_input.metadata["response_length"] = response_masks_tensor.shape[1]
    batch_num_seq, batch_padded_seq_len = sequences_tensor.shape
    logger.info(f"batch_num_seq: {batch_num_seq}, batch_padded_seq_len: {batch_padded_seq_len}")
    self.all_metrics.update(
        {
            "generate/batch_num_seq": batch_num_seq,
            "generate/batch_padded_seq_len": batch_padded_seq_len,
        }
    )
    training_input.metadata["avg_response_length"] = sum(
        len(sample_response_ids) for sample_response_ids in response_ids
    ) / len(response_ids)
    # 6. Pad the batch, only needed for step-wise training's `fwd_logprobs_values_reward()`.
    logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
    dp_size = self.dispatch.get_lcm_dp_size()
    # Round the batch up to a multiple of the data-parallel size.
    pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size
    training_input = pad_training_input_batch(training_input, pad_size)
    logger.info(f"Number of sequences after padding: {len(training_input['sequences'])}")
    return training_input
generate(input_batch: GeneratorInput) -> GeneratorOutputGenerate rollouts.
If colocate_all is enabled:
- before calling this method, the policy model should be on CPU and inference engine should be awake (i.e. on GPU).
- after calling this method, the same model placement still holds.
Source code in skyrl/train/trainer.py:727-754
@torch.no_grad()
async def generate(
    self,
    input_batch: GeneratorInput,
) -> GeneratorOutput:
    """
    Generate rollouts.

    If colocate_all is enabled:
        - before calling this method, the policy model should be on CPU and inference engine should
          be awake (i.e. on GPU).
        - after calling this method, the same model placement still holds.
    """
    # NOTE: we assume that .generate returns samples in the same order as passed in
    generator_output: GeneratorOutput = await self.generator.generate(input_batch)
    # Fold rollout metrics into the trainer-wide metrics dict, then strip the key
    # so downstream consumers only see trajectory data.
    rollout_metrics = generator_output["rollout_metrics"]
    generator_output.pop("rollout_metrics", None)
    if rollout_metrics is not None:
        self.all_metrics.update(rollout_metrics)
    validate_generator_output(
        len(input_batch["prompts"]),
        generator_output,
        step_wise=self.cfg.generator.step_wise_trajectories,
    )
    return generator_output
postprocess_generator_output(generator_output: GeneratorOutput, uids: List[str]) -> Tuple[GeneratorOutput, List[str]]Converts to per token rewards and computes pass@N.
For step-wise training with merge_stepwise_output=true, also collapses
consecutive turns sharing a common prefix into a single sequence; uids
is shortened to match.
In the future algorithm specific reward or loss mask post processing should be done here.
Returns:
| Type | Description |
|---|---|
| Tuple[GeneratorOutput, List[str]] | (generator_output, uids) — uids may be shorter than the input when merging. |
Source code in skyrl/train/trainer.py:756-844
@torch.no_grad()
def postprocess_generator_output(
    self, generator_output: GeneratorOutput, uids: List[str]
) -> Tuple[GeneratorOutput, List[str]]:
    """
    Converts to per token rewards and computes pass@N.

    For step-wise training with ``merge_stepwise_output=true``, also collapses
    consecutive turns sharing a common prefix into a single sequence; ``uids``
    is shortened to match.

    In the future algorithm specific reward or loss mask post processing should be done here.

    Returns:
        (generator_output, uids) — uids may be shorter than the input when merging.
    """
    generator_output_for_metrics = generator_output
    uids_for_metrics = uids
    if self.cfg.generator.step_wise_trajectories:
        # For step-wise training, compute metrics over the final step of each
        # trajectory only.
        generator_output_for_metrics = defaultdict(list)
        for key in generator_output:
            if isinstance(generator_output[key], list):
                generator_output_for_metrics[key] = [
                    generator_output[key][i]
                    for i in range(len(generator_output[key]))
                    if generator_output["is_last_step"][i]
                ]
        uids_for_metrics = [
            uid for uid, is_last_step in zip(uids, generator_output["is_last_step"]) if is_last_step
        ]
    # only use `generator_output_for_metrics` for metrics calculation
    # For step-wise training, we only calculate metrics for the last step of each trajectory
    overall_metrics = get_metrics_from_generator_output(
        generator_output_for_metrics,
        uids_for_metrics,
    )
    # Prefix-aware merging of step-wise turns.
    if self.cfg.generator.merge_stepwise_output:
        assert self.cfg.generator.step_wise_trajectories, "merge_stepwise_output requires step-wise training"
        num_seq_before_merge = len(generator_output["response_ids"])
        generator_output = merge_stepwise_output(generator_output)
        num_seq_after_merge = len(generator_output["response_ids"])
        logger.info(f"Merged step wise: {num_seq_before_merge} sequences -> {num_seq_after_merge} sequences")
        self.all_metrics.update(
            {
                "generate/num_seq_before_merge": num_seq_before_merge,
                "generate/num_seq_after_merge": num_seq_after_merge,
            }
        )
        # Merging changes the row count, so rebuild uids from the merged output.
        uids = [tid.instance_id for tid in generator_output["trajectory_ids"]]
    # these use the full generator output
    rewards: Union[List[float], List[List[float]]] = generator_output["rewards"]
    responses: List[List[int]] = generator_output["response_ids"]
    per_token_rewards: List[List[float]] = []
    # Check if rewards are already token-level (List[List[float]]) or response-level (List[float])
    if rewards and isinstance(rewards[0], list):
        # Token-level rewards: rewards is List[List[float]]
        per_token_rewards = rewards
    else:
        if self.cfg.trainer.algorithm.zero_variance_filter:
            kept_indices_set = set(zero_variance_filter(rewards, uids))
            generator_output["loss_masks"] = [
                [0] * len(mask) if i not in kept_indices_set else mask
                for i, mask in enumerate(generator_output["loss_masks"])
            ]
        # Response-level rewards: rewards is List[float], convert to per-token rewards
        for reward, response in zip(rewards, responses):
            per_token_reward = [0.0] * len(response)
            # Fix: guard against empty responses — there is no final token to
            # attach the scalar reward to, so leave the empty list as-is
            # instead of raising IndexError.
            if per_token_reward:
                per_token_reward[-1] = float(reward)
            per_token_rewards.append(per_token_reward)
    n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt
    reward_metrics = {
        f"reward/avg_pass_at_{n_samples_per_prompt}": overall_metrics["pass_at_n"],
        "reward/avg_raw_reward": overall_metrics["avg_score"],
        "reward/mean_positive_reward": overall_metrics["mean_positive_reward"],
    }
    self.all_metrics.update(reward_metrics)
    logger.info(
        f"reward/avg_pass_at_{n_samples_per_prompt}: {overall_metrics['pass_at_n']}, reward/avg_raw_reward: {overall_metrics['avg_score']}, reward/mean_positive_reward: {overall_metrics['mean_positive_reward']}"
    )
    # re-assign reward but now it's per token rewards
    generator_output["rewards"] = per_token_rewards
    return generator_output, uids
compute_advantages_and_returns(data: TrainingInputBatch) -> TrainingInputBatchCalculate advantages and returns for the data batch.
Expects:
- `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
- `["response_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
- `["loss_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
- `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
- `["rewards"]`: Float[torch.Tensor, "batch_size seqlen"]
- `.metadata["uids"]`: List[str]
- `.metadata["is_last_step"]`: List[bool] for step-wise training
Adds:
- `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"]
- `["returns"]`: Float[torch.Tensor, "batch_size seqlen"]
Source code in skyrl/train/trainer.py:846-957
@torch.no_grad()
def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingInputBatch:
    """Calculate advantages and returns for the data batch.

    Expects:
        - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `["response_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `["loss_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
        - `["rewards"]`: Float[torch.Tensor, "batch_size seqlen"]
        - `.metadata["uids"]`: List[str]
        - `.metadata["is_last_step"]`: List[bool] for step-wise training

    Adds:
        - `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"]
        - `["returns"]`: Float[torch.Tensor, "batch_size seqlen"]

    Also records reward/advantage summary metrics into `self.all_metrics` and
    `data.metadata["metrics"]`.
    """
    token_level_rewards = data["rewards"]
    if self.cfg.generator.step_wise_trajectories:
        is_last_step = torch.tensor(data.metadata["is_last_step"], dtype=torch.bool)
        index = np.array(data.metadata["uids"])
        values = data["values"]
        # Step-wise only supports outcome-based estimators (GRPO, RLOO, MAXRL); ensured by `validate_cfg`.
        # We use the last step of each trajectory to compute advantages and broadcast them to
        # all steps of that trajectory, so we ignore per-step rewards in step-wise training.
        # We pass an all-ones mask here so the estimator returns the scalar advantage at every
        # position. The real per-step `response_mask` is re-applied on broadcast below.
        # Shapes:
        #   traj_ids, (batch_size,): trajectory id per step (cumsum of shifted is_last_step)
        #   last_step_advantages/returns,
        #     (num_traj, seqlen): scalar advantage/return per trajectory at every position
        #   last_step_advantages/returns[traj_ids],
        #     (batch_size, seqlen): broadcast to every step of the owning trajectory
        #   response_mask_float,
        #     (batch_size, seqlen): per-step response mask
        last_step_response_mask = data["response_mask"][is_last_step]
        last_step_advantages, last_step_returns = ppo_utils.compute_advantages_and_returns(
            token_level_rewards=token_level_rewards[is_last_step],
            response_mask=torch.ones_like(last_step_response_mask, dtype=torch.float),
            index=index[is_last_step.cpu().numpy()],
            adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
            values=values[is_last_step] if values is not None else None,
            config=self.cfg.trainer.algorithm,
            gamma=self.cfg.trainer.algorithm.gamma,
            lambd=self.cfg.trainer.algorithm.lambd,
            grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
        )
        # Assign each step a trajectory id: the id increments on the row *after*
        # a last-step row, hence the shift-by-one before cumsum.
        traj_ids = (
            torch.cat([torch.tensor([False], device=is_last_step.device), is_last_step[:-1]]).int().cumsum(dim=0)
        )
        num_traj = traj_ids[-1].item() + 1
        assert num_traj == len(
            last_step_advantages
        ), f"num_traj {num_traj} doesn't match the number of trajectories as given by `is_last_step` {len(last_step_advantages)}. The `is_last_step` tensor is likely malformed"
        response_mask_float = data["response_mask"].to(last_step_advantages.dtype)
        # Broadcast each trajectory's scalar advantage/return to all of its steps,
        # re-applying the per-step response mask.
        advantages = last_step_advantages[traj_ids] * response_mask_float
        returns = last_step_returns[traj_ids] * response_mask_float
    else:
        advantages, returns = ppo_utils.compute_advantages_and_returns(
            token_level_rewards=token_level_rewards,
            response_mask=data["response_mask"],
            index=data.metadata["uids"],
            adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
            config=self.cfg.trainer.algorithm,
            values=data["values"],
            gamma=self.cfg.trainer.algorithm.gamma,
            lambd=self.cfg.trainer.algorithm.lambd,
            grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
        )
    data["returns"] = returns
    data["advantages"] = advantages
    # remove padding while calculating metrics
    pad_size = data.metadata.get("pad_size", 0)
    num_samples = len(token_level_rewards)
    return_sums = token_level_rewards.sum(dim=-1)[: num_samples - pad_size]
    if self.cfg.generator.step_wise_trajectories:
        # Only final-step rows carry the trajectory reward in step-wise mode.
        avg_rewards: float = return_sums[is_last_step[: num_samples - pad_size]].mean().item()
    else:
        avg_rewards: float = return_sums.mean().item()
    avg_response_length = data.metadata["avg_response_length"]
    data = data.to("cpu")
    # Average advantages over response tokens only (padding rows excluded).
    valid_advantages = torch.masked_select(
        data["advantages"][: num_samples - pad_size, ...], data["response_mask"][: num_samples - pad_size].bool()
    )
    avg_advantages: float = valid_advantages.mean().item()
    avg_advantages_abs: float = valid_advantages.abs().mean().item()
    if "metrics" not in data.metadata:
        data.metadata["metrics"] = {}
    data.metadata["metrics"].update(
        {
            "avg_final_rewards": avg_rewards,
            "avg_response_length": avg_response_length,
            "avg_advantages": avg_advantages,
            "avg_advantages_abs": avg_advantages_abs,
        }
    )
    logger.info(f"avg_final_rewards: {avg_rewards}, avg_response_length: {avg_response_length}")
    self.all_metrics.update(
        {
            "loss/avg_final_rewards": avg_rewards,
            "loss/avg_raw_advantages": avg_advantages,
            "loss/avg_raw_advantages_abs": avg_advantages_abs,
        }
    )
    return data
dump_data(data: TrainingInputBatch, file_name: str)Dump data to pickle file
Source code in skyrl/train/trainer.py:959-965
def dump_data(self, data: TrainingInputBatch, file_name: str):
    """Serialize a training batch to ``<export_path>/dumped_data/<file_name>.pkl``."""
    target_dir = Path(self.cfg.trainer.export_path) / "dumped_data"
    # Create the dump directory lazily; repeated calls are a no-op.
    target_dir.mkdir(parents=True, exist_ok=True)
    data.save(target_dir / f"{file_name}.pkl")
fwd_logprobs_values_reward(training_input: TrainingInputBatch)Calculate values from the critic, log probs from the policy and ref model.
Dispatch handles offload/backload automatically for all colocation configurations.
Expects:
- `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
- `["attention_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
- `.metadata["response_length"]`: Int
Adds:
- `["base_action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
- `["action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
- `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
Source code in skyrl/train/trainer.py:967-1044
@torch.no_grad()
def fwd_logprobs_values_reward(
    self,
    training_input: TrainingInputBatch,
):
    """
    Calculate values from the critic, log probs from the policy and ref model.

    Dispatch handles offload/backload automatically for all colocation configurations.

    Expects:
        - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `["attention_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `.metadata["response_length"]`: Int

    Adds:
        - `["base_action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
        - `["action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
        - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]

    Also records rollout-vs-trainer logprob drift metrics when rollout logprobs
    are present.
    """
    # Build the minimal view of the batch needed for forward passes; optional
    # keys (MoE expert indices, vision inputs) are included only when present.
    fwd_keys = ["sequences", "attention_mask"]
    if training_input.get("rollout_expert_indices") is not None:
        fwd_keys.append("rollout_expert_indices")
    if training_input.get("pixel_values") is not None:
        fwd_keys.append("pixel_values")
    if training_input.get("image_grid_thw") is not None:
        fwd_keys.append("image_grid_thw")
    data_fwd_pass = training_input.select(keys=fwd_keys, metadata_keys=["response_length"])
    values = None
    base_log_probs = None
    action_log_probs = None
    # Critic forward (dispatch handles offload/backload automatically)
    if self.has_critic:
        critic_output = self.dispatch.forward("critic", data_fwd_pass)
        values = critic_output["output"]
    # Ref forward
    if self.ref_model is not None:
        ref_output = self.dispatch.forward("ref", data_fwd_pass)
        base_log_probs = ref_output["output"]
        self.dispatch.empty_cache("ref")
    # Policy forward
    policy_output = self.dispatch.forward("policy", data_fwd_pass)
    action_log_probs = policy_output["output"]
    # Empty cache after all forward passes
    self.dispatch.empty_cache()
    sequences_all: torch.Tensor = training_input["sequences"]
    # NOTE (sumanthrh): The slicing is needed to make sure that the batch dimension doesn't change for the tensordict.
    base_log_probs = base_log_probs[: len(sequences_all)] if base_log_probs is not None else None
    action_log_probs = action_log_probs[: len(sequences_all)]
    values = values[: len(sequences_all)] if values is not None else None
    training_input["base_action_log_probs"] = base_log_probs
    training_input["action_log_probs"] = action_log_probs
    training_input["values"] = values
    if training_input.get("rollout_logprobs", None) is not None:
        # calculates the difference in probs between inference and trainer components
        # only consider response tokens
        logprobs_diff = (
            training_input["rollout_logprobs"][training_input["loss_mask"] > 0]
            - action_log_probs[training_input["loss_mask"] > 0]
        ).abs()
        logprobs_diff_mean = logprobs_diff.mean().item()
        logprobs_diff_std = logprobs_diff.std().item()
        self.all_metrics.update(
            {
                "policy/rollout_train_logprobs_abs_diff_mean": logprobs_diff_mean,
                "policy/rollout_train_logprobs_abs_diff_std": logprobs_diff_std,
            }
        )
    return training_input
apply_reward_kl_penalty(data: TrainingInputBatch) -> TrainingInputBatchApplies a penalty for KL divergence between the policy log probs and the base model log probs to the rewards.
Source code in skyrl/train/trainer.py:1046-1101
def apply_reward_kl_penalty(
    self,
    data: TrainingInputBatch,
) -> TrainingInputBatch:
    """Applies a penalty for KL divergence between the policy log probs and the base model log probs to the rewards.

    Reads `["loss_mask"]`, `["rewards"]`, `["base_action_log_probs"]` and
    `["action_log_probs"]` from `data`; rewrites `["rewards"]` in place as
    `rewards - kl * max(0, kl_loss_coef)` and records KL metrics in
    `data.metadata["metrics"]` and `self.all_metrics`.
    """
    loss_masks_all: torch.Tensor = data["loss_mask"]
    rewards: torch.Tensor = data["rewards"]
    base_action_log_probs: torch.Tensor = data["base_action_log_probs"]
    action_log_probs: torch.Tensor = data["action_log_probs"]
    # single batched computation
    with torch.no_grad():
        kl: Float[torch.Tensor, "batch_size seqlen"] = compute_approx_kl(  # type: ignore
            action_log_probs,
            base_action_log_probs,
            loss_mask=loss_masks_all,
            kl_estimator_type=self.cfg.trainer.algorithm.kl_estimator_type,
        )
        kl_max: Float[torch.Tensor, "batch_size"] = torch.max(kl.abs(), dim=-1)[0]  # noqa: F821
        kl_mean: Float[torch.Tensor, "batch_size"] = masked_mean(kl, loss_masks_all, dim=-1)  # noqa: F821
    # NOTE (erictang000): only supporting custom rewards currently
    # Use the adaptive controller's coefficient when configured, otherwise the
    # static coefficient from the algorithm config.
    kl_loss_coef = (
        self.reward_kl_controller.value
        if self.reward_kl_controller is not None
        else self.cfg.trainer.algorithm.kl_loss_coef
    )
    # max(0, ...) ensures a negative configured coefficient never *adds* reward.
    rewards = rewards - kl * max(0, kl_loss_coef)
    data["rewards"] = rewards
    avg_kl: float = kl_mean.mean().item()
    avg_kl_max: float = kl_max.mean().item()
    # update the kl controller
    if self.reward_kl_controller is not None:
        self.reward_kl_controller.update(current=avg_kl, n_steps=kl.shape[0])  # n_steps is just the batch size
    if "metrics" not in data.metadata:
        data.metadata["metrics"] = {}
    data.metadata["metrics"].update(
        {
            "avg_kl": avg_kl,
            "avg_kl_max": avg_kl_max,
            "kl_loss_coef": kl_loss_coef,
        }
    )
    self.all_metrics.update(
        {
            "loss/avg_kl": avg_kl,
            "loss/avg_kl_max": avg_kl_max,
            "loss/kl_loss_coef": kl_loss_coef,
        }
    )
    return data
train_critic_and_policy(data: TrainingInputBatch)Run the training step for the policy and critic models.
Uses forward_backward + optim_step for both FSDP and Megatron strategies.
Source code in skyrl/train/trainer.py:1182-1208
def train_critic_and_policy(self, data: TrainingInputBatch):
    """
    Run the training step for the policy and critic models.

    Uses forward_backward + optim_step for both FSDP and Megatron strategies.
    Returns the policy training status dict.
    """
    data.metadata["global_step"] = self.global_step
    # Unified training interface for both FSDP and Megatron: critic first
    # (when present), then policy.
    critic_status = None
    if self.has_critic:
        with Timer("critic_train", self.all_timings):
            critic_status = self._execute_training_step("critic", data)
    with Timer("policy_train", self.all_timings):
        policy_status = self._execute_training_step("policy", data)
    # Fold per-model training stats into the global metrics dict.
    for prefix, status in (("critic", critic_status), ("policy", policy_status)):
        if status is None:
            continue
        self.all_metrics.update({f"{prefix}/{k}": v for k, v in status.items()})
    self.dispatch.empty_cache()
    return policy_status
handle_dynamic_sampling(generator_output: GeneratorOutput, uids: List[str]) -> Tuple[GeneratorOutput, List[str], bool]Handle dynamic sampling for the current batch.
Accumulates the generator output and UIDs across batches if we are sampling repeatedly and applies the dynamic sampling strategy (i.e. filter, replace) to the current batch. If we hit the limit of max sample batches, we raise an error.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
generator_output | GeneratorOutput | Current batch generator output | required |
uids | List[str] | Current batch UIDs | required |
Returns:
| Name | Type | Description |
|---|---|---|
processed_output | GeneratorOutput | Filtered generator output |
processed_uids | List[str] | Filtered UIDs |
keep_sampling | bool | Whether to keep sampling |
Source code in skyrl/train/trainer.py:1210-1270
def handle_dynamic_sampling(
self, generator_output: GeneratorOutput, uids: List[str]
) -> Tuple[GeneratorOutput, List[str], bool]:
"""
Handle dynamic sampling for the current batch.
Accumulates the generator output and UIDs across batches if we are sampling repeatedly
and applies the dynamic sampling strategy (i.e. filter, replace) to the current batch.
If we hit the limit of max sample batches, we raise an error.
Args:
generator_output: Current batch generator output
uids: Current batch UIDs
Returns:
processed_output: Filtered generator output
processed_uids: Filtered UIDs
keep_sampling: Whether to keep sampling
"""
# Prepare sampling configuration
max_sample_batches = self.cfg.trainer.algorithm.dynamic_sampling.max_sample_batches
dynamic_sampling_config = {
"type": self.cfg.trainer.algorithm.dynamic_sampling.type,
"max_sample_batches": max_sample_batches,
"min_replace_ratio": self.cfg.trainer.algorithm.dynamic_sampling.min_replace_ratio,
"train_batch_size": self.cfg.trainer.train_batch_size,
"n_samples_per_prompt": self.cfg.generator.n_samples_per_prompt,
}
if self.dynamic_sampling_state is None:
self.dynamic_sampling_state: DynamicSamplingState = {
"sample_batch_count": 1,
}
else:
self.dynamic_sampling_state["sample_batch_count"] += 1
# Handle dynamic sampling using utilities
processed_output, processed_uids, keep_sampling, updated_state = trainer_utils.handle_dynamic_sampling(
generator_output, uids, dynamic_sampling_config, self.dynamic_sampling_state
)
# Check max resample limit, and if we hit it, raise an error
if (
keep_sampling
and max_sample_batches > 0
and self.dynamic_sampling_state["sample_batch_count"] >= max_sample_batches
):
raise RuntimeError(
f"Exiting training loop due to hitting dynamic sampling limit for "
f"{self.cfg.trainer.algorithm.dynamic_sampling.type} strategy with "
f"{self.cfg.trainer.algorithm.dynamic_sampling.max_sample_batches} max sample batches. "
f"Please check your data difficulty distribution."
)
# Update state
self.dynamic_sampling_state = updated_state
if not keep_sampling:
# Reset state when sampling is complete
self.dynamic_sampling_state = None
return processed_output, processed_uids, keep_samplingmethod save_checkpoints
save_checkpoints()Save the model, optimizer, and training states to disk.
Dispatch handles offload/backload automatically for all colocation configurations.
Source code in skyrl/train/trainer.py:1281-1330
def save_checkpoints(self):
"""
Save the model, optimizer, and training states to disk.
Dispatch handles offload/backload automatically for all colocation configurations.
"""
# Create global step folder structure
global_step_folder = os.path.join(self.cfg.trainer.ckpt_path, f"global_step_{self.global_step}")
policy_save_dir = os.path.join(global_step_folder, "policy")
critic_save_dir = os.path.join(global_step_folder, "critic")
io.makedirs(global_step_folder, exist_ok=True)
# Save policy checkpoint (dispatch handles offload/backload)
self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer)
# Save critic checkpoint (if it exists)
if self.has_critic:
self.dispatch.save_checkpoint("critic", critic_save_dir, self.tokenizer)
# Save dataloader state
dataloader_save_path = os.path.join(global_step_folder, "data.pt")
try:
dataloader_state_dict = self.train_dataloader.state_dict()
with io.open_file(dataloader_save_path, "wb") as f:
torch.save(dataloader_state_dict, f)
logger.info(f"Saved dataloader state to {dataloader_save_path}")
except Exception as e:
logger.warning(f"Failed to save dataloader state: {e}")
# Save additional trainer state
trainer_state = {
"global_step": self.global_step,
"config": asdict(self.cfg),
}
trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
with io.open_file(trainer_state_path, "wb") as f:
torch.save(trainer_state, f)
logger.info(f"Saved trainer state to {trainer_state_path}")
# Atomic tracking - write this last after all saves succeed
latest_checkpoint_file = os.path.join(self.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt")
with io.open_file(latest_checkpoint_file, "w") as f:
f.write(str(self.global_step))
logger.info(f"Successfully saved checkpoint for global_step_{self.global_step} to: {global_step_folder}")
# Clean up old checkpoints after successful save
with Timer("cleanup_old_checkpoints", self.all_timings):
self._cleanup_old_checkpoints()method load_checkpoints
load_checkpoints() -> Tuple[int, str]Load complete checkpoint state and return the global_step to resume from. Returns 0 if no checkpoint is loaded.
If colocate_all is True, assumes that the policy model is currently on GPU.
Returns:
| Name | Type | Description |
|---|---|---|
global_step | int | The global step to resume from. |
checkpoint_path | str | The path to the checkpoint. |
Source code in skyrl/train/trainer.py:1345-1456
def load_checkpoints(self) -> Tuple[int, str]:
"""
Load complete checkpoint state and return the global_step to resume from.
Returns 0 if no checkpoint is loaded.
If colocate_all is True, assumes that the policy model is currently on GPU.
Returns:
global_step: The global step to resume from.
checkpoint_path: The path to the checkpoint.
"""
checkpoint_path = None
# Check if resumption is enabled
if self.resume_mode == ResumeMode.NONE:
logger.info("Checkpoint resumption disabled, starting training from scratch")
return 0, None
# first, let's get resume_path
elif self.resume_mode == ResumeMode.LATEST:
latest_checkpoint_file = os.path.join(self.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt")
if not io.exists(latest_checkpoint_file):
logger.info("No checkpoint found, starting training from scratch")
return 0, None
with io.open_file(latest_checkpoint_file, "r") as f:
ckpt_iteration = int(f.read().strip())
checkpoint_path = os.path.join(self.cfg.trainer.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_iteration}")
# Run validation: Make sure ckpt folder is consistent with latest_ckpt_global_step.txt
validate_consistency_for_latest_checkpoint(
self.cfg.trainer.ckpt_path,
ckpt_iteration,
checkpoint_path,
latest_checkpoint_file,
self.cfg.trainer.ckpt_interval,
)
else:
# Get and validate resume path
checkpoint_path = Path(self.cfg.trainer.resume_path)
if not checkpoint_path:
raise ValueError("`trainer.resume_path` must be specified when resume_mode is 'from_path'")
# Validate that it's a global_step directory
if GLOBAL_STEP_PREFIX not in checkpoint_path.name:
raise ValueError(
f"`trainer.resume_path` must point to a directory whose name starting with {GLOBAL_STEP_PREFIX}, got: {checkpoint_path}"
)
# Validate that the path exists
if not io.exists(str(checkpoint_path)):
raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
logger.info(f"Loading checkpoint from: {checkpoint_path}")
# Extract global step from checkpoint path
global_step = extract_step_from_path(Path(checkpoint_path))
if global_step == -1:
raise ValueError(f"Checkpoint path {checkpoint_path} is not a valid checkpoint path")
logger.info(f"Resuming from global_step: {global_step}")
# Define paths for different checkpoint components
policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
critic_ckpt_dir = os.path.join(checkpoint_path, "critic")
trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
dataloader_state_path = os.path.join(checkpoint_path, "data.pt")
# Validate that required checkpoint files exist
if not io.exists(trainer_state_path):
raise FileNotFoundError(f"Trainer state file not found: {trainer_state_path}")
# 1. Load and validate trainer state
with io.open_file(trainer_state_path, "rb") as f:
trainer_state = torch.load(f, map_location="cpu", weights_only=False)
saved_global_step = trainer_state.get("global_step", global_step)
logger.info("Successfully loaded trainer state")
if saved_global_step != global_step:
logger.warning(f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value.")
# 2. Load dataloader state if available
if io.exists(dataloader_state_path):
try:
with io.open_file(dataloader_state_path, "rb") as f:
dataloader_state = torch.load(f, map_location="cpu", weights_only=False)
self.train_dataloader.load_state_dict(dataloader_state)
logger.info("Successfully loaded dataloader state")
except Exception as e:
logger.warning(f"Failed to load dataloader state: {e}. Dataloader will start from beginning.")
else:
logger.warning(
f"No dataloader state found at {dataloader_state_path}. Dataloader will start from beginning."
)
# 3. Load policy checkpoint (dispatch handles offload/backload)
logger.info(f"Loading policy checkpoint from {policy_ckpt_dir}")
self.dispatch.load_checkpoint(
"policy",
policy_ckpt_dir,
load_optimizer_states=True,
load_lr_scheduler_states=True,
)
logger.info("Successfully loaded policy checkpoint")
# 4. Load critic checkpoint if it exists and we have a critic model
if self.has_critic:
logger.info(f"Loading critic checkpoint from {critic_ckpt_dir}")
self.dispatch.load_checkpoint(
"critic",
critic_ckpt_dir,
load_optimizer_states=True,
load_lr_scheduler_states=True,
)
logger.info("Successfully loaded critic checkpoint")
logger.info(f"Successfully loaded complete checkpoint state from global_step_{global_step}")
return global_step, str(checkpoint_path)method save_models
save_models()Save the model parameters in HF format at cfg.trainer.export_path.
Dispatch handles offload/backload automatically for all colocation configurations.
Source code in skyrl/train/trainer.py:1458-1471
def save_models(self):
"""
Save the model parameters in HF format at `cfg.trainer.export_path`.
Dispatch handles offload/backload automatically for all colocation configurations.
"""
policy_export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}", "policy")
self.dispatch.save_hf_model("policy", policy_export_dir, self.tokenizer)
if self.has_critic:
critic_export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}", "critic")
self.dispatch.save_hf_model("critic", critic_export_dir, self.tokenizer)
logger.info("Successfully saved model weights.")method update_ref_with_policy
update_ref_with_policy()Update the reference model with the policy model weights (required by some algorithms).
Dispatch handles offload/backload automatically for all colocation configurations. After this method, save_weights_for_sampler() should be called to sync weights.
Source code in skyrl/train/trainer.py:1473-1496
def update_ref_with_policy(self):
"""
Update the reference model with the policy model weights (required by some algorithms).
Dispatch handles offload/backload automatically for all colocation configurations.
After this method, save_weights_for_sampler() should be called to sync weights.
"""
# TODO(tgriggs): Make policy-to-ref sync faster.
policy_export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}", "policy")
# Save policy model (dispatch handles GPU state)
self.dispatch.save_hf_model("policy", policy_export_dir, self.tokenizer)
# Re-initialize ref model from saved policy (dispatch handles offloading policy first)
self.dispatch.init_model("ref", policy_export_dir)
# Clean up temporary saved model files
try:
shutil.rmtree(policy_export_dir)
logger.info(f"Cleaned up temporary policy export directory: {policy_export_dir}")
except Exception as e:
logger.warning(f"Failed to clean up temporary policy export directory {policy_export_dir}: {e}")
logger.info("Successfully updated ref model with policy model, training continues.")Dispatch APIs
class Dispatch
Bases: ABC
Base class for dispatch types
Dispatch types are responsible for:
- dispatching method calls to actors handling data sharding if necessary
- collecting results from actors and concatenating results if necessary
- validating arguments for dispatch
Functions:
| Name | Description |
|---|---|
dispatch | Dispatches method calls to the actors with data sharding if necessary. |
async_collect | Collects results from the actors asynchronously in an asyncio-compatible way. |
sync_collect | Collects results from the actors synchronously and returns a TrainingOutputBatch. |
validate_dispatch_args | Validate and process arguments for dispatch. |
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:62-99
class Dispatch(ABC):
"""Base class for dispatch types
Dispatch types are responsible for:
- dispatching method calls to actors handling data sharding if necessary
- collecting results from actors and concatenating results if necessary
- validating arguments for dispatch
"""
@classmethod
@abstractmethod
def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
"""Dispatches method calls to the actors with data sharding if necessary."""
pass
@classmethod
@abstractmethod
async def async_collect(
cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
) -> Optional[TrainingOutputBatch]:
"""Collects results from the actors asynchronously in an asyncio-compatible way."""
pass
@classmethod
@abstractmethod
def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
"""Collects results from the actors synchronously and returns a `TrainingOutputBatch`."""
pass
@classmethod
@abstractmethod
def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
"""Validate and process arguments for dispatch.
Returns:
Tuple of (args, kwargs) to be passed to dispatch
"""
passattr dispatch
dispatch(actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]
Dispatches method calls to the actors with data sharding if necessary.
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:71-75
@classmethod
@abstractmethod
def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
"""Dispatches method calls to the actors with data sharding if necessary."""
passmethod abstractmethod async classmethod async_collect
async_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]Collects results from the actors asynchronously in an asyncio-compatible way.
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:77-83
@classmethod
@abstractmethod
async def async_collect(
cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
) -> Optional[TrainingOutputBatch]:
"""Collects results from the actors asynchronously in an asyncio-compatible way."""
passmethod abstractmethod classmethod sync_collect
sync_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]Collects results from the actors synchronously and returns a TrainingOutputBatch.
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:85-89
@classmethod
@abstractmethod
def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
"""Collects results from the actors synchronously and returns a `TrainingOutputBatch`."""
passmethod abstractmethod classmethod validate_dispatch_args
validate_dispatch_args(*args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]Validate and process arguments for dispatch.
Returns:
| Type | Description |
|---|---|
| Tuple[Tuple, Dict[str, Any]] | Tuple of (args, kwargs) to be passed to dispatch |
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:91-99
@classmethod
@abstractmethod
def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
"""Validate and process arguments for dispatch.
Returns:
Tuple of (args, kwargs) to be passed to dispatch
"""
passclass MeshDispatch
Bases: Dispatch
Mesh dispatch type to dispatch data to a group of actors along the device mesh.
Supports DP (Data Parallel), SP (Sequence Parallel), TP (Tensor Parallel) and PP (Pipeline Parallel) parallelism. The actor method should accept a single argument - the data batch.
For data dispatch:
- The input data is chunked into `dp_size` equal chunks, where `dp_size` is the size of data parallelism.
- Each actor with the same DP rank processes the same data chunk in parallel.
For data collection:
- Data is collected only from the primary rank of each model/sequence parallel group.
- The primary rank is defined as the rank with (SP=0, TP=0, PP=0).
- The collected chunks are concatenated in order of DP rank to reconstruct the full data.
Example: For a world size of 8, with DP size=2, SP size=2, TP size=2, PP size=1:
- Data dispatch: The data is chunked into 2 chunks. All actors with DP rank 0 process the first chunk, and all actors with DP rank 1 process the second chunk.
- Data collection: Only two actors contribute to the final output - the primary rank from each DP group: (DP=0, SP=0, TP=0, PP=0) and (DP=1, SP=0, TP=0, PP=0). Their chunks are concatenated in order.
Functions:
| Name | Description |
|---|---|
dispatch | |
async_collect | |
sync_collect | |
stage_chunks | Pre-stage mini-batch chunks into the object store. |
dispatch_from_staged | Dispatch pre-staged per-DP chunks to workers. |
validate_dispatch_args |
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:102-255
class MeshDispatch(Dispatch):
"""Mesh dispatch type to dispatch data to a group of actors along the device mesh.
Supports DP (Data Parallel), SP (Sequence Parallel), TP (Tensor Parallel) and PP (Pipeline Parallel) parallelism.
The actor method should accept a single argument - the data batch.
For data dispatch:
* The input data is chunked into `dp_size` equal chunks, where `dp_size` is the size of data parallelism.
* Each actor with the same DP rank processes the same data chunk in parallel.
For data collection:
* Data is collected only from the primary rank of each model/sequence parallel group.
* The primary rank is defined as the rank with (SP=0, TP=0, PP=0).
* The collected chunks are concatenated in order of DP rank to reconstruct the full data.
Example: For a world size of 8, with DP size=2, SP size=2, TP size=2, PP size=1:
* Data dispatch: The data is chunked into 2 chunks. All actors with DP rank 0 process the first chunk,
and all actors with DP rank 1 process the second chunk.
* Data collection: Only two actors contribute to the final output - the primary rank from each DP group:
(DP=0, SP=0, TP=0, PP=0) and (DP=1, SP=0, TP=0, PP=0). Their chunks are concatenated in order.
"""
@classmethod
def dispatch(cls, actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch, **kwargs) -> List[ObjectRef]:
assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
object_refs = []
dp_size = actor_infos[0].rank.dp_size
assert len(data) % dp_size == 0, "data batch size must be divisible by dp_size, got {} and {}".format(
len(data), dp_size
)
chunk_size = len(data) // dp_size
data_chunks: List[TrainingInputBatch] = data.chunk(chunk_size)
# Put each unique chunk in object store ONCE to avoid redundant serialization
# when the same chunk is sent to multiple workers (e.g., SP/TP replicas)
chunk_refs: List[ObjectRef] = [ray.put(chunk) for chunk in data_chunks]
for actor_info in actor_infos:
# Pass ObjectRef instead of data - workers will fetch from object store
chunk_ref = chunk_refs[actor_info.rank.dp]
object_refs.append(getattr(actor_info.handle, method).remote(chunk_ref, **kwargs))
return object_refs
@classmethod
async def async_collect(
cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
) -> Optional[TrainingOutputBatch]:
assert len(actor_infos) == len(object_refs), "`actor_infos` and `object_refs` must have the same length"
all_objects = await asyncio.gather(*object_refs)
if len(all_objects) and all_objects[0] is not None:
return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
return
@classmethod
def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
assert len(actor_infos) == len(object_refs), "`actor_infos` and `object_refs` must have the same length"
all_objects = ray.get(object_refs)
if len(all_objects) and all_objects[0] is not None:
return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
# all should be none
assert all(obj is None for obj in all_objects), "Got a mix of `None` and non-`None` objects"
return
@classmethod
def stage_chunks(
cls,
dp_size: int,
data: TrainingInputBatch,
mini_batch_boundaries: List[Tuple[int, int]],
) -> List[List[ObjectRef]]:
"""Pre-stage mini-batch chunks into the object store.
Each mini-batch is defined by a ``(start, end)`` index pair from mini_batch_boundaries.
Mini-batches are individually padded so that their size is divisible by dp_size, using dummy
entries with ``loss_mask=0`` that do not affect the loss.
Args:
dp_size: Number of data-parallel ranks.
data: Full TrainingInputBatch to slice from.
mini_batch_boundaries: List of ``(start, end)`` index pairs. The i-th mini-batch is
data[mini_batch_boundaries[i][0]:mini_batch_boundaries[i][1]].
Returns:
``result[i][dp_rank]`` - ObjectRef for mini-batch *i*, DP rank *dp_rank*.
"""
all_chunk_refs: List[List[ObjectRef]] = []
for start, end in mini_batch_boundaries:
mini_batch = data[start:end]
mb_size = end - start
# Pad to make divisible by dp_size. Will only be non-zero for step-wise training.
pad_size = (-mb_size) % dp_size
if pad_size > 0:
mini_batch = pad_training_input_batch(mini_batch, pad_size)
mini_batch_size = len(mini_batch)
assert (
mini_batch_size % dp_size == 0
), f"mini_batch_size % dp_size != 0, got {mini_batch_size} and {dp_size}"
chunk_size = mini_batch_size // dp_size
chunks = mini_batch.chunk(chunk_size)
all_chunk_refs.append([ray.put(chunk) for chunk in chunks])
return all_chunk_refs
@classmethod
def dispatch_from_staged(
cls,
actor_infos: List[ActorInfo],
method: str,
chunk_refs: List[ObjectRef],
**kwargs,
) -> List[ObjectRef]:
"""
Dispatch pre-staged per-DP chunks to workers.
Each worker receives only its own chunk (already in the object
store), avoiding unnecessary deserialization overhead.
Args:
actor_infos: List of actor info objects
method: Name of method to call on workers (receives a single data chunk)
chunk_refs: Pre-staged ObjectRefs, one per DP rank (from ``stage_chunks``)
**kwargs: Additional keyword arguments to pass to the method
Returns:
List of ObjectRefs for worker results
"""
assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
object_refs = []
for actor_info in actor_infos:
chunk_ref = chunk_refs[actor_info.rank.dp]
object_refs.append(getattr(actor_info.handle, method).remote(chunk_ref, **kwargs))
return object_refs
@classmethod
def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
# Extract data from either positional arg or kwarg
if args:
data = args[0]
remaining_kwargs = kwargs
elif "data" in kwargs:
data = kwargs.pop("data")
remaining_kwargs = kwargs
else:
raise ValueError("MeshDispatch requires 'data' as first positional argument or keyword argument")
if not isinstance(data, TrainingInputBatch):
raise ValueError(f"For MeshDispatch, `data` entry should be a `TrainingInputBatch`, got {type(data)}")
# Pass through data as positional arg, and any other kwargs (e.g., loss_fn, loss_fn_config)
return (data,), remaining_kwargsattr dispatch
dispatch(actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch, **kwargs) -> List[ObjectRef]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:128-147
@classmethod
def dispatch(cls, actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch, **kwargs) -> List[ObjectRef]:
assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
object_refs = []
dp_size = actor_infos[0].rank.dp_size
assert len(data) % dp_size == 0, "data batch size must be divisible by dp_size, got {} and {}".format(
len(data), dp_size
)
chunk_size = len(data) // dp_size
data_chunks: List[TrainingInputBatch] = data.chunk(chunk_size)
# Put each unique chunk in object store ONCE to avoid redundant serialization
# when the same chunk is sent to multiple workers (e.g., SP/TP replicas)
chunk_refs: List[ObjectRef] = [ray.put(chunk) for chunk in data_chunks]
for actor_info in actor_infos:
# Pass ObjectRef instead of data - workers will fetch from object store
chunk_ref = chunk_refs[actor_info.rank.dp]
object_refs.append(getattr(actor_info.handle, method).remote(chunk_ref, **kwargs))
return object_refsmethod abstractmethod async classmethod async_collect
async_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:149-157
@classmethod
async def async_collect(
cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
) -> Optional[TrainingOutputBatch]:
assert len(actor_infos) == len(object_refs), "`actor_infos` and `object_refs` must have the same length"
all_objects = await asyncio.gather(*object_refs)
if len(all_objects) and all_objects[0] is not None:
return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
returnmethod abstractmethod classmethod sync_collect
sync_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:159-167
@classmethod
def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
assert len(actor_infos) == len(object_refs), "`actor_infos` and `object_refs` must have the same length"
all_objects = ray.get(object_refs)
if len(all_objects) and all_objects[0] is not None:
return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
# all should be none
assert all(obj is None for obj in all_objects), "Got a mix of `None` and non-`None` objects"
returnmethod classmethod stage_chunks
stage_chunks(dp_size: int, data: TrainingInputBatch, mini_batch_boundaries: List[Tuple[int, int]]) -> List[List[ObjectRef]]Pre-stage mini-batch chunks into the object store.
Each mini-batch is defined by a (start, end) index pair from mini_batch_boundaries.
Mini-batches are individually padded so that their size is divisible by dp_size, using dummy
entries with loss_mask=0 that do not affect the loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dp_size | int | Number of data-parallel ranks. | required |
data | TrainingInputBatch | Full TrainingInputBatch to slice from. | required |
mini_batch_boundaries | List[Tuple[int, int]] | List of (start, end) index pairs. The i-th mini-batch is data[mini_batch_boundaries[i][0]:mini_batch_boundaries[i][1]]. | required |
Returns:
| Type | Description |
|---|---|
| List[List[ObjectRef]] | result[i][dp_rank] - ObjectRef for mini-batch i, DP rank dp_rank. |
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:169-208
@classmethod
def stage_chunks(
cls,
dp_size: int,
data: TrainingInputBatch,
mini_batch_boundaries: List[Tuple[int, int]],
) -> List[List[ObjectRef]]:
"""Pre-stage mini-batch chunks into the object store.
Each mini-batch is defined by a ``(start, end)`` index pair from mini_batch_boundaries.
Mini-batches are individually padded so that their size is divisible by dp_size, using dummy
entries with ``loss_mask=0`` that do not affect the loss.
Args:
dp_size: Number of data-parallel ranks.
data: Full TrainingInputBatch to slice from.
mini_batch_boundaries: List of ``(start, end)`` index pairs. The i-th mini-batch is
data[mini_batch_boundaries[i][0]:mini_batch_boundaries[i][1]].
Returns:
``result[i][dp_rank]`` - ObjectRef for mini-batch *i*, DP rank *dp_rank*.
"""
all_chunk_refs: List[List[ObjectRef]] = []
for start, end in mini_batch_boundaries:
mini_batch = data[start:end]
mb_size = end - start
# Pad to make divisible by dp_size. Will only be non-zero for step-wise training.
pad_size = (-mb_size) % dp_size
if pad_size > 0:
mini_batch = pad_training_input_batch(mini_batch, pad_size)
mini_batch_size = len(mini_batch)
assert (
mini_batch_size % dp_size == 0
), f"mini_batch_size % dp_size != 0, got {mini_batch_size} and {dp_size}"
chunk_size = mini_batch_size // dp_size
chunks = mini_batch.chunk(chunk_size)
all_chunk_refs.append([ray.put(chunk) for chunk in chunks])
return all_chunk_refsmethod classmethod dispatch_from_staged
dispatch_from_staged(actor_infos: List[ActorInfo], method: str, chunk_refs: List[ObjectRef], **kwargs) -> List[ObjectRef]
Dispatch pre-staged per-DP chunks to workers.
Each worker receives only its own chunk (already in the object store), avoiding unnecessary deserialization overhead.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
actor_infos | List[ActorInfo] | List of actor info objects | required |
method | str | Name of method to call on workers (receives a single data chunk) | required |
chunk_refs | List[ObjectRef] | Pre-staged ObjectRefs, one per DP rank (from stage_chunks) | required |
**kwargs | Additional keyword arguments to pass to the method | {} |
Returns:
| Type | Description |
|---|---|
| List[ObjectRef] | List of ObjectRefs for worker results |
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:210-238
@classmethod
def dispatch_from_staged(
cls,
actor_infos: List[ActorInfo],
method: str,
chunk_refs: List[ObjectRef],
**kwargs,
) -> List[ObjectRef]:
"""
Dispatch pre-staged per-DP chunks to workers.
Each worker receives only its own chunk (already in the object
store), avoiding unnecessary deserialization overhead.
Args:
actor_infos: List of actor info objects
method: Name of method to call on workers (receives a single data chunk)
chunk_refs: Pre-staged ObjectRefs, one per DP rank (from ``stage_chunks``)
**kwargs: Additional keyword arguments to pass to the method
Returns:
List of ObjectRefs for worker results
"""
assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
object_refs = []
for actor_info in actor_infos:
chunk_ref = chunk_refs[actor_info.rank.dp]
object_refs.append(getattr(actor_info.handle, method).remote(chunk_ref, **kwargs))
return object_refsmethod abstractmethod classmethod validate_dispatch_args
validate_dispatch_args(*args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:240-255
@classmethod
def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
# Extract data from either positional arg or kwarg
if args:
data = args[0]
remaining_kwargs = kwargs
elif "data" in kwargs:
data = kwargs.pop("data")
remaining_kwargs = kwargs
else:
raise ValueError("MeshDispatch requires 'data' as first positional argument or keyword argument")
if not isinstance(data, TrainingInputBatch):
raise ValueError(f"For MeshDispatch, `data` entry should be a `TrainingInputBatch`, got {type(data)}")
# Pass through data as positional arg, and any other kwargs (e.g., loss_fn, loss_fn_config)
return (data,), remaining_kwargsclass PassThroughDispatch
Bases: Dispatch
PassThrough dispatch type to dispatch data to a group of actors without any sharding.
This is useful for cases where we want to run the same method on all the actors. Supports methods with any number of arguments.
Functions:
| Name | Description |
|---|---|
dispatch | |
async_collect | |
sync_collect | |
validate_dispatch_args |
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:258-293
class PassThroughDispatch(Dispatch):
"""PassThrough dispatch type to dispatch data to a group of actors without any sharding.
This is useful for cases where we want to run the same method on all the actors.
Supports methods with any number of arguments.
"""
@classmethod
def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
return [getattr(actor_info.handle, method).remote(*args, **kwargs) for actor_info in actor_infos]
@classmethod
async def async_collect(
cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
) -> Optional[TrainingOutputBatch]:
all_objects = await asyncio.gather(*object_refs)
if len(all_objects) and all_objects[0] is not None:
return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
return
@classmethod
def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
data_batches = ray.get(object_refs)
if len(data_batches) > 0 and data_batches[0] is not None:
assert isinstance(
data_batches[0], TrainingOutputBatch
), "data_batches must be a list of `TrainingOutputBatch` objects"
return concatenate_outputs_after_mesh_dispatch(actor_infos, data_batches)
# all should be none
assert all(obj is None for obj in data_batches), "Got a mix of `None` and non-`None` objects"
return
@classmethod
def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
# no validation needed just pass everything
        return args, kwargs

attr dispatch
dispatch(actor_infos: List[ActorInfo], method: str, *args: str, **kwargs: str) -> List[ObjectRef]Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:265-267
@classmethod
def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
        return [getattr(actor_info.handle, method).remote(*args, **kwargs) for actor_info in actor_infos]

method abstractmethod async classmethod async_collect
async_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:269-276
@classmethod
async def async_collect(
cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
) -> Optional[TrainingOutputBatch]:
all_objects = await asyncio.gather(*object_refs)
if len(all_objects) and all_objects[0] is not None:
return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
returnmethod abstractmethod classmethod sync_collect
sync_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:278-288
@classmethod
def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
data_batches = ray.get(object_refs)
if len(data_batches) > 0 and data_batches[0] is not None:
assert isinstance(
data_batches[0], TrainingOutputBatch
), "data_batches must be a list of `TrainingOutputBatch` objects"
return concatenate_outputs_after_mesh_dispatch(actor_infos, data_batches)
# all should be none
assert all(obj is None for obj in data_batches), "Got a mix of `None` and non-`None` objects"
        return

method abstractmethod classmethod validate_dispatch_args
validate_dispatch_args(*args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:290-293
@classmethod
def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
# no validation needed just pass everything
        return args, kwargs

Worker APIs
The base worker abstraction in SkyRL.
class Worker
Worker(cfg: TrainerConfig, *args, **kwargs)

Bases: DistributedTorchRayActor
Functions:
| Name | Description |
|---|---|
get_node_local_rank | |
init_worker_process_group | |
get_mesh_rank | |
get_gpu_id | |
get_ray_node_id | |
get_master_addr_port | |
init_model | Initialize worker state (model, and optimizer if applicable) on worker. |
empty_cache | Empty GPU memory cache on Worker's CUDA device |
set_algorithm_config | |
offload_to_cpu | Offload all worker state to CPU. |
backload_to_gpu | Backload worker state to GPU |
get_cuda_memory | Get CUDA memory usage on worker's CUDA device. |
save_memory_snapshot | Save a snapshot of memory usage on the Worker's CUDA device. |
init_weight_sync_state | Initialize state for weight syncing with Inference Engine Client |
forward | Run forward pass on the input batch in inference mode. |
Attributes:
| Name | Type | Description |
|---|---|---|
sequence_parallel_size | int | |
record_memory | ||
cfg |
Source code in skyrl/backends/skyrl_train/workers/worker.py:230-398
class Worker(DistributedTorchRayActor):
def __init__(self, cfg: TrainerConfig, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cfg = cfg
self._transfer_strategy_cls = None # Set in init_weight_transfer_communicator
if self.cfg.algorithm.temperature is None:
raise ValueError("`cfg.algorithm.temperature` must be set")
def init_model(self, *args, **kwargs):
"""Initialize worker state (model, and optimizer if applicable) on worker."""
raise NotImplementedError()
def empty_cache(self) -> None:
"""Empty GPU memory cache on Worker's CUDA device"""
torch.cuda.empty_cache()
def set_algorithm_config(self, **kwargs) -> None:
for key, value in kwargs.items():
setattr(self.cfg.algorithm, key, value)
def offload_to_cpu(self, pin_memory=True, non_blocking=True):
"""Offload all worker state to CPU.
After this function runs, only temporary reserved memory and torch's pre-loaded cuda kernels (~ GB) will remain
Args:
            pin_memory: Whether to use pinned/page-locked memory on CPU
non_blocking: Whether the operation is non-blocking
"""
raise NotImplementedError()
def backload_to_gpu(self, non_blocking=True):
"""Backload worker state to GPU
Args:
non_blocking: Whether the operation is non-blocking
"""
raise NotImplementedError()
def get_cuda_memory(self) -> Dict[str, Any]:
"""Get CUDA memory usage on worker's CUDA device."""
torch.cuda.synchronize()
free, total = torch.cuda.mem_get_info()
return {
"allocated": torch.cuda.memory_allocated(),
"reserved": torch.cuda.memory_reserved(),
"free": free,
"total": total,
}
def save_memory_snapshot(self, tag: str = ""):
"""Save a snapshot of memory usage on the Worker's CUDA device.
No-ops if record_memory is False.
Args:
tag: Label for the snapshot (e.g., "forward_backward", "optim_step")
.. note::
This function should be called on all the ranks in the worker group simultaneously.
"""
if not self.record_memory:
return
# Track snapshot count for unique filenames
if not hasattr(self, "_snapshot_count"):
self._snapshot_count = 0
self._snapshot_count += 1
rank = torch.distributed.get_rank()
save_path = os.path.join(self.cfg.ckpt_path, "memory_snapshots")
if self._local_rank == 0 and not io.exists(save_path):
io.makedirs(save_path, exist_ok=True)
torch.distributed.barrier()
tag_str = f"_{tag}" if tag else ""
file_name = f"rank_{rank}{tag_str}_{self._snapshot_count}.pickle"
record_memory_path = os.path.join(save_path, file_name)
if io.exists(record_memory_path):
# seeing issues if we don't remove the file first
io.remove(record_memory_path)
torch.cuda.memory._dump_snapshot(record_memory_path)
async def init_weight_sync_state(
self,
inference_engine_client: "Union[InferenceEngineClient, RemoteInferenceClient]",
inference_engine_cfg: "InferenceEngineConfig",
):
"""Initialize state for weight syncing with Inference Engine Client
Creates init info and sender, then sends init info to inference engines
so they can create receivers.
.. note::
This function should be called on all the ranks in the worker group simultaneously.
"""
from skyrl.backends.skyrl_train.weight_sync import get_transfer_strategy_cls
assert inference_engine_client is not None
# Determine transfer strategy based on inference engine config and placement
self._transfer_strategy_cls = get_transfer_strategy_cls(
weight_sync_backend=inference_engine_cfg.weight_sync_backend,
colocate_all=self.cfg.placement.colocate_all,
)
# For new inference path, fetch world_size from servers
# For legacy path, calculate from config
inference_world_size = None
if _SKYRL_USE_NEW_INFERENCE and hasattr(inference_engine_client, "get_world_size"):
inference_world_size, _ = await inference_engine_client.get_world_size()
# Create init info on all ranks (it's deterministic from cfg or fetched world_size)
init_info = self._transfer_strategy_cls.create_init_info(
inference_engine_cfg, inference_world_size=inference_world_size
)
# Create sender on all ranks
# Strategy implementations may have different logic for different ranks
tasks = [
asyncio.to_thread(
self._transfer_strategy_cls.create_sender,
init_info=init_info,
inference_client=inference_engine_client,
),
]
# Only rank 0 initializes receivers on inference engines
# NOTE: For broadcast strategy, sender and receiver init must run concurrently
# because both need to join the same process group to avoid deadlock
if torch.distributed.get_rank() == 0:
tasks.append(inference_engine_client.init_weight_update_communicator(init_info))
results = await asyncio.gather(*tasks)
self._weight_transfer_sender = results[0] # sender is always first task
# # Register signal handlers for termination only on rank 0
# NOTE (sumanthrh): This doesn't work yet, and is thus commented out.
# The better way is to just have this specified in __del__, but there is
        # no guarantee that __del__ will be called in general. Ray also doesn't
        # explicitly call __del__ when the actor shuts down.
# It's commented out so that we can fix this in the future.
# atexit.register(self._handle_termination)
torch.distributed.barrier()
def forward(
self,
data: TrainingInputBatch,
) -> TrainingOutputBatch:
"""Run forward pass on the input batch in inference mode.
This is a wrapper around `_forward_micro_batch` that runs in micro batches of `cfg.micro_forward_batch_size_per_gpu`.
"""
# run in micro batches of cfg.micro_forward_batch_size_per_gpu
# TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc.
micro_batches = data.chunk(self.cfg.micro_forward_batch_size_per_gpu)
outputs = []
for micro_batch in micro_batches:
outputs.append(self._forward_micro_batch(micro_batch))
output = TrainingOutputBatch.cat(outputs)
if output.device is not None and output.device != torch.device("cpu"):
output = output.to("cpu")
return output
def _forward_micro_batch(self, micro_batch: TrainingInputBatch) -> TrainingOutputBatch:
        raise NotImplementedError()

attr sequence_parallel_size
sequence_parallel_size: int = sequence_parallel_size

attr record_memory
record_memory = record_memory

get_node_local_rank
get_node_local_rank()

init_worker_process_group
init_worker_process_group()

get_mesh_rank
get_mesh_rank()

get_gpu_id
get_gpu_id()

get_ray_node_id
get_ray_node_id()

get_master_addr_port
get_master_addr_port()

attr cfg
cfg = cfg

method init_model
init_model(*args, **kwargs)

Initialize worker state (model, and optimizer if applicable) on worker.
Source code in skyrl/backends/skyrl_train/workers/worker.py:239-241
def init_model(self, *args, **kwargs):
"""Initialize worker state (model, and optimizer if applicable) on worker."""
raise NotImplementedError()method empty_cache
empty_cache() -> NoneEmpty GPU memory cache on Worker's CUDA device
Source code in skyrl/backends/skyrl_train/workers/worker.py:243-245
def empty_cache(self) -> None:
"""Empty GPU memory cache on Worker's CUDA device"""
torch.cuda.empty_cache()method set_algorithm_config
set_algorithm_config(**kwargs) -> NoneSource code in skyrl/backends/skyrl_train/workers/worker.py:247-249
def set_algorithm_config(self, **kwargs) -> None:
for key, value in kwargs.items():
setattr(self.cfg.algorithm, key, value)method offload_to_cpu
offload_to_cpu(pin_memory = True, non_blocking = True)Offload all worker state to CPU.
After this function runs, only temporary reserved memory and torch's pre-loaded cuda kernels (~ GB) will remain
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
pin_memory | Whether to use pinned/page-locked memory on CPU | True |
non_blocking | Whether the operation is non-blocking | True |
Source code in skyrl/backends/skyrl_train/workers/worker.py:251-260
def offload_to_cpu(self, pin_memory=True, non_blocking=True):
"""Offload all worker state to CPU.
After this function runs, only temporary reserved memory and torch's pre-loaded cuda kernels (~ GB) will remain
Args:
            pin_memory: Whether to use pinned/page-locked memory on CPU
non_blocking: Whether the operation is non-blocking
"""
raise NotImplementedError()method backload_to_gpu
backload_to_gpu(non_blocking = True)Backload worker state to GPU
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
non_blocking | Whether the operation is non-blocking | True |
Source code in skyrl/backends/skyrl_train/workers/worker.py:262-268
def backload_to_gpu(self, non_blocking=True):
"""Backload worker state to GPU
Args:
non_blocking: Whether the operation is non-blocking
"""
raise NotImplementedError()method get_cuda_memory
get_cuda_memory() -> Dict[str, Any]Get CUDA memory usage on worker's CUDA device.
Source code in skyrl/backends/skyrl_train/workers/worker.py:270-279
def get_cuda_memory(self) -> Dict[str, Any]:
"""Get CUDA memory usage on worker's CUDA device."""
torch.cuda.synchronize()
free, total = torch.cuda.mem_get_info()
return {
"allocated": torch.cuda.memory_allocated(),
"reserved": torch.cuda.memory_reserved(),
"free": free,
"total": total,
}method save_memory_snapshot
save_memory_snapshot(tag: str = '')Save a snapshot of memory usage on the Worker's CUDA device.
No-ops if record_memory is False.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
tag | str | Label for the snapshot (e.g., "forward_backward", "optim_step") | '' |
.. note:: This function should be called on all the ranks in the worker group simultaneously.
Source code in skyrl/backends/skyrl_train/workers/worker.py:281-312
def save_memory_snapshot(self, tag: str = ""):
"""Save a snapshot of memory usage on the Worker's CUDA device.
No-ops if record_memory is False.
Args:
tag: Label for the snapshot (e.g., "forward_backward", "optim_step")
.. note::
This function should be called on all the ranks in the worker group simultaneously.
"""
if not self.record_memory:
return
# Track snapshot count for unique filenames
if not hasattr(self, "_snapshot_count"):
self._snapshot_count = 0
self._snapshot_count += 1
rank = torch.distributed.get_rank()
save_path = os.path.join(self.cfg.ckpt_path, "memory_snapshots")
if self._local_rank == 0 and not io.exists(save_path):
io.makedirs(save_path, exist_ok=True)
torch.distributed.barrier()
tag_str = f"_{tag}" if tag else ""
file_name = f"rank_{rank}{tag_str}_{self._snapshot_count}.pickle"
record_memory_path = os.path.join(save_path, file_name)
if io.exists(record_memory_path):
# seeing issues if we don't remove the file first
io.remove(record_memory_path)
torch.cuda.memory._dump_snapshot(record_memory_path)method init_weight_sync_state
init_weight_sync_state(inference_engine_client: Union[InferenceEngineClient, RemoteInferenceClient], inference_engine_cfg: InferenceEngineConfig)Initialize state for weight syncing with Inference Engine Client
Creates init info and sender, then sends init info to inference engines so they can create receivers.
.. note:: This function should be called on all the ranks in the worker group simultaneously.
Source code in skyrl/backends/skyrl_train/workers/worker.py:314-375
async def init_weight_sync_state(
self,
inference_engine_client: "Union[InferenceEngineClient, RemoteInferenceClient]",
inference_engine_cfg: "InferenceEngineConfig",
):
"""Initialize state for weight syncing with Inference Engine Client
Creates init info and sender, then sends init info to inference engines
so they can create receivers.
.. note::
This function should be called on all the ranks in the worker group simultaneously.
"""
from skyrl.backends.skyrl_train.weight_sync import get_transfer_strategy_cls
assert inference_engine_client is not None
# Determine transfer strategy based on inference engine config and placement
self._transfer_strategy_cls = get_transfer_strategy_cls(
weight_sync_backend=inference_engine_cfg.weight_sync_backend,
colocate_all=self.cfg.placement.colocate_all,
)
# For new inference path, fetch world_size from servers
# For legacy path, calculate from config
inference_world_size = None
if _SKYRL_USE_NEW_INFERENCE and hasattr(inference_engine_client, "get_world_size"):
inference_world_size, _ = await inference_engine_client.get_world_size()
# Create init info on all ranks (it's deterministic from cfg or fetched world_size)
init_info = self._transfer_strategy_cls.create_init_info(
inference_engine_cfg, inference_world_size=inference_world_size
)
# Create sender on all ranks
# Strategy implementations may have different logic for different ranks
tasks = [
asyncio.to_thread(
self._transfer_strategy_cls.create_sender,
init_info=init_info,
inference_client=inference_engine_client,
),
]
# Only rank 0 initializes receivers on inference engines
# NOTE: For broadcast strategy, sender and receiver init must run concurrently
# because both need to join the same process group to avoid deadlock
if torch.distributed.get_rank() == 0:
tasks.append(inference_engine_client.init_weight_update_communicator(init_info))
results = await asyncio.gather(*tasks)
self._weight_transfer_sender = results[0] # sender is always first task
# # Register signal handlers for termination only on rank 0
# NOTE (sumanthrh): This doesn't work yet, and is thus commented out.
# The better way is to just have this specified in __del__, but there is
        # no guarantee that __del__ will be called in general. Ray also doesn't
        # explicitly call __del__ when the actor shuts down.
# It's commented out so that we can fix this in the future.
# atexit.register(self._handle_termination)
torch.distributed.barrier()method abstractmethod forward
forward(data: TrainingInputBatch) -> TrainingOutputBatch

Run forward pass on the input batch in inference mode.
This is a wrapper around _forward_micro_batch that runs in micro batches of cfg.micro_forward_batch_size_per_gpu.
Source code in skyrl/backends/skyrl_train/workers/worker.py:377-395
def forward(
self,
data: TrainingInputBatch,
) -> TrainingOutputBatch:
"""Run forward pass on the input batch in inference mode.
This is a wrapper around `_forward_micro_batch` that runs in micro batches of `cfg.micro_forward_batch_size_per_gpu`.
"""
# run in micro batches of cfg.micro_forward_batch_size_per_gpu
# TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc.
micro_batches = data.chunk(self.cfg.micro_forward_batch_size_per_gpu)
outputs = []
for micro_batch in micro_batches:
outputs.append(self._forward_micro_batch(micro_batch))
output = TrainingOutputBatch.cat(outputs)
if output.device is not None and output.device != torch.device("cpu"):
output = output.to("cpu")
return outputclass PPORayActorGroup
PPORayActorGroup(cfg: TrainerConfig, num_nodes: int, num_gpus_per_node: int, ray_actor_type: Type[Worker], pg: Optional[ResolvedPlacementGroup] = None, num_gpus_per_actor: float = 1.0, resources: Optional[Dict[str, float]] = None, num_resources_per_node: Optional[int] = None, colocate_all: bool = False, sequence_parallel_size: int = 1, record_memory: bool = False) -> None

A group of ray actors. Functions starting with 'async' should return a list of object refs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
cfg | TrainerConfig | config object for workers | required |
num_nodes | int | Number of nodes for this actor group. | required |
num_gpus_per_node | int | Number of gpus for this actor group. | required |
ray_actor_type | Type[Worker] | PPO model type that this actor group serves. | required |
pg | ResolvedPlacementGroup | Placement group to schedule actor on. If none, create new placement group automatically. Defaults to None. | None |
num_gpus_per_actor | float | Number of gpus allocated for each actor. If < 1.0, multiple models can share same gpu. Defaults to 1. | 1.0 |
Functions:
| Name | Description |
|---|---|
async_init_model | Asynchronously initialize worker state (model, and optimizer if applicable) from model path |
offload_to_cpu | Offload all worker state to CPU. |
backload_to_gpu | Backload worker state to GPU |
run_method | Run a method on all actors using specified dispatch type synchronously. |
async_run_ray_method | Run a method on all actors using specified dispatch type asynchronously. |
async_run_method | Run a method on all actors using specified dispatch type in an asyncio-compatible way. |
Attributes:
| Name | Type | Description |
|---|---|---|
cfg | ||
ray_actor_type | ||
colocate_all | ||
sequence_parallel_size | ||
record_memory |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
pg | Optional[ResolvedPlacementGroup] | Placement group for the worker group. Accepts a single PlacementGroup, or None. Note that if colocate_all is True, the number of bundles in the placement group must match world_size. | None |
Source code in skyrl/backends/skyrl_train/workers/worker.py:402-669
class PPORayActorGroup:
"""
A group of ray actors
    Functions starting with 'async' should return a list of object refs
Args:
cfg: config object for workers
num_nodes (int): Number of nodes for this actor group.
num_gpus_per_node (int): Number of gpus for this actor group.
        ray_actor_type (Type[Worker]): PPO model type that this actor group serves.
pg (ResolvedPlacementGroup, optional): Placement group to schedule actor on.
If none, create new placement group automatically. Defaults to None.
num_gpus_per_actor (float, optional): Number of gpus allocated for each actor.
If < 1.0, multiple models can share same gpu. Defaults to 1.
"""
def __init__(
self,
cfg: TrainerConfig,
num_nodes,
num_gpus_per_node,
ray_actor_type: Type[Worker],
pg: Optional[ResolvedPlacementGroup] = None,
num_gpus_per_actor: float = 1.0,
resources: Optional[Dict[str, float]] = None,
num_resources_per_node: Optional[int] = None,
colocate_all: bool = False,
sequence_parallel_size: int = 1,
record_memory: bool = False,
) -> None:
"""
Args:
pg: Placement group for the worker group. Accepts a single PlacementGroup, or None.
Note that if colocate_all is True, the number of bundles in the placement group must match world_size.
"""
self.cfg = cfg
self._num_nodes = num_nodes
self._num_gpus_per_node = num_gpus_per_node
self.ray_actor_type = ray_actor_type
# custom resources, see https://docs.ray.io/en/latest/ray-core/scheduling/resources.html
self._resources = resources
self._num_resources_per_node = num_resources_per_node
self.colocate_all = colocate_all
self.sequence_parallel_size = sequence_parallel_size
self.record_memory = record_memory
self._initiate_actors(pg, num_gpus_per_actor)
def _initiate_actors(self, pg: Optional[ResolvedPlacementGroup], num_gpus_per_actor: float):
"""Initialize Ray actors in the worker group.
Args:
pg: A single placement group for the worker group, or None.
num_gpus_per_actor: The number of gpus to allocate per actor.
"""
world_size = self._num_nodes * self._num_gpus_per_node
# Extract raw Ray PlacementGroup and pre-computed reordered indices from ResolvedPlacementGroup.
# Only use reordered indices when the PG has one bundle per GPU (single-GPU bundles),
# i.e. the bundle count matches world_size. Multi-GPU bundles (whole-node bundles)
# don't need reordering since each bundle already represents a full node.
reordered_bundle_indices = []
raw_pg = None
if pg is not None:
assert isinstance(pg, ResolvedPlacementGroup), f"pg must be a `ResolvedPlacementGroup` got {type(pg)}."
raw_pg = pg.pg
if len(placement_group_table(raw_pg)["bundles"]) == world_size:
reordered_bundle_indices = pg.reordered_bundle_indices
if self.colocate_all:
assert (
raw_pg is not None
), "if colocate_all is True, the shared placement group must be provided to PPORayActorGroup"
pg_data = placement_group_table(raw_pg)
assert len(pg_data["bundles"]) == world_size, (
f"if colocate_all is True, the number of bundles in the placement group "
f"must match world_size. Got {len(pg_data['bundles'])} bundles but world_size={world_size}"
)
# If no PG provided, create one internally
if raw_pg is None and self._num_gpus_per_node > 1:
bundles = [{"GPU": self._num_gpus_per_node, "CPU": self._num_gpus_per_node} for _ in range(self._num_nodes)]
if self._resources:
resources_name = list(self._resources.keys())[0]
for i in range(len(bundles)):
bundles[i][resources_name] = self._num_resources_per_node
raw_pg = placement_group(bundles, strategy="PACK")
get_ray_pg_ready_with_timeout(raw_pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
def _scheduling_strategy_for_rank(rank):
if reordered_bundle_indices:
return PlacementGroupSchedulingStrategy(
placement_group=raw_pg,
placement_group_bundle_index=reordered_bundle_indices[rank],
)
elif raw_pg is not None:
return PlacementGroupSchedulingStrategy(
placement_group=raw_pg,
placement_group_bundle_index=rank // self._num_gpus_per_node,
)
            # else we are in the single-GPU-per-node case, in which case we don't need to set
            # bundle indices
return None
sched = _scheduling_strategy_for_rank(0)
actor_options = {
"num_cpus": num_gpus_per_actor,
"num_gpus": num_gpus_per_actor,
"resources": self._resources,
}
if sched is not None:
actor_options["scheduling_strategy"] = sched
master_actor = self.ray_actor_type.options(**actor_options).remote(
cfg=self.cfg,
world_size=world_size,
rank=0,
local_rank=0,
master_addr=None,
master_port=None,
sequence_parallel_size=self.sequence_parallel_size,
record_memory=self.record_memory,
)
self._actor_handlers = [master_actor]
if world_size > 1:
master_addr, master_port = ray.get(master_actor.get_master_addr_port.remote())
for rank in range(1, world_size):
local_rank = rank % self._num_gpus_per_node
sched = _scheduling_strategy_for_rank(rank)
actor_options = {
"num_cpus": num_gpus_per_actor,
"num_gpus": num_gpus_per_actor,
"resources": self._resources,
}
if sched is not None:
actor_options["scheduling_strategy"] = sched
worker_actor = self.ray_actor_type.options(**actor_options).remote(
cfg=self.cfg,
world_size=world_size,
rank=rank,
local_rank=local_rank,
master_addr=master_addr,
master_port=master_port,
sequence_parallel_size=self.sequence_parallel_size,
record_memory=self.record_memory,
)
self._actor_handlers.append(worker_actor)
# Initialize process group
logger.info("Initializing process group for RayActorGroup")
ray.get([actor.init_worker_process_group.remote() for actor in self._actor_handlers])
logger.info("Initialized process group for RayActorGroup")
self.actor_infos = [ActorInfo(actor, ray.get(actor.get_mesh_rank.remote())) for actor in self._actor_handlers]
logger.info(f"Mesh Ranks: {[actor_info.rank for actor_info in self.actor_infos]}")
def async_init_model(
self,
*args,
**kwargs,
) -> List[ObjectRef]:
"""Asynchronously initialize worker state (model, and optimizer if applicable) from model path
on all the workers.
Returns:
A list of ray object refs.
"""
return [actor.init_model.remote(*args, **kwargs) for actor in self._actor_handlers]
def offload_to_cpu(self, nonblocking=False, offload_optimizer=True, offload_model=True):
"""Offload all worker state to CPU.
Args:
nonblocking: Whether this operation is synchronous or asynchronous.
If `nonblocking=True`, then the function returns a list of object refs.
"""
refs = [
actor.offload_to_cpu.remote(offload_optimizer=offload_optimizer, offload_model=offload_model)
for actor in self._actor_handlers
]
if nonblocking:
return refs
return ray.get(refs)
def backload_to_gpu(self, nonblocking=False, backload_optimizer=True, backload_model=True):
"""Backload worker state to GPU
Args:
nonblocking: Whether this operation is synchronous or asynchronous.
If `nonblocking=True`, then the function returns a list of ObjectRefs.
"""
refs = [
actor.backload_to_gpu.remote(backload_optimizer=backload_optimizer, backload_model=backload_model)
for actor in self._actor_handlers
]
if nonblocking:
return refs
return ray.get(refs)
def run_method(self, dispatch_type: str, method_name: str, *args, **kwargs) -> Optional[TrainingOutputBatch]:
"""Run a method on all actors using specified dispatch type synchronously.
The method should either return `None` or a `TrainingOutputBatch` object.
Args:
dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
method_name: Name of the method to call on actors
*args: Positional arguments to pass to the method
**kwargs: Keyword arguments to pass to the method
Returns:
Collect results from all the actors.
"""
dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type)
# validate the dispatch args to be sent to `.dispatch`
args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs)
# Dispatch the method call
object_refs = dispatch_class.dispatch(self.actor_infos, method_name, *args, **kwargs)
# Collect results from all the actors
ret = dispatch_class.sync_collect(self.actor_infos, object_refs)
return ret
def async_run_ray_method(self, dispatch_type: str, method_name: str, *args, **kwargs) -> List[ObjectRef]:
"""Run a method on all actors using specified dispatch type asynchronously.
Args:
dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
method_name: Name of the method to call on actors
*args: Positional arguments to pass to the method
**kwargs: Keyword arguments to pass to the method
Returns:
List of object references
"""
dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type)
# validate the dispatch args to be sent to `.dispatch`
args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs)
# Dispatch the method call
object_refs = dispatch_class.dispatch(self.actor_infos, method_name, *args, **kwargs)
return object_refs
async def async_run_method(
self, dispatch_type: str, method_name: str, *args, **kwargs
) -> Optional[TrainingOutputBatch]:
"""Run a method on all actors using specified dispatch type in an asyncio-compatible way.
Args:
dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
method_name: Name of the method to call on actors
*args: Positional arguments to pass to the method
**kwargs: Keyword arguments to pass to the method
Returns:
TrainingOutputBatch: concatenated results from all actors
"""
dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type)
# validate the dispatch args to be sent to `.dispatch`
args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs)
# Dispatch the method call
object_refs = dispatch_class.dispatch(self.actor_infos, method_name, *args, **kwargs)
return await dispatch_class.async_collect(self.actor_infos, object_refs)attr cfg
cfg = cfgattr ray_actor_type
ray_actor_type = ray_actor_typeattr colocate_all
colocate_all = colocate_allattr sequence_parallel_size
sequence_parallel_size = sequence_parallel_sizeattr record_memory
record_memory = record_memorymethod async_init_model
async_init_model(*args, **kwargs) -> List[ObjectRef]Asynchronously initialize worker state (model, and optimizer if applicable) from model path on all the workers.
Returns:
| Type | Description |
|---|---|
| List[ObjectRef] | A list of ray object refs. |
Source code in skyrl/backends/skyrl_train/workers/worker.py:562-573
def async_init_model(
self,
*args,
**kwargs,
) -> List[ObjectRef]:
"""Asynchronously initialize worker state (model, and optimizer if applicable) from model path
on all the workers.
Returns:
A list of ray object refs.
"""
return [actor.init_model.remote(*args, **kwargs) for actor in self._actor_handlers]method offload_to_cpu
offload_to_cpu(nonblocking = False, offload_optimizer = True, offload_model = True)Offload all worker state to CPU.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
nonblocking | Whether this operation is synchronous or asynchronous. | False |
Source code in skyrl/backends/skyrl_train/workers/worker.py:575-588
def offload_to_cpu(self, nonblocking=False, offload_optimizer=True, offload_model=True):
"""Offload all worker state to CPU.
Args:
nonblocking: Whether this operation is synchronous or asynchronous.
If `nonblocking=True`, then the function returns a list of object refs.
"""
refs = [
actor.offload_to_cpu.remote(offload_optimizer=offload_optimizer, offload_model=offload_model)
for actor in self._actor_handlers
]
if nonblocking:
return refs
return ray.get(refs)method backload_to_gpu
backload_to_gpu(nonblocking = False, backload_optimizer = True, backload_model = True)Backload worker state to GPU
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
nonblocking | Whether this operation is synchronous or asynchronous. | False |
Source code in skyrl/backends/skyrl_train/workers/worker.py:590-603
def backload_to_gpu(self, nonblocking=False, backload_optimizer=True, backload_model=True):
    """Move previously offloaded worker state back onto the GPU on all workers.

    Args:
        nonblocking: Whether this operation is synchronous or asynchronous.
            If `nonblocking=True`, then the function returns a list of ObjectRefs
            instead of blocking on completion.
        backload_optimizer: Whether the optimizer state should be moved back.
        backload_model: Whether the model weights should be moved back.
    """
    pending = []
    for handler in self._actor_handlers:
        pending.append(
            handler.backload_to_gpu.remote(
                backload_optimizer=backload_optimizer, backload_model=backload_model
            )
        )
    if not nonblocking:
        # Block until every worker finished backloading.
        return ray.get(pending)
    return pending
run_method(dispatch_type: str, method_name: str, *args, **kwargs) -> Optional[TrainingOutputBatch]
Run a method on all actors using specified dispatch type synchronously.
The method should either return None or a TrainingOutputBatch object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dispatch_type | str | Type of dispatch to use ("mesh" or "pass_through") | required |
method_name | str | Name of the method to call on actors | required |
*args | Positional arguments to pass to the method | () | |
**kwargs | Keyword arguments to pass to the method | {} |
Returns:
| Type | Description |
|---|---|
| Optional[TrainingOutputBatch] | Collect results from all the actors. |
Source code in skyrl/backends/skyrl_train/workers/worker.py:605-627
def run_method(self, dispatch_type: str, method_name: str, *args, **kwargs) -> Optional[TrainingOutputBatch]:
    """Synchronously run a method on all actors via the chosen dispatch strategy.

    The target method should return either `None` or a `TrainingOutputBatch` object.

    Args:
        dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
        method_name: Name of the method to call on actors
        *args: Positional arguments to pass to the method
        **kwargs: Keyword arguments to pass to the method

    Returns:
        Collect results from all the actors.
    """
    dispatcher: Dispatch = DispatchRegistry.get(dispatch_type)
    # Normalize/validate the arguments before handing them to `.dispatch`.
    args, kwargs = dispatcher.validate_dispatch_args(*args, **kwargs)
    # Fan the call out to every actor, then gather the results synchronously.
    refs = dispatcher.dispatch(self.actor_infos, method_name, *args, **kwargs)
    return dispatcher.sync_collect(self.actor_infos, refs)
async_run_ray_method(dispatch_type: str, method_name: str, *args, **kwargs) -> List[ObjectRef]
Run a method on all actors using specified dispatch type asynchronously.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dispatch_type | str | Type of dispatch to use ("mesh" or "pass_through") | required |
method_name | str | Name of the method to call on actors | required |
*args | Positional arguments to pass to the method | () | |
**kwargs | Keyword arguments to pass to the method | {} |
Returns:
| Type | Description |
|---|---|
| List[ObjectRef] | List of object references |
Source code in skyrl/backends/skyrl_train/workers/worker.py:629-647
def async_run_ray_method(self, dispatch_type: str, method_name: str, *args, **kwargs) -> List[ObjectRef]:
    """Asynchronously run a method on all actors via the chosen dispatch strategy.

    Unlike `run_method`, this does not collect results; callers receive the raw refs.

    Args:
        dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
        method_name: Name of the method to call on actors
        *args: Positional arguments to pass to the method
        **kwargs: Keyword arguments to pass to the method

    Returns:
        List of object references
    """
    dispatcher: Dispatch = DispatchRegistry.get(dispatch_type)
    # Normalize/validate the arguments before handing them to `.dispatch`.
    args, kwargs = dispatcher.validate_dispatch_args(*args, **kwargs)
    # Fan the call out to every actor and return the pending refs immediately.
    return dispatcher.dispatch(self.actor_infos, method_name, *args, **kwargs)
async_run_method(dispatch_type: str, method_name: str, *args, **kwargs) -> Optional[TrainingOutputBatch]
Run a method on all actors using specified dispatch type in an asyncio-compatible way.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dispatch_type | str | Type of dispatch to use ("mesh" or "pass_through") | required |
method_name | str | Name of the method to call on actors | required |
*args | Positional arguments to pass to the method | () | |
**kwargs | Keyword arguments to pass to the method | {} |
Returns:
| Name | Type | Description |
|---|---|---|
TrainingOutputBatch | Optional[TrainingOutputBatch] | concatenated results from all actors |
Source code in skyrl/backends/skyrl_train/workers/worker.py:649-669
async def async_run_method(
    self, dispatch_type: str, method_name: str, *args, **kwargs
) -> Optional[TrainingOutputBatch]:
    """Run a method on all actors using specified dispatch type in an asyncio-compatible way.

    Args:
        dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
        method_name: Name of the method to call on actors
        *args: Positional arguments to pass to the method
        **kwargs: Keyword arguments to pass to the method

    Returns:
        TrainingOutputBatch: concatenated results from all actors
    """
    dispatcher: Dispatch = DispatchRegistry.get(dispatch_type)
    # Normalize/validate the arguments before handing them to `.dispatch`.
    args, kwargs = dispatcher.validate_dispatch_args(*args, **kwargs)
    # Fan the call out to every actor, then await the async collection.
    refs = dispatcher.dispatch(self.actor_infos, method_name, *args, **kwargs)
    results = await dispatcher.async_collect(self.actor_infos, refs)
    return results