Step-Wise Training
As described in Agent Integration, there are multiple ways to integrate a custom agent with SkyRL. The simplest — re-tokenization — works out of the box for many agent harnesses and has been used successfully in open-source recipes.
However, re-tokenization has two fundamental limitations:
- Re-tokenization drift. When the full conversation string is re-tokenized after generation, the resulting token IDs can differ from what the model actually generated. Causes include non-unique BPE tokenization (e.g. "HAVING" → H+AVING vs HAV+ING), tool-call serialization changes, and chat template differences at turn boundaries. While this is acceptable for basic synchronous training, it becomes a real problem when you want rollout correction (e.g. TIS, truncated importance sampling), which is crucial for fully async RL. TIS computes importance ratios π_current(token) / π_rollout(token), and if the training tokens differ from the generation tokens, the recorded rollout_logprobs no longer correspond to the actual tokens being trained on, making the ratios meaningless.
- Context management. Many agent harnesses perform operations that make the chat history non-strictly-appending — for example, stripping thinking tokens between turns, summarizing long contexts, or resetting the conversation window. Re-tokenization assumes a single linear conversation, so it cannot represent these discontinuities. Note that token-in-token-out (approach 2 in Agent Integration) also requires a strictly appending token sequence on its own, but it can be combined with step-wise training to handle context management.
Step-wise training addresses both problems. Instead of producing one (prompt, response) pair per trajectory, it decomposes each multi-turn trajectory into N separate training samples (one per LLM turn), using the exact token IDs and logprobs from the inference engine (via vLLM's return_token_ids). Each step's prompt is the full context the model saw at that turn, and the response is exactly the tokens the model generated. Because each turn is an independent sample, context management operations between turns are naturally supported — there is no requirement that turn N+1's prompt be a prefix extension of turn N's full sequence.
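As a sketch of this decomposition (names are hypothetical, not SkyRL's API), a trajectory of N turns becomes N samples, each carrying the exact prompt context and generated token IDs of its turn:

```python
# Sketch: decompose one multi-turn trajectory into per-turn training samples.
# `turns` is a hypothetical record of what the inference engine returned at
# each turn: the exact prompt token IDs (full context) and completion token IDs.

def decompose_trajectory(turns):
    """Turn a list of (prompt_token_ids, response_token_ids) pairs into
    step-wise samples, marking the last step of the trajectory."""
    samples = []
    for i, (prompt_ids, response_ids) in enumerate(turns):
        samples.append({
            "prompt_token_ids": prompt_ids,       # full context the model saw
            "response_ids": response_ids,         # exact generated tokens
            "is_last_step": i == len(turns) - 1,  # True only for the final turn
        })
    return samples
```

Note that nothing here requires turn i+1's prompt to extend turn i's sequence, which is what makes context management operations between turns representable.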
Quick start
To see how SkyRLGymGenerator supports step-wise training, you can run it with the search-r1 example:

USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true bash examples/train/search/run_search.sh

This page will also guide you through implementing step-wise training for your custom generator.
Impact on Training
When step-wise is enabled, a batch of T trajectories with an average of M turns per trajectory produces T×M training samples (sequences). This means:
- Each mini-batch contains the sequences for exactly policy_mini_batch_size prompts, regardless of how many turns those prompts produced. This means the number of mini-batches (and hence optimizer steps) per training batch is always train_batch_size / policy_mini_batch_size, independent of the number of turns. It also means that the actual mini-batch size (number of sequences) trained in each mini-batch can vary. Each mini-batch always leads to a single optimizer step.
- Advantages are computed on last steps only, then broadcast to all steps of the same trajectory. This is mathematically equivalent to non-step-wise advantage computation for GRPO.
- Training time grows as O(M²) vs O(M) in the number of turns, since each trajectory of M turns becomes M sequences to forward (each with a growing prompt prefix), as opposed to 1 sequence. SkyRL supports prefix-aware merging of per-step sequences when the prefix matches via the config flag generator.merge_stepwise_output, which can reduce the O(M²) cost if the chat history is linearly appending across turns and there is no token mismatch. See https://github.com/NovaSky-AI/SkyRL/pull/1532
- Metrics like generate/avg_num_tokens and generate/avg_response_length are per-turn rather than per-trajectory, since each training sample is a single turn.
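The advantage broadcast can be sketched as follows (a simplified illustration, not SkyRL's actual implementation): per-trajectory advantages are computed from last-step rewards only, then copied to every step of that trajectory:

```python
def broadcast_advantages(trajectory_ids, is_last_step, last_step_advantages):
    """Broadcast per-trajectory advantages (computed on last steps only)
    to every step of the same trajectory.

    `last_step_advantages` holds one advantage per trajectory, in order of
    each trajectory's appearance in the batch.
    """
    adv_iter = iter(last_step_advantages)
    adv_by_tid = {}
    # Collect the advantage for each trajectory from its last step.
    for tid, last in zip(trajectory_ids, is_last_step):
        if last:
            adv_by_tid[tid] = next(adv_iter)
    # Broadcast: every step inherits its trajectory's advantage.
    return [adv_by_tid[tid] for tid in trajectory_ids]
```

For GRPO-style outcome advantages, this yields the same per-token advantage as computing it on the undecomposed trajectory.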
Some algorithms have their behavior altered by step-wise decomposition, since each turn is now treated as its own sequence:
- GSPO loss, which computes a sequence-level importance weight — under step-wise training, it operates over one turn rather than the entire trajectory.
- Off-policy rollout correction besides token-level TIS (trainer.algorithm.off_policy_correction.tis_ratio_type="token") — sequence-level corrections aggregate over a different scope.
- Loss reduction methods like sequence_mean and seq_mean_token_sum_norm — trajectories with more turns contribute proportionally more to the loss.
That said, some research suggests that treating each turn as a separate sequence may actually be beneficial. See the section on Modelling Multi-Turn Agentic Task as Chunked MDP.
Configuration
Enable step-wise training by setting:
generator.step_wise_trajectories=true

This flag is defined in GeneratorConfig (skyrl/train/config/config.py):
@dataclass
class GeneratorConfig(BaseConfig):
    step_wise_trajectories: bool = False

GeneratorOutput Format
Normally, each element in GeneratorOutput (i.e. response_ids[i], prompt_token_ids[i], rewards[i], etc.) represents a single trajectory. With step-wise training, each element instead represents a single step (one LLM turn within a trajectory). A trajectory with 3 turns produces 3 elements rather than 1.
The GeneratorOutput TypedDict is defined in skyrl/train/generators/base.py:
class GeneratorOutput(TypedDict):
    prompt_token_ids: List[List[int]]
    response_ids: List[List[int]]
    rewards: Union[List[float], List[List[float]]]
    loss_masks: List[List[int]]
    stop_reasons: Optional[List[str]]
    rollout_metrics: Optional[Dict[str, Any]]
    rollout_logprobs: Optional[List[List[float]]]
    trajectory_ids: Optional[List[TrajectoryID]]
    rollout_expert_indices: Optional[List[List[List[List[int]]]]]
    # Applicable only for step-wise training
    is_last_step: Optional[List[bool]]

Step-Wise Fields
When step_wise_trajectories=True, the following fields become relevant:
| Field | Type | Description |
|---|---|---|
| is_last_step | List[bool] | Marks the final step of each trajectory. Must have at least one True, and the last element must be True. |
| trajectory_ids | List[TrajectoryID] | Associates each step-sample with its parent trajectory. All steps of the same trajectory share the same TrajectoryID. |
| rollout_logprobs | List[List[float]] | Per-token logprobs from the inference engine, aligned with response_ids. Required for TIS. |
Concrete Example
Consider 2 trajectories: trajectory A has 3 turns, trajectory B has 2 turns.
GeneratorOutput(
    prompt_token_ids=[
        [tok_A_prompt_turn1],  # A, step 0: initial prompt
        [tok_A_prompt_turn2],  # A, step 1: prompt + turn1 history
        [tok_A_prompt_turn3],  # A, step 2: prompt + turn1+2 history
        [tok_B_prompt_turn1],  # B, step 0: initial prompt
        [tok_B_prompt_turn2],  # B, step 1: prompt + turn1 history
    ],
    response_ids=[
        [tok_A_resp_turn1],  # exact tokens generated by model at turn 1
        [tok_A_resp_turn2],  # exact tokens generated by model at turn 2
        [tok_A_resp_turn3],  # exact tokens generated by model at turn 3
        [tok_B_resp_turn1],
        [tok_B_resp_turn2],
    ],
    rewards=[
        [0.0, 0.0, ..., 0.0],  # A step 0: all zeros (intermediate)
        [0.0, 0.0, ..., 0.0],  # A step 1: all zeros (intermediate)
        [0.0, 0.0, ..., 1.0],  # A step 2: reward at last token of last step
        [0.0, 0.0, ..., 0.0],  # B step 0: all zeros (intermediate)
        [0.0, 0.0, ..., 0.5],  # B step 1: reward at last token of last step
    ],
    loss_masks=[
        [1, 1, ..., 1],  # all 1s: every response token is trainable
        [1, 1, ..., 1],  # (no interleaved obs tokens in step-wise)
        [1, 1, ..., 1],
        [1, 1, ..., 1],
        [1, 1, ..., 1],
    ],
    rollout_logprobs=[
        [-1.2, -0.8, ..., -2.1],  # exact logprobs from inference engine
        [-0.5, -1.1, ..., -0.9],
        [-1.0, -0.3, ..., -1.5],
        [-0.7, -1.4, ..., -0.6],
        [-1.3, -0.2, ..., -1.8],
    ],
    is_last_step=[False, False, True, False, True],
    trajectory_ids=[tid_A, tid_A, tid_A, tid_B, tid_B],
    stop_reasons=["tool_call", "tool_call", "stop", "tool_call", "stop"],
    rollout_metrics={...},
)

Key Invariants
The following are validated by _validate_step_wise_fields() in skyrl/train/utils/trainer_utils.py:
- is_last_step and trajectory_ids must be present and non-None.
- Lengths must match response_ids. Every list-type field has one entry per step-sample.
- is_last_step[-1] must be True. The last sample in the batch must be the final step of its trajectory.
- Contiguous ordering. All steps of the same trajectory must be adjacent — no interleaving. This is critical because the trainer's advantage broadcast uses cumsum(shifted_is_last_step) to map steps to trajectories, which silently produces wrong results if steps are interleaved.
- Boundary alignment. is_last_step[i] must be True wherever trajectory_ids changes (i.e., at every trajectory boundary).
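A minimal sketch of these checks (simplified for illustration, not the actual body of _validate_step_wise_fields()):

```python
def validate_step_wise_fields(trajectory_ids, is_last_step, response_ids):
    """Sketch of the step-wise batch invariants described above."""
    n = len(response_ids)
    assert trajectory_ids is not None and is_last_step is not None
    assert len(trajectory_ids) == n and len(is_last_step) == n
    # The last sample in the batch must close its trajectory.
    assert is_last_step[-1], "is_last_step[-1] must be True"
    finished = set()
    for i, tid in enumerate(trajectory_ids):
        # Contiguous ordering: a trajectory must not reappear after it ends.
        assert tid not in finished, "steps of one trajectory must be adjacent"
        # Boundary alignment: the step before a trajectory change (and the
        # final step of the batch) must be marked as a last step.
        at_boundary = i + 1 == n or trajectory_ids[i + 1] != tid
        if at_boundary:
            assert is_last_step[i], "is_last_step must be True at boundaries"
            finished.add(tid)
    return True
```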
Implementing Step-Wise for Custom Generators
If you are implementing a custom generator that supports step-wise training:
1. Collect Exact Token IDs and Logprobs (if using TIS)
Use vLLM's return_token_ids (via extra_body in LiteLLM or directly) to get the exact token IDs for both the prompt and completion at each turn. Do not re-tokenize from strings — this is the whole point of step-wise training.
# Example: requesting token IDs via LiteLLM
response = await litellm.acompletion(
    model="hosted_vllm/your-model",
    messages=messages,
    extra_body={
        "return_token_ids": True,
        "logprobs": True,
    },
)
# Access: response.choices[0].token_ids, response.choices[0].prompt_token_ids, response.choices[0].logprobs

2. Set Loss Masks
Set loss_mask = [1] * len(response_ids[i]) for each step. Since each step's response contains only the model's completion tokens (no interleaved observations), all tokens are trainable.
3. Assign Rewards
Only the last step of each trajectory receives the actual reward. Intermediate steps get all zeros:
for i, step_output in enumerate(trajectory_steps):
    if i == len(trajectory_steps) - 1:
        # Last step: reward at last token position
        rewards = [0.0] * (len(step_output.response_ids) - 1) + [trajectory_reward]
    else:
        # Intermediate step: all zeros
        rewards = [0.0] * len(step_output.response_ids)

Note that SkyRL currently only supports trajectory-level rewards for step-wise training. Therefore, the reward should be placed at the last token of the last step, and all non-last-step rewards are ignored. The last step's reward is used to estimate the advantage, which is then broadcast to previous turns. Because of this, you should only use outcome-based advantage estimators (cfg.trainer.algorithm.advantage_estimator set to grpo, rloo, or maxrl); reinforce++ and gae are rejected at config validation time.
4. Ensure Contiguous Ordering
All steps of trajectory A must appear before any steps of trajectory B in the output lists:
# Correct:   [A_step0, A_step1, A_step2, B_step0, B_step1]
# INCORRECT: [A_step0, B_step0, A_step1, B_step1, A_step2]

5. Mark Trajectory Boundaries
Set is_last_step[i] = True for the final step of each trajectory, False for all others.
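Putting steps 1 through 5 together, here is a minimal sketch of assembling the step-wise fields for a batch (the input shape and helper name are hypothetical; real generators would build a full GeneratorOutput):

```python
def assemble_step_wise_output(trajectories):
    """Assemble step-wise fields from finished trajectories.

    `trajectories` is a list of (trajectory_id, reward, steps), where each
    step is a (prompt_token_ids, response_ids, logprobs) tuple recorded from
    the inference engine. Trajectories are emitted contiguously (step 4).
    """
    out = {
        "prompt_token_ids": [], "response_ids": [], "rewards": [],
        "loss_masks": [], "rollout_logprobs": [],
        "is_last_step": [], "trajectory_ids": [],
    }
    for tid, reward, steps in trajectories:
        for i, (prompt_ids, resp_ids, logprobs) in enumerate(steps):
            last = i == len(steps) - 1
            out["prompt_token_ids"].append(prompt_ids)   # step 1: exact tokens
            out["response_ids"].append(resp_ids)
            out["rollout_logprobs"].append(logprobs)     # step 1: exact logprobs
            out["loss_masks"].append([1] * len(resp_ids))  # step 2: all trainable
            # Step 3: reward only at the last token of the last step.
            out["rewards"].append(
                [0.0] * (len(resp_ids) - 1) + [reward] if last
                else [0.0] * len(resp_ids)
            )
            out["is_last_step"].append(last)             # step 5: boundaries
            out["trajectory_ids"].append(tid)
    return out
```

Because trajectories are appended one at a time, the contiguity and boundary-alignment invariants hold by construction.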