Trainer
Trainer API — RayPPOTrainer, Dispatch, Worker APIs.
Trainer Class
class RayPPOTrainer
RayPPOTrainer(cfg: SkyRLTrainConfig, tracker: Tracking, tokenizer: AutoTokenizer, train_dataset: Optional[PromptDataset], inference_engine_client: InferenceEngineClient, generator: GeneratorInterface, colocate_pg: Optional[PlacementGroup] = None, eval_dataset: Optional[PromptDataset] = None)

Functions:
| Name | Description |
|---|---|
eval | Run generation and scoring on the evaluation dataset. |
train | Main training loop for PPO |
build_models | Initialize the actors for training, and handle colocation logic |
init_weight_sync_state | Setup the connection between policy model and inference engine for weight syncing. |
sync_policy_weights_to_inference_engines | Broadcast policy weights to inference engines. |
convert_to_training_input | Converts lists to a padded batch of tensors for training |
generate | Generate rollouts. |
postprocess_generator_output | Converts to per token rewards and computes pass@N. |
compute_advantages_and_returns | Calculate advantages and returns for the data batch. |
dump_data | Dump data to pickle file |
pad_batch | Pad the batch to be divisible by dp size |
fwd_logprobs_values_reward | Calculate values from the critic, log probs from the policy and ref model. |
apply_reward_kl_penalty | Applies a penalty for KL divergence between the policy log probs and the base model log probs to the rewards. |
train_critic_and_policy | Run the training step for the policy and critic models. |
handle_dynamic_sampling | Handle dynamic sampling for the current batch. |
save_checkpoints | Save the model, optimizer, and training states to disk. |
load_checkpoints | Load complete checkpoint state and return the global_step to resume from. |
save_models | Save the model parameters in HF format at cfg.trainer.export_path. |
update_ref_with_policy | Update the reference model with the policy model weights (required by some algorithms). |
Attributes:
| Name | Type | Description |
|---|---|---|
cfg | ||
colocate_all | ||
tracker | ||
tokenizer | ||
train_dataset | ||
eval_dataset | ||
inference_engine_client | ||
generator | ||
train_dataloader | ||
total_training_steps | ||
eval_dataloader | ||
colocate_pg | ||
resume_mode | ||
all_metrics | ||
all_timings | ||
global_step | ||
policy_model | PPORayActorGroup | |
critic_model | Optional[PPORayActorGroup] | |
ref_model | Optional[PPORayActorGroup] | |
dynamic_sampling_state | Optional[DynamicSamplingState] | |
reward_kl_controller | Optional[Union[FixedKLController, AdaptiveKLController]] | |
dispatch | WorkerDispatch | |
has_critic | bool | Check if critic model is configured. |
Source code in skyrl/train/trainer.py:82-1410
class RayPPOTrainer:
def __init__(
    self,
    cfg: SkyRLTrainConfig,
    tracker: Tracking,
    tokenizer: AutoTokenizer,
    train_dataset: Optional[PromptDataset],
    inference_engine_client: InferenceEngineClient,
    generator: GeneratorInterface,
    colocate_pg: Optional[PlacementGroup] = None,
    eval_dataset: Optional[PromptDataset] = None,
):
    """Store collaborators and initialize per-run trainer state.

    Args:
        cfg: Full trainer/generator/environment configuration.
        tracker: Experiment tracker used for metric logging.
        tokenizer: Tokenizer used for decoding/padding during training.
        train_dataset: Training prompts; may be None (e.g. data supplied
            externally), in which case no train dataloader is built.
        inference_engine_client: Client for the rollout inference engines.
        generator: Rollout generator implementation.
        colocate_pg: Placement group used when all models are colocated.
        eval_dataset: Optional evaluation prompts; eval dataloader is only
            built when provided.
    """
    self.cfg = cfg
    self.colocate_all = cfg.trainer.placement.colocate_all
    self.tracker = tracker
    self.tokenizer = tokenizer
    self.train_dataset = train_dataset
    self.eval_dataset = eval_dataset
    self.inference_engine_client = inference_engine_client
    self.generator = generator
    self.train_dataloader = None
    self.total_training_steps = None
    # Overridable hook: populates train_dataloader/total_training_steps
    # (both stay None when train_dataset is None).
    self._build_train_dataloader_and_compute_training_steps()
    self.eval_dataloader = (
        build_dataloader(self.cfg, eval_dataset, is_train=False) if eval_dataset is not None else None
    )
    self.colocate_pg = colocate_pg
    self.resume_mode = ResumeMode(cfg.trainer.resume_mode)
    # Per-step metric/timing accumulators; flushed to the tracker each step.
    self.all_metrics = {}
    self.all_timings = {}
    self.global_step = 0
    # initialized in `build_models`
    self.policy_model: PPORayActorGroup = None
    self.critic_model: Optional[PPORayActorGroup] = None
    self.ref_model: Optional[PPORayActorGroup] = None
    # used for checkpoint cleanup
    self._node_ids: Optional[List[str]] = None
    self.dynamic_sampling_state: Optional[DynamicSamplingState] = None
    self.reward_kl_controller: Optional[Union[FixedKLController, AdaptiveKLController]] = None
    self.dispatch: WorkerDispatch = None
    configure_ray_worker_logging()
@property
def has_critic(self) -> bool:
"""Check if critic model is configured."""
return bool(self.cfg.trainer.critic.model.path)
def _build_train_dataloader_and_compute_training_steps(self):
    """Construct the train dataloader and derive the total training step count.

    Overridable hook: subclasses can customize dataloader construction (for
    instance, fully async training needs a batch size of 1, among other
    features). Defaults to `trainer_utils.build_dataloader` with
    `is_train=True`. When `train_dataset` is None (e.g. the Tinker backend
    provides data externally), nothing is built.
    """
    if self.train_dataset is None:
        return
    self.train_dataloader = build_dataloader(self.cfg, self.train_dataset, is_train=True)
    self.total_training_steps = len(self.train_dataloader) * self.cfg.trainer.epochs
@torch.no_grad()
async def eval(self) -> Dict[str, float]:
    """Run generation and scoring on the evaluation dataset.

    Metrics are recorded as of `self.global_step` completed training steps;
    metrics logged at global_step 0 reflect evaluation before any training.

    Returns:
        A dictionary of evaluation metrics.
    """
    # Step-wise trajectories use a dedicated evaluation path; both take the
    # same arguments.
    eval_fn = evaluate_step_wise if self.cfg.generator.step_wise_trajectories else evaluate
    return await eval_fn(
        eval_dataloader=self.eval_dataloader,
        generator=self.generator,
        cfg=self.cfg,
        global_step=self.global_step,
        tokenizer=self.tokenizer,
    )
async def train(self):
    """
    Main training loop for PPO.

    Per batch: generate rollouts, postprocess rewards, run forward passes
    for logprobs/values, compute advantages/returns, train the policy (and
    critic when configured), then sync updated weights back to the
    inference engines. Checkpointing, evaluation, and metric logging run on
    their configured intervals; final checkpoint/model saves happen after
    the loop.
    """
    # Initialize weight sync state between policy model and inference engines.
    with Timer("init_weight_sync_state"):
        self.init_weight_sync_state()
    # Load checkpoint state if resumption is enabled.
    if self.resume_mode != ResumeMode.NONE:
        with Timer("load_checkpoints"):
            self.global_step, _ = self.load_checkpoints()
    # Prepare weights for sampling
    with Timer("sync_weights"):
        await self.dispatch.save_weights_for_sampler()
    # Eval before training
    if self.cfg.trainer.eval_interval > 0 and self.cfg.trainer.eval_before_train:
        with Timer("eval", self.all_timings):
            eval_metrics = await self.eval()
            self.tracker.log(eval_metrics, step=self.global_step, commit=True)
    # initialize kl controller
    if self.cfg.trainer.algorithm.use_kl_in_reward:
        self.reward_kl_controller = get_kl_controller(self.cfg.trainer.algorithm)
    # main training loop
    pbar = tqdm(total=self.total_training_steps, initial=self.global_step, desc="Training Batches Processed")
    # When resuming mid-run, derive the starting epoch from the restored step.
    start_epoch = self.global_step // len(self.train_dataloader)
    self.global_step += 1  # start training at global_step 1
    for epoch in range(start_epoch, self.cfg.trainer.epochs):
        for iter, rand_prompts in enumerate(self.train_dataloader):
            with Timer("step", self.all_timings):
                # for colocate_all=true, inference engine is always on GPU when starting the training step
                # 0. truncate data to have even shards
                rand_prompts = self._remove_tail_data(rand_prompts)
                generator_input, uids = prepare_generator_input(
                    rand_prompts,
                    self.cfg.generator.n_samples_per_prompt,
                    get_sampling_params_for_backend(
                        self.cfg.generator.inference_engine.backend, self.cfg.generator.sampling_params
                    ),
                    self.cfg.environment.env_class,
                    "train",
                    self.global_step,
                )
                # 1.1. generation phase
                with Timer("generate", self.all_timings):
                    generator_output: GeneratorOutput = await self.generate(generator_input)
                if self.cfg.generator.step_wise_trajectories:
                    # NOTE: We use instance_ids from `trajectory_ids` here instead of re-using `uids`
                    # this is because in step-wise training, len(uids) != len(generator_output["response_ids"])
                    uids = [trajectory_id.instance_id for trajectory_id in generator_output["trajectory_ids"]]
                # dynamic sampling
                if self.cfg.trainer.algorithm.dynamic_sampling.type is not None:
                    generator_output, uids, keep_sampling = self.handle_dynamic_sampling(generator_output, uids)
                    if keep_sampling:  # continue sampling
                        # update progress bar for current batch (but not global step)
                        pbar.update(1)
                        continue
                if self.colocate_all:
                    # if we are not continuing sampling, we sleep the inference engine
                    await self.inference_engine_client.sleep()
                # 1.2 postprocess rewards
                with Timer("postprocess_generator_output", self.all_timings):
                    generator_output = self.postprocess_generator_output(generator_output, uids)
                # 2. print example just for debugging
                vis = self.tokenizer.decode(generator_output["response_ids"][0])
                log_example(
                    logger,
                    prompt=generator_input["prompts"][0],
                    response=vis,
                    reward=generator_output["rewards"][0],
                )
                # 3. Convert GeneratorOutput to TrainingInputBatch
                with Timer("convert_to_training_input", self.all_timings):
                    training_input: TrainingInputBatch = self.convert_to_training_input(generator_output, uids)
                    logger.info(f"Number of sequences: {len(training_input['sequences'])}")
                # 4. Inference and calculate values, log probs, rewards, kl divergence
                with Timer("fwd_logprobs_values_reward", self.all_timings):
                    training_input = self.fwd_logprobs_values_reward(training_input)
                # 5. apply kl divergence penalty to rewards
                if self.cfg.trainer.algorithm.use_kl_in_reward:
                    with Timer("apply_reward_kl_penalty", self.all_timings):
                        training_input = self.apply_reward_kl_penalty(training_input)
                # 6. calculate advantages and returns
                with Timer("compute_advantages_and_returns", self.all_timings):
                    training_input = self.compute_advantages_and_returns(training_input)
                    # remove some unwanted keys
                    for key in ["rewards"]:
                        training_input.pop(key)
                    training_input.metadata.pop("uids")
                    if self.cfg.trainer.algorithm.advantage_batch_normalize:
                        training_input = normalize_advantages_dict(training_input)
                if self.cfg.trainer.dump_data_batch:
                    # dump data to file
                    with Timer("dump_data_batch"):
                        self.dump_data(training_input, file_name=f"global_step_{self.global_step}_training_input")
                # 7. train policy/critic model
                # Policy model is backloaded to GPU during training
                with Timer("train_critic_and_policy", self.all_timings):
                    status = self.train_critic_and_policy(training_input)
                # 8. conditionally save checkpoints and hf model
                if self.cfg.trainer.ckpt_interval > 0 and self.global_step % self.cfg.trainer.ckpt_interval == 0:
                    with Timer("save_checkpoints", self.all_timings):
                        self.save_checkpoints()
                if (
                    self.cfg.trainer.hf_save_interval > 0
                    and self.global_step % self.cfg.trainer.hf_save_interval == 0
                ):
                    with Timer("save_hf_model", self.all_timings):
                        self.save_models()
                # 9. conditionally sync policy and ref at the end of the epoch
                if (
                    self.cfg.trainer.update_ref_every_epoch
                    and self.ref_model is not None
                    and iter == len(self.train_dataloader) - 1
                    and epoch != self.cfg.trainer.epochs - 1  # skip updating ref at the end of the last epoch
                ):
                    with Timer("update_ref_with_policy", self.all_timings):
                        self.update_ref_with_policy()
                # 10. Prepare weights for sampling
                with Timer("sync_weights", self.all_timings):
                    await self.dispatch.save_weights_for_sampler()
                # 11. set logs
                logger.info(status)
                # log epoch info
                self.all_metrics.update({"trainer/epoch": epoch, "trainer/global_step": self.global_step})
                if self.cfg.trainer.eval_interval > 0 and (
                    self.global_step % self.cfg.trainer.eval_interval == 0
                    or self.global_step == self.total_training_steps
                ):
                    with Timer("eval", self.all_timings):
                        eval_metrics = await self.eval()
                        self.all_metrics.update(eval_metrics)
            # Flush accumulated metrics/timings for this step, then reset.
            log_payload = {
                **self.all_metrics,
                **{f"timing/{k}": v for k, v in self.all_timings.items()},
            }
            self.tracker.log(log_payload, step=self.global_step, commit=True)
            self.all_metrics = {}
            self.all_timings = {}
            # update progress bar after logging
            pbar.update(1)
            self.global_step += 1
            del training_input, generator_output
    pbar.close()
    if self.colocate_all:
        await self.inference_engine_client.sleep()
    # Final checkpoint/model save after the last step.
    if self.cfg.trainer.ckpt_interval > 0:
        with Timer("save_checkpoints", self.all_timings):
            self.save_checkpoints()
            logger.info("Saved final checkpoint.")
    if self.cfg.trainer.hf_save_interval > 0:
        with Timer("save_hf_model", self.all_timings):
            self.save_models()
            logger.info("Saved final model.")
    self.tracker.finish()
    logger.info("Training done!")
def _remove_tail_data(self, entries: List[Any]) -> List[Any]:
"""Remove tail data to have even shards in terms of *effective* samples.
Each prompt produces `n_samples_per_prompt` samples. For data-parallel
training we care that the total number of samples is nicely splittable
across the (combined) data-parallel size of all enabled models.
"""
lcm_dp_size = self.dispatch.get_lcm_dp_size()
n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt
# We want the largest m <= len(entries) such that:
# (m * n_samples_per_prompt) % lcm_dp_size == 0
#
# Let g = gcd(lcm_dp_size, n_samples_per_prompt). Then this is equivalent
# to requiring m to be a multiple of (lcm_dp_size / g).
stride = lcm_dp_size // math.gcd(lcm_dp_size, n_samples_per_prompt)
if stride <= 1:
# Every prompt count is valid, keep all entries.
return entries
kept_prompts = (len(entries) // stride) * stride
return entries[:kept_prompts]
def build_models(self, PolicyWorker, CriticWorker, RefWorker):
    """
    Initialize the actors for training, and handle colocation logic.

    Creates the policy actor group (always), plus ref/critic groups when
    enabled by the config, places them according to the configured
    colocation strategy, initializes their weights, and wires up the
    unified `WorkerDispatch`.
    """
    cfg = self.cfg
    pg = None
    # A reference model is only needed when a KL term (in loss or reward) is used.
    use_ref_model = cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward
    if cfg.trainer.placement.colocate_all:
        # All models and inference engines share the same GPUs; every group
        # must therefore span exactly the same number of GPUs.
        num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
        num_critic_gpus = cfg.trainer.placement.critic_num_gpus_per_node * cfg.trainer.placement.critic_num_nodes
        num_ref_gpus = cfg.trainer.placement.ref_num_gpus_per_node * cfg.trainer.placement.ref_num_nodes
        ie_cfg = cfg.generator.inference_engine
        num_rollout_gpus = (
            ie_cfg.num_engines
            * ie_cfg.tensor_parallel_size
            * ie_cfg.pipeline_parallel_size
            * ie_cfg.data_parallel_size
        )
        assert (
            num_policy_gpus == num_rollout_gpus
        ), "num_policy_gpus and num_rollout_gpus must be the same when colocating all models"
        pg = self.colocate_pg
        policy_model = PPORayActorGroup(
            cfg.trainer,
            cfg.trainer.placement.policy_num_nodes,
            cfg.trainer.placement.policy_num_gpus_per_node,
            PolicyWorker,
            pg=pg,
            # fractional GPU share so colocated actors can pack onto the same devices
            num_gpus_per_actor=0.2 if pg else 1,
            colocate_all=True,
            sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
            record_memory=cfg.trainer.policy.record_memory,
        )
        if use_ref_model:
            assert (
                num_policy_gpus == num_ref_gpus
            ), "num_policy_gpus and num_ref_gpus must be the same when colocating policy and ref model"
            ref_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.ref_num_nodes,
                cfg.trainer.placement.ref_num_gpus_per_node,
                RefWorker,
                pg=pg,
                num_gpus_per_actor=0.2 if pg else 1,
                colocate_all=True,
                sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
            )
        else:
            ref_model = None
        if cfg.trainer.critic.model.path:
            assert (
                num_policy_gpus == num_critic_gpus
            ), "num_policy_gpus and num_critic_gpus must be the same when colocating policy and critic model"
            critic_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.critic_num_nodes,
                cfg.trainer.placement.critic_num_gpus_per_node,
                CriticWorker,
                pg=pg,
                # NOTE(review): unlike policy/ref above, this is `0.2` rather than
                # `0.2 if pg else 1` — presumably pg is always set here; confirm.
                num_gpus_per_actor=0.2,
                colocate_all=True,
                sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
            )
        else:
            critic_model = None
    else:
        # Non-fully-colocated mode: only policy and ref may share GPUs.
        if cfg.trainer.placement.colocate_policy_ref and use_ref_model:
            assert (
                cfg.trainer.placement.policy_num_nodes == cfg.trainer.placement.ref_num_nodes
                and cfg.trainer.placement.policy_num_gpus_per_node == cfg.trainer.placement.ref_num_gpus_per_node
            ), "num_nodes and num_gpus_per_node must be the same when colocate policy and ref model."
            bundles = [
                {
                    "GPU": cfg.trainer.placement.policy_num_gpus_per_node,
                    "CPU": cfg.trainer.placement.policy_num_gpus_per_node,
                }
                for _ in range(cfg.trainer.placement.policy_num_nodes)
            ]
            pg = placement_group(bundles, strategy="PACK")
            get_ray_pg_ready_with_timeout(pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
        # When pg is set (policy+ref colocated), split each GPU 0.75/0.25.
        policy_model = PPORayActorGroup(
            cfg.trainer,
            cfg.trainer.placement.policy_num_nodes,
            cfg.trainer.placement.policy_num_gpus_per_node,
            PolicyWorker,
            pg=pg,
            num_gpus_per_actor=0.75 if pg else 1,
            colocate_all=False,
            sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
        )
        if use_ref_model:
            ref_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.ref_num_nodes,
                cfg.trainer.placement.ref_num_gpus_per_node,
                RefWorker,
                pg=pg,
                num_gpus_per_actor=0.25 if pg else 1,
                colocate_all=False,
                sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
            )
        else:
            ref_model = None
        if cfg.trainer.critic.model.path:
            # Critic always gets its own whole GPUs in this mode (no pg).
            critic_model = PPORayActorGroup(
                cfg.trainer,
                cfg.trainer.placement.critic_num_nodes,
                cfg.trainer.placement.critic_num_gpus_per_node,
                CriticWorker,
                num_gpus_per_actor=1,
                colocate_all=False,
                sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
            )
        else:
            critic_model = None
    # Optimizer steps per training batch = minibatches per batch * update epochs.
    policy_steps_per_train_batch = (
        cfg.trainer.train_batch_size // cfg.trainer.policy_mini_batch_size * cfg.trainer.update_epochs_per_batch
    )
    critic_steps_per_train_batch = 0
    if cfg.trainer.critic.model.path:
        critic_steps_per_train_batch = (
            cfg.trainer.train_batch_size // cfg.trainer.critic_mini_batch_size * cfg.trainer.update_epochs_per_batch
        )
    # None when total_training_steps is unknown (no train dataloader built).
    policy_num_training_steps = (
        self.total_training_steps * policy_steps_per_train_batch if self.total_training_steps is not None else None
    )
    critic_num_training_steps = (
        self.total_training_steps * critic_steps_per_train_batch if self.total_training_steps is not None else None
    )
    if not cfg.trainer.placement.colocate_all:
        # Models live on distinct GPUs: initialize all groups concurrently.
        refs = []
        if ref_model is not None:
            refs.extend(ref_model.async_init_model(cfg.trainer.ref.model.path))
        refs.extend(
            policy_model.async_init_model(
                cfg.trainer.policy.model.path,
                num_training_steps=policy_num_training_steps,
            )
        )
        if cfg.trainer.critic.model.path:
            refs.extend(
                critic_model.async_init_model(
                    cfg.trainer.critic.model.path,
                    num_training_steps=critic_num_training_steps,
                )
            )
        ray.get(refs)
        ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
    else:
        # Fully colocated: initialize one group at a time and offload each to
        # CPU before the next, so they never compete for GPU memory.
        if ref_model is not None:
            ray.get(ref_model.async_init_model(cfg.trainer.ref.model.path))
            ref_model.offload_to_cpu()
        ray.get(
            policy_model.async_init_model(
                cfg.trainer.policy.model.path,
                num_training_steps=policy_num_training_steps,
            )
        )
        ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
        policy_model.offload_to_cpu()
        if cfg.trainer.critic.model.path:
            ray.get(
                critic_model.async_init_model(
                    cfg.trainer.critic.model.path,
                    num_training_steps=critic_num_training_steps,
                )
            )
            critic_model.offload_to_cpu()
    self.policy_model: PPORayActorGroup = policy_model
    self.critic_model: Optional[PPORayActorGroup] = critic_model
    self.ref_model: Optional[PPORayActorGroup] = ref_model
    # Create unified dispatch that manages all actor groups
    self.dispatch = WorkerDispatch(
        cfg=self.cfg,
        policy_actor_group=policy_model,
        critic_actor_group=critic_model,
        ref_actor_group=ref_model,
        inference_engine_client=self.inference_engine_client,
    )
    # Mark all models as offloaded if colocate_all (they were offloaded above)
    if self.colocate_all:
        self.dispatch.mark_all_offloaded()
    logger.info("init policy/ref/critic models done")
def init_weight_sync_state(self):
    """Establish the policy-to-inference-engine connection used for weight syncing."""
    client = self.inference_engine_client
    self.dispatch.init_weight_sync_state(client)
    logger.info("Initialized weight sync state for policy model and inference engines.")
def sync_policy_weights_to_inference_engines(self) -> List[ObjectRef]:
    """Broadcast current policy weights to the inference engines.

    Prefer `dispatch.save_weights_for_sampler()` in new code: it runs the
    full weight sync protocol including offload/backload. This method is
    kept for backward compatibility with subclasses.
    TODO(tgriggs): Remove this method when migration is complete.
    """
    engine_client = self.inference_engine_client
    engine_cfg = self.cfg.generator.inference_engine
    return self.policy_model.async_run_ray_method(
        "pass_through",
        "broadcast_to_inference_engines",
        engine_client,
        engine_cfg,
    )
def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch:
    """Converts lists to a padded batch of tensors for training.

    Builds sequences/attention_mask/response_mask/rewards/loss_mask tensors
    (plus rollout logprobs when present), attaches uids and response-length
    metadata, and pads the batch to be divisible by the dp size.
    """
    prompt_ids: List[List[int]] = generator_output["prompt_token_ids"]
    response_ids: List[List[int]] = generator_output["response_ids"]
    rewards: List[List[float]] = generator_output["rewards"]
    loss_masks: List[List[int]] = generator_output["loss_masks"]
    logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None)
    (
        sequences_tensor,
        attention_masks_tensor,
        response_masks_tensor,
        rewards_tensor,
        loss_masks_tensor,
        rollout_logprobs_tensor,
    ) = convert_prompts_responses_to_batch_tensors(
        self.tokenizer,
        prompt_ids,
        response_ids,
        rewards,
        loss_masks,
        logprobs,
    )
    # sanity check for off_policy_correction: rollout logprobs must exist and
    # be shaped like the responses when any correction mode is enabled
    off_policy_correction = self.cfg.trainer.algorithm.off_policy_correction
    tis_ratio_type = off_policy_correction.tis_ratio_type
    sequence_mask_metric = off_policy_correction.sequence_mask_metric
    if tis_ratio_type is not None or sequence_mask_metric is not None:
        assert (
            rollout_logprobs_tensor is not None
        ), "expected non-null rollout logprobs tensor when off_policy_correction is enabled"
        assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses"
    training_input = TrainingInputBatch(
        {
            "sequences": sequences_tensor,  # Full trajectories (padded and concatenated prompts and responses)
            "attention_mask": attention_masks_tensor,
            "response_mask": response_masks_tensor,
            "rewards": rewards_tensor,
            "loss_mask": loss_masks_tensor,
            "rollout_logprobs": rollout_logprobs_tensor,
            "is_last_step": (
                torch.tensor(generator_output["is_last_step"], dtype=torch.bool)
                if generator_output.get("is_last_step", None) is not None
                else None
            ),
        },
    )
    training_input.metadata = {"uids": uids}
    # padded response length
    training_input.metadata["response_length"] = response_masks_tensor.shape[1]
    if self.cfg.generator.step_wise_trajectories:
        assert (
            "trajectory_ids" in generator_output
        ), "Expected `trajectory_ids` in generator output for step wise training"
        training_input.metadata["trajectory_ids"] = [
            trajectory_id.to_string() for trajectory_id in generator_output["trajectory_ids"]
        ]
        # Sum lengths only for samples flagged is_last_step; the denominator
        # still counts all step samples.
        training_input.metadata["avg_response_length"] = sum(
            len(sample_response_ids)
            for sample_response_ids, is_last_step in zip(response_ids, generator_output["is_last_step"])
            if is_last_step
        ) / len(response_ids)
    else:
        training_input.metadata["avg_response_length"] = sum(
            len(sample_response_ids) for sample_response_ids in response_ids
        ) / len(response_ids)
    logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
    training_input = self.pad_batch(training_input)
    logger.info(f"Number of sequences after padding: {len(training_input['sequences'])}")
    return training_input
@torch.no_grad()
async def generate(
    self,
    input_batch: GeneratorInput,
) -> GeneratorOutput:
    """Generate rollouts for one batch of prompts.

    If colocate_all is enabled:
    - before calling this method, the policy model should be on CPU and the
      inference engine should be awake (i.e. on GPU).
    - after calling this method, the same model placement still holds.
    """
    # NOTE: we assume that .generate returns samples in the same order as passed in
    output: GeneratorOutput = await self.generator.generate(input_batch)
    rollout_metrics = output["rollout_metrics"]
    if rollout_metrics is not None:
        # Surface rollout metrics in the per-step metric payload.
        self.all_metrics.update(rollout_metrics)
    # In step-wise mode the output count can differ from the prompt count,
    # so only validate in the standard path.
    if not self.cfg.generator.step_wise_trajectories:
        validate_generator_output(len(input_batch["prompts"]), output)
    return output
@torch.no_grad()
def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
    """
    Converts to per token rewards and computes pass@N.
    In the future algorithm specific reward or loss mask post processing should be done here.
    """
    generator_output_for_metrics = generator_output
    uids_for_metrics = uids
    if self.cfg.generator.step_wise_trajectories:
        # For metrics, keep only samples flagged is_last_step (one per trajectory).
        generator_output_for_metrics = defaultdict(list)
        for key in generator_output:
            if isinstance(generator_output[key], list):
                generator_output_for_metrics[key] = [
                    generator_output[key][i]
                    for i in range(len(generator_output[key]))
                    if generator_output["is_last_step"][i]
                ]
        uids_for_metrics = [
            uid for uid, is_last_step in zip(uids, generator_output["is_last_step"]) if is_last_step
        ]
    # only use `generator_output_for_metrics` for metrics calculation
    # For step-wise training, we only calculate metrics for the last step of each trajectory
    overall_metrics = get_metrics_from_generator_output(
        generator_output_for_metrics,
        uids_for_metrics,
    )
    # these use the full generator output
    rewards: Union[List[float], List[List[float]]] = generator_output["rewards"]
    responses: List[List[int]] = generator_output["response_ids"]
    per_token_rewards: List[List[float]] = []
    # Check if rewards are already token-level (List[List[float]]) or response-level (List[float])
    if rewards and isinstance(rewards[0], list):
        # Token-level rewards: rewards is List[List[float]]
        per_token_rewards = rewards
    else:
        if self.cfg.trainer.algorithm.zero_variance_filter:
            # Zero out loss masks of samples filtered for having zero reward variance.
            kept_indices_set = set(zero_variance_filter(rewards, uids))
            generator_output["loss_masks"] = [
                [0] * len(mask) if i not in kept_indices_set else mask
                for i, mask in enumerate(generator_output["loss_masks"])
            ]
        # Response-level rewards: rewards is List[float], convert to per-token rewards
        # by placing the full reward on the final token.
        # NOTE(review): assumes every response is non-empty; an empty response
        # would make `per_token_reward[-1]` raise IndexError — confirm upstream.
        for reward, response in zip(rewards, responses):
            per_token_reward = [0.0] * len(response)
            per_token_reward[-1] = float(reward)
            per_token_rewards.append(per_token_reward)
    n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt
    reward_metrics = {
        f"reward/avg_pass_at_{n_samples_per_prompt}": overall_metrics["pass_at_n"],
        "reward/avg_raw_reward": overall_metrics["avg_score"],
        "reward/mean_positive_reward": overall_metrics["mean_positive_reward"],
    }
    self.all_metrics.update(reward_metrics)
    logger.info(
        f"reward/avg_pass_at_{n_samples_per_prompt}: {overall_metrics['pass_at_n']}, reward/avg_raw_reward: {overall_metrics['avg_score']}, reward/mean_positive_reward: {overall_metrics['mean_positive_reward']}"
    )
    # re-assign reward but now it's per token rewards
    generator_output["rewards"] = per_token_rewards
    return generator_output
@torch.no_grad()
def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingInputBatch:
    """Calculate advantages and returns for the data batch.
    Expects:
        - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `["response_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `["loss_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
        - `["rewards"]`: Float[torch.Tensor, "batch_size seqlen"]
        - `.metadata["uids"]`: List[str]
    Adds:
        - `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"]
        - `["returns"]`: Float[torch.Tensor, "batch_size seqlen"]

    Also records reward/advantage summary metrics (computed over the
    non-padded rows) into `metadata["metrics"]` and `self.all_metrics`, and
    moves the batch to CPU.
    """
    token_level_rewards = data["rewards"]
    if self.cfg.generator.step_wise_trajectories:
        is_last_step = data["is_last_step"].bool()
        index = np.array(data.metadata["uids"])
        values = data["values"]
        # Use the last step of each trajectory to compute advantages. Compatible with any advantage estimator
        # NOTE(Charlie): so we ignore per-step rewards in step-wise training.
        last_step_advantages, last_step_returns = ppo_utils.compute_advantages_and_returns(
            token_level_rewards=token_level_rewards[is_last_step],
            response_mask=data["response_mask"][is_last_step],
            index=index[is_last_step.cpu().numpy()],
            adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
            values=values[is_last_step] if values is not None else None,
            config=self.cfg.trainer.algorithm,
            gamma=self.cfg.trainer.algorithm.gamma,
            lambd=self.cfg.trainer.algorithm.lambd,
            grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
        )
        # Broadcast each trajectory's advantage and return to all steps of each trajectory.
        # Rows are grouped by cumulative-summing a right-shifted `is_last_step`:
        # each row receives the index of the trajectory it belongs to.
        traj_ids = (
            torch.cat([torch.tensor([False], device=is_last_step.device), is_last_step[:-1]]).int().cumsum(dim=0)
        )
        num_groups = traj_ids[-1].item() + 1
        assert num_groups == len(
            last_step_advantages
        ), f"number of groups {num_groups} doesn't match the number of trajectories as given by `is_last_step` {len(last_step_advantages)}. The `is_last_step` tensor is likely malformed"
        advantages = last_step_advantages[traj_ids]
        returns = last_step_returns[traj_ids]
    else:
        advantages, returns = ppo_utils.compute_advantages_and_returns(
            token_level_rewards=token_level_rewards,
            response_mask=data["response_mask"],
            index=data.metadata["uids"],
            adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
            config=self.cfg.trainer.algorithm,
            values=data["values"],
            gamma=self.cfg.trainer.algorithm.gamma,
            lambd=self.cfg.trainer.algorithm.lambd,
            grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
        )
    data["returns"] = returns
    data["advantages"] = advantages
    # remove padding while calculating metrics
    pad_size = data.metadata.get("pad_size", 0)
    num_samples = len(token_level_rewards)
    # Per-sample reward totals, excluding rows appended by `pad_batch`.
    return_sums = token_level_rewards.sum(dim=-1)[: num_samples - pad_size]
    if self.cfg.generator.step_wise_trajectories:
        # Average only over trajectory-final steps.
        avg_rewards: float = return_sums[data["is_last_step"][: num_samples - pad_size]].mean().item()
    else:
        avg_rewards: float = return_sums.mean().item()
    avg_response_length = data.metadata["avg_response_length"]
    data = data.to("cpu")
    # Advantage stats over response tokens only (response_mask), padding excluded.
    valid_advantages = torch.masked_select(
        data["advantages"][: num_samples - pad_size, ...], data["response_mask"][: num_samples - pad_size].bool()
    )
    avg_advantages: float = valid_advantages.mean().item()
    avg_advantages_abs: float = valid_advantages.abs().mean().item()
    if "metrics" not in data.metadata:
        data.metadata["metrics"] = {}
    data.metadata["metrics"].update(
        {
            "avg_final_rewards": avg_rewards,
            "avg_response_length": avg_response_length,
            "avg_advantages": avg_advantages,
            "avg_advantages_abs": avg_advantages_abs,
        }
    )
    logger.info(f"avg_final_rewards: {avg_rewards}, avg_response_length: {avg_response_length}")
    self.all_metrics.update(
        {
            "loss/avg_final_rewards": avg_rewards,
            "loss/avg_raw_advantages": avg_advantages,
            "loss/avg_raw_advantages_abs": avg_advantages_abs,
        }
    )
    return data
def dump_data(self, data: TrainingInputBatch, file_name: str):
    """Serialize the batch to ``<export_path>/dumped_data/<file_name>.pkl``."""
    out_dir = Path(self.cfg.trainer.export_path) / "dumped_data"
    out_dir.mkdir(parents=True, exist_ok=True)
    data.save(out_dir / f"{file_name}.pkl")
def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch:
    """Pad the batch to be divisible by dp size.

    Appends `pad_size` rows to every tensor so the batch size becomes a
    multiple of the combined (LCM) data-parallel size of all enabled models.
    Padding rows are cloned from the head of the batch so they are always
    valid inputs, except:
      - `loss_mask` is padded with zeros so padding never contributes to loss,
      - `is_last_step` is padded with ones (each padded row ends its own
        trajectory group).
    Records `metadata["pad_size"]` (0 when no padding is needed) so later
    stages can strip the padding before computing metrics.

    Returns:
        The original batch when already divisible, otherwise a new
        `TrainingInputBatch` with padded tensors and (deep-)copied metadata.
    """
    dp_size = self.dispatch.get_lcm_dp_size()
    # Smallest non-negative pad making batch_size % dp_size == 0. Exact
    # integer arithmetic; replaces the previous float-based
    # `math.ceil(batch_size / dp_size) * dp_size - batch_size` (and its
    # redundant function-scope `import math`).
    pad_size = -training_input.batch_size % dp_size
    new_tensors = {}
    training_input.metadata["pad_size"] = pad_size
    if pad_size == 0:
        return training_input
    for key, tensor in training_input.items():
        if tensor is not None:
            additional_dims = tuple(tensor.shape[1:]) if len(tensor.shape) > 1 else ()
            if key == "is_last_step":
                padding_tensor = torch.ones(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
            elif key == "loss_mask":
                # ensures that padding tensors don't count towards the loss
                padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
            else:
                # ensures all padding tensors are in a valid format by cloning `pad_size` from the original input
                # `pad_size` is guaranteed to be smaller than batch_size
                padding_tensor = tensor[:pad_size].clone()
            new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0)
    new_training_input = TrainingInputBatch(new_tensors)
    new_training_input.metadata = {}
    # Sentinel uids/trajectory ids keep padded rows in their own groups.
    new_training_input.metadata["uids"] = training_input.metadata["uids"] + [f"pad{i}" for i in range(pad_size)]
    if "trajectory_ids" in training_input.metadata:
        new_training_input.metadata["trajectory_ids"] = training_input.metadata["trajectory_ids"] + [
            f"pad{i}" for i in range(pad_size)
        ]
    for key, value in training_input.metadata.items():
        if key not in ["uids", "trajectory_ids"]:
            new_training_input.metadata[key] = copy.deepcopy(value)
    return new_training_input
@torch.no_grad()
def fwd_logprobs_values_reward(
    self,
    training_input: TrainingInputBatch,
):
    """
    Calculate values from the critic, log probs from the policy and ref model.

    Dispatch handles offload/backload automatically for all colocation configurations.
    Runs under `torch.no_grad()`: these are inference-only forward passes.

    Expects:
        - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `["attention_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
        - `.metadata["response_length"]`: Int

    Adds:
        - `["base_action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
        - `["action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
        - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]

    Also records rollout-vs-trainer logprob drift metrics when the generator
    provided `["rollout_logprobs"]`.
    """
    # Forward only the keys the workers need; metadata carries the response length.
    data_fwd_pass = training_input.select(keys=["sequences", "attention_mask"], metadata_keys=["response_length"])
    values = None
    base_log_probs = None
    action_log_probs = None
    # Critic forward (dispatch handles offload/backload automatically)
    if self.has_critic:
        critic_output = self.dispatch.forward("critic", data_fwd_pass)
        values = critic_output["output"]
    # Ref forward
    if self.ref_model is not None:
        ref_output = self.dispatch.forward("ref", data_fwd_pass)
        base_log_probs = ref_output["output"]
        self.dispatch.empty_cache("ref")
    # Policy forward
    policy_output = self.dispatch.forward("policy", data_fwd_pass)
    action_log_probs = policy_output["output"]
    # Empty cache after all forward passes
    self.dispatch.empty_cache()
    sequences_all: torch.Tensor = training_input["sequences"]
    # NOTE (sumanthrh): The slicing is needed to make sure that the batch dimension doesn't change for the tensordict.
    base_log_probs = base_log_probs[: len(sequences_all)] if base_log_probs is not None else None
    action_log_probs = action_log_probs[: len(sequences_all)]
    values = values[: len(sequences_all)] if values is not None else None
    training_input["base_action_log_probs"] = base_log_probs
    training_input["action_log_probs"] = action_log_probs
    training_input["values"] = values
    if training_input.get("rollout_logprobs", None) is not None:
        # calculates the difference in probs between inference and trainer components
        # only consider response tokens (loss_mask > 0)
        logprobs_diff = (
            training_input["rollout_logprobs"][training_input["loss_mask"] > 0]
            - action_log_probs[training_input["loss_mask"] > 0]
        ).abs()
        logprobs_diff_mean = logprobs_diff.mean().item()
        logprobs_diff_std = logprobs_diff.std().item()
        self.all_metrics.update(
            {
                "policy/rollout_train_logprobs_abs_diff_mean": logprobs_diff_mean,
                "policy/rollout_train_logprobs_abs_diff_std": logprobs_diff_std,
            }
        )
    return training_input
def apply_reward_kl_penalty(
    self,
    data: TrainingInputBatch,
) -> TrainingInputBatch:
    """Applies a penalty for KL divergence between the policy log probs and the base model log probs to the rewards."""
    response_masks: torch.Tensor = data["loss_mask"]
    # Estimate per-token KL between the trained policy and the frozen reference
    # model in a single batched call.
    kl: Float[torch.Tensor, "batch_size seqlen"] = compute_approx_kl(  # type: ignore
        data["action_log_probs"],
        data["base_action_log_probs"],
        loss_mask=response_masks,
        kl_estimator_type=self.cfg.trainer.algorithm.kl_estimator_type,
    )
    per_seq_kl_max: Float[torch.Tensor, "batch_size"] = torch.max(kl.abs(), dim=-1)[0]  # noqa: F821
    per_seq_kl_mean: Float[torch.Tensor, "batch_size"] = masked_mean(kl, response_masks, dim=-1)  # noqa: F821
    # NOTE (erictang000): only supporting custom rewards currently
    if self.reward_kl_controller is not None:
        kl_loss_coef = self.reward_kl_controller.value
    else:
        kl_loss_coef = self.cfg.trainer.algorithm.kl_loss_coef
    # Subtract the (non-negative) KL penalty from the per-token rewards.
    data["rewards"] = data["rewards"] - kl * max(0, kl_loss_coef)
    avg_kl: float = per_seq_kl_mean.mean().item()
    avg_kl_max: float = per_seq_kl_max.mean().item()
    # update the kl controller
    if self.reward_kl_controller is not None:
        self.reward_kl_controller.update(current=avg_kl, n_steps=kl.shape[0])  # n_steps is just the batch size
    if "metrics" not in data.metadata:
        data.metadata["metrics"] = {}
    data.metadata["metrics"].update(
        {
            "avg_kl": avg_kl,
            "avg_kl_max": avg_kl_max,
            "kl_loss_coef": kl_loss_coef,
        }
    )
    self.all_metrics.update(
        {
            "loss/avg_kl": avg_kl,
            "loss/avg_kl_max": avg_kl_max,
            "loss/kl_loss_coef": kl_loss_coef,
        }
    )
    return data
def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]:
    """
    Execute training step for FSDP strategy using forward_backward + optim_step.

    The trainer loops over epochs and mini-batches. Workers handle micro-batching
    internally for gradient accumulation (memory efficiency).

    Uses staged data approach: the full batch is put in Ray object store once,
    and workers fetch + slice locally to avoid repeated serialization.

    Args:
        model: Model name ("policy" or "critic")
        data: Training data batch

    Returns:
        Dict of reduced metrics from training
    """
    # Compute mini batch size from config (algorithm-level concept)
    n_samples = self.cfg.generator.n_samples_per_prompt
    if model == "policy":
        mini_batch_size = self.cfg.trainer.policy_mini_batch_size * n_samples
    else:
        mini_batch_size = self.cfg.trainer.critic_mini_batch_size * n_samples
    # Hoisted out of the epoch loop: the batch size does not change across epochs.
    num_mini_batches = len(data) // mini_batch_size
    if num_mini_batches == 0:
        # Previously this silently skipped all optimizer steps; flag it loudly.
        logger.warning(
            f"Batch of size {len(data)} is smaller than the mini batch size {mini_batch_size} "
            f"for model '{model}'; no optimizer steps will run this training step."
        )
    all_metrics: Dict[str, List[float]] = defaultdict(list)
    # Stage full batch in object store ONCE to avoid repeated serialization
    data_ref = self.dispatch.stage_data(data)
    # Training loop over epochs and mini-batches
    for _epoch in range(self.cfg.trainer.update_epochs_per_batch):
        for local_step in range(num_mini_batches):
            start_idx = local_step * mini_batch_size
            end_idx = (local_step + 1) * mini_batch_size
            # Workers fetch from object store and slice locally
            status = self.dispatch.forward_backward_from_staged(model, data_ref, start_idx, end_idx)
            for k, v in status.items():
                all_metrics[k].append(v)
            # Optimizer step after each mini batch
            grad_norm = self.dispatch.optim_step(model)
            if grad_norm is not None:
                all_metrics["grad_norm"].append(grad_norm)
    # Reduce metrics across all mini-batches and epochs.
    # pop out loss_fn_outputs since it's not a scalar metric and to avoid logging it
    all_metrics.pop("loss_fn_outputs", None)
    reduced_metrics = reduce_metrics(all_metrics)
    return reduced_metrics
def train_critic_and_policy(self, data: TrainingInputBatch):
    """
    Run the training step for the policy and critic models.

    Uses forward_backward + optim_step for both FSDP and Megatron strategies.
    Returns the reduced policy training metrics.
    """
    data.metadata["global_step"] = self.global_step
    # Unified training interface for both FSDP and Megatron: critic first (if
    # configured), then policy, each timed separately.
    critic_status = None
    if self.has_critic:
        with Timer("critic_train", self.all_timings):
            critic_status = self._execute_training_step("critic", data)
    with Timer("policy_train", self.all_timings):
        policy_status = self._execute_training_step("policy", data)
    # Record per-model metrics under "critic/" and "policy/" prefixes.
    if critic_status is not None:
        for name, value in critic_status.items():
            self.all_metrics.update({f"critic/{name}": value})
    for name, value in policy_status.items():
        self.all_metrics.update({f"policy/{name}": value})
    self.dispatch.empty_cache()
    return policy_status
def handle_dynamic_sampling(
    self, generator_output: GeneratorOutput, uids: List[str]
) -> Tuple[GeneratorOutput, List[str], bool]:
    """
    Handle dynamic sampling for the current batch.

    Accumulates the generator output and UIDs across batches if we are sampling repeatedly
    and applies the dynamic sampling strategy (i.e. filter, replace) to the current batch.
    If we hit the limit of max sample batches, we raise an error.

    Args:
        generator_output: Current batch generator output
        uids: Current batch UIDs

    Returns:
        processed_output: Filtered generator output
        processed_uids: Filtered UIDs
        keep_sampling: Whether to keep sampling
    """
    sampling_cfg = self.cfg.trainer.algorithm.dynamic_sampling
    max_sample_batches = sampling_cfg.max_sample_batches
    dynamic_sampling_config = {
        "type": sampling_cfg.type,
        "max_sample_batches": max_sample_batches,
        "min_replace_ratio": sampling_cfg.min_replace_ratio,
        "train_batch_size": self.cfg.trainer.train_batch_size,
        "n_samples_per_prompt": self.cfg.generator.n_samples_per_prompt,
    }
    # Track how many consecutive batches were sampled for this training step.
    if self.dynamic_sampling_state is None:
        self.dynamic_sampling_state: DynamicSamplingState = {
            "sample_batch_count": 1,
        }
    else:
        self.dynamic_sampling_state["sample_batch_count"] += 1
    # Delegate the actual filter/replace strategy to the shared utility.
    processed_output, processed_uids, keep_sampling, updated_state = trainer_utils.handle_dynamic_sampling(
        generator_output, uids, dynamic_sampling_config, self.dynamic_sampling_state
    )
    # Refuse to resample forever: hitting the cap indicates a data-difficulty problem.
    hit_resample_limit = (
        keep_sampling
        and max_sample_batches > 0
        and self.dynamic_sampling_state["sample_batch_count"] >= max_sample_batches
    )
    if hit_resample_limit:
        raise RuntimeError(
            f"Exiting training loop due to hitting dynamic sampling limit for "
            f"{self.cfg.trainer.algorithm.dynamic_sampling.type} strategy with "
            f"{self.cfg.trainer.algorithm.dynamic_sampling.max_sample_batches} max sample batches. "
            f"Please check your data difficulty distribution."
        )
    # Carry state forward while still sampling; clear it once the batch is final.
    self.dynamic_sampling_state = updated_state if keep_sampling else None
    return processed_output, processed_uids, keep_sampling
def _get_dp_group_models(self, rank: int, model_type: str = ""):
model = getattr(self, model_type)
return model._actor_handlers[rank]
def _get_mesh_rank(self, rank: int, model_type: str = "") -> MeshRank:
    # Look up the mesh rank recorded for actor `rank` in the named actor group.
    group: PPORayActorGroup = getattr(self, model_type)
    info: ActorInfo = group.actor_infos[rank]
    return info.rank
def save_checkpoints(self):
    """
    Save the model, optimizer, and training states to disk.

    Dispatch handles offload/backload automatically for all colocation configurations.
    Layout: `<ckpt_path>/global_step_<N>/{policy,critic,data.pt,trainer_state.pt}`.
    """
    step_dir = os.path.join(self.cfg.trainer.ckpt_path, f"global_step_{self.global_step}")
    policy_save_dir = os.path.join(step_dir, "policy")
    critic_save_dir = os.path.join(step_dir, "critic")
    io.makedirs(step_dir, exist_ok=True)
    # Model states (dispatch handles offload/backload)
    self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer)
    if self.has_critic:
        self.dispatch.save_checkpoint("critic", critic_save_dir, self.tokenizer)
    # Dataloader state is best-effort: a failure here must not lose the model checkpoint.
    dataloader_save_path = os.path.join(step_dir, "data.pt")
    try:
        dataloader_state_dict = self.train_dataloader.state_dict()
        with io.open_file(dataloader_save_path, "wb") as f:
            torch.save(dataloader_state_dict, f)
        logger.info(f"Saved dataloader state to {dataloader_save_path}")
    except Exception as e:
        logger.warning(f"Failed to save dataloader state: {e}")
    # Trainer state: global step plus a full config snapshot.
    trainer_state_path = os.path.join(step_dir, "trainer_state.pt")
    with io.open_file(trainer_state_path, "wb") as f:
        torch.save({"global_step": self.global_step, "config": asdict(self.cfg)}, f)
    logger.info(f"Saved trainer state to {trainer_state_path}")
    # Atomic tracking - write this last after all saves succeed
    latest_checkpoint_file = os.path.join(self.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt")
    with io.open_file(latest_checkpoint_file, "w") as f:
        f.write(str(self.global_step))
    logger.info(f"Successfully saved checkpoint for global_step_{self.global_step} to: {step_dir}")
    # Clean up old checkpoints after successful save
    with Timer("cleanup_old_checkpoints", self.all_timings):
        self._cleanup_old_checkpoints()
def _cleanup_old_checkpoints(self):
    """Prune old checkpoint directories on every node and on the driver."""
    # Discover node ids lazily and cache them for subsequent cleanups.
    if not self._node_ids:
        self._node_ids = self.dispatch.get_node_ids()
    ckpt_path = self.cfg.trainer.ckpt_path
    max_to_keep = self.cfg.trainer.max_ckpts_to_keep
    run_on_each_node(self._node_ids, cleanup_old_checkpoints, ckpt_path, max_to_keep)
    # run on driver as well
    # NOTE (sumanthrh): the function will get called twice on the node with driver process, but it's ok because it's idempotent
    cleanup_old_checkpoints(ckpt_path, max_to_keep)
def load_checkpoints(self) -> Tuple[int, str]:
    """
    Load complete checkpoint state and return the global_step to resume from.

    Returns 0 if no checkpoint is loaded.
    If colocate_all is True, assumes that the policy model is currently on GPU.

    Resolution order: resume disabled -> (0, None); LATEST -> read the
    `latest_ckpt_global_step.txt` marker; otherwise use `trainer.resume_path`.

    Returns:
        global_step: The global step to resume from.
        checkpoint_path: The path to the checkpoint.

    Raises:
        ValueError: If the resume path is missing/invalid or not a global_step directory.
        FileNotFoundError: If the resume path or the trainer state file does not exist.
    """
    checkpoint_path = None
    # Check if resumption is enabled
    if self.resume_mode == ResumeMode.NONE:
        logger.info("Checkpoint resumption disabled, starting training from scratch")
        return 0, None
    # first, let's get resume_path
    elif self.resume_mode == ResumeMode.LATEST:
        latest_checkpoint_file = os.path.join(self.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt")
        if not io.exists(latest_checkpoint_file):
            logger.info("No checkpoint found, starting training from scratch")
            return 0, None
        # The marker file holds the integer global step of the newest checkpoint.
        with io.open_file(latest_checkpoint_file, "r") as f:
            ckpt_iteration = int(f.read().strip())
        checkpoint_path = os.path.join(self.cfg.trainer.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_iteration}")
        # Run validation: Make sure ckpt folder is consistent with latest_ckpt_global_step.txt
        validate_consistency_for_latest_checkpoint(
            self.cfg.trainer.ckpt_path,
            ckpt_iteration,
            checkpoint_path,
            latest_checkpoint_file,
            self.cfg.trainer.ckpt_interval,
        )
    else:
        # Get and validate resume path (resume_mode is "from_path")
        checkpoint_path = Path(self.cfg.trainer.resume_path)
        if not checkpoint_path:
            raise ValueError("`trainer.resume_path` must be specified when resume_mode is 'from_path'")
        # Validate that it's a global_step directory
        if GLOBAL_STEP_PREFIX not in checkpoint_path.name:
            raise ValueError(
                f"`trainer.resume_path` must point to a directory whose name starting with {GLOBAL_STEP_PREFIX}, got: {checkpoint_path}"
            )
        # Validate that the path exists
        if not io.exists(str(checkpoint_path)):
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
    logger.info(f"Loading checkpoint from: {checkpoint_path}")
    # Extract global step from checkpoint path
    global_step = extract_step_from_path(Path(checkpoint_path))
    if global_step == -1:
        raise ValueError(f"Checkpoint path {checkpoint_path} is not a valid checkpoint path")
    logger.info(f"Resuming from global_step: {global_step}")
    # Define paths for different checkpoint components
    policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
    critic_ckpt_dir = os.path.join(checkpoint_path, "critic")
    trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
    dataloader_state_path = os.path.join(checkpoint_path, "data.pt")
    # Validate that required checkpoint files exist
    if not io.exists(trainer_state_path):
        raise FileNotFoundError(f"Trainer state file not found: {trainer_state_path}")
    # 1. Load and validate trainer state
    with io.open_file(trainer_state_path, "rb") as f:
        trainer_state = torch.load(f, map_location="cpu", weights_only=False)
        saved_global_step = trainer_state.get("global_step", global_step)
        logger.info("Successfully loaded trainer state")
    # The path-derived step wins over the saved one on mismatch (warn only).
    if saved_global_step != global_step:
        logger.warning(f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value.")
    # 2. Load dataloader state if available (best-effort: fall back to a fresh dataloader)
    if io.exists(dataloader_state_path):
        try:
            with io.open_file(dataloader_state_path, "rb") as f:
                dataloader_state = torch.load(f, map_location="cpu", weights_only=False)
                self.train_dataloader.load_state_dict(dataloader_state)
                logger.info("Successfully loaded dataloader state")
        except Exception as e:
            logger.warning(f"Failed to load dataloader state: {e}. Dataloader will start from beginning.")
    else:
        logger.warning(
            f"No dataloader state found at {dataloader_state_path}. Dataloader will start from beginning."
        )
    # 3. Load policy checkpoint (dispatch handles offload/backload)
    logger.info(f"Loading policy checkpoint from {policy_ckpt_dir}")
    self.dispatch.load_checkpoint(
        "policy",
        policy_ckpt_dir,
        load_optimizer_states=True,
        load_lr_scheduler_states=True,
    )
    logger.info("Successfully loaded policy checkpoint")
    # 4. Load critic checkpoint if it exists and we have a critic model
    if self.has_critic:
        logger.info(f"Loading critic checkpoint from {critic_ckpt_dir}")
        self.dispatch.load_checkpoint(
            "critic",
            critic_ckpt_dir,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        logger.info("Successfully loaded critic checkpoint")
    logger.info(f"Successfully loaded complete checkpoint state from global_step_{global_step}")
    return global_step, str(checkpoint_path)
def save_models(self):
    """
    Save the model parameters in HF format at `cfg.trainer.export_path`.

    Dispatch handles offload/backload automatically for all colocation configurations.
    """
    export_root = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}")
    # Policy always exports; critic only when configured.
    self.dispatch.save_hf_model("policy", os.path.join(export_root, "policy"), self.tokenizer)
    if self.has_critic:
        self.dispatch.save_hf_model("critic", os.path.join(export_root, "critic"), self.tokenizer)
    logger.info("Successfully saved model weights.")
def update_ref_with_policy(self):
    """
    Update the reference model with the policy model weights (required by some algorithms).

    Dispatch handles offload/backload automatically for all colocation configurations.
    After this method, save_weights_for_sampler() should be called to sync weights.
    """
    # TODO(tgriggs): Make policy-to-ref sync faster.
    policy_export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}", "policy")
    # Round-trip through disk: export the policy in HF format, then re-init the
    # ref model from that directory (dispatch handles GPU state for both).
    self.dispatch.save_hf_model("policy", policy_export_dir, self.tokenizer)
    self.dispatch.init_model("ref", policy_export_dir)
    # Clean up temporary saved model files (best-effort)
    try:
        shutil.rmtree(policy_export_dir)
        logger.info(f"Cleaned up temporary policy export directory: {policy_export_dir}")
    except Exception as e:
        logger.warning(f"Failed to clean up temporary policy export directory {policy_export_dir}: {e}")
    logger.info("Successfully updated ref model with policy model, training continues.")
cfg = cfgattr colocate_all
colocate_all = cfg.trainer.placement.colocate_allattr tracker
tracker = trackerattr tokenizer
tokenizer = tokenizerattr train_dataset
train_dataset = train_datasetattr eval_dataset
eval_dataset = eval_datasetattr inference_engine_client
inference_engine_client = inference_engine_clientattr generator
generator = generatorattr train_dataloader
train_dataloader = Noneattr total_training_steps
total_training_steps = Noneattr eval_dataloader
eval_dataloader = build_dataloader(self.cfg, eval_dataset, is_train=False) if eval_dataset is not None else Noneattr colocate_pg
colocate_pg = colocate_pgattr resume_mode
resume_mode = ResumeMode(cfg.trainer.resume_mode)attr all_metrics
all_metrics = {}attr all_timings
all_timings = {}attr global_step
global_step = 0attr policy_model
policy_model: PPORayActorGroup = Noneattr critic_model
critic_model: Optional[PPORayActorGroup] = Noneattr ref_model
ref_model: Optional[PPORayActorGroup] = Noneattr dynamic_sampling_state
dynamic_sampling_state: Optional[DynamicSamplingState] = Noneattr reward_kl_controller
reward_kl_controller: Optional[Union[FixedKLController, AdaptiveKLController]] = Noneattr dispatch
dispatch: WorkerDispatch = Noneattr property has_critic
has_critic: boolCheck if critic model is configured.
method async eval
eval() -> Dict[str, float] — Run generation and scoring on the evaluation dataset.
The eval metrics are recorded after having finished training self.global_step steps.
Metrics recorded at global_step 0 correspond to evaluations before training.
Returns:
| Type | Description |
|---|---|
| Dict[str, float] | A dictionary of evaluation metrics. |
Source code in skyrl/train/trainer.py:148-175
@torch.no_grad()
async def eval(self) -> Dict[str, float]:
"""
Run generation and scoring on the evaluation dataset.
The eval metrics are recorded after having finished training `self.global_step` steps.
Metrics recorded in global_step 0 corresponds to evaluations before training.
Returns:
A dictionary of evaluation metrics.
"""
if self.cfg.generator.step_wise_trajectories:
eval_metrics = await evaluate_step_wise(
eval_dataloader=self.eval_dataloader,
generator=self.generator,
cfg=self.cfg,
global_step=self.global_step,
tokenizer=self.tokenizer,
)
else:
eval_metrics = await evaluate(
eval_dataloader=self.eval_dataloader,
generator=self.generator,
cfg=self.cfg,
global_step=self.global_step,
tokenizer=self.tokenizer,
)
return eval_metricsmethod async train
train() — Main training loop for PPO.
Source code in skyrl/train/trainer.py:177-359
async def train(self):
"""
Main training loop for PPO
"""
# Initialize weight sync state between policy model and inference engines.
with Timer("init_weight_sync_state"):
self.init_weight_sync_state()
# Load checkpoint state if resumption is enabled.
if self.resume_mode != ResumeMode.NONE:
with Timer("load_checkpoints"):
self.global_step, _ = self.load_checkpoints()
# Prepare weights for sampling
with Timer("sync_weights"):
await self.dispatch.save_weights_for_sampler()
# Eval before training
if self.cfg.trainer.eval_interval > 0 and self.cfg.trainer.eval_before_train:
with Timer("eval", self.all_timings):
eval_metrics = await self.eval()
self.tracker.log(eval_metrics, step=self.global_step, commit=True)
# initialize kl controller
if self.cfg.trainer.algorithm.use_kl_in_reward:
self.reward_kl_controller = get_kl_controller(self.cfg.trainer.algorithm)
# main training loop
pbar = tqdm(total=self.total_training_steps, initial=self.global_step, desc="Training Batches Processed")
start_epoch = self.global_step // len(self.train_dataloader)
self.global_step += 1 # start training at global_step 1
for epoch in range(start_epoch, self.cfg.trainer.epochs):
for iter, rand_prompts in enumerate(self.train_dataloader):
with Timer("step", self.all_timings):
# for colocate_all=true, inference engine is always on GPU when starting the training step
# 0. truncate data to have even shards
rand_prompts = self._remove_tail_data(rand_prompts)
generator_input, uids = prepare_generator_input(
rand_prompts,
self.cfg.generator.n_samples_per_prompt,
get_sampling_params_for_backend(
self.cfg.generator.inference_engine.backend, self.cfg.generator.sampling_params
),
self.cfg.environment.env_class,
"train",
self.global_step,
)
# 1.1. generation phase
with Timer("generate", self.all_timings):
generator_output: GeneratorOutput = await self.generate(generator_input)
if self.cfg.generator.step_wise_trajectories:
# NOTE: We use instance_ids from `trajectory_ids` here instead of re-using `uids`
# this is because in step-wise training, len(uids) != len(generator_output["response_ids"])
uids = [trajectory_id.instance_id for trajectory_id in generator_output["trajectory_ids"]]
# dynamic sampling
if self.cfg.trainer.algorithm.dynamic_sampling.type is not None:
generator_output, uids, keep_sampling = self.handle_dynamic_sampling(generator_output, uids)
if keep_sampling: # continue sampling
# update progress bar for current batch (but not global step)
pbar.update(1)
continue
if self.colocate_all:
# if we are not continuing sampling, we sleep the inference engine
await self.inference_engine_client.sleep()
# 1.2 postprocess rewards
with Timer("postprocess_generator_output", self.all_timings):
generator_output = self.postprocess_generator_output(generator_output, uids)
# 2. print example just for debugging
vis = self.tokenizer.decode(generator_output["response_ids"][0])
log_example(
logger,
prompt=generator_input["prompts"][0],
response=vis,
reward=generator_output["rewards"][0],
)
# 3. Convert GeneratorOutput to TrainingInputBatch
with Timer("convert_to_training_input", self.all_timings):
training_input: TrainingInputBatch = self.convert_to_training_input(generator_output, uids)
logger.info(f"Number of sequences: {len(training_input['sequences'])}")
# 4. Inference and calculate values, log probs, rewards, kl divergence
with Timer("fwd_logprobs_values_reward", self.all_timings):
training_input = self.fwd_logprobs_values_reward(training_input)
# 5. apply kl divergence penalty to rewards
if self.cfg.trainer.algorithm.use_kl_in_reward:
with Timer("apply_reward_kl_penalty", self.all_timings):
training_input = self.apply_reward_kl_penalty(training_input)
# 6. calculate advantages and returns
with Timer("compute_advantages_and_returns", self.all_timings):
training_input = self.compute_advantages_and_returns(training_input)
# remove some unwanted keys
for key in ["rewards"]:
training_input.pop(key)
training_input.metadata.pop("uids")
if self.cfg.trainer.algorithm.advantage_batch_normalize:
training_input = normalize_advantages_dict(training_input)
if self.cfg.trainer.dump_data_batch:
# dump data to file
with Timer("dump_data_batch"):
self.dump_data(training_input, file_name=f"global_step_{self.global_step}_training_input")
# 7. train policy/critic model
# Policy model is backloaded to GPU during training
with Timer("train_critic_and_policy", self.all_timings):
status = self.train_critic_and_policy(training_input)
# 8. conditionally save checkpoints and hf model
if self.cfg.trainer.ckpt_interval > 0 and self.global_step % self.cfg.trainer.ckpt_interval == 0:
with Timer("save_checkpoints", self.all_timings):
self.save_checkpoints()
if (
self.cfg.trainer.hf_save_interval > 0
and self.global_step % self.cfg.trainer.hf_save_interval == 0
):
with Timer("save_hf_model", self.all_timings):
self.save_models()
# 9. conditionally sync policy and ref at the end of the epoch
if (
self.cfg.trainer.update_ref_every_epoch
and self.ref_model is not None
and iter == len(self.train_dataloader) - 1
and epoch != self.cfg.trainer.epochs - 1 # skip updating ref at the end of the last epoch
):
with Timer("update_ref_with_policy", self.all_timings):
self.update_ref_with_policy()
# 10. Prepare weights for sampling
with Timer("sync_weights", self.all_timings):
await self.dispatch.save_weights_for_sampler()
# 11. set logs
logger.info(status)
# log epoch info
self.all_metrics.update({"trainer/epoch": epoch, "trainer/global_step": self.global_step})
if self.cfg.trainer.eval_interval > 0 and (
self.global_step % self.cfg.trainer.eval_interval == 0
or self.global_step == self.total_training_steps
):
with Timer("eval", self.all_timings):
eval_metrics = await self.eval()
self.all_metrics.update(eval_metrics)
log_payload = {
**self.all_metrics,
**{f"timing/{k}": v for k, v in self.all_timings.items()},
}
self.tracker.log(log_payload, step=self.global_step, commit=True)
self.all_metrics = {}
self.all_timings = {}
# update progress bar after logging
pbar.update(1)
self.global_step += 1
del training_input, generator_output
pbar.close()
if self.colocate_all:
await self.inference_engine_client.sleep()
if self.cfg.trainer.ckpt_interval > 0:
with Timer("save_checkpoints", self.all_timings):
self.save_checkpoints()
logger.info("Saved final checkpoint.")
if self.cfg.trainer.hf_save_interval > 0:
with Timer("save_hf_model", self.all_timings):
self.save_models()
logger.info("Saved final model.")
self.tracker.finish()
logger.info("Training done!")method build_models
build_models(PolicyWorker, CriticWorker, RefWorker)Initialize the actors for training, and handle colocation logic
Source code in skyrl/train/trainer.py:385-580
def build_models(self, PolicyWorker, CriticWorker, RefWorker):
"""
Initialize the actors for training, and handle colocation logic
"""
cfg = self.cfg
pg = None
use_ref_model = cfg.trainer.algorithm.use_kl_loss or cfg.trainer.algorithm.use_kl_in_reward
if cfg.trainer.placement.colocate_all:
num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
num_critic_gpus = cfg.trainer.placement.critic_num_gpus_per_node * cfg.trainer.placement.critic_num_nodes
num_ref_gpus = cfg.trainer.placement.ref_num_gpus_per_node * cfg.trainer.placement.ref_num_nodes
ie_cfg = cfg.generator.inference_engine
num_rollout_gpus = (
ie_cfg.num_engines
* ie_cfg.tensor_parallel_size
* ie_cfg.pipeline_parallel_size
* ie_cfg.data_parallel_size
)
assert (
num_policy_gpus == num_rollout_gpus
), "num_policy_gpus and num_rollout_gpus must be the same when colocating all models"
pg = self.colocate_pg
policy_model = PPORayActorGroup(
cfg.trainer,
cfg.trainer.placement.policy_num_nodes,
cfg.trainer.placement.policy_num_gpus_per_node,
PolicyWorker,
pg=pg,
num_gpus_per_actor=0.2 if pg else 1,
colocate_all=True,
sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
record_memory=cfg.trainer.policy.record_memory,
)
if use_ref_model:
assert (
num_policy_gpus == num_ref_gpus
), "num_policy_gpus and num_ref_gpus must be the same when colocating policy and ref model"
ref_model = PPORayActorGroup(
cfg.trainer,
cfg.trainer.placement.ref_num_nodes,
cfg.trainer.placement.ref_num_gpus_per_node,
RefWorker,
pg=pg,
num_gpus_per_actor=0.2 if pg else 1,
colocate_all=True,
sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
)
else:
ref_model = None
if cfg.trainer.critic.model.path:
assert (
num_policy_gpus == num_critic_gpus
), "num_policy_gpus and num_critic_gpus must be the same when colocating policy and critic model"
critic_model = PPORayActorGroup(
cfg.trainer,
cfg.trainer.placement.critic_num_nodes,
cfg.trainer.placement.critic_num_gpus_per_node,
CriticWorker,
pg=pg,
num_gpus_per_actor=0.2,
colocate_all=True,
sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
)
else:
critic_model = None
else:
if cfg.trainer.placement.colocate_policy_ref and use_ref_model:
assert (
cfg.trainer.placement.policy_num_nodes == cfg.trainer.placement.ref_num_nodes
and cfg.trainer.placement.policy_num_gpus_per_node == cfg.trainer.placement.ref_num_gpus_per_node
), "num_nodes and num_gpus_per_node must be the same when colocate policy and ref model."
bundles = [
{
"GPU": cfg.trainer.placement.policy_num_gpus_per_node,
"CPU": cfg.trainer.placement.policy_num_gpus_per_node,
}
for _ in range(cfg.trainer.placement.policy_num_nodes)
]
pg = placement_group(bundles, strategy="PACK")
get_ray_pg_ready_with_timeout(pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
policy_model = PPORayActorGroup(
cfg.trainer,
cfg.trainer.placement.policy_num_nodes,
cfg.trainer.placement.policy_num_gpus_per_node,
PolicyWorker,
pg=pg,
num_gpus_per_actor=0.75 if pg else 1,
colocate_all=False,
sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
)
if use_ref_model:
ref_model = PPORayActorGroup(
cfg.trainer,
cfg.trainer.placement.ref_num_nodes,
cfg.trainer.placement.ref_num_gpus_per_node,
RefWorker,
pg=pg,
num_gpus_per_actor=0.25 if pg else 1,
colocate_all=False,
sequence_parallel_size=cfg.trainer.ref.sequence_parallel_size,
)
else:
ref_model = None
if cfg.trainer.critic.model.path:
critic_model = PPORayActorGroup(
cfg.trainer,
cfg.trainer.placement.critic_num_nodes,
cfg.trainer.placement.critic_num_gpus_per_node,
CriticWorker,
num_gpus_per_actor=1,
colocate_all=False,
sequence_parallel_size=cfg.trainer.critic.sequence_parallel_size,
)
else:
critic_model = None
policy_steps_per_train_batch = (
cfg.trainer.train_batch_size // cfg.trainer.policy_mini_batch_size * cfg.trainer.update_epochs_per_batch
)
critic_steps_per_train_batch = 0
if cfg.trainer.critic.model.path:
critic_steps_per_train_batch = (
cfg.trainer.train_batch_size // cfg.trainer.critic_mini_batch_size * cfg.trainer.update_epochs_per_batch
)
policy_num_training_steps = (
self.total_training_steps * policy_steps_per_train_batch if self.total_training_steps is not None else None
)
critic_num_training_steps = (
self.total_training_steps * critic_steps_per_train_batch if self.total_training_steps is not None else None
)
if not cfg.trainer.placement.colocate_all:
refs = []
if ref_model is not None:
refs.extend(ref_model.async_init_model(cfg.trainer.ref.model.path))
refs.extend(
policy_model.async_init_model(
cfg.trainer.policy.model.path,
num_training_steps=policy_num_training_steps,
)
)
if cfg.trainer.critic.model.path:
refs.extend(
critic_model.async_init_model(
cfg.trainer.critic.model.path,
num_training_steps=critic_num_training_steps,
)
)
ray.get(refs)
ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
else:
if ref_model is not None:
ray.get(ref_model.async_init_model(cfg.trainer.ref.model.path))
ref_model.offload_to_cpu()
ray.get(
policy_model.async_init_model(
cfg.trainer.policy.model.path,
num_training_steps=policy_num_training_steps,
)
)
ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
policy_model.offload_to_cpu()
if cfg.trainer.critic.model.path:
ray.get(
critic_model.async_init_model(
cfg.trainer.critic.model.path,
num_training_steps=critic_num_training_steps,
)
)
critic_model.offload_to_cpu()
self.policy_model: PPORayActorGroup = policy_model
self.critic_model: Optional[PPORayActorGroup] = critic_model
self.ref_model: Optional[PPORayActorGroup] = ref_model
# Create unified dispatch that manages all actor groups
self.dispatch = WorkerDispatch(
cfg=self.cfg,
policy_actor_group=policy_model,
critic_actor_group=critic_model,
ref_actor_group=ref_model,
inference_engine_client=self.inference_engine_client,
)
# Mark all models as offloaded if colocate_all (they were offloaded above)
if self.colocate_all:
self.dispatch.mark_all_offloaded()
logger.info("init policy/ref/critic models done")method init_weight_sync_state
init_weight_sync_state()Setup the connection between policy model and inference engine for weight syncing.
Source code in skyrl/train/trainer.py:582-587
def init_weight_sync_state(self):
    """
    Setup the connection between policy model and inference engine for weight syncing.

    Delegates to the dispatch layer, which wires the policy actors to the
    inference engine client so later weight broadcasts have a channel to use.
    """
    self.dispatch.init_weight_sync_state(self.inference_engine_client)
    logger.info("Initialized weight sync state for policy model and inference engines.")
method sync_policy_weights_to_inference_engines
sync_policy_weights_to_inference_engines() -> List[ObjectRef]

Broadcast policy weights to inference engines.
Note: For new code, prefer using dispatch.save_weights_for_sampler() which handles the full weight sync protocol including offload/backload. This method is kept for backward compatibility with subclasses. TODO(tgriggs): Remove this method when migration is complete.
Source code in skyrl/train/trainer.py:589-602
def sync_policy_weights_to_inference_engines(self) -> List[ObjectRef]:
    """Broadcast policy weights to inference engines.

    Note: For new code, prefer using dispatch.save_weights_for_sampler() which
    handles the full weight sync protocol including offload/backload.
    This method is kept for backward compatibility with subclasses.
    TODO(tgriggs): Remove this method when migration is complete.

    Returns:
        A list of Ray object refs for the in-flight broadcast; the call is
        asynchronous and the refs can be waited on by the caller.
    """
    # "pass_through" runs `broadcast_to_inference_engines` on the policy actors
    # with the given arguments forwarded as-is.
    return self.policy_model.async_run_ray_method(
        "pass_through",
        "broadcast_to_inference_engines",
        self.inference_engine_client,
        self.cfg.generator.inference_engine,
    )
method convert_to_training_input
convert_to_training_input(generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatchConverts lists to a padded batch of tensors for training
Source code in skyrl/train/trainer.py:604-678
def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch:
    """Converts lists to a padded batch of tensors for training.

    Builds padded `sequences`/masks/rewards tensors from the generator's
    ragged per-sample lists, attaches batch metadata (uids, response length,
    trajectory ids for step-wise training), and pads the batch to be
    divisible by the data-parallel size via `self.pad_batch`.
    """
    prompt_ids: List[List[int]] = generator_output["prompt_token_ids"]
    response_ids: List[List[int]] = generator_output["response_ids"]
    rewards: List[List[float]] = generator_output["rewards"]
    loss_masks: List[List[int]] = generator_output["loss_masks"]
    logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None)
    (
        sequences_tensor,
        attention_masks_tensor,
        response_masks_tensor,
        rewards_tensor,
        loss_masks_tensor,
        rollout_logprobs_tensor,
    ) = convert_prompts_responses_to_batch_tensors(
        self.tokenizer,
        prompt_ids,
        response_ids,
        rewards,
        loss_masks,
        logprobs,
    )
    # sanity check for off_policy_correction: the correction needs rollout
    # logprobs aligned token-for-token with the (padded) responses.
    off_policy_correction = self.cfg.trainer.algorithm.off_policy_correction
    tis_ratio_type = off_policy_correction.tis_ratio_type
    sequence_mask_metric = off_policy_correction.sequence_mask_metric
    if tis_ratio_type is not None or sequence_mask_metric is not None:
        assert (
            rollout_logprobs_tensor is not None
        ), "expected non-null rollout logprobs tensor when off_policy_correction is enabled"
        assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses"
    training_input = TrainingInputBatch(
        {
            "sequences": sequences_tensor,  # Full trajectories (padded and concatenated prompts and responses)
            "attention_mask": attention_masks_tensor,
            "response_mask": response_masks_tensor,
            "rewards": rewards_tensor,
            "loss_mask": loss_masks_tensor,
            "rollout_logprobs": rollout_logprobs_tensor,
            "is_last_step": (
                torch.tensor(generator_output["is_last_step"], dtype=torch.bool)
                if generator_output.get("is_last_step", None) is not None
                else None
            ),
        },
    )
    training_input.metadata = {"uids": uids}
    # padded response length
    training_input.metadata["response_length"] = response_masks_tensor.shape[1]
    if self.cfg.generator.step_wise_trajectories:
        assert (
            "trajectory_ids" in generator_output
        ), "Expected `trajectory_ids` in generator output for step wise training"
        training_input.metadata["trajectory_ids"] = [
            trajectory_id.to_string() for trajectory_id in generator_output["trajectory_ids"]
        ]
        # NOTE(review): only last-step responses are summed here, but the
        # denominator is the total number of steps, not the number of
        # trajectories — confirm this averaging is intended.
        training_input.metadata["avg_response_length"] = sum(
            len(sample_response_ids)
            for sample_response_ids, is_last_step in zip(response_ids, generator_output["is_last_step"])
            if is_last_step
        ) / len(response_ids)
    else:
        training_input.metadata["avg_response_length"] = sum(
            len(sample_response_ids) for sample_response_ids in response_ids
        ) / len(response_ids)
    logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
    training_input = self.pad_batch(training_input)
    logger.info(f"Number of sequences after padding: {len(training_input['sequences'])}")
    return training_input
method async generate
generate(input_batch: GeneratorInput) -> GeneratorOutputGenerate rollouts.
If colocate_all is enabled:
- before calling this method, the policy model should be on CPU and inference engine should be awake (i.e. on GPU).
- after calling this method, the same model placement still holds.
Source code in skyrl/train/trainer.py:680-703
@torch.no_grad()
async def generate(
    self,
    input_batch: GeneratorInput,
) -> GeneratorOutput:
    """
    Generate rollouts.

    If colocate_all is enabled:
    - before calling this method, the policy model should be on CPU and inference engine should
    be awake (i.e. on GPU).
    - after calling this method, the same model placement still holds.
    """
    # NOTE: we assume that .generate returns samples in the same order as passed in
    generator_output: GeneratorOutput = await self.generator.generate(input_batch)
    # add rollout metrics to self.all_metrics
    if generator_output["rollout_metrics"] is not None:
        self.all_metrics.update(generator_output["rollout_metrics"])
    # NOTE(review): per-prompt output validation is skipped for step-wise
    # trajectories — presumably because outputs are per-step rather than
    # per-prompt there; confirm against the generator implementation.
    if not self.cfg.generator.step_wise_trajectories:
        validate_generator_output(len(input_batch["prompts"]), generator_output)
    return generator_output
method postprocess_generator_output
postprocess_generator_output(generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutputConverts to per token rewards and computes pass@N.
In the future algorithm specific reward or loss mask post processing should be done here.
Source code in skyrl/train/trainer.py:705-769
@torch.no_grad()
def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
    """
    Converts to per token rewards and computes pass@N.
    In the future algorithm specific reward or loss mask post processing should be done here.
    """
    generator_output_for_metrics = generator_output
    uids_for_metrics = uids
    if self.cfg.generator.step_wise_trajectories:
        # Filter every list-valued field down to the entries flagged as the
        # last step of a trajectory; non-list fields are dropped from the
        # metrics view (defaultdict returns [] for them).
        generator_output_for_metrics = defaultdict(list)
        for key in generator_output:
            if isinstance(generator_output[key], list):
                generator_output_for_metrics[key] = [
                    generator_output[key][i]
                    for i in range(len(generator_output[key]))
                    if generator_output["is_last_step"][i]
                ]
        uids_for_metrics = [
            uid for uid, is_last_step in zip(uids, generator_output["is_last_step"]) if is_last_step
        ]
    # only use `generator_output_for_metrics` for metrics calculation
    # For step-wise training, we only calculate metrics for the last step of each trajectory
    overall_metrics = get_metrics_from_generator_output(
        generator_output_for_metrics,
        uids_for_metrics,
    )
    # these use the full generator output
    rewards: Union[List[float], List[List[float]]] = generator_output["rewards"]
    responses: List[List[int]] = generator_output["response_ids"]
    per_token_rewards: List[List[float]] = []
    # Check if rewards are already token-level (List[List[float]]) or response-level (List[float])
    if rewards and isinstance(rewards[0], list):
        # Token-level rewards: rewards is List[List[float]]
        per_token_rewards = rewards
    else:
        if self.cfg.trainer.algorithm.zero_variance_filter:
            # Zero out the loss mask of samples filtered for having zero
            # reward variance within their uid group, so they drop out of
            # the loss while keeping the batch shape intact.
            kept_indices_set = set(zero_variance_filter(rewards, uids))
            generator_output["loss_masks"] = [
                [0] * len(mask) if i not in kept_indices_set else mask
                for i, mask in enumerate(generator_output["loss_masks"])
            ]
        # Response-level rewards: rewards is List[float], convert to per-token rewards
        # by crediting the entire scalar reward to the final response token.
        for reward, response in zip(rewards, responses):
            per_token_reward = [0.0] * len(response)
            per_token_reward[-1] = float(reward)
            per_token_rewards.append(per_token_reward)
    n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt
    reward_metrics = {
        f"reward/avg_pass_at_{n_samples_per_prompt}": overall_metrics["pass_at_n"],
        "reward/avg_raw_reward": overall_metrics["avg_score"],
        "reward/mean_positive_reward": overall_metrics["mean_positive_reward"],
    }
    self.all_metrics.update(reward_metrics)
    logger.info(
        f"reward/avg_pass_at_{n_samples_per_prompt}: {overall_metrics['pass_at_n']}, reward/avg_raw_reward: {overall_metrics['avg_score']}, reward/mean_positive_reward: {overall_metrics['mean_positive_reward']}"
    )
    # re-assign reward but now it's per token rewards
    generator_output["rewards"] = per_token_rewards
    return generator_output
method compute_advantages_and_returns
compute_advantages_and_returns(data: TrainingInputBatch) -> TrainingInputBatchCalculate advantages and returns for the data batch.
Expects:
- `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
- `["response_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
- `["loss_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
- `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
- `["rewards"]`: Float[torch.Tensor, "batch_size seqlen"]
- `.metadata["uids"]`: List[str]
Adds:
- `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"]
- `["returns"]`: Float[torch.Tensor, "batch_size seqlen"]
Source code in skyrl/train/trainer.py:771-869
@torch.no_grad()
def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingInputBatch:
    """Calculate advantages and returns for the data batch.
    Expects:
    - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
    - `["response_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
    - `["loss_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
    - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
    - `["rewards"]`: Float[torch.Tensor, "batch_size seqlen"]
    - `.metadata["uids"]`: List[str]
    Adds:
    - `["advantages"]`: Float[torch.Tensor, "batch_size seqlen"]
    - `["returns"]`: Float[torch.Tensor, "batch_size seqlen"]
    """
    token_level_rewards = data["rewards"]
    if self.cfg.generator.step_wise_trajectories:
        is_last_step = data["is_last_step"].bool()
        index = np.array(data.metadata["uids"])
        values = data["values"]
        # Use the last step of each trajectory to compute advantages. Compatible with any advantage estimator
        # NOTE(Charlie): so we ignore per-step rewards in step-wise training.
        last_step_advantages, last_step_returns = ppo_utils.compute_advantages_and_returns(
            token_level_rewards=token_level_rewards[is_last_step],
            response_mask=data["response_mask"][is_last_step],
            index=index[is_last_step.cpu().numpy()],
            adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
            values=values[is_last_step] if values is not None else None,
            config=self.cfg.trainer.algorithm,
            gamma=self.cfg.trainer.algorithm.gamma,
            lambd=self.cfg.trainer.algorithm.lambd,
            grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
        )
        # Broadcast each trajectory's advantage and return to all steps of each trajectory.
        # Shifting `is_last_step` right by one and taking a cumulative sum assigns
        # every row a 0-based trajectory id (a new id starts on the row *after*
        # each last step); indexing with it replicates the per-trajectory values.
        traj_ids = (
            torch.cat([torch.tensor([False], device=is_last_step.device), is_last_step[:-1]]).int().cumsum(dim=0)
        )
        num_groups = traj_ids[-1].item() + 1
        assert num_groups == len(
            last_step_advantages
        ), f"number of groups {num_groups} doesn't match the number of trajectories as given by `is_last_step` {len(last_step_advantages)}. The `is_last_step` tensor is likely malformed"
        advantages = last_step_advantages[traj_ids]
        returns = last_step_returns[traj_ids]
    else:
        advantages, returns = ppo_utils.compute_advantages_and_returns(
            token_level_rewards=token_level_rewards,
            response_mask=data["response_mask"],
            index=data.metadata["uids"],
            adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
            config=self.cfg.trainer.algorithm,
            values=data["values"],
            gamma=self.cfg.trainer.algorithm.gamma,
            lambd=self.cfg.trainer.algorithm.lambd,
            grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
        )
    data["returns"] = returns
    data["advantages"] = advantages
    # remove padding while calculating metrics
    pad_size = data.metadata.get("pad_size", 0)
    num_samples = len(token_level_rewards)
    return_sums = token_level_rewards.sum(dim=-1)[: num_samples - pad_size]
    if self.cfg.generator.step_wise_trajectories:
        avg_rewards: float = return_sums[data["is_last_step"][: num_samples - pad_size]].mean().item()
    else:
        avg_rewards: float = return_sums.mean().item()
    avg_response_length = data.metadata["avg_response_length"]
    data = data.to("cpu")
    # Only advantages on response tokens (per response_mask) count toward the metric.
    valid_advantages = torch.masked_select(
        data["advantages"][: num_samples - pad_size, ...], data["response_mask"][: num_samples - pad_size].bool()
    )
    avg_advantages: float = valid_advantages.mean().item()
    avg_advantages_abs: float = valid_advantages.abs().mean().item()
    if "metrics" not in data.metadata:
        data.metadata["metrics"] = {}
    data.metadata["metrics"].update(
        {
            "avg_final_rewards": avg_rewards,
            "avg_response_length": avg_response_length,
            "avg_advantages": avg_advantages,
            "avg_advantages_abs": avg_advantages_abs,
        }
    )
    logger.info(f"avg_final_rewards: {avg_rewards}, avg_response_length: {avg_response_length}")
    self.all_metrics.update(
        {
            "loss/avg_final_rewards": avg_rewards,
            "loss/avg_raw_advantages": avg_advantages,
            "loss/avg_raw_advantages_abs": avg_advantages_abs,
        }
    )
    return data
method dump_data
dump_data(data: TrainingInputBatch, file_name: str)Dump data to pickle file
Source code in skyrl/train/trainer.py:871-877
def dump_data(self, data: TrainingInputBatch, file_name: str):
"""
Dump data to pickle file
"""
data_save_dir = Path(self.cfg.trainer.export_path) / "dumped_data"
data_save_dir.mkdir(parents=True, exist_ok=True)
data.save(data_save_dir / f"{file_name}.pkl")method pad_batch
pad_batch(training_input: TrainingInputBatch) -> TrainingInputBatchPad the batch to be divisible by dp size
Source code in skyrl/train/trainer.py:879-914
def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch:
"""Pad the batch to be divisible by dp size"""
import math
dp_size = self.dispatch.get_lcm_dp_size()
pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size
new_tensors = {}
training_input.metadata["pad_size"] = pad_size
if pad_size == 0:
return training_input
for key, tensor in training_input.items():
if tensor is not None:
additional_dims = tuple(tensor.shape[1:]) if len(tensor.shape) > 1 else ()
if key == "is_last_step":
padding_tensor = torch.ones(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
elif key == "loss_mask":
# ensures that padding tensors don't count towards the loss
padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
else:
# ensures all padding tensors are in a valid format by cloning `pad_size` from the original input
# `pad_size` is guaranteed to be smaller than batch_size
padding_tensor = tensor[:pad_size].clone()
new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0)
new_training_input = TrainingInputBatch(new_tensors)
new_training_input.metadata = {}
new_training_input.metadata["uids"] = training_input.metadata["uids"] + [f"pad{i}" for i in range(pad_size)]
if "trajectory_ids" in training_input.metadata:
new_training_input.metadata["trajectory_ids"] = training_input.metadata["trajectory_ids"] + [
f"pad{i}" for i in range(pad_size)
]
for key, value in training_input.metadata.items():
if key not in ["uids", "trajectory_ids"]:
new_training_input.metadata[key] = copy.deepcopy(value)
return new_training_inputmethod fwd_logprobs_values_reward
fwd_logprobs_values_reward(training_input: TrainingInputBatch)Calculate values from the critic, log probs from the policy and ref model.
Dispatch handles offload/backload automatically for all colocation configurations.
Expects:
- `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
- `["attention_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
- `.metadata["response_length"]`: Int
Adds:
- `["base_action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
- `["action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
- `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
Source code in skyrl/train/trainer.py:916-986
@torch.no_grad()
def fwd_logprobs_values_reward(
    self,
    training_input: TrainingInputBatch,
):
    """
    Calculate values from the critic, log probs from the policy and ref model.
    Dispatch handles offload/backload automatically for all colocation configurations.
    Expects:
    - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"]
    - `["attention_mask"]`: Integer[torch.Tensor, "batch_size seqlen"]
    - `.metadata["response_length"]`: Int
    Adds:
    - `["base_action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
    - `["action_log_probs"]`: Float[torch.Tensor, "batch_size seqlen"]
    - `["values"]`: Float[torch.Tensor, "batch_size seqlen"]
    """
    # Forward passes only need the model inputs plus the response-length metadata.
    data_fwd_pass = training_input.select(keys=["sequences", "attention_mask"], metadata_keys=["response_length"])
    values = None
    base_log_probs = None
    action_log_probs = None
    # Critic forward (dispatch handles offload/backload automatically)
    if self.has_critic:
        critic_output = self.dispatch.forward("critic", data_fwd_pass)
        values = critic_output["output"]
    # Ref forward
    if self.ref_model is not None:
        ref_output = self.dispatch.forward("ref", data_fwd_pass)
        base_log_probs = ref_output["output"]
        self.dispatch.empty_cache("ref")
    # Policy forward
    policy_output = self.dispatch.forward("policy", data_fwd_pass)
    action_log_probs = policy_output["output"]
    # Empty cache after all forward passes
    self.dispatch.empty_cache()
    sequences_all: torch.Tensor = training_input["sequences"]
    # NOTE (sumanthrh): The slicing is needed to make sure that the batch dimension doesn't change for the tensordict.
    base_log_probs = base_log_probs[: len(sequences_all)] if base_log_probs is not None else None
    action_log_probs = action_log_probs[: len(sequences_all)]
    values = values[: len(sequences_all)] if values is not None else None
    training_input["base_action_log_probs"] = base_log_probs
    training_input["action_log_probs"] = action_log_probs
    training_input["values"] = values
    if training_input.get("rollout_logprobs", None) is not None:
        # calculates the difference in probs between inference and trainer components
        # only consider response tokens
        logprobs_diff = (
            training_input["rollout_logprobs"][training_input["loss_mask"] > 0]
            - action_log_probs[training_input["loss_mask"] > 0]
        ).abs()
        logprobs_diff_mean = logprobs_diff.mean().item()
        logprobs_diff_std = logprobs_diff.std().item()
        self.all_metrics.update(
            {
                "policy/rollout_train_logprobs_abs_diff_mean": logprobs_diff_mean,
                "policy/rollout_train_logprobs_abs_diff_std": logprobs_diff_std,
            }
        )
    return training_input
method apply_reward_kl_penalty
apply_reward_kl_penalty(data: TrainingInputBatch) -> TrainingInputBatchApplies a penalty for KL divergence between the policy log probs and the base model log probs to the rewards.
Source code in skyrl/train/trainer.py:988-1042
def apply_reward_kl_penalty(
    self,
    data: TrainingInputBatch,
) -> TrainingInputBatch:
    """Applies a penalty for KL divergence between the policy log probs and the base model log probs to the rewards.

    Subtracts `kl * kl_loss_coef` token-wise from `data["rewards"]`, updates
    the adaptive KL controller (if configured), and records KL statistics in
    both the batch metadata and `self.all_metrics`.
    """
    loss_masks_all: torch.Tensor = data["loss_mask"]
    rewards: torch.Tensor = data["rewards"]
    base_action_log_probs: torch.Tensor = data["base_action_log_probs"]
    action_log_probs: torch.Tensor = data["action_log_probs"]
    # single batched computation
    kl: Float[torch.Tensor, "batch_size seqlen"] = compute_approx_kl(  # type: ignore
        action_log_probs,
        base_action_log_probs,
        loss_mask=loss_masks_all,
        kl_estimator_type=self.cfg.trainer.algorithm.kl_estimator_type,
    )
    kl_max: Float[torch.Tensor, "batch_size"] = torch.max(kl.abs(), dim=-1)[0]  # noqa: F821
    kl_mean: Float[torch.Tensor, "batch_size"] = masked_mean(kl, loss_masks_all, dim=-1)  # noqa: F821
    # NOTE (erictang000): only supporting custom rewards currently
    # Adaptive controller value takes precedence over the static config coefficient.
    kl_loss_coef = (
        self.reward_kl_controller.value
        if self.reward_kl_controller is not None
        else self.cfg.trainer.algorithm.kl_loss_coef
    )
    # max(0, ...) guards against a negative coefficient turning the penalty into a bonus.
    rewards = rewards - kl * max(0, kl_loss_coef)
    data["rewards"] = rewards
    avg_kl: float = kl_mean.mean().item()
    avg_kl_max: float = kl_max.mean().item()
    # update the kl controller
    if self.reward_kl_controller is not None:
        self.reward_kl_controller.update(current=avg_kl, n_steps=kl.shape[0])  # n_steps is just the batch size
    if "metrics" not in data.metadata:
        data.metadata["metrics"] = {}
    data.metadata["metrics"].update(
        {
            "avg_kl": avg_kl,
            "avg_kl_max": avg_kl_max,
            "kl_loss_coef": kl_loss_coef,
        }
    )
    self.all_metrics.update(
        {
            "loss/avg_kl": avg_kl,
            "loss/avg_kl_max": avg_kl_max,
            "loss/kl_loss_coef": kl_loss_coef,
        }
    )
    return data
method train_critic_and_policy
train_critic_and_policy(data: TrainingInputBatch)Run the training step for the policy and critic models.
Uses forward_backward + optim_step for both FSDP and Megatron strategies.
Source code in skyrl/train/trainer.py:1096-1122
def train_critic_and_policy(self, data: TrainingInputBatch):
"""
Run the training step for the policy and critic models.
Uses forward_backward + optim_step for both FSDP and Megatron strategies.
"""
data.metadata["global_step"] = self.global_step
critic_status = None
# Unified training interface for both FSDP and Megatron
if self.has_critic:
with Timer("critic_train", self.all_timings):
critic_status = self._execute_training_step("critic", data)
with Timer("policy_train", self.all_timings):
policy_status = self._execute_training_step("policy", data)
# Update metrics
if critic_status is not None:
for k, v in critic_status.items():
self.all_metrics.update({f"critic/{k}": v})
for k, v in policy_status.items():
self.all_metrics.update({f"policy/{k}": v})
self.dispatch.empty_cache()
return policy_statusmethod handle_dynamic_sampling
handle_dynamic_sampling(generator_output: GeneratorOutput, uids: List[str]) -> Tuple[GeneratorOutput, List[str], bool]Handle dynamic sampling for the current batch.
Accumulates the generator output and UIDs across batches if we are sampling repeatedly and applies the dynamic sampling strategy (i.e. filter, replace) to the current batch. If we hit the limit of max sample batches, we raise an error.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
generator_output | GeneratorOutput | Current batch generator output | required |
uids | List[str] | Current batch UIDs | required |
Returns:
| Name | Type | Description |
|---|---|---|
processed_output | GeneratorOutput | Filtered generator output |
processed_uids | List[str] | Filtered UIDs |
keep_sampling | bool | Whether to keep sampling |
Source code in skyrl/train/trainer.py:1124-1184
def handle_dynamic_sampling(
    self, generator_output: GeneratorOutput, uids: List[str]
) -> Tuple[GeneratorOutput, List[str], bool]:
    """
    Handle dynamic sampling for the current batch.
    Accumulates the generator output and UIDs across batches if we are sampling repeatedly
    and applies the dynamic sampling strategy (i.e. filter, replace) to the current batch.
    If we hit the limit of max sample batches, we raise an error.
    Args:
        generator_output: Current batch generator output
        uids: Current batch UIDs
    Returns:
        processed_output: Filtered generator output
        processed_uids: Filtered UIDs
        keep_sampling: Whether to keep sampling
    """
    # Prepare sampling configuration
    max_sample_batches = self.cfg.trainer.algorithm.dynamic_sampling.max_sample_batches
    dynamic_sampling_config = {
        "type": self.cfg.trainer.algorithm.dynamic_sampling.type,
        "max_sample_batches": max_sample_batches,
        "min_replace_ratio": self.cfg.trainer.algorithm.dynamic_sampling.min_replace_ratio,
        "train_batch_size": self.cfg.trainer.train_batch_size,
        "n_samples_per_prompt": self.cfg.generator.n_samples_per_prompt,
    }
    # Track how many batches we have sampled for the current training step.
    if self.dynamic_sampling_state is None:
        self.dynamic_sampling_state: DynamicSamplingState = {
            "sample_batch_count": 1,
        }
    else:
        self.dynamic_sampling_state["sample_batch_count"] += 1
    # Handle dynamic sampling using utilities
    processed_output, processed_uids, keep_sampling, updated_state = trainer_utils.handle_dynamic_sampling(
        generator_output, uids, dynamic_sampling_config, self.dynamic_sampling_state
    )
    # Check max resample limit, and if we hit it, raise an error
    # (max_sample_batches <= 0 disables the limit).
    if (
        keep_sampling
        and max_sample_batches > 0
        and self.dynamic_sampling_state["sample_batch_count"] >= max_sample_batches
    ):
        raise RuntimeError(
            f"Exiting training loop due to hitting dynamic sampling limit for "
            f"{self.cfg.trainer.algorithm.dynamic_sampling.type} strategy with "
            f"{self.cfg.trainer.algorithm.dynamic_sampling.max_sample_batches} max sample batches. "
            f"Please check your data difficulty distribution."
        )
    # Update state
    self.dynamic_sampling_state = updated_state
    if not keep_sampling:
        # Reset state when sampling is complete
        self.dynamic_sampling_state = None
    return processed_output, processed_uids, keep_sampling
method save_checkpoints
save_checkpoints()Save the model, optimizer, and training states to disk.
Dispatch handles offload/backload automatically for all colocation configurations.
Source code in skyrl/train/trainer.py:1195-1244
def save_checkpoints(self):
    """
    Save the model, optimizer, and training states to disk.
    Dispatch handles offload/backload automatically for all colocation configurations.

    Layout: `<ckpt_path>/global_step_<N>/{policy,critic,data.pt,trainer_state.pt}`.
    The `latest_ckpt_global_step.txt` marker is written last so a partially
    written checkpoint is never advertised as the latest one.
    """
    # Create global step folder structure
    global_step_folder = os.path.join(self.cfg.trainer.ckpt_path, f"global_step_{self.global_step}")
    policy_save_dir = os.path.join(global_step_folder, "policy")
    critic_save_dir = os.path.join(global_step_folder, "critic")
    io.makedirs(global_step_folder, exist_ok=True)
    # Save policy checkpoint (dispatch handles offload/backload)
    self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer)
    # Save critic checkpoint (if it exists)
    if self.has_critic:
        self.dispatch.save_checkpoint("critic", critic_save_dir, self.tokenizer)
    # Save dataloader state; failure here is non-fatal (resume would just
    # restart the dataloader from the beginning).
    dataloader_save_path = os.path.join(global_step_folder, "data.pt")
    try:
        dataloader_state_dict = self.train_dataloader.state_dict()
        with io.open_file(dataloader_save_path, "wb") as f:
            torch.save(dataloader_state_dict, f)
        logger.info(f"Saved dataloader state to {dataloader_save_path}")
    except Exception as e:
        logger.warning(f"Failed to save dataloader state: {e}")
    # Save additional trainer state
    trainer_state = {
        "global_step": self.global_step,
        "config": asdict(self.cfg),
    }
    trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
    with io.open_file(trainer_state_path, "wb") as f:
        torch.save(trainer_state, f)
    logger.info(f"Saved trainer state to {trainer_state_path}")
    # Atomic tracking - write this last after all saves succeed
    latest_checkpoint_file = os.path.join(self.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt")
    with io.open_file(latest_checkpoint_file, "w") as f:
        f.write(str(self.global_step))
    logger.info(f"Successfully saved checkpoint for global_step_{self.global_step} to: {global_step_folder}")
    # Clean up old checkpoints after successful save
    with Timer("cleanup_old_checkpoints", self.all_timings):
        self._cleanup_old_checkpoints()
method load_checkpoints
load_checkpoints() -> Tuple[int, str]Load complete checkpoint state and return the global_step to resume from. Returns 0 if no checkpoint is loaded.
If colocate_all is True, assumes that the policy model is currently on GPU.
Returns:
| Name | Type | Description |
|---|---|---|
global_step | int | The global step to resume from. |
checkpoint_path | str | The path to the checkpoint. |
Source code in skyrl/train/trainer.py:1259-1370
def load_checkpoints(self) -> Tuple[int, str]:
"""
Load complete checkpoint state and return the global_step to resume from.
Returns 0 if no checkpoint is loaded.
If colocate_all is True, assumes that the policy model is currently on GPU.
Returns:
global_step: The global step to resume from.
checkpoint_path: The path to the checkpoint.
"""
checkpoint_path = None
# Check if resumption is enabled
if self.resume_mode == ResumeMode.NONE:
logger.info("Checkpoint resumption disabled, starting training from scratch")
return 0, None
# first, let's get resume_path
elif self.resume_mode == ResumeMode.LATEST:
latest_checkpoint_file = os.path.join(self.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt")
if not io.exists(latest_checkpoint_file):
logger.info("No checkpoint found, starting training from scratch")
return 0, None
with io.open_file(latest_checkpoint_file, "r") as f:
ckpt_iteration = int(f.read().strip())
checkpoint_path = os.path.join(self.cfg.trainer.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_iteration}")
# Run validation: Make sure ckpt folder is consistent with latest_ckpt_global_step.txt
validate_consistency_for_latest_checkpoint(
self.cfg.trainer.ckpt_path,
ckpt_iteration,
checkpoint_path,
latest_checkpoint_file,
self.cfg.trainer.ckpt_interval,
)
else:
# Get and validate resume path
checkpoint_path = Path(self.cfg.trainer.resume_path)
if not checkpoint_path:
raise ValueError("`trainer.resume_path` must be specified when resume_mode is 'from_path'")
# Validate that it's a global_step directory
if GLOBAL_STEP_PREFIX not in checkpoint_path.name:
raise ValueError(
f"`trainer.resume_path` must point to a directory whose name starting with {GLOBAL_STEP_PREFIX}, got: {checkpoint_path}"
)
# Validate that the path exists
if not io.exists(str(checkpoint_path)):
raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
logger.info(f"Loading checkpoint from: {checkpoint_path}")
# Extract global step from checkpoint path
global_step = extract_step_from_path(Path(checkpoint_path))
if global_step == -1:
raise ValueError(f"Checkpoint path {checkpoint_path} is not a valid checkpoint path")
logger.info(f"Resuming from global_step: {global_step}")
# Define paths for different checkpoint components
policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
critic_ckpt_dir = os.path.join(checkpoint_path, "critic")
trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
dataloader_state_path = os.path.join(checkpoint_path, "data.pt")
# Validate that required checkpoint files exist
if not io.exists(trainer_state_path):
raise FileNotFoundError(f"Trainer state file not found: {trainer_state_path}")
# 1. Load and validate trainer state
with io.open_file(trainer_state_path, "rb") as f:
trainer_state = torch.load(f, map_location="cpu", weights_only=False)
saved_global_step = trainer_state.get("global_step", global_step)
logger.info("Successfully loaded trainer state")
if saved_global_step != global_step:
logger.warning(f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value.")
# 2. Load dataloader state if available
if io.exists(dataloader_state_path):
try:
with io.open_file(dataloader_state_path, "rb") as f:
dataloader_state = torch.load(f, map_location="cpu", weights_only=False)
self.train_dataloader.load_state_dict(dataloader_state)
logger.info("Successfully loaded dataloader state")
except Exception as e:
logger.warning(f"Failed to load dataloader state: {e}. Dataloader will start from beginning.")
else:
logger.warning(
f"No dataloader state found at {dataloader_state_path}. Dataloader will start from beginning."
)
# 3. Load policy checkpoint (dispatch handles offload/backload)
logger.info(f"Loading policy checkpoint from {policy_ckpt_dir}")
self.dispatch.load_checkpoint(
"policy",
policy_ckpt_dir,
load_optimizer_states=True,
load_lr_scheduler_states=True,
)
logger.info("Successfully loaded policy checkpoint")
# 4. Load critic checkpoint if it exists and we have a critic model
if self.has_critic:
logger.info(f"Loading critic checkpoint from {critic_ckpt_dir}")
self.dispatch.load_checkpoint(
"critic",
critic_ckpt_dir,
load_optimizer_states=True,
load_lr_scheduler_states=True,
)
logger.info("Successfully loaded critic checkpoint")
logger.info(f"Successfully loaded complete checkpoint state from global_step_{global_step}")
return global_step, str(checkpoint_path)method save_models
save_models()Save the model parameters in HF format at cfg.trainer.export_path.
Dispatch handles offload/backload automatically for all colocation configurations.
Source code in skyrl/train/trainer.py:1372-1385
def save_models(self):
"""
Save the model parameters in HF format at `cfg.trainer.export_path`.
Dispatch handles offload/backload automatically for all colocation configurations.
"""
policy_export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}", "policy")
self.dispatch.save_hf_model("policy", policy_export_dir, self.tokenizer)
if self.has_critic:
critic_export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}", "critic")
self.dispatch.save_hf_model("critic", critic_export_dir, self.tokenizer)
logger.info("Successfully saved model weights.")
method update_ref_with_policy
update_ref_with_policy()
Update the reference model with the policy model weights (required by some algorithms).
Dispatch handles offload/backload automatically for all colocation configurations. After this method, save_weights_for_sampler() should be called to sync weights.
Source code in skyrl/train/trainer.py:1387-1410
def update_ref_with_policy(self):
"""
Update the reference model with the policy model weights (required by some algorithms).
Dispatch handles offload/backload automatically for all colocation configurations.
After this method, save_weights_for_sampler() should be called to sync weights.
"""
# TODO(tgriggs): Make policy-to-ref sync faster.
policy_export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}", "policy")
# Save policy model (dispatch handles GPU state)
self.dispatch.save_hf_model("policy", policy_export_dir, self.tokenizer)
# Re-initialize ref model from saved policy (dispatch handles offloading policy first)
self.dispatch.init_model("ref", policy_export_dir)
# Clean up temporary saved model files
try:
shutil.rmtree(policy_export_dir)
logger.info(f"Cleaned up temporary policy export directory: {policy_export_dir}")
except Exception as e:
logger.warning(f"Failed to clean up temporary policy export directory {policy_export_dir}: {e}")
logger.info("Successfully updated ref model with policy model, training continues.")
Dispatch APIs
class Dispatch
Bases: ABC
Base class for dispatch types
Dispatch types are responsible for:
- dispatching method calls to actors handling data sharding if necessary
- collecting results from actors and concatenating results if necessary
- validating arguments for dispatch
Functions:
| Name | Description |
|---|---|
dispatch | Dispatches method calls to the actors with data sharding if necessary. |
async_collect | Collects results from the actors asynchronously in an asyncio-compatible way. |
sync_collect | Collects results from the actors synchronously and returns a TrainingOutputBatch. |
validate_dispatch_args | Validate and process arguments for dispatch. |
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:61-98
class Dispatch(ABC):
"""Base class for dispatch types
Dispatch types are responsible for:
- dispatching method calls to actors handling data sharding if necessary
- collecting results from actors and concatenating results if necessary
- validating arguments for dispatch
"""
@classmethod
@abstractmethod
def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
"""Dispatches method calls to the actors with data sharding if necessary."""
pass
@classmethod
@abstractmethod
async def async_collect(
cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
) -> Optional[TrainingOutputBatch]:
"""Collects results from the actors asynchronously in an asyncio-compatible way."""
pass
@classmethod
@abstractmethod
def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
"""Collects results from the actors synchronously and returns a `TrainingOutputBatch`."""
pass
@classmethod
@abstractmethod
def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
"""Validate and process arguments for dispatch.
Returns:
Tuple of (args, kwargs) to be passed to dispatch
"""
pass
method abstractmethod classmethod dispatch
dispatch(actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]
Dispatches method calls to the actors with data sharding if necessary.
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:70-74
@classmethod
@abstractmethod
def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
"""Dispatches method calls to the actors with data sharding if necessary."""
pass
method abstractmethod async classmethod async_collect
async_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]
Collects results from the actors asynchronously in an asyncio-compatible way.
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:76-82
@classmethod
@abstractmethod
async def async_collect(
cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
) -> Optional[TrainingOutputBatch]:
"""Collects results from the actors asynchronously in an asyncio-compatible way."""
pass
method abstractmethod classmethod sync_collect
sync_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]
Collects results from the actors synchronously and returns a TrainingOutputBatch.
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:84-88
@classmethod
@abstractmethod
def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
"""Collects results from the actors synchronously and returns a `TrainingOutputBatch`."""
pass
method abstractmethod classmethod validate_dispatch_args
validate_dispatch_args(*args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]
Validate and process arguments for dispatch.
Returns:
| Type | Description |
|---|---|
| Tuple[Tuple, Dict[str, Any]] | Tuple of (args, kwargs) to be passed to dispatch |
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:90-98
@classmethod
@abstractmethod
def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
"""Validate and process arguments for dispatch.
Returns:
Tuple of (args, kwargs) to be passed to dispatch
"""
pass
class MeshDispatch
Bases: Dispatch
Mesh dispatch type to dispatch data to a group of actors along the device mesh.
Supports DP (Data Parallel), SP (Sequence Parallel), TP (Tensor Parallel) and PP (Pipeline Parallel) parallelism. The actor method should accept a single argument - the data batch.
For data dispatch:
- The input data is chunked into
dp_sizeequal chunks, wheredp_sizeis the size of data parallelism. - Each actor with the same DP rank processes the same data chunk in parallel.
For data collection:
- Data is collected only from the primary rank of each model/sequence parallel group.
- The primary rank is defined as the rank with (SP=0, TP=0, PP=0).
- The collected chunks are concatenated in order of DP rank to reconstruct the full data.
Example: For a world size of 8, with DP size=2, SP size=2, TP size=2, PP size=1:
- Data dispatch: The data is chunked into 2 chunks. All actors with DP rank 0 process the first chunk, and all actors with DP rank 1 process the second chunk.
- Data collection: Only two actors contribute to the final output - the primary rank from each DP group: (DP=0, SP=0, TP=0, PP=0) and (DP=1, SP=0, TP=0, PP=0). Their chunks are concatenated in order.
Functions:
| Name | Description |
|---|---|
dispatch | |
async_collect | |
sync_collect | |
dispatch_from_staged | Dispatch to workers using pre-staged data from object store. |
validate_dispatch_args |
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:101-227
class MeshDispatch(Dispatch):
"""Mesh dispatch type to dispatch data to a group of actors along the device mesh.
Supports DP (Data Parallel), SP (Sequence Parallel), TP (Tensor Parallel) and PP (Pipeline Parallel) parallelism.
The actor method should accept a single argument - the data batch.
For data dispatch:
* The input data is chunked into `dp_size` equal chunks, where `dp_size` is the size of data parallelism.
* Each actor with the same DP rank processes the same data chunk in parallel.
For data collection:
* Data is collected only from the primary rank of each model/sequence parallel group.
* The primary rank is defined as the rank with (SP=0, TP=0, PP=0).
* The collected chunks are concatenated in order of DP rank to reconstruct the full data.
Example: For a world size of 8, with DP size=2, SP size=2, TP size=2, PP size=1:
* Data dispatch: The data is chunked into 2 chunks. All actors with DP rank 0 process the first chunk,
and all actors with DP rank 1 process the second chunk.
* Data collection: Only two actors contribute to the final output - the primary rank from each DP group:
(DP=0, SP=0, TP=0, PP=0) and (DP=1, SP=0, TP=0, PP=0). Their chunks are concatenated in order.
"""
@classmethod
def dispatch(cls, actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch, **kwargs) -> List[ObjectRef]:
assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
object_refs = []
dp_size = actor_infos[0].rank.dp_size
assert len(data) % dp_size == 0, "data batch size must be divisible by dp_size, got {} and {}".format(
len(data), dp_size
)
chunk_size = len(data) // dp_size
data_chunks: List[TrainingInputBatch] = data.chunk(chunk_size)
# Put each unique chunk in object store ONCE to avoid redundant serialization
# when the same chunk is sent to multiple workers (e.g., SP/TP replicas)
chunk_refs: List[ObjectRef] = [ray.put(chunk) for chunk in data_chunks]
for actor_info in actor_infos:
# Pass ObjectRef instead of data - workers will fetch from object store
chunk_ref = chunk_refs[actor_info.rank.dp]
object_refs.append(getattr(actor_info.handle, method).remote(chunk_ref, **kwargs))
return object_refs
@classmethod
async def async_collect(
cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
) -> Optional[TrainingOutputBatch]:
assert len(actor_infos) == len(object_refs), "`actor_infos` and `object_refs` must have the same length"
all_objects = await asyncio.gather(*object_refs)
if len(all_objects) and all_objects[0] is not None:
return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
return
@classmethod
def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
assert len(actor_infos) == len(object_refs), "`actor_infos` and `object_refs` must have the same length"
all_objects = ray.get(object_refs)
if len(all_objects) and all_objects[0] is not None:
return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
# all should be none
assert all(obj is None for obj in all_objects), "Got a mix of `None` and non-`None` objects"
return
@classmethod
def dispatch_from_staged(
cls, actor_infos: List[ActorInfo], method: str, data_ref: ObjectRef, start_idx: int, end_idx: int, **kwargs
) -> List[ObjectRef]:
"""
Dispatch to workers using pre-staged data from object store.
Workers receive the full batch ObjectRef and slice indices, then fetch
and slice locally. This avoids serialization on each mini-batch iteration.
Args:
actor_infos: List of actor info objects
method: Name of method to call on workers
data_ref: ObjectRef to full TrainingInputBatch in object store
start_idx: Start index for mini-batch slice (before DP chunking)
end_idx: End index for mini-batch slice (before DP chunking)
**kwargs: Additional keyword arguments to pass to the method
Returns:
List of ObjectRefs for worker results
"""
assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
object_refs = []
dp_size = actor_infos[0].rank.dp_size
# Compute per-DP-rank slice indices
mini_batch_size = end_idx - start_idx
assert (
mini_batch_size % dp_size == 0
), f"mini_batch_size must be divisible by dp_size, got {mini_batch_size} and {dp_size}"
chunk_size = mini_batch_size // dp_size
for actor_info in actor_infos:
dp_rank = actor_info.rank.dp
# Compute this worker's slice of the mini-batch
worker_start = start_idx + dp_rank * chunk_size
worker_end = worker_start + chunk_size
object_refs.append(
getattr(actor_info.handle, method).remote(
data_ref, start_idx=worker_start, end_idx=worker_end, **kwargs
)
)
return object_refs
@classmethod
def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
# Extract data from either positional arg or kwarg
if args:
data = args[0]
remaining_kwargs = kwargs
elif "data" in kwargs:
data = kwargs.pop("data")
remaining_kwargs = kwargs
else:
raise ValueError("MeshDispatch requires 'data' as first positional argument or keyword argument")
if not isinstance(data, TrainingInputBatch):
raise ValueError(f"For MeshDispatch, `data` entry should be a `TrainingInputBatch`, got {type(data)}")
# Pass through data as positional arg, and any other kwargs (e.g., loss_fn, loss_fn_config)
return (data,), remaining_kwargs
method classmethod dispatch
dispatch(actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch, **kwargs) -> List[ObjectRef]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:127-146
@classmethod
def dispatch(cls, actor_infos: List[ActorInfo], method: str, data: TrainingInputBatch, **kwargs) -> List[ObjectRef]:
assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
object_refs = []
dp_size = actor_infos[0].rank.dp_size
assert len(data) % dp_size == 0, "data batch size must be divisible by dp_size, got {} and {}".format(
len(data), dp_size
)
chunk_size = len(data) // dp_size
data_chunks: List[TrainingInputBatch] = data.chunk(chunk_size)
# Put each unique chunk in object store ONCE to avoid redundant serialization
# when the same chunk is sent to multiple workers (e.g., SP/TP replicas)
chunk_refs: List[ObjectRef] = [ray.put(chunk) for chunk in data_chunks]
for actor_info in actor_infos:
# Pass ObjectRef instead of data - workers will fetch from object store
chunk_ref = chunk_refs[actor_info.rank.dp]
object_refs.append(getattr(actor_info.handle, method).remote(chunk_ref, **kwargs))
return object_refs
method async classmethod async_collect
async_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:148-156
@classmethod
async def async_collect(
cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
) -> Optional[TrainingOutputBatch]:
assert len(actor_infos) == len(object_refs), "`actor_infos` and `object_refs` must have the same length"
all_objects = await asyncio.gather(*object_refs)
if len(all_objects) and all_objects[0] is not None:
return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
return
method classmethod sync_collect
sync_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:158-166
@classmethod
def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
assert len(actor_infos) == len(object_refs), "`actor_infos` and `object_refs` must have the same length"
all_objects = ray.get(object_refs)
if len(all_objects) and all_objects[0] is not None:
return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
# all should be none
assert all(obj is None for obj in all_objects), "Got a mix of `None` and non-`None` objects"
return
method classmethod dispatch_from_staged
dispatch_from_staged(actor_infos: List[ActorInfo], method: str, data_ref: ObjectRef, start_idx: int, end_idx: int, **kwargs) -> List[ObjectRef]
Dispatch to workers using pre-staged data from object store.
Workers receive the full batch ObjectRef and slice indices, then fetch and slice locally. This avoids serialization on each mini-batch iteration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
actor_infos | List[ActorInfo] | List of actor info objects | required |
method | str | Name of method to call on workers | required |
data_ref | ObjectRef | ObjectRef to full TrainingInputBatch in object store | required |
start_idx | int | Start index for mini-batch slice (before DP chunking) | required |
end_idx | int | End index for mini-batch slice (before DP chunking) | required |
**kwargs | Additional keyword arguments to pass to the method | {} |
Returns:
| Type | Description |
|---|---|
| List[ObjectRef] | List of ObjectRefs for worker results |
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:168-210
@classmethod
def dispatch_from_staged(
cls, actor_infos: List[ActorInfo], method: str, data_ref: ObjectRef, start_idx: int, end_idx: int, **kwargs
) -> List[ObjectRef]:
"""
Dispatch to workers using pre-staged data from object store.
Workers receive the full batch ObjectRef and slice indices, then fetch
and slice locally. This avoids serialization on each mini-batch iteration.
Args:
actor_infos: List of actor info objects
method: Name of method to call on workers
data_ref: ObjectRef to full TrainingInputBatch in object store
start_idx: Start index for mini-batch slice (before DP chunking)
end_idx: End index for mini-batch slice (before DP chunking)
**kwargs: Additional keyword arguments to pass to the method
Returns:
List of ObjectRefs for worker results
"""
assert len(actor_infos) > 0, "actor_infos must be a non-empty list"
object_refs = []
dp_size = actor_infos[0].rank.dp_size
# Compute per-DP-rank slice indices
mini_batch_size = end_idx - start_idx
assert (
mini_batch_size % dp_size == 0
), f"mini_batch_size must be divisible by dp_size, got {mini_batch_size} and {dp_size}"
chunk_size = mini_batch_size // dp_size
for actor_info in actor_infos:
dp_rank = actor_info.rank.dp
# Compute this worker's slice of the mini-batch
worker_start = start_idx + dp_rank * chunk_size
worker_end = worker_start + chunk_size
object_refs.append(
getattr(actor_info.handle, method).remote(
data_ref, start_idx=worker_start, end_idx=worker_end, **kwargs
)
)
return object_refs
method classmethod validate_dispatch_args
validate_dispatch_args(*args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:212-227
@classmethod
def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
# Extract data from either positional arg or kwarg
if args:
data = args[0]
remaining_kwargs = kwargs
elif "data" in kwargs:
data = kwargs.pop("data")
remaining_kwargs = kwargs
else:
raise ValueError("MeshDispatch requires 'data' as first positional argument or keyword argument")
if not isinstance(data, TrainingInputBatch):
raise ValueError(f"For MeshDispatch, `data` entry should be a `TrainingInputBatch`, got {type(data)}")
# Pass through data as positional arg, and any other kwargs (e.g., loss_fn, loss_fn_config)
return (data,), remaining_kwargs
class PassThroughDispatch
Bases: Dispatch
PassThrough dispatch type to dispatch data to a group of actors without any sharding.
This is useful for cases where we want to run the same method on all the actors. Supports methods with any number of arguments.
Functions:
| Name | Description |
|---|---|
dispatch | |
async_collect | |
sync_collect | |
validate_dispatch_args |
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:230-265
class PassThroughDispatch(Dispatch):
"""PassThrough dispatch type to dispatch data to a group of actors without any sharding.
This is useful for cases where we want to run the same method on all the actors.
Supports methods with any number of arguments.
"""
@classmethod
def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
return [getattr(actor_info.handle, method).remote(*args, **kwargs) for actor_info in actor_infos]
@classmethod
async def async_collect(
cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
) -> Optional[TrainingOutputBatch]:
all_objects = await asyncio.gather(*object_refs)
if len(all_objects) and all_objects[0] is not None:
return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
return
@classmethod
def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
data_batches = ray.get(object_refs)
if len(data_batches) > 0 and data_batches[0] is not None:
assert isinstance(
data_batches[0], TrainingOutputBatch
), "data_batches must be a list of `TrainingOutputBatch` objects"
return concatenate_outputs_after_mesh_dispatch(actor_infos, data_batches)
# all should be none
assert all(obj is None for obj in data_batches), "Got a mix of `None` and non-`None` objects"
return
@classmethod
def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
# no validation needed just pass everything
return args, kwargs
method classmethod dispatch
dispatch(actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:237-239
@classmethod
def dispatch(cls, actor_infos: List[ActorInfo], method: str, *args, **kwargs) -> List[ObjectRef]:
return [getattr(actor_info.handle, method).remote(*args, **kwargs) for actor_info in actor_infos]
method async classmethod async_collect
async_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:241-248
@classmethod
async def async_collect(
cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]
) -> Optional[TrainingOutputBatch]:
all_objects = await asyncio.gather(*object_refs)
if len(all_objects) and all_objects[0] is not None:
return concatenate_outputs_after_mesh_dispatch(actor_infos, all_objects)
return
method classmethod sync_collect
sync_collect(actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:250-260
@classmethod
def sync_collect(cls, actor_infos: List[ActorInfo], object_refs: List[ObjectRef]) -> Optional[TrainingOutputBatch]:
data_batches = ray.get(object_refs)
if len(data_batches) > 0 and data_batches[0] is not None:
assert isinstance(
data_batches[0], TrainingOutputBatch
), "data_batches must be a list of `TrainingOutputBatch` objects"
return concatenate_outputs_after_mesh_dispatch(actor_infos, data_batches)
# all should be none
assert all(obj is None for obj in data_batches), "Got a mix of `None` and non-`None` objects"
return
method classmethod validate_dispatch_args
validate_dispatch_args(*args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]
Source code in skyrl/backends/skyrl_train/distributed/dispatch.py:262-265
@classmethod
def validate_dispatch_args(cls, *args, **kwargs) -> Tuple[Tuple, Dict[str, Any]]:
# no validation needed just pass everything
return args, kwargs
Worker APIs
The base worker abstraction in SkyRL.
class Worker
Worker(cfg: TrainerConfig, *args, **kwargs)
Bases: DistributedTorchRayActor
Functions:
| Name | Description |
|---|---|
get_node_local_rank | |
init_worker_process_group | |
get_mesh_rank | |
get_gpu_id | |
get_ray_node_id | |
get_master_addr_port | |
init_model | Initialize worker state (model, and optimizer if applicable) on worker. |
empty_cache | Empty GPU memory cache on Worker's CUDA device |
offload_to_cpu | Offload all worker state to CPU. |
backload_to_gpu | Backload worker state to GPU |
get_cuda_memory | Get CUDA memory usage on worker's CUDA device. |
save_memory_snapshot | Save a snapshot of memory usage on the Worker's CUDA device. |
init_weight_sync_state | Initialize state for weight syncing with Inference Engine Client |
forward | Run forward pass on the input batch in inference mode. |
Attributes:
| Name | Type | Description |
|---|---|---|
sequence_parallel_size | int | |
record_memory | ||
cfg |
Source code in skyrl/backends/skyrl_train/workers/worker.py:231-395
class Worker(DistributedTorchRayActor):
def __init__(self, cfg: TrainerConfig, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cfg = cfg
self._transfer_strategy_cls = None # Set in init_weight_transfer_communicator
if self.cfg.algorithm.temperature is None:
raise ValueError("`cfg.algorithm.temperature` must be set")
def init_model(self, *args, **kwargs):
"""Initialize worker state (model, and optimizer if applicable) on worker."""
raise NotImplementedError()
def empty_cache(self) -> None:
"""Empty GPU memory cache on Worker's CUDA device"""
torch.cuda.empty_cache()
def offload_to_cpu(self, pin_memory=True, non_blocking=True):
"""Offload all worker state to CPU.
After this function runs, only temporary reserved memory and torch's pre-loaded cuda kernels (~ GB) will remain
Args:
pin_memory: Whether to use pinned/ paged-locked memory on CPU
non_blocking: Whether the operation is non-blocking
"""
raise NotImplementedError()
def backload_to_gpu(self, non_blocking=True):
"""Backload worker state to GPU
Args:
non_blocking: Whether the operation is non-blocking
"""
raise NotImplementedError()
def get_cuda_memory(self) -> Dict[str, Any]:
"""Get CUDA memory usage on worker's CUDA device."""
torch.cuda.synchronize()
free, total = torch.cuda.mem_get_info()
return {
"allocated": torch.cuda.memory_allocated(),
"reserved": torch.cuda.memory_reserved(),
"free": free,
"total": total,
}
def save_memory_snapshot(self, tag: str = ""):
"""Save a snapshot of memory usage on the Worker's CUDA device.
No-ops if record_memory is False.
Args:
tag: Label for the snapshot (e.g., "forward_backward", "optim_step")
.. note::
This function should be called on all the ranks in the worker group simultaneously.
"""
if not self.record_memory:
return
# Track snapshot count for unique filenames
if not hasattr(self, "_snapshot_count"):
self._snapshot_count = 0
self._snapshot_count += 1
rank = torch.distributed.get_rank()
save_path = os.path.join(self.cfg.ckpt_path, "memory_snapshots")
if self._local_rank == 0 and not io.exists(save_path):
io.makedirs(save_path, exist_ok=True)
torch.distributed.barrier()
tag_str = f"_{tag}" if tag else ""
file_name = f"rank_{rank}{tag_str}_{self._snapshot_count}.pickle"
record_memory_path = os.path.join(save_path, file_name)
if io.exists(record_memory_path):
# seeing issues if we don't remove the file first
io.remove(record_memory_path)
torch.cuda.memory._dump_snapshot(record_memory_path)
async def init_weight_sync_state(
self,
inference_engine_client: "Union[InferenceEngineClient, RemoteInferenceClient]",
inference_engine_cfg: "InferenceEngineConfig",
):
"""Initialize state for weight syncing with Inference Engine Client
Creates init info and sender, then sends init info to inference engines
so they can create receivers.
.. note::
This function should be called on all the ranks in the worker group simultaneously.
"""
from skyrl.backends.skyrl_train.weight_sync import get_transfer_strategy_cls
assert inference_engine_client is not None
# Determine transfer strategy based on inference engine config and placement
self._transfer_strategy_cls = get_transfer_strategy_cls(
weight_sync_backend=inference_engine_cfg.weight_sync_backend,
colocate_all=self.cfg.placement.colocate_all,
)
# For new inference path, fetch world_size from servers
# For legacy path, calculate from config
inference_world_size = None
if _SKYRL_USE_NEW_INFERENCE and hasattr(inference_engine_client, "get_world_size"):
inference_world_size, _ = await inference_engine_client.get_world_size()
# Create init info on all ranks (it's deterministic from cfg or fetched world_size)
init_info = self._transfer_strategy_cls.create_init_info(
inference_engine_cfg, inference_world_size=inference_world_size
)
# Create sender on all ranks
# Strategy implementations may have different logic for different ranks
tasks = [
asyncio.to_thread(
self._transfer_strategy_cls.create_sender,
init_info=init_info,
inference_client=inference_engine_client,
),
]
# Only rank 0 initializes receivers on inference engines
# NOTE: For broadcast strategy, sender and receiver init must run concurrently
# because both need to join the same process group to avoid deadlock
if torch.distributed.get_rank() == 0:
tasks.append(inference_engine_client.init_weight_update_communicator(init_info))
results = await asyncio.gather(*tasks)
self._weight_transfer_sender = results[0] # sender is always first task
# # Register signal handlers for termination only on rank 0
# NOTE (sumanthrh): This doesn't work yet, and is thus commented out.
# The better way is to just have this specified in __del__, but there is
# no guarattee that __del__ will be called in general. Ray also doesn't
# explictly call __del__ when the actor shuts down.
# It's commented out so that we can fix this in the future.
# atexit.register(self._handle_termination)
torch.distributed.barrier()
def forward(
self,
data: TrainingInputBatch,
) -> TrainingOutputBatch:
"""Run forward pass on the input batch in inference mode.
This is a wrapper around `_forward_micro_batch` that runs in micro batches of `cfg.micro_forward_batch_size_per_gpu`.
"""
# run in micro batches of cfg.micro_forward_batch_size_per_gpu
# TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc.
micro_batches = data.chunk(self.cfg.micro_forward_batch_size_per_gpu)
outputs = []
for micro_batch in micro_batches:
outputs.append(self._forward_micro_batch(micro_batch))
output = TrainingOutputBatch.cat(outputs)
if output.device is not None and output.device != torch.device("cpu"):
output = output.to("cpu")
return output
def _forward_micro_batch(self, micro_batch: TrainingInputBatch) -> TrainingOutputBatch:
raise NotImplementedError()
attr sequence_parallel_size
sequence_parallel_size: int = sequence_parallel_size
attr record_memory
record_memory = record_memory
get_node_local_rank
get_node_local_rank()
init_worker_process_group
init_worker_process_group()
get_mesh_rank
get_mesh_rank()
get_gpu_id
get_gpu_id()
get_ray_node_id
get_ray_node_id()
get_master_addr_port
get_master_addr_port()
attr cfg
cfg = cfg
method init_model
init_model(*args, **kwargs)
Initialize worker state (model, and optimizer if applicable) on worker.
Source code in skyrl/backends/skyrl_train/workers/worker.py:240-242
def init_model(self, *args, **kwargs):
"""Initialize worker state (model, and optimizer if applicable) on worker."""
raise NotImplementedError()
method empty_cache
empty_cache() -> None
Empty GPU memory cache on Worker's CUDA device
Source code in skyrl/backends/skyrl_train/workers/worker.py:244-246
def empty_cache(self) -> None:
"""Empty GPU memory cache on Worker's CUDA device"""
torch.cuda.empty_cache()
method offload_to_cpu
offload_to_cpu(pin_memory=True, non_blocking=True)
Offload all worker state to CPU.
After this function runs, only temporary reserved memory and torch's pre-loaded cuda kernels (~ GB) will remain
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
pin_memory | Whether to use pinned/ paged-locked memory on CPU | True | |
non_blocking | Whether the operation is non-blocking | True |
Source code in skyrl/backends/skyrl_train/workers/worker.py:248-257
def offload_to_cpu(self, pin_memory=True, non_blocking=True):
"""Offload all worker state to CPU.
After this function runs, only temporary reserved memory and torch's pre-loaded cuda kernels (~ GB) will remain
Args:
pin_memory: Whether to use pinned/ paged-locked memory on CPU
non_blocking: Whether the operation is non-blocking
"""
raise NotImplementedError()
method backload_to_gpu
backload_to_gpu(non_blocking=True)
Backload worker state to GPU
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
non_blocking | Whether the operation is non-blocking | True |
Source code in skyrl/backends/skyrl_train/workers/worker.py:259-265
def backload_to_gpu(self, non_blocking=True):
"""Backload worker state to GPU
Args:
non_blocking: Whether the operation is non-blocking
"""
raise NotImplementedError()
method get_cuda_memory
get_cuda_memory() -> Dict[str, Any]
Get CUDA memory usage on worker's CUDA device.
Source code in skyrl/backends/skyrl_train/workers/worker.py:267-276
def get_cuda_memory(self) -> Dict[str, Any]:
"""Get CUDA memory usage on worker's CUDA device."""
torch.cuda.synchronize()
free, total = torch.cuda.mem_get_info()
return {
"allocated": torch.cuda.memory_allocated(),
"reserved": torch.cuda.memory_reserved(),
"free": free,
"total": total,
}method save_memory_snapshot
save_memory_snapshot(tag: str = '')Save a snapshot of memory usage on the Worker's CUDA device.
No-ops if record_memory is False.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
tag | str | Label for the snapshot (e.g., "forward_backward", "optim_step") | '' |
.. note:: This function should be called on all the ranks in the worker group simultaneously.
Source code in skyrl/backends/skyrl_train/workers/worker.py:278-309
def save_memory_snapshot(self, tag: str = ""):
"""Save a snapshot of memory usage on the Worker's CUDA device.
No-ops if record_memory is False.
Args:
tag: Label for the snapshot (e.g., "forward_backward", "optim_step")
.. note::
This function should be called on all the ranks in the worker group simultaneously.
"""
if not self.record_memory:
return
# Track snapshot count for unique filenames
if not hasattr(self, "_snapshot_count"):
self._snapshot_count = 0
self._snapshot_count += 1
rank = torch.distributed.get_rank()
save_path = os.path.join(self.cfg.ckpt_path, "memory_snapshots")
if self._local_rank == 0 and not io.exists(save_path):
io.makedirs(save_path, exist_ok=True)
torch.distributed.barrier()
tag_str = f"_{tag}" if tag else ""
file_name = f"rank_{rank}{tag_str}_{self._snapshot_count}.pickle"
record_memory_path = os.path.join(save_path, file_name)
if io.exists(record_memory_path):
# seeing issues if we don't remove the file first
io.remove(record_memory_path)
torch.cuda.memory._dump_snapshot(record_memory_path)method init_weight_sync_state
init_weight_sync_state(inference_engine_client: Union[InferenceEngineClient, RemoteInferenceClient], inference_engine_cfg: InferenceEngineConfig)Initialize state for weight syncing with Inference Engine Client
Creates init info and sender, then sends init info to inference engines so they can create receivers.
.. note:: This function should be called on all the ranks in the worker group simultaneously.
Source code in skyrl/backends/skyrl_train/workers/worker.py:311-372
async def init_weight_sync_state(
self,
inference_engine_client: "Union[InferenceEngineClient, RemoteInferenceClient]",
inference_engine_cfg: "InferenceEngineConfig",
):
"""Initialize state for weight syncing with Inference Engine Client
Creates init info and sender, then sends init info to inference engines
so they can create receivers.
.. note::
This function should be called on all the ranks in the worker group simultaneously.
"""
from skyrl.backends.skyrl_train.weight_sync import get_transfer_strategy_cls
assert inference_engine_client is not None
# Determine transfer strategy based on inference engine config and placement
self._transfer_strategy_cls = get_transfer_strategy_cls(
weight_sync_backend=inference_engine_cfg.weight_sync_backend,
colocate_all=self.cfg.placement.colocate_all,
)
# For new inference path, fetch world_size from servers
# For legacy path, calculate from config
inference_world_size = None
if _SKYRL_USE_NEW_INFERENCE and hasattr(inference_engine_client, "get_world_size"):
inference_world_size, _ = await inference_engine_client.get_world_size()
# Create init info on all ranks (it's deterministic from cfg or fetched world_size)
init_info = self._transfer_strategy_cls.create_init_info(
inference_engine_cfg, inference_world_size=inference_world_size
)
# Create sender on all ranks
# Strategy implementations may have different logic for different ranks
tasks = [
asyncio.to_thread(
self._transfer_strategy_cls.create_sender,
init_info=init_info,
inference_client=inference_engine_client,
),
]
# Only rank 0 initializes receivers on inference engines
# NOTE: For broadcast strategy, sender and receiver init must run concurrently
# because both need to join the same process group to avoid deadlock
if torch.distributed.get_rank() == 0:
tasks.append(inference_engine_client.init_weight_update_communicator(init_info))
results = await asyncio.gather(*tasks)
self._weight_transfer_sender = results[0] # sender is always first task
# # Register signal handlers for termination only on rank 0
# NOTE (sumanthrh): This doesn't work yet, and is thus commented out.
# The better way is to just have this specified in __del__, but there is
# no guarantee that __del__ will be called in general. Ray also doesn't
# explicitly call __del__ when the actor shuts down.
# It's commented out so that we can fix this in the future.
# atexit.register(self._handle_termination)
torch.distributed.barrier()method abstractmethod forward
forward(data: TrainingInputBatch) -> TrainingOutputBatchRun forward pass on the input batch in inference mode.
This is a wrapper around _forward_micro_batch that runs in micro batches of cfg.micro_forward_batch_size_per_gpu.
Source code in skyrl/backends/skyrl_train/workers/worker.py:374-392
def forward(
self,
data: TrainingInputBatch,
) -> TrainingOutputBatch:
"""Run forward pass on the input batch in inference mode.
This is a wrapper around `_forward_micro_batch` that runs in micro batches of `cfg.micro_forward_batch_size_per_gpu`.
"""
# run in micro batches of cfg.micro_forward_batch_size_per_gpu
# TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc.
micro_batches = data.chunk(self.cfg.micro_forward_batch_size_per_gpu)
outputs = []
for micro_batch in micro_batches:
outputs.append(self._forward_micro_batch(micro_batch))
output = TrainingOutputBatch.cat(outputs)
if output.device is not None and output.device != torch.device("cpu"):
output = output.to("cpu")
return outputclass PPORayActorGroup
PPORayActorGroup(cfg: TrainerConfig, num_nodes: int, num_gpus_per_node: int, ray_actor_type: Type[Worker], pg: Optional[PlacementGroup] = None, num_gpus_per_actor: float = 1.0, resources: Optional[Dict[str, float]] = None, num_resources_per_node: Optional[int] = None, colocate_all: bool = False, sequence_parallel_size: int = 1, record_memory: bool = False) -> NoneA group of ray actors Functions starting with 'async' should return a list of object refs
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
cfg | TrainerConfig | config object for workers | required |
num_nodes | int | Number of nodes for this actor group. | required |
num_gpus_per_node | int | Number of gpus for this actor group. | required |
ray_actor_type | Type[Worker] | PPO model type that this actor group serves. | required |
pg | Optional[PlacementGroup] | Placement group to schedule actor on. If none, create new placement group automatically. Defaults to None. | None |
num_gpus_per_actor | float | Number of gpus allocated for each actor. If < 1.0, multiple models can share same gpu. Defaults to 1.0. | 1.0 |
Functions:
| Name | Description |
|---|---|
async_init_model | Asynchronously initialize worker state (model, and optimizer if applicable) from model path |
offload_to_cpu | Offload all worker state to CPU. |
backload_to_gpu | Backload worker state to GPU |
run_method | Run a method on all actors using specified dispatch type synchronously. |
async_run_ray_method | Run a method on all actors using specified dispatch type asynchronously. |
async_run_method | Run a method on all actors using specified dispatch type in an asyncio-compatible way. |
Attributes:
| Name | Type | Description |
|---|---|---|
cfg | ||
ray_actor_type | ||
colocate_all | ||
sequence_parallel_size | ||
record_memory |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
pg | Optional[PlacementGroup] | Placement group for the worker group. Accepts a single PlacementGroup, or None. Note that if colocate_all is True, the number of bundles in the placement group must match world_size. | None |
Source code in skyrl/backends/skyrl_train/workers/worker.py:399-663
class PPORayActorGroup:
"""
A group of ray actors
Functions starting with 'async' should return a list of object refs
Args:
cfg: config object for workers
num_nodes (int): Number of nodes for this actor group.
num_gpus_per_node (int): Number of gpus for this actor group.
ray_actor_type (Type[Worker]): PPO model type that this actor group serves.
pg (Optional[PlacementGroup]): Placement group to schedule actor on.
If none, create new placement group automatically. Defaults to None.
num_gpus_per_actor (float, optional): Number of gpus allocated for each actor.
If < 1.0, multiple models can share same gpu. Defaults to 1.
"""
def __init__(
self,
cfg: TrainerConfig,
num_nodes,
num_gpus_per_node,
ray_actor_type: Type[Worker],
pg: Optional[PlacementGroup] = None,
num_gpus_per_actor: float = 1.0,
resources: Optional[Dict[str, float]] = None,
num_resources_per_node: Optional[int] = None,
colocate_all: bool = False,
sequence_parallel_size: int = 1,
record_memory: bool = False,
) -> None:
"""
Args:
pg: Placement group for the worker group. Accepts a single PlacementGroup, or None.
Note that if colocate_all is True, the number of bundles in the placement group must match world_size.
"""
self.cfg = cfg
self._num_nodes = num_nodes
self._num_gpus_per_node = num_gpus_per_node
self.ray_actor_type = ray_actor_type
# custom resources, see https://docs.ray.io/en/latest/ray-core/scheduling/resources.html
self._resources = resources
self._num_resources_per_node = num_resources_per_node
self.colocate_all = colocate_all
self.sequence_parallel_size = sequence_parallel_size
self.record_memory = record_memory
self._initiate_actors(pg, num_gpus_per_actor)
def _initiate_actors(self, pg: Optional[PlacementGroup], num_gpus_per_actor: float):
"""Initialize Ray actors in the worker group.
Args:
pg: A single placement group for the worker group, or None.
num_gpus_per_actor: The number of gpus to allocate per actor.
"""
world_size = self._num_nodes * self._num_gpus_per_node
if self.colocate_all:
assert (
pg is not None
), "if colocate_all is True, the shared placement group must be provided to PPORayActorGroup"
pg_data = placement_group_table(pg)
assert len(pg_data["bundles"]) == world_size, (
f"if colocate_all is True, the number of bundles in the placement group "
f"must match world_size. Got {len(pg_data['bundles'])} bundles but world_size={world_size}"
)
# Build rank → bundle_index assignments sorted by (node_id, gpu_id)
# for deterministic ordering.
reordered_bundle_indices = []
if pg is not None:
pg_data = placement_group_table(pg)
if len(pg_data["bundles"]) == world_size:
reordered_bundle_indices = get_reordered_bundle_indices(pg)
# If no PG provided, create one internally
if pg is None and self._num_gpus_per_node > 1:
bundles = [{"GPU": self._num_gpus_per_node, "CPU": self._num_gpus_per_node} for _ in range(self._num_nodes)]
if self._resources:
resources_name = list(self._resources.keys())[0]
for i in range(len(bundles)):
bundles[i][resources_name] = self._num_resources_per_node
internal_pg = placement_group(bundles, strategy="PACK")
get_ray_pg_ready_with_timeout(internal_pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
pg = internal_pg
def _scheduling_strategy_for_rank(rank):
if reordered_bundle_indices:
return PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_bundle_index=reordered_bundle_indices[rank],
)
elif pg is not None:
return PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_bundle_index=rank // self._num_gpus_per_node,
)
# else we are in the single-GPU-per-node case, in which case we don't need to set
# bundle indices
return None
sched = _scheduling_strategy_for_rank(0)
actor_options = {
"num_cpus": num_gpus_per_actor,
"num_gpus": num_gpus_per_actor,
"resources": self._resources,
}
if sched is not None:
actor_options["scheduling_strategy"] = sched
master_actor = self.ray_actor_type.options(**actor_options).remote(
cfg=self.cfg,
world_size=world_size,
rank=0,
local_rank=0,
master_addr=None,
master_port=None,
sequence_parallel_size=self.sequence_parallel_size,
record_memory=self.record_memory,
)
self._actor_handlers = [master_actor]
if world_size > 1:
master_addr, master_port = ray.get(master_actor.get_master_addr_port.remote())
for rank in range(1, world_size):
local_rank = rank % self._num_gpus_per_node
sched = _scheduling_strategy_for_rank(rank)
actor_options = {
"num_cpus": num_gpus_per_actor,
"num_gpus": num_gpus_per_actor,
"resources": self._resources,
}
if sched is not None:
actor_options["scheduling_strategy"] = sched
worker_actor = self.ray_actor_type.options(**actor_options).remote(
cfg=self.cfg,
world_size=world_size,
rank=rank,
local_rank=local_rank,
master_addr=master_addr,
master_port=master_port,
sequence_parallel_size=self.sequence_parallel_size,
record_memory=self.record_memory,
)
self._actor_handlers.append(worker_actor)
# Initialize process group
logger.info("Initializing process group for RayActorGroup")
ray.get([actor.init_worker_process_group.remote() for actor in self._actor_handlers])
logger.info("Initialized process group for RayActorGroup")
self.actor_infos = [ActorInfo(actor, ray.get(actor.get_mesh_rank.remote())) for actor in self._actor_handlers]
logger.info(f"Mesh Ranks: {[actor_info.rank for actor_info in self.actor_infos]}")
def async_init_model(
self,
*args,
**kwargs,
) -> List[ObjectRef]:
"""Asynchronously initialize worker state (model, and optimizer if applicable) from model path
on all the workers.
Returns:
A list of ray object refs.
"""
return [actor.init_model.remote(*args, **kwargs) for actor in self._actor_handlers]
def offload_to_cpu(self, nonblocking=False, offload_optimizer=True, offload_model=True):
"""Offload all worker state to CPU.
Args:
nonblocking: Whether this operation is synchronous or asynchronous.
If `nonblocking=True`, then the function returns a list of object refs.
"""
refs = [
actor.offload_to_cpu.remote(offload_optimizer=offload_optimizer, offload_model=offload_model)
for actor in self._actor_handlers
]
if nonblocking:
return refs
return ray.get(refs)
def backload_to_gpu(self, nonblocking=False, backload_optimizer=True, backload_model=True):
"""Backload worker state to GPU
Args:
nonblocking: Whether this operation is synchronous or asynchronous.
If `nonblocking=True`, then the function returns a list of ObjectRefs.
"""
refs = [
actor.backload_to_gpu.remote(backload_optimizer=backload_optimizer, backload_model=backload_model)
for actor in self._actor_handlers
]
if nonblocking:
return refs
return ray.get(refs)
def run_method(self, dispatch_type: str, method_name: str, *args, **kwargs) -> Optional[TrainingOutputBatch]:
"""Run a method on all actors using specified dispatch type synchronously.
The method should either return `None` or a `TrainingOutputBatch` object.
Args:
dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
method_name: Name of the method to call on actors
*args: Positional arguments to pass to the method
**kwargs: Keyword arguments to pass to the method
Returns:
Collect results from all the actors.
"""
dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type)
# validate the dispatch args to be sent to `.dispatch`
args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs)
# Dispatch the method call
object_refs = dispatch_class.dispatch(self.actor_infos, method_name, *args, **kwargs)
# Collect results from all the actors
ret = dispatch_class.sync_collect(self.actor_infos, object_refs)
return ret
def async_run_ray_method(self, dispatch_type: str, method_name: str, *args, **kwargs) -> List[ObjectRef]:
"""Run a method on all actors using specified dispatch type asynchronously.
Args:
dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
method_name: Name of the method to call on actors
*args: Positional arguments to pass to the method
**kwargs: Keyword arguments to pass to the method
Returns:
List of object references
"""
dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type)
# validate the dispatch args to be sent to `.dispatch`
args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs)
# Dispatch the method call
object_refs = dispatch_class.dispatch(self.actor_infos, method_name, *args, **kwargs)
return object_refs
async def async_run_method(
self, dispatch_type: str, method_name: str, *args, **kwargs
) -> Optional[TrainingOutputBatch]:
"""Run a method on all actors using specified dispatch type in an asyncio-compatible way.
Args:
dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
method_name: Name of the method to call on actors
*args: Positional arguments to pass to the method
**kwargs: Keyword arguments to pass to the method
Returns:
TrainingOutputBatch: concatenated results from all actors
"""
dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type)
# validate the dispatch args to be sent to `.dispatch`
args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs)
# Dispatch the method call
object_refs = dispatch_class.dispatch(self.actor_infos, method_name, *args, **kwargs)
return await dispatch_class.async_collect(self.actor_infos, object_refs)attr cfg
cfg = cfgattr ray_actor_type
ray_actor_type = ray_actor_typeattr colocate_all
colocate_all = colocate_allattr sequence_parallel_size
sequence_parallel_size = sequence_parallel_sizeattr record_memory
record_memory = record_memorymethod async_init_model
async_init_model(*args, **kwargs) -> List[ObjectRef]Asynchronously initialize worker state (model, and optimizer if applicable) from model path on all the workers.
Returns:
| Type | Description |
|---|---|
| List[ObjectRef] | A list of ray object refs. |
Source code in skyrl/backends/skyrl_train/workers/worker.py:556-567
def async_init_model(
self,
*args,
**kwargs,
) -> List[ObjectRef]:
"""Asynchronously initialize worker state (model, and optimizer if applicable) from model path
on all the workers.
Returns:
A list of ray object refs.
"""
return [actor.init_model.remote(*args, **kwargs) for actor in self._actor_handlers]method offload_to_cpu
offload_to_cpu(nonblocking = False, offload_optimizer = True, offload_model = True)Offload all worker state to CPU.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
nonblocking | Whether this operation is synchronous or asynchronous. | False |
Source code in skyrl/backends/skyrl_train/workers/worker.py:569-582
def offload_to_cpu(self, nonblocking=False, offload_optimizer=True, offload_model=True):
"""Offload all worker state to CPU.
Args:
nonblocking: Whether this operation is synchronous or asynchronous.
If `nonblocking=True`, then the function returns a list of object refs.
"""
refs = [
actor.offload_to_cpu.remote(offload_optimizer=offload_optimizer, offload_model=offload_model)
for actor in self._actor_handlers
]
if nonblocking:
return refs
return ray.get(refs)method backload_to_gpu
backload_to_gpu(nonblocking = False, backload_optimizer = True, backload_model = True)Backload worker state to GPU
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
nonblocking | Whether this operation is synchronous or asynchronous. | False |
Source code in skyrl/backends/skyrl_train/workers/worker.py:584-597
def backload_to_gpu(self, nonblocking=False, backload_optimizer=True, backload_model=True):
"""Backload worker state to GPU
Args:
nonblocking: Whether this operation is synchronous or asynchronous.
If `nonblocking=True`, then the function returns a list of ObjectRefs.
"""
refs = [
actor.backload_to_gpu.remote(backload_optimizer=backload_optimizer, backload_model=backload_model)
for actor in self._actor_handlers
]
if nonblocking:
return refs
return ray.get(refs)method run_method
run_method(dispatch_type: str, method_name: str, *args: str, **kwargs: str) -> Optional[TrainingOutputBatch]Run a method on all actors using specified dispatch type synchronously.
The method should either return None or a TrainingOutputBatch object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dispatch_type | str | Type of dispatch to use ("mesh" or "pass_through") | required |
method_name | str | Name of the method to call on actors | required |
*args | Positional arguments to pass to the method | () | |
**kwargs | Keyword arguments to pass to the method | {} |
Returns:
| Type | Description |
|---|---|
| Optional[TrainingOutputBatch] | Collect results from all the actors. |
Source code in skyrl/backends/skyrl_train/workers/worker.py:599-621
def run_method(self, dispatch_type: str, method_name: str, *args, **kwargs) -> Optional[TrainingOutputBatch]:
"""Run a method on all actors using specified dispatch type synchronously.
The method should either return `None` or a `TrainingOutputBatch` object.
Args:
dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
method_name: Name of the method to call on actors
*args: Positional arguments to pass to the method
**kwargs: Keyword arguments to pass to the method
Returns:
Collect results from all the actors.
"""
dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type)
# validate the dispatch args to be sent to `.dispatch`
args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs)
# Dispatch the method call
object_refs = dispatch_class.dispatch(self.actor_infos, method_name, *args, **kwargs)
# Collect results from all the actors
ret = dispatch_class.sync_collect(self.actor_infos, object_refs)
return retmethod async_run_ray_method
async_run_ray_method(dispatch_type: str, method_name: str, *args: str, **kwargs: str) -> List[ObjectRef]Run a method on all actors using specified dispatch type asynchronously.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dispatch_type | str | Type of dispatch to use ("mesh" or "pass_through") | required |
method_name | str | Name of the method to call on actors | required |
*args | Positional arguments to pass to the method | () | |
**kwargs | Keyword arguments to pass to the method | {} |
Returns:
| Type | Description |
|---|---|
| List[ObjectRef] | List of object references |
Source code in skyrl/backends/skyrl_train/workers/worker.py:623-641
def async_run_ray_method(self, dispatch_type: str, method_name: str, *args, **kwargs) -> List[ObjectRef]:
"""Run a method on all actors using specified dispatch type asynchronously.
Args:
dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
method_name: Name of the method to call on actors
*args: Positional arguments to pass to the method
**kwargs: Keyword arguments to pass to the method
Returns:
List of object references
"""
dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type)
# validate the dispatch args to be sent to `.dispatch`
args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs)
# Dispatch the method call
object_refs = dispatch_class.dispatch(self.actor_infos, method_name, *args, **kwargs)
return object_refsmethod async async_run_method
async_run_method(dispatch_type: str, method_name: str, *args: str, **kwargs: str) -> Optional[TrainingOutputBatch]Run a method on all actors using specified dispatch type in an asyncio-compatible way.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dispatch_type | str | Type of dispatch to use ("mesh" or "pass_through") | required |
method_name | str | Name of the method to call on actors | required |
*args | Positional arguments to pass to the method | () | |
**kwargs | Keyword arguments to pass to the method | {} |
Returns:
| Name | Type | Description |
|---|---|---|
TrainingOutputBatch | Optional[TrainingOutputBatch] | concatenated results from all actors |
Source code in skyrl/backends/skyrl_train/workers/worker.py:643-663
async def async_run_method(
self, dispatch_type: str, method_name: str, *args, **kwargs
) -> Optional[TrainingOutputBatch]:
"""Run a method on all actors using specified dispatch type in an asyncio-compatible way.
Args:
dispatch_type: Type of dispatch to use ("mesh" or "pass_through")
method_name: Name of the method to call on actors
*args: Positional arguments to pass to the method
**kwargs: Keyword arguments to pass to the method
Returns:
TrainingOutputBatch: concatenated results from all actors
"""
dispatch_class: Dispatch = DispatchRegistry.get(dispatch_type)
# validate the dispatch args to be sent to `.dispatch`
args, kwargs = dispatch_class.validate_dispatch_args(*args, **kwargs)
# Dispatch the method call
object_refs = dispatch_class.dispatch(self.actor_infos, method_name, *args, **kwargs)
return await dispatch_class.async_collect(self.actor_infos, object_refs)