SFT
Supervised Fine-Tuning configuration and trainer.
Configuration
class SFTPlacementConfig
SFTPlacementConfig(num_nodes: int = 1, num_gpus_per_node: int = 4) -> None
Bases: BaseConfig
Placement configuration for SFT training
Functions:
| Name | Description |
|---|---|
| from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
| num_nodes | int | |
| num_gpus_per_node | int | |
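The two placement fields multiply to give the total GPU count. As a rough sketch of how that total interacts with Megatron parallelism, the snippet below mirrors the integer division that `SFTTrainer._dp_size` performs. `data_parallel_size` is a hypothetical standalone helper, not part of SkyRL, and the context-parallel size of 1 is an assumption (its default is not stated on this page).

```python
def data_parallel_size(
    num_nodes: int,
    num_gpus_per_node: int,
    tp: int = 1,
    pp: int = 1,
    cp: int = 1,  # assumption: context parallelism disabled
) -> int:
    """Hypothetical helper mirroring the arithmetic in SFTTrainer._dp_size."""
    total_gpus = num_nodes * num_gpus_per_node
    # GPUs consumed by one model replica under tensor/pipeline/context parallelism.
    model_parallel = tp * pp * cp
    return total_gpus // model_parallel


# Defaults from this page: 1 node x 4 GPUs with TP=2 and PP=2.
print(data_parallel_size(1, 4, tp=2, pp=2))  # 1
```

With the defaults, all four GPUs are consumed by a single model replica, so there is exactly one data-parallel rank.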
Source code in skyrl/train/config/sft_config.py:48-53
@dataclass
class SFTPlacementConfig(BaseConfig):
    """Placement configuration for SFT training"""

    num_nodes: int = 1
    num_gpus_per_node: int = 4
from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr num_nodes
num_nodes: int = 1
attr num_gpus_per_node
num_gpus_per_node: int = 4
class SFTConfig
SFTConfig(model: ModelConfig = (lambda: ModelConfig(path='Qwen/Qwen3-0.6B'))(), optimizer_config: OptimizerConfig = OptimizerConfig(), placement: SFTPlacementConfig = SFTPlacementConfig(), megatron_config: MegatronConfig = (lambda: MegatronConfig(tensor_model_parallel_size=2, pipeline_model_parallel_size=2))(), fsdp_config: FSDPConfig = FSDPConfig(), sequence_parallel_size: int = 1, model_config_kwargs: dict = dict(), use_torch_compile: bool = False, record_memory: bool = False, strategy: str = 'megatron', dataset_name: str = 'yahma/alpaca-cleaned', dataset_split: str = 'train[:100]', messages_key: str = 'messages', tools_key: str = 'tools', system_key: str = 'system', eval_dataset_name: Optional[str] = None, eval_dataset_split: str = 'validation', eval_interval: int = 0, eval_before_train: bool = False, max_length: Optional[int] = None, num_steps: Optional[int] = None, num_epochs: Optional[int] = 1, batch_size: int = 4, micro_train_batch_size_per_gpu: int = 2, logger: str = 'console', project_name: str = 'skyrl_sft', run_name: str = 'skyrl_sft_run', tags: Optional[List[str]] = None, ckpt_path: str = '', ckpt_interval: int = 0, enable_ray_gpu_monitor: bool = True, max_ckpts_to_keep: int = -1, resume_from: str = '', hf_save_interval: int = 0, export_path: str = '', seed: int = 42, num_workers: int = 8, cache_dir: str = os.path.join(os.environ.get('XDG_CACHE_HOME', os.path.expanduser('~/.cache')), 'skyrl', 'tokenized_datasets'), force_recache: bool = False, disable_cache: bool = False, train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE, remove_microbatch_padding: bool = True, use_sequence_packing: bool = False, max_tokens_per_microbatch: Optional[int] = None, dummy_run_full_ctx: bool = False, dummy_run_max_steps: int = 5) -> None
Bases: BaseConfig
Configuration for SFT training.
Usage::
cfg = SFTConfig(
    strategy="megatron",
    placement=SFTPlacementConfig(num_gpus_per_node=4),
    megatron_config=MegatronConfig(tensor_model_parallel_size=2, pipeline_model_parallel_size=2),
)
Or from CLI::
cfg = SFTConfig.from_cli_overrides(sys.argv[1:])
Functions:
| Name | Description |
|---|---|
| from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
| from_cli_overrides | Construct an SFTConfig from CLI arguments or a dict of overrides. |
| resolved_bin_capacity | FFD bin capacity (max tokens per bin) when sequence packing is enabled. |
Attributes:
| Name | Type | Description |
|---|---|---|
| model | ModelConfig | |
| optimizer_config | OptimizerConfig | |
| placement | SFTPlacementConfig | |
| megatron_config | MegatronConfig | |
| fsdp_config | FSDPConfig | |
| sequence_parallel_size | int | Ulysses sequence parallelism size. |
| model_config_kwargs | dict | Pass-through kwargs for the HuggingFace model config (FSDP backends). For Megatron, use megatron_config.transformer_config_kwargs instead. |
| use_torch_compile | bool | Apply torch.compile to logits calculation. |
| record_memory | bool | Save memory snapshots to {ckpt_path}/memory_snapshots/. Visualize by dragging pickle files to https://docs.pytorch.org/memory_viz. |
| strategy | str | Training backend: "megatron" or "fsdp". |
| dataset_name | str | |
| dataset_split | str | |
| messages_key | str | Column name for chat-format datasets. |
| tools_key | str | Column name holding per-row tool/function schemas for tool-calling datasets (e.g. APIGen-MT, xLAM, ToolACE). May be a list[dict] or a JSON-encoded string. Ignored if the column is absent from the dataset. |
| system_key | str | Column name holding a per-row system prompt to prepend when messages does not already start with a system turn. Ignored if absent. |
| eval_dataset_name | Optional[str] | HuggingFace dataset name (or path) used to compute eval loss during training. When None (default), eval is disabled. |
| eval_dataset_split | str | Split of the eval dataset to load (e.g. "validation", "test[:500]"). |
| eval_interval | int | Run eval every N training steps. Eval also runs once at the end of training when an eval dataset is configured. 0 disables periodic eval. |
| eval_before_train | bool | If True, run a baseline eval pass before training begins (logged at step 0). |
| max_length | Optional[int] | Maximum length of tokenized sequences. If specified, all sequences are truncated to this value. By default, no truncation is performed. |
| num_steps | Optional[int] | Number of training steps. If None, num_epochs is used to derive the step count. |
| num_epochs | Optional[int] | Number of training epochs. Used when num_steps is None. Default: 1 epoch. |
| batch_size | int | |
| micro_train_batch_size_per_gpu | int | |
| logger | str | Logging backend: "console" or "wandb". |
| project_name | str | |
| run_name | str | |
| tags | Optional[List[str]] | Optional list of tags to apply to the W&B run. Has no effect on other backends. |
| ckpt_path | str | |
| ckpt_interval | int | Values <= 0 disable checkpointing. |
| enable_ray_gpu_monitor | bool | Enable background Ray GPU/RAM metrics collection and logging to wandb. |
| max_ckpts_to_keep | int | -1 to keep all checkpoints, N to keep only the last N. |
| resume_from | str | "" = no resume, "latest" = latest checkpoint, or a path to a global_step_N directory. |
| hf_save_interval | int | Save HuggingFace-format weights every N steps. 0 = disabled. |
| export_path | str | Directory for HF-format exports. Defaults to ckpt_path/hf_exports if empty. |
| seed | int | |
| num_workers | int | Number of worker processes for parallel tokenization during dataset loading. Set to 0 for single-threaded. |
| cache_dir | str | Directory to cache tokenized datasets. For multi-node training, set this to an NFS-mounted path so all nodes can share the cache. |
| force_recache | bool | If True, ignore existing cache and re-tokenize the dataset. |
| disable_cache | bool | If True, disable cache completely (always tokenize from scratch). |
| train_on_what | TrainOnWhat | Which tokens to compute loss on. See TrainOnWhat for options. |
| remove_microbatch_padding | bool | Pack multiple sequences per microbatch (requires flash_attn). |
| use_sequence_packing | bool | Enable controller-level FFD bin-packing across the global mini-batch. Requires remove_microbatch_padding=True and the Megatron backend. |
| max_tokens_per_microbatch | Optional[int] | FFD bin capacity (max tokens per bin) when use_sequence_packing=True. Must be >= max_length. None (default) resolves to max_length. |
| dummy_run_full_ctx | bool | Skip real data; fabricate full-context sequences. |
| dummy_run_max_steps | int | Number of steps to run in dummy mode. |
Source code in skyrl/train/config/sft_config.py:56-256
@dataclass
class SFTConfig(BaseConfig):
    """Configuration for SFT training.

    Usage::

        cfg = SFTConfig(
            strategy="megatron",
            placement=SFTPlacementConfig(num_gpus_per_node=4),
            megatron_config=MegatronConfig(tensor_model_parallel_size=2,
                                           pipeline_model_parallel_size=2),
        )

    Or from CLI::

        cfg = SFTConfig.from_cli_overrides(sys.argv[1:])
    """

    @classmethod
    def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SFTConfig":
        """Construct an SFTConfig from CLI arguments or a dict of overrides.

        Parses CLI dotlist arguments via OmegaConf and builds a typed config.
        Dataclass field defaults are used for any values not specified.

        Args:
            args: Either a list of CLI arguments in 'key.path=value' format, or a dict
                mapping dot-notation keys to values.
                Example list: ['strategy=megatron', 'model.path=Qwen/Qwen3-0.6B']
                Example dict: {'strategy': 'megatron', 'model.path': 'Qwen/Qwen3-0.6B'}

        Returns:
            A fully constructed SFTConfig with CLI overrides applied.

        Raises:
            ValueError: If both ``num_epochs`` and ``num_steps`` are explicitly provided.
        """
        if isinstance(args, dict):
            args = [f"{k}={v}" for k, v in args.items()]

        overrides = OmegaConf.from_cli(args)

        # Check for mutual exclusion before constructing the full config
        if "num_epochs" in overrides and "num_steps" in overrides:
            raise ValueError("Cannot specify both num_epochs and num_steps")

        # Accept the deprecated ``use_sample_packing`` key as an alias for
        # ``remove_microbatch_padding``. Remap it before construction so the
        # strict key validation does not reject the old name.
        if "use_sample_packing" in overrides:
            if "remove_microbatch_padding" in overrides:
                raise ValueError(
                    "Specify only one of use_sample_packing (deprecated) and remove_microbatch_padding, not both."
                )
            import warnings

            warnings.warn(
                "use_sample_packing has been renamed to remove_microbatch_padding; "
                "use remove_microbatch_padding instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            overrides["remove_microbatch_padding"] = overrides["use_sample_packing"]
            del overrides["use_sample_packing"]

        return cls.from_dict_config(overrides)
    # ---- Reused SkyRL config objects ----
    model: ModelConfig = field(default_factory=lambda: ModelConfig(path="Qwen/Qwen3-0.6B"))
    optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig)
    placement: SFTPlacementConfig = field(default_factory=SFTPlacementConfig)
    megatron_config: MegatronConfig = field(
        default_factory=lambda: MegatronConfig(
            tensor_model_parallel_size=2,
            pipeline_model_parallel_size=2,
        )
    )
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)

    # Ulysses sequence parallelism
    sequence_parallel_size: int = 1
    """Ulysses sequence parallelism size"""
    model_config_kwargs: dict = field(default_factory=dict)
    """Pass-through kwargs for the HuggingFace model config (FSDP backends).
    For Megatron, use ``megatron_config.transformer_config_kwargs`` instead."""
    use_torch_compile: bool = False
    """Apply torch.compile to logits calculation."""
    record_memory: bool = False
    """Save memory snapshots to ``{ckpt_path}/memory_snapshots/``.
    Visualize by dragging pickle files to https://docs.pytorch.org/memory_viz."""

    # ---- SFT-specific flat fields ----
    strategy: str = "megatron"  # "megatron" or "fsdp"
    dataset_name: str = "yahma/alpaca-cleaned"
    dataset_split: str = "train[:100]"
    messages_key: str = "messages"  # column name for chat-format datasets
    tools_key: str = "tools"
    """Column name holding per-row tool/function schemas for tool-calling datasets
    (e.g. APIGen-MT, xLAM, ToolACE). May be a list[dict] or a JSON-encoded string.
    Ignored if the column is absent from the dataset."""
    system_key: str = "system"
    """Column name holding a per-row system prompt to prepend when ``messages``
    does not already start with a system turn. Ignored if absent."""

    # ---- Evaluation dataset ----
    eval_dataset_name: Optional[str] = None
    """HuggingFace dataset name (or path) used to compute eval loss during training.
    When ``None`` (default), eval is disabled."""
    eval_dataset_split: str = "validation"
    """Split of the eval dataset to load (e.g. ``"validation"``, ``"test[:500]"``)."""
    eval_interval: int = 0
    """Run eval every N training steps. Eval also runs once at the end of training
    when an eval dataset is configured. ``0`` disables periodic eval."""
    eval_before_train: bool = False
    """If True, run a baseline eval pass before training begins (logged at step 0)."""

    max_length: Optional[int] = None
    """Maximum length of tokenized sequences. If specified, all sequences will be truncated to this value
    By default, no truncation is performed"""
    num_steps: Optional[int] = None
    """Number of training steps. If None, num_epochs is used to derive the step count."""
    num_epochs: Optional[int] = 1
    """Number of training epochs. Used when num_steps is None. Default: 1 epoch."""
    batch_size: int = 4
    micro_train_batch_size_per_gpu: int = 2
    logger: str = "console"  # "console" or "wandb"
    project_name: str = "skyrl_sft"
    run_name: str = "skyrl_sft_run"
    tags: Optional[List[str]] = None
    """Optional list of tags to apply to the W&B run. Has no effect on other backends."""
    ckpt_path: str = ""
    ckpt_interval: int = 0  # <= 0 -> no checkpointing
    enable_ray_gpu_monitor: bool = True
    """Enable background Ray GPU/RAM metrics collection and logging to wandb."""
    max_ckpts_to_keep: int = -1
    """-1 to keep all checkpoints, N to keep only the last N."""
    resume_from: str = ""  # "" = no resume, "latest" = latest checkpoint, or path to global_step_N dir

    # ---- HF export ----
    hf_save_interval: int = 0
    """Save HuggingFace-format weights every N steps. 0 = disabled."""
    export_path: str = ""
    """Directory for HF-format exports. Defaults to ckpt_path/hf_exports if empty."""
    seed: int = 42

    # ---- Data loading ----
    num_workers: int = 8
    """Number of worker processes for parallel tokenization during dataset loading. Set to 0 for single-threaded."""

    # ---- Tokenized dataset caching ----
    cache_dir: str = os.path.join(
        os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")), "skyrl", "tokenized_datasets"
    )
    """Directory to cache tokenized datasets. For multi-node training, set this to an NFS-mounted path so all nodes can
    share the cache."""
    force_recache: bool = False
    """If True, ignore existing cache and re-tokenize the dataset."""
    disable_cache: bool = False
    """If True, disable cache completely (always tokenize from scratch)."""

    # ---- Training target ----
    train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE
    """Which tokens to compute loss on. See :class:`TrainOnWhat` for options."""

    # ---- Packing ----
    remove_microbatch_padding: bool = True  # Pack multiple sequences per microbatch (requires flash_attn)
    use_sequence_packing: bool = False
    """Enable controller-level FFD bin-packing across the global mini-batch.
    Requires ``remove_microbatch_padding=True`` and the Megatron backend. When
    enabled, ``SFTTrainer`` uses ``PackedDataCollator`` instead of
    ``DefaultCollator``. Each bin row becomes one row in the dispatched batch
    and one worker micro-batch.
    """
    max_tokens_per_microbatch: Optional[int] = None
    """FFD bin capacity (max tokens per bin) when ``use_sequence_packing=True``.
    Each bin row becomes one worker micro-batch, so this is the token budget for
    one micro-batch. Must be ``>= max_length`` so any single sequence fits in a
    bin. ``None`` (default) resolves to ``max_length`` (each bin holds one
    sequence)."""

    # ---- Dummy run / benchmarking ----
    dummy_run_full_ctx: bool = False  # Skip real data; fabricate full-context sequences
    dummy_run_max_steps: int = 5  # Number of steps to run in dummy mode
    def resolved_bin_capacity(self) -> int:
        """FFD bin capacity (max tokens per bin) when sequence packing is enabled.

        Resolves ``max_tokens_per_microbatch`` against ``max_length``: when the
        token budget is ``None`` it falls back to ``max_length`` (each bin holds
        one sequence). Requires ``max_length`` to be set and the resolved budget
        to be ``>= max_length`` so any single sequence fits in a bin.
        """
        if self.max_length is None:
            raise ValueError("max_tokens_per_microbatch requires max_length to be set.")
        max_tokens = self.max_tokens_per_microbatch
        if max_tokens is None:
            max_tokens = self.max_length
        if max_tokens < self.max_length:
            raise ValueError(
                f"max_tokens_per_microbatch ({max_tokens}) must be >= max_length "
                f"({self.max_length}) so any single sequence fits in a bin."
            )
        return max_tokens
from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
method classmethod from_cli_overrides
from_cli_overrides(args: Union[List[str], dict]) -> SFTConfig
Construct an SFTConfig from CLI arguments or a dict of overrides.
Parses CLI dotlist arguments via OmegaConf and builds a typed config. Dataclass field defaults are used for any values not specified.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| args | Union[List[str], dict] | Either a list of CLI arguments in 'key.path=value' format, or a dict mapping dot-notation keys to values. Example list: ['strategy=megatron', 'model.path=Qwen/Qwen3-0.6B']. Example dict: {'strategy': 'megatron', 'model.path': 'Qwen/Qwen3-0.6B'}. | required |
Returns:
| Type | Description |
|---|---|
| SFTConfig | A fully constructed SFTConfig with CLI overrides applied. |
Raises:
| Type | Description |
|---|---|
| ValueError | If both num_epochs and num_steps are explicitly provided. |
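The checks that run before the typed config is built can be illustrated without OmegaConf. The sketch below is a stdlib-only approximation: `to_dotlist` and `preprocess_overrides` are hypothetical names, and a plain dict stands in for the parsed override container. It shows the dict-to-dotlist conversion, the num_epochs/num_steps mutual exclusion, and the remapping of the deprecated use_sample_packing key.

```python
import warnings


def to_dotlist(overrides: dict) -> list:
    """Turn {'model.path': 'x'} into ['model.path=x'], as done for dict input."""
    return [f"{key}={value}" for key, value in overrides.items()]


def preprocess_overrides(overrides: dict) -> dict:
    """Hypothetical stand-in for the validation applied to parsed overrides."""
    result = dict(overrides)  # work on a copy; the caller's dict is untouched

    # num_epochs and num_steps are mutually exclusive when both are given explicitly.
    if "num_epochs" in result and "num_steps" in result:
        raise ValueError("Cannot specify both num_epochs and num_steps")

    # The deprecated key is accepted as an alias and renamed before construction.
    if "use_sample_packing" in result:
        if "remove_microbatch_padding" in result:
            raise ValueError(
                "Specify only one of use_sample_packing (deprecated) "
                "and remove_microbatch_padding, not both."
            )
        warnings.warn(
            "use_sample_packing has been renamed to remove_microbatch_padding",
            DeprecationWarning,
            stacklevel=2,
        )
        result["remove_microbatch_padding"] = result.pop("use_sample_packing")
    return result


print(to_dotlist({"strategy": "fsdp", "model.path": "Qwen/Qwen3-0.6B"}))
```

Passing only one of num_epochs or num_steps is always accepted; the error fires only when both keys appear in the overrides.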
Source code in skyrl/train/config/sft_config.py:74-118
@classmethod
def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SFTConfig":
    """Construct an SFTConfig from CLI arguments or a dict of overrides.

    Parses CLI dotlist arguments via OmegaConf and builds a typed config.
    Dataclass field defaults are used for any values not specified.

    Args:
        args: Either a list of CLI arguments in 'key.path=value' format, or a dict
            mapping dot-notation keys to values.
            Example list: ['strategy=megatron', 'model.path=Qwen/Qwen3-0.6B']
            Example dict: {'strategy': 'megatron', 'model.path': 'Qwen/Qwen3-0.6B'}

    Returns:
        A fully constructed SFTConfig with CLI overrides applied.

    Raises:
        ValueError: If both ``num_epochs`` and ``num_steps`` are explicitly provided.
    """
    if isinstance(args, dict):
        args = [f"{k}={v}" for k, v in args.items()]

    overrides = OmegaConf.from_cli(args)

    # Check for mutual exclusion before constructing the full config
    if "num_epochs" in overrides and "num_steps" in overrides:
        raise ValueError("Cannot specify both num_epochs and num_steps")

    # Accept the deprecated ``use_sample_packing`` key as an alias for
    # ``remove_microbatch_padding``. Remap it before construction so the
    # strict key validation does not reject the old name.
    if "use_sample_packing" in overrides:
        if "remove_microbatch_padding" in overrides:
            raise ValueError(
                "Specify only one of use_sample_packing (deprecated) and remove_microbatch_padding, not both."
            )
        import warnings

        warnings.warn(
            "use_sample_packing has been renamed to remove_microbatch_padding; "
            "use remove_microbatch_padding instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        overrides["remove_microbatch_padding"] = overrides["use_sample_packing"]
        del overrides["use_sample_packing"]

    return cls.from_dict_config(overrides)
attr model
model: ModelConfig = field(default_factory=(lambda: ModelConfig(path='Qwen/Qwen3-0.6B')))
attr optimizer_config
optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig)
attr placement
placement: SFTPlacementConfig = field(default_factory=SFTPlacementConfig)
attr megatron_config
megatron_config: MegatronConfig = field(default_factory=(lambda: MegatronConfig(tensor_model_parallel_size=2, pipeline_model_parallel_size=2)))
attr fsdp_config
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
attr sequence_parallel_size
sequence_parallel_size: int = 1
Ulysses sequence parallelism size
attr model_config_kwargs
model_config_kwargs: dict = field(default_factory=dict)
Pass-through kwargs for the HuggingFace model config (FSDP backends).
For Megatron, use megatron_config.transformer_config_kwargs instead.
attr use_torch_compile
use_torch_compile: bool = False
Apply torch.compile to logits calculation.
attr record_memory
record_memory: bool = False
Save memory snapshots to {ckpt_path}/memory_snapshots/.
Visualize by dragging pickle files to https://docs.pytorch.org/memory_viz.
attr strategy
strategy: str = 'megatron'
attr dataset_name
dataset_name: str = 'yahma/alpaca-cleaned'
attr dataset_split
dataset_split: str = 'train[:100]'
attr messages_key
messages_key: str = 'messages'
attr tools_key
tools_key: str = 'tools'
Column name holding per-row tool/function schemas for tool-calling datasets (e.g. APIGen-MT, xLAM, ToolACE). May be a list[dict] or a JSON-encoded string. Ignored if the column is absent from the dataset.
attr system_key
system_key: str = 'system'
Column name holding a per-row system prompt to prepend when messages does not already start with a system turn. Ignored if absent.
attr eval_dataset_name
eval_dataset_name: Optional[str] = None
HuggingFace dataset name (or path) used to compute eval loss during training.
When None (default), eval is disabled.
attr eval_dataset_split
eval_dataset_split: str = 'validation'
Split of the eval dataset to load (e.g. "validation", "test[:500]").
attr eval_interval
eval_interval: int = 0
Run eval every N training steps. Eval also runs once at the end of training when an eval dataset is configured. 0 disables periodic eval.
attr eval_before_train
eval_before_train: bool = False
If True, run a baseline eval pass before training begins (logged at step 0).
attr max_length
max_length: Optional[int] = None
Maximum length of tokenized sequences. If specified, all sequences are truncated to this value. By default, no truncation is performed.
attr num_steps
num_steps: Optional[int] = None
Number of training steps. If None, num_epochs is used to derive the step count.
attr num_epochs
num_epochs: Optional[int] = 1
Number of training epochs. Used when num_steps is None. Default: 1 epoch.
attr batch_size
batch_size: int = 4
attr micro_train_batch_size_per_gpu
micro_train_batch_size_per_gpu: int = 2
attr logger
logger: str = 'console'
attr project_name
project_name: str = 'skyrl_sft'
attr run_name
run_name: str = 'skyrl_sft_run'
attr tags
tags: Optional[List[str]] = None
Optional list of tags to apply to the W&B run. Has no effect on other backends.
attr ckpt_path
ckpt_path: str = ''
attr ckpt_interval
ckpt_interval: int = 0
attr enable_ray_gpu_monitor
enable_ray_gpu_monitor: bool = True
Enable background Ray GPU/RAM metrics collection and logging to wandb.
attr max_ckpts_to_keep
max_ckpts_to_keep: int = -1
-1 to keep all checkpoints, N to keep only the last N.
attr resume_from
resume_from: str = ''
attr hf_save_interval
hf_save_interval: int = 0
Save HuggingFace-format weights every N steps. 0 = disabled.
attr export_path
export_path: str = ''
Directory for HF-format exports. Defaults to ckpt_path/hf_exports if empty.
attr seed
seed: int = 42
attr num_workers
num_workers: int = 8
Number of worker processes for parallel tokenization during dataset loading. Set to 0 for single-threaded.
attr cache_dir
cache_dir: str = os.path.join(os.environ.get('XDG_CACHE_HOME', os.path.expanduser('~/.cache')), 'skyrl', 'tokenized_datasets')
Directory to cache tokenized datasets. For multi-node training, set this to an NFS-mounted path so all nodes can share the cache.
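To see how the default cache location resolves, the sketch below reproduces the same expression in a small function. `default_cache_dir` is a hypothetical helper that takes the environment as an argument so both cases can be shown; SkyRL itself evaluates the expression directly as the field default.

```python
import os
from typing import Mapping, Optional


def default_cache_dir(env: Optional[Mapping[str, str]] = None) -> str:
    """Hypothetical helper reproducing the default expression for cache_dir."""
    env = os.environ if env is None else env
    # XDG_CACHE_HOME wins when set; otherwise fall back to ~/.cache.
    base = env.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
    return os.path.join(base, "skyrl", "tokenized_datasets")


# With XDG_CACHE_HOME pointing at a shared mount (illustrative path):
print(default_cache_dir({"XDG_CACHE_HOME": "/mnt/shared/cache"}))
# Without it, the path lives under the user's home directory:
print(default_cache_dir({}))
```

Because the field default is a plain expression rather than a default_factory, Python evaluates it once when the class is defined, so XDG_CACHE_HOME needs to be set before the module is imported to affect the default. Passing cache_dir explicitly avoids that ordering concern.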
attr force_recache
force_recache: bool = False
If True, ignore existing cache and re-tokenize the dataset.
attr disable_cache
disable_cache: bool = False
If True, disable cache completely (always tokenize from scratch).
attr train_on_what
train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE
Which tokens to compute loss on. See TrainOnWhat for options.
attr remove_microbatch_padding
remove_microbatch_padding: bool = True
attr use_sequence_packing
use_sequence_packing: bool = False
Enable controller-level FFD bin-packing across the global mini-batch.
Requires remove_microbatch_padding=True and the Megatron backend. When enabled, SFTTrainer uses PackedDataCollator instead of DefaultCollator. Each bin row becomes one row in the dispatched batch and one worker micro-batch.
attr max_tokens_per_microbatch
max_tokens_per_microbatch: Optional[int] = None
FFD bin capacity (max tokens per bin) when use_sequence_packing=True.
Each bin row becomes one worker micro-batch, so this is the token budget for one micro-batch. Must be >= max_length so any single sequence fits in a bin. None (default) resolves to max_length (each bin holds one sequence).
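For readers unfamiliar with FFD (first-fit decreasing), the sketch below shows the general technique on a list of sequence lengths. It is a generic illustration only: the internals of PackedDataCollator are not shown on this page, and `first_fit_decreasing` is a hypothetical function name.

```python
from typing import List


def first_fit_decreasing(lengths: List[int], capacity: int) -> List[List[int]]:
    """Pack sequence indices into bins holding at most `capacity` tokens."""
    if any(n > capacity for n in lengths):
        # Same constraint as the config: the budget must fit the longest sequence.
        raise ValueError("capacity must be >= the longest sequence")

    # Visit sequences from longest to shortest.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins: List[List[int]] = []  # sequence indices per bin
    loads: List[int] = []       # tokens already placed in each bin

    for i in order:
        for b, load in enumerate(loads):
            if load + lengths[i] <= capacity:
                # First bin with enough room wins.
                bins[b].append(i)
                loads[b] += lengths[i]
                break
        else:
            # No existing bin fits; open a new one.
            bins.append([i])
            loads.append(lengths[i])
    return bins


print(first_fit_decreasing([5, 3, 4, 2, 6], capacity=8))  # [[4, 3], [0, 1], [2]]
```

In this example five sequences totalling 20 tokens fit into three bins of capacity 8. With capacity equal to the longest sequence and no shorter sequences to share a bin with, each bin would hold a single sequence, which matches the behaviour described for the None default.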
attr dummy_run_full_ctx
dummy_run_full_ctx: bool = False
attr dummy_run_max_steps
dummy_run_max_steps: int = 5
method resolved_bin_capacity
resolved_bin_capacity() -> int
FFD bin capacity (max tokens per bin) when sequence packing is enabled.
Resolves max_tokens_per_microbatch against max_length: when the token budget is None, it falls back to max_length (each bin holds one sequence). Requires max_length to be set and the resolved budget to be >= max_length so any single sequence fits in a bin.
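The resolution rule has four outcomes, which the following standalone sketch walks through. `resolve_bin_capacity` is a hypothetical free function that takes the two fields as arguments instead of reading them from an SFTConfig instance; the numeric values are illustrative.

```python
from typing import Optional


def resolve_bin_capacity(max_length: Optional[int], max_tokens_per_microbatch: Optional[int]) -> int:
    """Hypothetical free-function version of the resolution rule."""
    if max_length is None:
        raise ValueError("max_tokens_per_microbatch requires max_length to be set.")
    # Fall back to max_length when no explicit token budget is given.
    budget = max_length if max_tokens_per_microbatch is None else max_tokens_per_microbatch
    if budget < max_length:
        raise ValueError("max_tokens_per_microbatch must be >= max_length")
    return budget


cases = [(2048, None), (2048, 8192), (None, 4096), (4096, 2048)]
for max_length, budget in cases:
    try:
        print(max_length, budget, "->", resolve_bin_capacity(max_length, budget))
    except ValueError as err:
        print(max_length, budget, "-> ValueError:", err)
```

The first two cases resolve to 2048 and 8192; the last two raise because max_length is unset or the budget is smaller than max_length.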
Source code in skyrl/train/config/sft_config.py:238-256
def resolved_bin_capacity(self) -> int:
    """FFD bin capacity (max tokens per bin) when sequence packing is enabled.

    Resolves ``max_tokens_per_microbatch`` against ``max_length``: when the
    token budget is ``None`` it falls back to ``max_length`` (each bin holds
    one sequence). Requires ``max_length`` to be set and the resolved budget
    to be ``>= max_length`` so any single sequence fits in a bin.
    """
    if self.max_length is None:
        raise ValueError("max_tokens_per_microbatch requires max_length to be set.")
    max_tokens = self.max_tokens_per_microbatch
    if max_tokens is None:
        max_tokens = self.max_length
    if max_tokens < self.max_length:
        raise ValueError(
            f"max_tokens_per_microbatch ({max_tokens}) must be >= max_length "
            f"({self.max_length}) so any single sequence fits in a bin."
        )
    return max_tokens
Trainer
class SFTTrainer
SFTTrainer(cfg: SFTConfig, skyrl_cfg: SkyRLTrainConfig | None = None, callbacks: Optional[list[TrainingCallback]] = None)
SFT trainer supporting FSDP and Megatron backends.
Unlike RayPPOTrainer, this does NOT subclass it. SFT's concerns are fundamentally different: no generation, no critic, no advantages, no KL penalty. Sharing a base class would create confusing dead code paths.
Usage::
trainer = SFTTrainer(SFTConfig(strategy="megatron"))
trainer.setup()
trainer.train()
trainer.shutdown()
Functions:
| Name | Description |
|---|---|
| setup | Initialize tokenizer, workers, dispatch, and tracker. |
| add_callback | Register a callback. Can be called anytime; events fired after this call will reach the new callback. |
| load_dataset | Load and tokenize the training dataset. |
| load_eval_dataset | Load and tokenize the eval dataset, or return None if not configured. |
| collate_batch | Collate examples into a TrainingInputBatch via the configured collator. |
| load_checkpoint | Load a checkpoint and return the step number to resume from. |
| run_eval | Compute eval loss over the full eval dataset. |
| train_step | Execute a single training step: forward_backward + optim_step. |
| train | Full training loop: load data, iterate, log, checkpoint. |
| save_checkpoint | Save a checkpoint at the given step. Returns the checkpoint folder path. |
| save_hf_model | Save policy weights in HuggingFace format. |
| shutdown | Finish tracking. |
Attributes:
| Name | Type | Description |
|---|---|---|
| sft_cfg | | |
| cfg | | |
| tokenizer | | |
| dispatch | WorkerDispatch \| None | |
| tracker | Tracking \| None | |
| global_step | | |
| collator | | |
Source code in skyrl/train/sft_trainer.py:675-1755
class SFTTrainer:
    """SFT trainer supporting FSDP and Megatron backends.

    Unlike RayPPOTrainer, this does NOT subclass it. SFT's concerns are
    fundamentally different: no generation, no critic, no advantages, no
    KL penalty. Sharing a base class would create confusing dead code paths.

    Usage::

        trainer = SFTTrainer(SFTConfig(strategy="megatron"))
        trainer.setup()
        trainer.train()
        trainer.shutdown()
    """

    def __init__(
        self,
        cfg: SFTConfig,
        skyrl_cfg: SkyRLTrainConfig | None = None,
        callbacks: Optional[list[TrainingCallback]] = None,
    ):
        self.sft_cfg = cfg
        # Accept a pre-built bridge config to avoid redundant rebuilds.
        # When not provided (e.g. standalone usage), build it here.
        self.cfg = skyrl_cfg if skyrl_cfg is not None else build_skyrl_config_for_sft(cfg)
        self.tokenizer = None
        self.dispatch: WorkerDispatch | None = None
        self.tracker: Tracking | None = None
        self.global_step = 0
        # running count of total non-padding tokens trained on
        self._total_tokens_processed = 0
        self.collator = None  # built in setup() once the tokenizer is available
        self._num_training_gpus: int = cfg.placement.num_nodes * cfg.placement.num_gpus_per_node
        self._ray_gpu_monitor = RayGpuMonitor() if cfg.enable_ray_gpu_monitor else None
        self._callback_handler = CallbackHandler(callbacks)
        self._training_control = TrainingControl()
        # Loop metadata used to build CallbackInput. Populated in train().
        self._total_steps: int = 0
        self._steps_per_epoch: int = 0
        self._current_epoch: int = 0

    def _build_collator(self, tokenizer):
        """Select the batch collator from the configured packing mode.

        ``PackedDataCollator`` performs controller-level FFD bin-packing
        (Megatron-only, ``use_sequence_packing=True``); ``DefaultCollator``
        left-pads each example. The choice is fixed by static config; the
        ``tokenizer`` is passed in by :meth:`setup` once it is available. The
        packed config is validated here.
        """
        # Imported lazily to avoid a circular import: ``collators`` imports
        # ``collate_sft_batch`` from this module.
        from skyrl.train.dataset.collators import DefaultCollator, PackedDataCollator

        if self.sft_cfg.use_sequence_packing:
            self._validate_packing_cfg()
            return PackedDataCollator(
                tokenizer=tokenizer,
                max_tokens_per_microbatch=self.sft_cfg.resolved_bin_capacity(),
                tp_size=self.sft_cfg.megatron_config.tensor_model_parallel_size,
                pp_size=self.sft_cfg.megatron_config.pipeline_model_parallel_size,
                cp_size=self.sft_cfg.megatron_config.context_parallel_size,
                dp_size=self._dp_size(),
                batch_size=self.sft_cfg.batch_size,
                micro_train_batch_size_per_gpu=self.sft_cfg.micro_train_batch_size_per_gpu,
            )
        return DefaultCollator(
            tokenizer=tokenizer,
            micro_train_batch_size_per_gpu=self.sft_cfg.micro_train_batch_size_per_gpu,
        )

    def _dp_size(self) -> int:
        """Number of DP ranks under the configured Megatron parallelism."""
        total_gpus = self.sft_cfg.placement.num_nodes * self.sft_cfg.placement.num_gpus_per_node
        tp = self.sft_cfg.megatron_config.tensor_model_parallel_size
        pp = self.sft_cfg.megatron_config.pipeline_model_parallel_size
        cp = self.sft_cfg.megatron_config.context_parallel_size
        return total_gpus // (tp * pp * cp)

    def _validate_packing_cfg(self):
        """Validate the config when ``use_sequence_packing=True``."""
        if self.sft_cfg.strategy != "megatron":
            raise ValueError(
                f"use_sequence_packing=True only supports strategy='megatron'; got "
                f"{self.sft_cfg.strategy!r}. Use the FSDP packing path instead."
            )
        # Sequence packing needs the THD layout, so it implies
        # remove_microbatch_padding=True. Auto-enable it (warning if the user
        # explicitly set it False) instead of erroring on the contradiction.
        if not self.sft_cfg.remove_microbatch_padding:
            logger.warning(
                "use_sequence_packing=True requires the THD layout; "
                "setting remove_microbatch_padding=True (was False)."
            )
            self.sft_cfg.remove_microbatch_padding = True

    # ------------------------------------------------------------------ #
    # Setup
    # ------------------------------------------------------------------ #

    def setup(self):
        """Initialize tokenizer, workers, dispatch, and tracker.

        Ray must already be initialized before calling this (either via
        ``initialize_ray`` on the head node or inside a Ray task).
        """
        self.tokenizer = get_tokenizer(
            self.cfg.trainer.policy.model.path,
            trust_remote_code=True,
            use_fast=not self.cfg.trainer.disable_fast_tokenizer,
            padding_side="left",
        )
        self.collator = self._build_collator(self.tokenizer)
        self._init_tracker()
        self._init_workers()

    def _init_workers(self):
        """Create PPORayActorGroup and WorkerDispatch.

        Selects the correct PolicyWorker based on strategy.
        """
        if self.sft_cfg.strategy == "megatron":
            from skyrl.backends.skyrl_train.workers.megatron.megatron_worker import (
                PolicyWorker,
            )
        else:
            from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker

        num_gpus = self.sft_cfg.placement.num_gpus_per_node
        raw_pg = placement_group(
            [{"GPU": num_gpus, "CPU": num_gpus}] * self.sft_cfg.placement.num_nodes,
            strategy="PACK",
        )
        get_ray_pg_ready_with_timeout(raw_pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
        pg = ResolvedPlacementGroup(raw_pg)

        actor_group = PPORayActorGroup(
            self.cfg.trainer,
            num_nodes=self.sft_cfg.placement.num_nodes,
            num_gpus_per_node=num_gpus,
            ray_actor_type=PolicyWorker,
            pg=pg,
            num_gpus_per_actor=1,
            colocate_all=False,
            sequence_parallel_size=self.cfg.trainer.policy.sequence_parallel_size,
            record_memory=self.cfg.trainer.policy.record_memory,
        )

        num_training_steps = (
            self.sft_cfg.dummy_run_max_steps if self.sft_cfg.dummy_run_full_ctx else self.sft_cfg.num_steps
        )
        # num_steps may be None when num_epochs is used; the worker will use its
        # default (large value) for the LR scheduler in that case.
        ray.get(
            actor_group.async_init_model(
                self.sft_cfg.model.path,
                num_training_steps=num_training_steps,
            )
        )
        ray.get(actor_group.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))

        self.dispatch = WorkerDispatch(self.cfg, policy_actor_group=actor_group)

    def _init_tracker(self):
        self.tracker = Tracking(
            project_name=self.cfg.trainer.project_name,
            experiment_name=self.cfg.trainer.run_name,
            backends=self.cfg.trainer.logger,
            config=self.sft_cfg,
            tags=self.cfg.trainer.tags,
        )

    def add_callback(self, callback: TrainingCallback) -> None:
        """Register a callback. Can be called anytime; events fired after this
        call will reach the new callback."""
        self._callback_handler.add(callback)

    def _build_callback_input(self, **fields) -> CallbackInput:
        """Snapshot loop counters + per-event fields into a CallbackInput."""
        return CallbackInput(
            global_step=self.global_step,
            epoch=self._current_epoch,
            total_steps=self._total_steps,
            steps_per_epoch=self._steps_per_epoch,
            **fields,
        )

    def _fire(self, event_name: str, **fields) -> None:
        """Build a CallbackInput and dispatch the given event to all callbacks."""
        cb_input = self._build_callback_input(**fields)
        getattr(self._callback_handler, event_name)(self, cb_input, self._training_control)

    # ------------------------------------------------------------------ #
    # Data
    # ------------------------------------------------------------------ #
def _load_and_tokenize(self, dataset_name: str, dataset_split: str) -> list:
"""Load and tokenize a dataset with caching support.
Auto-detects the dataset format based on column names:
- If a ``messages_key`` column exists, uses chat-format tokenization.
- If ``instruction`` and ``output`` columns exist, uses Alpaca-format
tokenization.
Uses manual multiprocessing for parallel tokenization when num_workers > 0.
With parallel mode, uses slice-based loading where each worker loads its
own data slice directly from HuggingFace to eliminate pickle overhead.
Caching:
- Tokenized datasets are cached to disk as a HuggingFace ``Dataset``
(arrow-backed, memory-mapped) for reuse across runs.
- Cache key is a hash of dataset name, split, model, and tokenization params.
- Set ``force_recache=True`` to ignore cache and re-tokenize.
- Set ``disable_cache=True`` to disable caching entirely.
Args:
dataset_name: HuggingFace dataset name (e.g. ``"yahma/alpaca-cleaned"``).
dataset_split: Dataset split (e.g. ``"train[:100]"`` or ``"test"``).
Returns:
A list of tokenized examples (dicts with ``input_ids``,
``attention_mask``, ``num_actions``).
"""
# Check cache first (unless disabled or force_recache)
if not self.sft_cfg.disable_cache:
cache_dir = self.sft_cfg.cache_dir
# Compute cache key
tools_key = self.sft_cfg.tools_key if self.sft_cfg.tools_key else None
system_key = self.sft_cfg.system_key if self.sft_cfg.system_key else None
cache_key = _compute_cache_key(
dataset_name=dataset_name,
dataset_split=dataset_split,
model_path=self.sft_cfg.model.path,
max_length=self.sft_cfg.max_length,
messages_key=self.sft_cfg.messages_key,
train_on_what=self.sft_cfg.train_on_what.value,
tools_key=tools_key,
system_key=system_key,
)
cache_path = _get_cache_path(cache_dir, cache_key)
# Try to load from cache (unless force_recache)
if not self.sft_cfg.force_recache:
cached = _load_from_cache(cache_path)
if cached is not None:
return cached
logger.info("Cache miss or force_recache=True, tokenizing dataset...")
logger.info(f"Cache key: {cache_key}")
logger.info(f"Loading dataset '{dataset_name}' split='{dataset_split}'...")
dataset = load_dataset(dataset_name, split=dataset_split)
columns = dataset.column_names
num_workers = self.sft_cfg.num_workers
# Sequential tokenization path
if num_workers == 0:
logger.info("Tokenizing dataset (sequential)...")
if self.sft_cfg.messages_key in columns:
tools_key = self.sft_cfg.tools_key if self.sft_cfg.tools_key in columns else None
system_key = self.sft_cfg.system_key if self.sft_cfg.system_key in columns else None
tokenized = [
tokenize_chat_example(
ex,
self.tokenizer,
self.sft_cfg.max_length,
self.sft_cfg.messages_key,
train_on_what=self.sft_cfg.train_on_what,
tools_key=tools_key,
system_key=system_key,
)
for ex in dataset
]
elif "instruction" in columns and "output" in columns:
tokenized = [tokenize_sft_example(ex, self.tokenizer, self.sft_cfg.max_length) for ex in dataset]
else:
raise ValueError(
f"Unrecognized dataset format. Expected '{self.sft_cfg.messages_key}' column "
f"(chat format) or 'instruction'+'output' columns (Alpaca format). "
f"Found columns: {columns}"
)
tokenized = [ex for ex in tokenized if ex is not None]
logger.info(f"Tokenized {len(tokenized)} examples (filtered from {len(dataset)})")
# Save to cache if enabled
if not self.sft_cfg.disable_cache:
# TODO (sumanthrh): Currently we use a simple list instead of dataset + stateful dataloader
# for simplicity but for caching we use HF Dataset since file sizes can get large
# We should migrate to using HF datasets + a dataloader so that we don't materialize
# the full dataset in memory
_save_to_cache(cache_path, tokenized)
return tokenized
# Parallel tokenization path with slice-based loading
logger.info(f"Tokenizing dataset with {num_workers} workers (slice-based loading)...")
# Cache tokenizer to temp dir for fast worker loading
tokenizer_cache_dir = tempfile.mkdtemp(prefix="skyrl_tokenizer_")
try:
self.tokenizer.save_pretrained(tokenizer_cache_dir)
# Compute slice boundaries from the size of the already-loaded dataset.
# The original split string is forwarded to workers verbatim so HF parses
# it (no local regex); each worker then loads its own slice.
dataset_size = len(dataset)
chunk_size = max(1, dataset_size // num_workers)
# Generate worker slice boundaries
worker_args = []
for worker_idx in range(num_workers):
worker_start = worker_idx * chunk_size
# Last worker takes any remainder
if worker_idx == num_workers - 1:
worker_end = dataset_size
else:
worker_end = min((worker_idx + 1) * chunk_size, dataset_size)
# Skip empty slices
if worker_start >= worker_end:
continue
# Prepare worker arguments based on format
if self.sft_cfg.messages_key in columns:
tools_key = self.sft_cfg.tools_key if self.sft_cfg.tools_key in columns else None
system_key = self.sft_cfg.system_key if self.sft_cfg.system_key in columns else None
worker_args.append(
(
dataset_name,
dataset_split,
worker_start,
worker_end,
tokenizer_cache_dir,
self.sft_cfg.max_length,
self.sft_cfg.messages_key,
self.sft_cfg.train_on_what.value,
tools_key,
system_key,
)
)
elif "instruction" in columns and "output" in columns:
worker_args.append(
(
dataset_name,
dataset_split,
worker_start,
worker_end,
tokenizer_cache_dir,
self.sft_cfg.max_length,
)
)
else:
raise ValueError(
f"Unrecognized dataset format. Expected '{self.sft_cfg.messages_key}' column "
f"(chat format) or 'instruction'+'output' columns (Alpaca format). "
f"Found columns: {columns}"
)
# Select worker function based on format
if self.sft_cfg.messages_key in columns:
worker_fn = _tokenize_chat_slice_worker
else:
worker_fn = _tokenize_alpaca_slice_worker
logger.info(f"Dividing {dataset_size} examples among {len(worker_args)} workers")
# Use spawn to avoid Ray fork issues
ctx = mp.get_context("spawn")
# Process in parallel
with ctx.Pool(processes=num_workers) as pool:
results = pool.map(worker_fn, worker_args)
# Flatten results
tokenized = []
for chunk_results in results:
tokenized.extend(chunk_results)
logger.info(f"Tokenized {len(tokenized)} examples (filtered from {dataset_size})")
# Save to cache if enabled
if not self.sft_cfg.disable_cache:
_save_to_cache(cache_path, tokenized)
return tokenized
finally:
# Cleanup temp tokenizer cache
import shutil
shutil.rmtree(tokenizer_cache_dir, ignore_errors=True)
def load_dataset(self) -> list:
"""Load and tokenize the training dataset."""
return self._load_and_tokenize(self.sft_cfg.dataset_name, self.sft_cfg.dataset_split)
def load_eval_dataset(self) -> Optional[list]:
"""Load and tokenize the eval dataset, or return ``None`` if not configured."""
if not self.sft_cfg.eval_dataset_name:
return None
return self._load_and_tokenize(self.sft_cfg.eval_dataset_name, self.sft_cfg.eval_dataset_split)
def _log_dataset_stats(self, tokenized: list) -> None:
"""Log tokenized sequence length statistics over the training set.
Reports count, total, mean, median (q50), q25, q75, min, and max of the
tokenized ``input_ids`` lengths. Logs once via ``logger.info``.
"""
if not tokenized:
logger.warning("No tokenized examples to compute stats over")
return
lengths = [len(ex["input_ids"]) for ex in tokenized]
n = len(lengths)
sorted_lengths = sorted(lengths)
def pct(p: float) -> int:
# Simple nearest-rank percentile over ints; adequate for dataset stats.
idx = max(0, min(n - 1, int(round((p / 100.0) * (n - 1)))))
return sorted_lengths[idx]
mean_len = sum(lengths) / n
q25 = pct(25)
q50 = pct(50)
q75 = pct(75)
min_len = sorted_lengths[0]
max_len = sorted_lengths[-1]
logger.info(
f"Dataset stats (tokenized lengths over {n} examples):\n"
f"total={sum(lengths)}, mean={mean_len:.1f}, median={q50}, q25={q25}, q75={q75}, min={min_len}, max={max_len}"
)
def collate_batch(self, examples: list, batch_size: int) -> TrainingInputBatch:
"""Collate examples into a TrainingInputBatch via the configured collator.
Delegates to ``self.collator`` (``DefaultCollator`` or, when sequence
packing is enabled, ``PackedDataCollator``).
Args:
examples: Tokenized examples to collate.
batch_size: Global batch dimension. The train path passes
``sft_cfg.batch_size`` and the eval path passes its
per-dispatch chunk size.
"""
return self.collator(examples, batch_size=batch_size)
# ------------------------------------------------------------------ #
# Checkpoint resume
# ------------------------------------------------------------------ #
def load_checkpoint(self) -> int:
"""Load a checkpoint and return the step number to resume from.
Behaviour depends on ``sft_cfg.resume_from``:
- ``""`` (empty): no resume, return 0.
- ``"latest"``: read ``latest_ckpt_global_step.txt`` from ``ckpt_path``.
- otherwise: treat as a direct path to a ``global_step_N`` directory.
Returns:
The global step to resume from (0 if no checkpoint loaded).
"""
resume_from = self.sft_cfg.resume_from
if not resume_from:
return 0
if resume_from == "latest":
if not self.sft_cfg.ckpt_path:
logger.info("resume_from='latest' but ckpt_path is empty, starting from scratch")
return 0
latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
if not io.exists(latest_file):
logger.info("No latest checkpoint marker found, starting from scratch")
return 0
with io.open_file(latest_file, "r") as f:
ckpt_step = int(f.read().strip())
checkpoint_path = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_step}")
# Validate consistency: ensure no stale checkpoint folders from prior runs
validate_consistency_for_latest_checkpoint(
self.sft_cfg.ckpt_path,
ckpt_step,
checkpoint_path,
latest_file,
self.sft_cfg.ckpt_interval,
)
else:
checkpoint_path = resume_from
if not io.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
global_step = extract_step_from_path(checkpoint_path)
if global_step == -1:
raise ValueError(
f"Cannot extract step number from checkpoint path: {checkpoint_path}. "
f"Expected a directory named '{GLOBAL_STEP_PREFIX}<N>'."
)
# Load and validate trainer state if available
trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
if io.exists(trainer_state_path):
with io.open_file(trainer_state_path, "rb") as f:
trainer_state = torch.load(f, map_location="cpu", weights_only=False)
saved_global_step = trainer_state.get("global_step", global_step)
logger.info("Successfully loaded trainer state")
if saved_global_step != global_step:
logger.warning(
f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value."
)
else:
logger.warning(
f"No trainer_state.pt found at {trainer_state_path}. "
"This checkpoint was likely saved by an older version."
)
policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
logger.info(f"Loading checkpoint from {checkpoint_path} (step {global_step})")
self.dispatch.load_checkpoint(
"policy",
policy_ckpt_dir,
load_optimizer_states=True,
load_lr_scheduler_states=True,
)
logger.info(f"Successfully resumed from global_step_{global_step}")
return global_step
# ------------------------------------------------------------------ #
# Training
# ------------------------------------------------------------------ #
def run_eval(self, eval_tokenized: list) -> tuple[dict, int]:
"""Compute eval loss over the full eval dataset.
Iterates the eval dataset in chunks of ``micro_train_batch_size_per_gpu * dp_size``
(i.e. exactly one micro-batch per DP rank per dispatch call), calls
:meth:`WorkerDispatch.forward` with ``loss_fn="cross_entropy"`` (which
runs the model in ``eval()`` mode under ``no_grad``), and aggregates the
per-batch losses into a token-weighted mean.
Each per-batch loss is a per-non-pad-token mean within its batch, so
weighting it by that batch's non-pad token count yields the true
per-non-pad-token mean across the eval dataset.
Args:
eval_tokenized: Pre-tokenized eval dataset (output of
:meth:`load_eval_dataset`).
Returns:
``(metrics, num_eval_batches)`` where ``metrics`` contains
``eval_loss`` and ``num_eval_batches`` is bookkeeping for
stdout logging (not a wandb metric).
"""
num_eval = len(eval_tokenized)
if num_eval == 0:
raise ValueError(
"Eval dataset is empty. Provide a non-empty eval split or disable eval "
"by setting eval_dataset_name=None."
)
# One micro-batch per DP rank per dispatch call — keeps memory usage bounded
# and removes the need for a separate `eval_batch_size` knob.
dp_size = self.dispatch.dp_size("policy")
eval_chunk_size = self.sft_cfg.micro_train_batch_size_per_gpu * dp_size
# Pad a trailing partial batch up to ``eval_chunk_size`` via
# ``pad_training_input_batch`` (which zeros ``loss_mask`` on padded rows).
# Padded rows contribute 0 to the cross-entropy numerator, and the
# pre-padding ``total_nonpad`` scaling in ``collate_batch`` excludes
# them from the denominator, so the reported ``eval_loss`` is the
# per-real-token mean over the full (non-padded) eval set.
num_eval_batches = ceil(num_eval / eval_chunk_size)
total_loss_weighted = 0.0
total_tokens = 0
for batch_idx in range(num_eval_batches):
start = batch_idx * eval_chunk_size
end = min(start + eval_chunk_size, num_eval)
batch_examples = eval_tokenized[start:end]
batch = self.collator(batch_examples, batch_size=eval_chunk_size)
# Pad the last (possibly-short) chunk so every dispatch sees exactly
# ``eval_chunk_size`` rows. ``pad_training_input_batch`` zeros the
# ``loss_mask`` for padding rows; with ``pad_size=0`` it is a no-op.
pad_rows = eval_chunk_size - len(batch_examples)
if pad_rows > 0:
logger.info(
f"Padding final eval batch by {pad_rows} rows "
f"({len(batch_examples)} real -> {eval_chunk_size} total); "
f"padded rows are masked out of the loss."
)
batch = pad_training_input_batch(batch, pad_rows)
# Count non-pad response tokens. The collated ``loss_mask`` was 0/1 before
# scaling, so the count is recovered by counting strictly positive entries.
# Padded rows have loss_mask=0 and are therefore excluded.
nonpad_tokens = int((batch["loss_mask"] > 0).sum().item())
output = self.dispatch.forward(
"policy",
batch,
loss_fn="cross_entropy",
loss_fn_config=None,
)
batch_loss = float(output.metrics.get("loss", float("nan")))
total_loss_weighted += batch_loss * nonpad_tokens
total_tokens += nonpad_tokens
eval_loss = total_loss_weighted / max(total_tokens, 1)
return {"eval_loss": eval_loss}, num_eval_batches
def train_step(self, batch: TrainingInputBatch, step: int) -> dict:
"""Execute a single training step: forward_backward + optim_step.
Args:
batch: The collated training batch.
step: Current global step (reserved for future use, e.g. scheduling).
Returns:
Dict with ``loss``, ``grad_norm``, and ``timings``.
"""
timings: dict[str, float] = {}
with Timer("forward_backward", timings):
output = self.dispatch.forward_backward("policy", batch, loss_fn="cross_entropy")
with Timer("optim_step", timings):
grad_norm = self.dispatch.optim_step("policy")
metrics = output.metrics
loss_val = metrics.get("final_loss", metrics.get("loss", float("nan")))
return {
"loss": loss_val,
"grad_norm": grad_norm,
"timings": timings,
}
def _validate_batch_parallelism(self):
"""Validate that batch_size is compatible with data-parallel and micro-batch sizes."""
batch_size = self.sft_cfg.batch_size
total_gpus = self.sft_cfg.placement.num_nodes * self.sft_cfg.placement.num_gpus_per_node
if self.sft_cfg.use_sequence_packing:
# With packing, batch_size is the *example* count (not bins) and the
# per-DP-rank bin count == bins_per_shard. The worker micro batch
# size refers to bin rows per micro-batch, derived from the
# ``max_tokens_per_microbatch`` token budget. We only require
# batch_size >= dp_size (every DP rank needs >= 1 bin) and do NOT
# require batch_size % micro_train_batch_size_per_gpu == 0, because
# micro_train_batch_size_per_gpu refers to bins-per-MB, not
# examples-per-MB; FFD rounds the bin count up to a multiple of
# dp_size, and bins/MB is a separate knob.
dp_size = self._dp_size()
if batch_size < dp_size:
raise ValueError(
f"batch_size ({batch_size}) must be >= dp_size ({dp_size}) when "
f"use_sequence_packing=True (each DP rank needs at least one bin)."
)
return
if self.sft_cfg.strategy == "megatron":
tp = self.sft_cfg.megatron_config.tensor_model_parallel_size
pp = self.sft_cfg.megatron_config.pipeline_model_parallel_size
dp_size = total_gpus // (tp * pp)
else:
# FSDP: all GPUs are data-parallel
dp_size = total_gpus
if batch_size % dp_size != 0:
raise ValueError(f"batch_size ({batch_size}) must be divisible by data-parallel size ({dp_size})")
per_dp_batch = batch_size // dp_size
micro_batch = self.sft_cfg.micro_train_batch_size_per_gpu
if per_dp_batch % micro_batch != 0:
raise ValueError(
f"batch_size ({self.sft_cfg.batch_size}) / dp_size ({dp_size}) must be divisible by "
f"micro_train_batch_size_per_gpu ({micro_batch})"
)
def _build_dummy_batch(self) -> TrainingInputBatch:
"""Build a dummy batch of random full-context sequences for benchmarking."""
batch_size = self.sft_cfg.batch_size
max_length = self.sft_cfg.max_length
micro_batch_size = self.sft_cfg.micro_train_batch_size_per_gpu
vocab_size = self.tokenizer.vocab_size
# num_actions is max_length - 1 because the autoregressive model
# produces log-probs for positions 1..T (predicting next token),
# so the first token has no corresponding log-prob.
num_actions = max_length - 1
sequences = torch.randint(0, vocab_size, (batch_size, max_length), dtype=torch.long)
attention_mask = torch.ones(batch_size, max_length, dtype=torch.long)
# All tokens are non-pad in the dummy batch, so total_nonpad = batch_size * num_actions.
# Scaling = batch_size / (micro_batch_size * total_nonpad)
# = 1 / (micro_batch_size * num_actions)
total_nonpad = batch_size * num_actions
loss_mask = torch.ones(batch_size, num_actions, dtype=torch.float) * (
batch_size / (micro_batch_size * total_nonpad)
)
batch = TrainingInputBatch(
{
"sequences": sequences,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
}
)
batch.metadata = {"response_length": num_actions}
return batch
def _train_dummy(self):
"""Dummy training loop for benchmarking. Skips real data, checkpoints, and resume."""
self._validate_batch_parallelism()
batch = self._build_dummy_batch()
num_steps = self.sft_cfg.dummy_run_max_steps
logger.info(
f"Starting dummy SFT training for {num_steps} steps "
f"(batch_size={self.sft_cfg.batch_size}, max_length={self.sft_cfg.max_length})..."
)
if self._ray_gpu_monitor is not None:
self._ray_gpu_monitor.start()
for step in range(num_steps):
all_timings: dict[str, float] = {}
with Timer("step", all_timings):
step_result = self.train_step(batch, step)
all_timings.update(step_result["timings"])
actual_num_tokens = batch["attention_mask"].sum().item()
self._total_tokens_processed += actual_num_tokens
tokens_per_second = actual_num_tokens / all_timings["step"]
log_dict = {
"train/loss": step_result["loss"],
"train/grad_norm": step_result["grad_norm"],
"train/tokens_per_second": tokens_per_second,
"train/tokens_per_second_per_gpu": tokens_per_second / self._num_training_gpus,
"train/actual_num_tokens": actual_num_tokens,
"train/total_tokens_processed": self._total_tokens_processed,
}
log_dict.update({f"timing/{k}": v for k, v in all_timings.items()})
if self._ray_gpu_monitor is not None:
log_dict.update(self._ray_gpu_monitor.flush())
self.tracker.log(log_dict, step=step, commit=True)
logger.info(
f"Step {step}: loss={step_result['loss']:.4f}, "
f"grad_norm={step_result['grad_norm']}, "
f"tokens_per_second={tokens_per_second:.0f}"
)
logger.info("Dummy SFT training complete!")
def train(self):
"""Full training loop: load data, iterate, log, checkpoint."""
if self.sft_cfg.dummy_run_full_ctx:
if self.sft_cfg.resume_from:
logger.warning("resume_from is ignored in dummy run mode")
return self._train_dummy()
tokenized = self.load_dataset()
# Log tokenized sequence length statistics (once, before training loop)
self._log_dataset_stats(tokenized)
# Load eval dataset (if configured). We load once up-front so the
# tokenization cost is amortized across all eval invocations.
eval_tokenized = self.load_eval_dataset()
if eval_tokenized is not None:
logger.info(f"Eval dataset loaded: {len(eval_tokenized)} examples")
batch_size = self.sft_cfg.batch_size
# steps_per_epoch is always derived from the data; callbacks rely on it.
steps_per_epoch = max(1, ceil(len(tokenized) / batch_size))
# Resolve num_steps: explicit num_steps takes precedence; otherwise derive from num_epochs.
if self.sft_cfg.num_steps is not None:
num_steps = self.sft_cfg.num_steps
else:
num_steps = self.sft_cfg.num_epochs * steps_per_epoch
logger.info(
f"num_steps not set; deriving from num_epochs={self.sft_cfg.num_epochs}: "
f"ceil({len(tokenized)} / {batch_size}) * {self.sft_cfg.num_epochs} = {num_steps} steps"
)
# Early validation: dataset must have at least batch_size examples
if len(tokenized) < batch_size:
raise ValueError(
f"Dataset has {len(tokenized)} examples after tokenization, but batch_size={batch_size}. "
f"Reduce batch_size or use more data."
)
self._validate_batch_parallelism()
# Resume from checkpoint if configured
start_step = self.load_checkpoint()
# Shuffle data before training
rng = random.Random(self.sft_cfg.seed)
rng.shuffle(tokenized)
# When resuming, start_step is the last *completed* step (checkpoint is
# saved AFTER the optimizer update), so we begin at start_step + 1 to
# avoid replaying that step.
# Replay epoch shuffles for reproducibility on resume
start_epoch = (start_step * batch_size) // len(tokenized)
for _ in range(start_epoch):
rng.shuffle(tokenized)
current_epoch = start_epoch
# Initialize `global_step`
self.global_step = start_step
# Publish loop metadata so CallbackInput can be built consistently.
self._total_steps = num_steps
self._steps_per_epoch = steps_per_epoch
self._current_epoch = current_epoch
self._training_control.reset()
logger.info(f"Starting SFT training for {num_steps} steps (batch_size={batch_size})...")
if start_step > 0:
logger.info(f"Resuming from step {start_step}")
if self._ray_gpu_monitor is not None:
self._ray_gpu_monitor.start()
# Tracks whether the most recent in-loop iteration saved a checkpoint
# (either via the ckpt_interval or via a callback-driven ``should_save``).
did_save_last_step = False
self._fire("on_train_start")
# Baseline eval before training begins, logged at ``start_step`` (0 for a
# fresh run). Wandb's step counter starts at 0, and the training loop's first
# commit logs at ``start_step + 1``, so this step does not conflict with later steps.
if self.sft_cfg.eval_before_train and eval_tokenized is not None:
self._fire("on_eval_start")
eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
self._fire("on_eval_end", metrics=eval_metrics)
baseline_log = {f"eval/{k}": v for k, v in eval_metrics.items()}
self._fire("on_log", logs=baseline_log)
self.tracker.log(baseline_log, step=self.global_step, commit=True)
logger.info(
f"Baseline eval before training: "
f"eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
f"over {num_eval_batches} batches"
)
# SkyRL starts counting at step 1
self.global_step = start_step + 1 if start_step > 0 else 1
self._fire("on_epoch_start")
epoch_in_progress = True
while self.global_step <= num_steps:
all_timings: dict[str, float] = {}
with Timer("step", all_timings):
# Data loading with wrap-around
with Timer("data_loading", all_timings):
start_idx = (self.global_step * batch_size) % len(tokenized)
end_idx = start_idx + batch_size
if end_idx > len(tokenized):
batch_examples = tokenized[start_idx:] + tokenized[: end_idx - len(tokenized)]
else:
batch_examples = tokenized[start_idx:end_idx]
batch = self.collator(batch_examples, batch_size=batch_size)
self._fire("on_step_start", batch=batch)
# Training step
step_result = self.train_step(batch, self.global_step)
all_timings.update(step_result["timings"])
# Compute throughput using actual (non-padding) tokens
batch_padded_seq_len = batch["sequences"].shape[1]
actual_num_tokens = batch["attention_mask"].sum().item()
self._total_tokens_processed += actual_num_tokens
tokens_per_second = actual_num_tokens / all_timings["step"]
# Build log dict
log_dict = {
"train/loss": step_result["loss"],
"train/grad_norm": step_result["grad_norm"],
"train/tokens_per_second": tokens_per_second,
"train/tokens_per_second_per_gpu": tokens_per_second / self._num_training_gpus,
"train/actual_num_tokens": actual_num_tokens,
"train/batch_padded_seq_len": batch_padded_seq_len,
"train/total_tokens_processed": self._total_tokens_processed,
}
log_dict.update({f"timing/{k}": v for k, v in all_timings.items()})
if self._ray_gpu_monitor is not None:
log_dict.update(self._ray_gpu_monitor.flush())
self._fire("on_step_end", batch=batch, metrics=step_result)
# Capture callback-driven triggers, then reset so they only fire once.
force_save = self._training_control.should_save
force_eval = self._training_control.should_evaluate
self._training_control.should_save = False
self._training_control.should_evaluate = False
# Checkpoint: interval-driven or callback-requested.
interval_save = (
self.sft_cfg.ckpt_interval > 0
and self.global_step > 0
and self.global_step % self.sft_cfg.ckpt_interval == 0
)
did_save_last_step = force_save or interval_save
if did_save_last_step:
with Timer("save_checkpoint", all_timings):
ckpt_path = self.save_checkpoint()
log_dict["timing/save_checkpoint"] = all_timings["save_checkpoint"]
self._fire("on_save", ckpt_path=ckpt_path)
# HF export at regular intervals
if self.sft_cfg.hf_save_interval > 0 and self.global_step % self.sft_cfg.hf_save_interval == 0:
with Timer("save_hf_model", all_timings):
self.save_hf_model()
log_dict["timing/save_hf_model"] = all_timings["save_hf_model"]
eval_metrics = None
num_eval_batches: int | None = None
# Eval fires at step N where N % eval_interval == 0 and N > 0, OR
# whenever a callback set ``control.should_evaluate``.
interval_eval = self.sft_cfg.eval_interval > 0 and self.global_step % self.sft_cfg.eval_interval == 0
if eval_tokenized is not None and (force_eval or interval_eval):
self._fire("on_eval_start")
with Timer("eval", all_timings):
eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
self._fire("on_eval_end", metrics=eval_metrics)
if eval_metrics:
log_dict.update({f"eval/{k}": v for k, v in eval_metrics.items()})
log_dict["timing/eval"] = all_timings["eval"]
log_dict.update({"train/epoch": current_epoch, "train/global_step": self.global_step})
# Callbacks may mutate log_dict in place via on_log.
self._fire("on_log", logs=log_dict)
self.tracker.log(log_dict, step=self.global_step, commit=True)
if self.global_step % 5 == 0:
logger.info(
f"Step {self.global_step}: loss={step_result['loss']:.4f}, " f"grad_norm={step_result['grad_norm']}"
)
if eval_metrics:
logger.info(
f"Step {self.global_step}: eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
f"over {num_eval_batches} batches"
)
# Check for epoch boundary and reshuffle
epoch = (self.global_step * batch_size) // len(tokenized)
if epoch > current_epoch:
self._fire("on_epoch_end")
epoch_in_progress = False
for _ in range(epoch - current_epoch):
rng.shuffle(tokenized)
current_epoch = epoch
self._current_epoch = epoch
if self.global_step + 1 <= num_steps:
self._fire("on_epoch_start")
epoch_in_progress = True
self.global_step += 1
self.global_step = min(self.global_step, num_steps)
# Pair the leading on_epoch_start: fire on_epoch_end if we exited the
# loop mid-epoch
if epoch_in_progress:
self._fire("on_epoch_end")
epoch_in_progress = False
# Save final checkpoint (if checkpointing is enabled). Skip if the last
# in-loop iteration already saved (either via ckpt_interval or via a
# callback-driven force-save) so we don't double-save.
if self.sft_cfg.ckpt_path and not did_save_last_step:
final_step = num_steps
logger.info(f"Saving final checkpoint at step {final_step}")
ckpt_path = self.save_checkpoint()
self._fire("on_save", ckpt_path=ckpt_path)
# Save final HF model if enabled (only if not already saved at last step)
if self.sft_cfg.hf_save_interval > 0:
final_step = num_steps
already_saved = final_step % self.sft_cfg.hf_save_interval == 0
if not already_saved:
self.global_step = final_step
logger.info(f"Saving final HF model at step {final_step}")
self.save_hf_model()
# Final eval pass (skip if the last step already ran eval).
# NOTE: The last in-loop tracker.log(..., commit=True) at step=num_steps
# advanced wandb's internal step counter to num_steps+1. Logging the
# final eval at step=num_steps would be rejected by wandb with
# "step N < current step N+1". We log the final eval at num_steps+1
# (one past the last committed train step) in a single combined
# tracker.log() call, preserving wandb step ordering. We use a local
# ``final_eval_step`` rather than mutating ``self.global_step``: the
# bump is purely a wandb-step accounting concern, not real trainer
# state.
if eval_tokenized is not None:
already_ran = self.sft_cfg.eval_interval > 0 and num_steps % self.sft_cfg.eval_interval == 0
if not already_ran:
final_eval_step = num_steps + 1
eval_timings: dict[str, float] = {}
self._fire("on_eval_start")
with Timer("eval", eval_timings):
eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
self._fire("on_eval_end", metrics=eval_metrics)
if eval_metrics:
eval_log = {f"eval/{k}": v for k, v in eval_metrics.items()}
eval_log["timing/eval"] = eval_timings["eval"]
self._fire("on_log", logs=eval_log)
self.tracker.log(eval_log, step=final_eval_step, commit=True)
logger.info(
f"Final eval at step {final_eval_step}: "
f"eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
f"over {num_eval_batches} batches"
)
self._fire("on_train_end")
logger.info("SFT training complete!")
def save_checkpoint(self) -> str:
"""Save a checkpoint at the current ``global_step``. Returns the checkpoint folder path."""
step = self.global_step
global_step_folder = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{step}")
policy_save_dir = os.path.join(global_step_folder, "policy")
io.makedirs(global_step_folder, exist_ok=True)
logger.info(f"Saving checkpoint at step {step} to {global_step_folder}")
self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer)
# Save trainer state so resume can cross-check the step number (mirrors PPO's trainer_state.pt)
trainer_state = {
"global_step": step,
"config": asdict(self.sft_cfg),
}
trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
with io.open_file(trainer_state_path, "wb") as f:
torch.save(trainer_state, f)
logger.info(f"Saved trainer state to {trainer_state_path}")
# Atomic tracking -- write this last after all saves succeed
latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
with io.open_file(latest_file, "w") as f:
f.write(str(step))
logger.info(f"Checkpoint saved for global_step_{step}")
# Clean up old checkpoints after successful save
cleanup_old_checkpoints(self.sft_cfg.ckpt_path, self.sft_cfg.max_ckpts_to_keep)
return global_step_folder
def save_hf_model(self):
"""Save policy weights in HuggingFace format.
Export path: cfg.trainer.export_path/global_step_{step}/policy
Mirrors the pattern used by the RL trainer's save_models().
"""
step = self.global_step
policy_export_dir = os.path.join(
self.cfg.trainer.export_path,
f"{GLOBAL_STEP_PREFIX}{step}",
"policy",
)
self.dispatch.save_hf_model("policy", policy_export_dir, self.tokenizer)
logger.info(f"Saved HF model weights at step {step} to {policy_export_dir}")
# ------------------------------------------------------------------ #
# Lifecycle
# ------------------------------------------------------------------ #
def shutdown(self):
"""Finish tracking.
Does NOT call ``ray.shutdown()`` -- when running inside a Ray task
(the normal path via ``sft_entrypoint``), shutting down Ray from
within the task would be incorrect. The head-node process owns
the Ray lifecycle.
"""
if self._ray_gpu_monitor is not None:
self._ray_gpu_monitor.stop()
if self.tracker is not None:
self.tracker.finish()
attr sft_cfg
sft_cfg = cfg
attr cfg
cfg = skyrl_cfg if skyrl_cfg is not None else build_skyrl_config_for_sft(cfg)
attr tokenizer
tokenizer = None
attr dispatch
dispatch: WorkerDispatch | None = None
attr tracker
tracker: Tracking | None = None
attr global_step
global_step = 0
attr collator
collator = None
method setup
setup()
Initialize tokenizer, workers, dispatch, and tracker.
Ray must already be initialized before calling this (either via
initialize_ray on the head node or inside a Ray task).
Source code in skyrl/train/sft_trainer.py:777-791
def setup(self):
    """Initialize tokenizer, workers, dispatch, and tracker.
    Ray must already be initialized before calling this (either via
    ``initialize_ray`` on the head node or inside a Ray task).
    """
    self.tokenizer = get_tokenizer(
        self.cfg.trainer.policy.model.path,
        trust_remote_code=True,
        use_fast=not self.cfg.trainer.disable_fast_tokenizer,
        padding_side="left",
    )
    self.collator = self._build_collator(self.tokenizer)
    self._init_tracker()
    self._init_workers()
method add_callback
add_callback(callback: TrainingCallback) -> None
Register a callback. Can be called anytime; events fired after this call will reach the new callback.
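The registration semantics described above (a callback only sees events fired after it was added) can be illustrated with a small stand-in. This is a sketch, not the SkyRL implementation: `SketchCallbackHandler` and `RecordingCallback` are hypothetical names, and only the event names (`on_train_start`, `on_step_end`) are taken from the trainer source.

```python
class RecordingCallback:
    """Hypothetical callback that records the names of the events it receives."""

    def __init__(self):
        self.events = []

    def on_train_start(self, **kwargs):
        self.events.append("on_train_start")

    def on_step_end(self, **kwargs):
        self.events.append("on_step_end")


class SketchCallbackHandler:
    """Stand-in for the trainer's internal handler (assumed behaviour, not the real class)."""

    def __init__(self):
        self._callbacks = []

    def add(self, callback):
        # Registration is a plain append, so it is safe at any point in time.
        self._callbacks.append(callback)

    def fire(self, event, **kwargs):
        # Dispatch to every registered callback that defines a hook for this event.
        for cb in self._callbacks:
            hook = getattr(cb, event, None)
            if callable(hook):
                hook(**kwargs)


handler = SketchCallbackHandler()
early = RecordingCallback()
handler.add(early)
handler.fire("on_train_start")

# A callback added later only receives events fired after registration.
late = RecordingCallback()
handler.add(late)
handler.fire("on_step_end", metrics={"loss": 0.5})
```

Here `early` sees both events, while `late` sees only `on_step_end`.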
Source code in skyrl/train/sft_trainer.py:848-851
def add_callback(self, callback: TrainingCallback) -> None:
    """Register a callback. Can be called anytime; events fired after this
    call will reach the new callback."""
    self._callback_handler.add(callback)
method load_dataset
load_dataset() -> list
Load and tokenize the training dataset.
Source code in skyrl/train/sft_trainer.py:1068-1070
def load_dataset(self) -> list:
    """Load and tokenize the training dataset."""
    return self._load_and_tokenize(self.sft_cfg.dataset_name, self.sft_cfg.dataset_split)
method load_eval_dataset
load_eval_dataset() -> Optional[list]
Load and tokenize the eval dataset, or return None if not configured.
Source code in skyrl/train/sft_trainer.py:1072-1076
def load_eval_dataset(self) -> Optional[list]:
    """Load and tokenize the eval dataset, or return ``None`` if not configured."""
    if not self.sft_cfg.eval_dataset_name:
        return None
    return self._load_and_tokenize(self.sft_cfg.eval_dataset_name, self.sft_cfg.eval_dataset_split)
method collate_batch
collate_batch(examples: list, batch_size: int) -> TrainingInputBatch
Collate examples into a TrainingInputBatch via the configured collator.
Delegates to self.collator (DefaultCollator or, when sequence
packing is enabled, PackedDataCollator).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| examples | list | Tokenized examples to collate. | required |
| batch_size | int | Global batch dimension. The train path passes sft_cfg.batch_size and the eval path passes its per-dispatch chunk size. | required |
Source code in skyrl/train/sft_trainer.py:1109-1121
def collate_batch(self, examples: list, batch_size: int) -> TrainingInputBatch:
    """Collate examples into a TrainingInputBatch via the configured collator.
    Delegates to ``self.collator`` (``DefaultCollator`` or, when sequence
    packing is enabled, ``PackedDataCollator``).
    Args:
        examples: Tokenized examples to collate.
        batch_size: Global batch dimension. The train path passes
            ``sft_cfg.batch_size`` and the eval path passes its
            per-dispatch chunk size.
    """
    return self.collator(examples, batch_size=batch_size)
method abstractmethod load_checkpoint
load_checkpoint() -> int
Load a checkpoint and return the step number to resume from.
Behaviour depends on sft_cfg.resume_from:
- "" (empty): no resume, return 0.
- "latest": read latest_ckpt_global_step.txt from ckpt_path.
- otherwise: treat as a direct path to a global_step_N directory.
Returns:
| Type | Description |
|---|---|
| int | The global step to resume from (0 if no checkpoint loaded). |
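The three `resume_from` cases can be sketched against the local filesystem. This is a simplified illustration under stated assumptions: `resolve_resume_path` is a hypothetical helper, `GLOBAL_STEP_PREFIX` is assumed to be `"global_step_"` (matching the directory names above), and plain `os`/`open` calls stand in for the trainer's `io` helpers. The consistency validation and the actual weight loading are omitted.

```python
import os
import re
import tempfile

GLOBAL_STEP_PREFIX = "global_step_"  # assumed value, matching the directory names in the docs


def resolve_resume_path(resume_from: str, ckpt_path: str):
    """Return (checkpoint_path, step), or (None, 0) when nothing should be resumed."""
    if not resume_from:
        return None, 0
    if resume_from == "latest":
        if not ckpt_path:
            return None, 0
        marker = os.path.join(ckpt_path, "latest_ckpt_global_step.txt")
        if not os.path.exists(marker):
            return None, 0
        with open(marker) as f:
            step = int(f.read().strip())
        checkpoint_path = os.path.join(ckpt_path, f"{GLOBAL_STEP_PREFIX}{step}")
    else:
        # Any other value is treated as a direct path to a global_step_N directory.
        checkpoint_path = resume_from
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
    name = os.path.basename(os.path.normpath(checkpoint_path))
    match = re.fullmatch(rf"{re.escape(GLOBAL_STEP_PREFIX)}(\d+)", name)
    if match is None:
        raise ValueError(f"Cannot extract step number from checkpoint path: {checkpoint_path}")
    return checkpoint_path, int(match.group(1))


with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "global_step_20"))
    with open(os.path.join(root, "latest_ckpt_global_step.txt"), "w") as f:
        f.write("20")

    latest_step = resolve_resume_path("latest", root)[1]
    direct_step = resolve_resume_path(os.path.join(root, "global_step_20"), root)[1]
    nothing = resolve_resume_path("", root)
```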
Source code in skyrl/train/sft_trainer.py:1127-1200
def load_checkpoint(self) -> int:
    """Load a checkpoint and return the step number to resume from.
    Behaviour depends on ``sft_cfg.resume_from``:
    - ``""`` (empty): no resume, return 0.
    - ``"latest"``: read ``latest_ckpt_global_step.txt`` from ``ckpt_path``.
    - otherwise: treat as a direct path to a ``global_step_N`` directory.
    Returns:
        The global step to resume from (0 if no checkpoint loaded).
    """
    resume_from = self.sft_cfg.resume_from
    if not resume_from:
        return 0
    if resume_from == "latest":
        if not self.sft_cfg.ckpt_path:
            logger.info("resume_from='latest' but ckpt_path is empty, starting from scratch")
            return 0
        latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
        if not io.exists(latest_file):
            logger.info("No latest checkpoint marker found, starting from scratch")
            return 0
        with io.open_file(latest_file, "r") as f:
            ckpt_step = int(f.read().strip())
        checkpoint_path = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_step}")
        # Validate consistency: ensure no stale checkpoint folders from prior runs
        validate_consistency_for_latest_checkpoint(
            self.sft_cfg.ckpt_path,
            ckpt_step,
            checkpoint_path,
            latest_file,
            self.sft_cfg.ckpt_interval,
        )
    else:
        checkpoint_path = resume_from
    if not io.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
    global_step = extract_step_from_path(checkpoint_path)
    if global_step == -1:
        raise ValueError(
            f"Cannot extract step number from checkpoint path: {checkpoint_path}. "
            f"Expected a directory named '{GLOBAL_STEP_PREFIX}<N>'."
        )
    # Load and validate trainer state if available
    trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
    if io.exists(trainer_state_path):
        with io.open_file(trainer_state_path, "rb") as f:
            trainer_state = torch.load(f, map_location="cpu", weights_only=False)
        saved_global_step = trainer_state.get("global_step", global_step)
        logger.info("Successfully loaded trainer state")
        if saved_global_step != global_step:
            logger.warning(
                f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value."
            )
    else:
        logger.warning(
            f"No trainer_state.pt found at {trainer_state_path}. "
            "This checkpoint was likely saved by an older version."
        )
    policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
    logger.info(f"Loading checkpoint from {checkpoint_path} (step {global_step})")
    self.dispatch.load_checkpoint(
        "policy",
        policy_ckpt_dir,
        load_optimizer_states=True,
        load_lr_scheduler_states=True,
    )
    logger.info(f"Successfully resumed from global_step_{global_step}")
    return global_step
method run_eval
run_eval(eval_tokenized: list) -> tuple[dict, int]
Compute eval loss over the full eval dataset.
Iterates the eval dataset in chunks of micro_train_batch_size_per_gpu * dp_size (i.e. exactly one micro-batch per DP rank per dispatch call), calls WorkerDispatch.forward with loss_fn="cross_entropy" (which runs the model in eval() mode under no_grad), and aggregates the per-batch losses into a token-weighted mean.
Each per-batch loss is itself a per-non-pad-token mean within that batch, so weighting by the batch's non-pad token count yields the true per-non-pad-token mean across the eval dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| eval_tokenized | list | Pre-tokenized eval dataset (output of load_eval_dataset). | required |
Returns:
| Type | Description |
|---|---|
| tuple[dict, int] | (metrics, num_eval_batches), where metrics contains eval_loss and num_eval_batches is bookkeeping for stdout logging (not a wandb metric). |
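A small worked example may help show why the aggregation is token-weighted rather than a plain average over batches, and how the chunk count and final-chunk padding fall out of the sizes. The numbers and the helper name `token_weighted_mean` are illustrative assumptions; only the arithmetic mirrors the description above.

```python
from math import ceil


def token_weighted_mean(batch_losses, batch_token_counts):
    """Combine per-batch mean losses into a per-token mean over all batches."""
    total_weighted = 0.0
    total_tokens = 0
    for loss, n_tokens in zip(batch_losses, batch_token_counts):
        total_weighted += loss * n_tokens
        total_tokens += n_tokens
    # Guard against an all-padding eval set, as the trainer does with max(..., 1).
    return total_weighted / max(total_tokens, 1)


# Two batches: the first has three times as many non-pad tokens as the second.
batch_losses = [2.0, 1.0]
batch_token_counts = [30, 10]

weighted = token_weighted_mean(batch_losses, batch_token_counts)  # (60 + 10) / 40
unweighted = sum(batch_losses) / len(batch_losses)                # ignores token counts

# Chunking: one micro-batch per DP rank per dispatch call.
micro_train_batch_size_per_gpu = 2
dp_size = 4
num_eval = 21
eval_chunk_size = micro_train_batch_size_per_gpu * dp_size
num_eval_batches = ceil(num_eval / eval_chunk_size)
# Rows added to the final chunk; they carry a zero loss mask.
pad_rows = num_eval_batches * eval_chunk_size - num_eval
```

With these values the token-weighted mean is 1.75, whereas the unweighted batch average would report 1.5 and under-count the larger batch.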
Source code in skyrl/train/sft_trainer.py:1206-1282
def run_eval(self, eval_tokenized: list) -> tuple[dict, int]:
    """Compute eval loss over the full eval dataset.
    Iterates the eval dataset in chunks of ``micro_train_batch_size_per_gpu * dp_size``
    (i.e. exactly one micro-batch per DP rank per dispatch call), calls
    :meth:`WorkerDispatch.forward` with ``loss_fn="cross_entropy"`` (which
    runs the model in ``eval()`` mode under ``no_grad``), and aggregates the
    per-batch losses into a token-weighted mean.
    The aggregated loss is a token-weighted mean of the per-batch losses,
    which are themselves per-non-pad-token means within each batch. This
    yields the true per-non-pad-token mean across the eval dataset.
    Args:
        eval_tokenized: Pre-tokenized eval dataset (output of
            :meth:`load_eval_dataset`).
    Returns:
        ``(metrics, num_eval_batches)`` where ``metrics`` contains
        ``eval_loss`` and ``num_eval_batches`` is bookkeeping for
        stdout logging (not a wandb metric).
    """
    num_eval = len(eval_tokenized)
    if num_eval == 0:
        raise ValueError(
            "Eval dataset is empty. Provide a non-empty eval split or disable eval "
            "by setting eval_dataset_name=None."
        )
    # One micro-batch per DP rank per dispatch call -- keeps memory usage bounded
    # and removes the need for a separate `eval_batch_size` knob.
    dp_size = self.dispatch.dp_size("policy")
    eval_chunk_size = self.sft_cfg.micro_train_batch_size_per_gpu * dp_size
    # Pad a trailing partial batch up to ``eval_chunk_size`` via
    # ``pad_training_input_batch`` (which zeros ``loss_mask`` on padded rows).
    # Padded rows contribute 0 to the cross-entropy numerator, and the
    # pre-padding ``total_nonpad`` scaling in ``collate_batch`` excludes
    # them from the denominator, so the reported ``eval_loss`` is the
    # per-real-token mean over the full (non-padded) eval set.
    num_eval_batches = ceil(num_eval / eval_chunk_size)
    total_loss_weighted = 0.0
    total_tokens = 0
    for batch_idx in range(num_eval_batches):
        start = batch_idx * eval_chunk_size
        end = min(start + eval_chunk_size, num_eval)
        batch_examples = eval_tokenized[start:end]
        batch = self.collator(batch_examples, batch_size=eval_chunk_size)
        # Pad the last (possibly-short) chunk so every dispatch sees exactly
        # ``eval_chunk_size`` rows. ``pad_training_input_batch`` zeros the
        # ``loss_mask`` for padding rows; with ``pad_size=0`` it is a no-op.
        pad_rows = eval_chunk_size - len(batch_examples)
        if pad_rows > 0:
            logger.info(
                f"Padding final eval batch by {pad_rows} rows "
                f"({len(batch_examples)} real -> {eval_chunk_size} total); "
                f"padded rows are masked out of the loss."
            )
        batch = pad_training_input_batch(batch, pad_rows)
        # Count non-pad response tokens (from the unscaled mask, recovered from the batch)
        # We use the attention_mask response window via collate_sft_batch's loss_mask which
        # was 0/1 before scaling. Recover the count from the batch by counting positive entries.
        # Padded rows have loss_mask=0 so they are excluded here.
        nonpad_tokens = int((batch["loss_mask"] > 0).sum().item())
        output = self.dispatch.forward(
            "policy",
            batch,
            loss_fn="cross_entropy",
            loss_fn_config=None,
        )
        batch_loss = float(output.metrics.get("loss", float("nan")))
        total_loss_weighted += batch_loss * nonpad_tokens
        total_tokens += nonpad_tokens
    eval_loss = total_loss_weighted / max(total_tokens, 1)
    return {"eval_loss": eval_loss}, num_eval_batches
method train_step
train_step(batch: TrainingInputBatch, step: int) -> dict
Execute a single training step: forward_backward + optim_step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | TrainingInputBatch | The collated training batch. | required |
| step | int | Current global step (reserved for future use, e.g. scheduling). | required |
Returns:
| Type | Description |
|---|---|
| dict | Dict with loss, grad_norm, and timings. |
Source code in skyrl/train/sft_trainer.py:1284-1306
def train_step(self, batch: TrainingInputBatch, step: int) -> dict:
    """Execute a single training step: forward_backward + optim_step.
    Args:
        batch: The collated training batch.
        step: Current global step (reserved for future use, e.g. scheduling).
    Returns:
        Dict with ``loss``, ``grad_norm``, and ``timings``.
    """
    timings: dict[str, float] = {}
    with Timer("forward_backward", timings):
        output = self.dispatch.forward_backward("policy", batch, loss_fn="cross_entropy")
    with Timer("optim_step", timings):
        grad_norm = self.dispatch.optim_step("policy")
    metrics = output.metrics
    loss_val = metrics.get("final_loss", metrics.get("loss", float("nan")))
    return {
        "loss": loss_val,
        "grad_norm": grad_norm,
        "timings": timings,
    }
method train
train()
Full training loop: load data, iterate, log, checkpoint.
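Three pieces of arithmetic inside the loop are easier to follow in isolation: how the total step count is resolved, how a batch is sliced with wrap-around, and how epoch shuffles are replayed on resume. The sketch below restates them with plain lists. The helper names (`resolve_num_steps`, `batch_slice`, `order_at_epoch`) are hypothetical; the formulas follow the source listing for this method.

```python
import random
from math import ceil


def resolve_num_steps(num_examples, batch_size, num_steps=None, num_epochs=1):
    """An explicit num_steps wins; otherwise derive it from num_epochs."""
    steps_per_epoch = max(1, ceil(num_examples / batch_size))
    if num_steps is not None:
        return num_steps, steps_per_epoch
    return num_epochs * steps_per_epoch, steps_per_epoch


def batch_slice(data, global_step, batch_size):
    """Take batch_size examples starting at (step * batch_size) mod len, wrapping around."""
    start = (global_step * batch_size) % len(data)
    end = start + batch_size
    if end > len(data):
        return data[start:] + data[: end - len(data)]
    return data[start:end]


def order_at_epoch(num_examples, seed, epoch):
    """Replay the initial shuffle plus one extra shuffle per completed epoch."""
    data = list(range(num_examples))
    rng = random.Random(seed)
    rng.shuffle(data)
    for _ in range(epoch):
        rng.shuffle(data)
    return data


total_steps, steps_per_epoch = resolve_num_steps(10, 4, num_epochs=2)
explicit_steps, _ = resolve_num_steps(10, 4, num_steps=100, num_epochs=2)
wrapped = batch_slice(list(range(10)), global_step=2, batch_size=4)

# Resuming after step 3 with batch_size=4 over 10 examples means one epoch is done.
start_epoch = (3 * 4) // 10

# An uninterrupted run: initial shuffle, then one reshuffle at the epoch boundary.
continuous = list(range(10))
rng = random.Random(42)
rng.shuffle(continuous)
rng.shuffle(continuous)
resumed = order_at_epoch(10, seed=42, epoch=start_epoch)
```

Because the random generator is seeded and the shuffles are replayed in the same order, the resumed run sees the same example order as the uninterrupted one.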
Source code in skyrl/train/sft_trainer.py:1423-1694
def train(self):
    """Full training loop: load data, iterate, log, checkpoint."""
    if self.sft_cfg.dummy_run_full_ctx:
        if self.sft_cfg.resume_from:
            logger.warning("resume_from is ignored in dummy run mode")
        return self._train_dummy()
    tokenized = self.load_dataset()
    # Log tokenized sequence length statistics (once, before training loop)
    self._log_dataset_stats(tokenized)
    # Load eval dataset (if configured). We load once up-front so the
    # tokenization cost is amortized across all eval invocations.
    eval_tokenized = self.load_eval_dataset()
    if eval_tokenized is not None:
        logger.info(f"Eval dataset loaded: {len(eval_tokenized)} examples")
    batch_size = self.sft_cfg.batch_size
    # steps_per_epoch is always derived from the data; callbacks rely on it.
    steps_per_epoch = max(1, ceil(len(tokenized) / batch_size))
    # Resolve num_steps: explicit num_steps takes precedence; otherwise derive from num_epochs.
    if self.sft_cfg.num_steps is not None:
        num_steps = self.sft_cfg.num_steps
    else:
        num_steps = self.sft_cfg.num_epochs * steps_per_epoch
        logger.info(
            f"num_steps not set; deriving from num_epochs={self.sft_cfg.num_epochs}: "
            f"ceil({len(tokenized)} / {batch_size}) * {self.sft_cfg.num_epochs} = {num_steps} steps"
        )
    # Early validation: dataset must have at least batch_size examples
    if len(tokenized) < batch_size:
        raise ValueError(
            f"Dataset has {len(tokenized)} examples after tokenization, but batch_size={batch_size}. "
            f"Reduce batch_size or use more data."
        )
    self._validate_batch_parallelism()
    # Resume from checkpoint if configured
    start_step = self.load_checkpoint()
    # Shuffle data before training
    rng = random.Random(self.sft_cfg.seed)
    rng.shuffle(tokenized)
    # When resuming, start_step is the last *completed* step (checkpoint is
    # saved AFTER the optimizer update), so we begin at start_step + 1 to
    # avoid replaying that step.
    # Replay epoch shuffles for reproducibility on resume
    start_epoch = (start_step * batch_size) // len(tokenized)
    for _ in range(start_epoch):
        rng.shuffle(tokenized)
    current_epoch = start_epoch
    # Initialize `global_step`
    self.global_step = start_step
    # Publish loop metadata so CallbackInput can be built consistently.
    self._total_steps = num_steps
    self._steps_per_epoch = steps_per_epoch
    self._current_epoch = current_epoch
    self._training_control.reset()
    logger.info(f"Starting SFT training for {num_steps} steps (batch_size={batch_size})...")
    if start_step > 0:
        logger.info(f"Resuming from step {start_step}")
    if self._ray_gpu_monitor is not None:
        self._ray_gpu_monitor.start()
    # Tracks whether the most recent in-loop iteration saved a checkpoint
    # (either via the ckpt_interval or via a callback-driven ``should_save``).
    did_save_last_step = False
    self._fire("on_train_start")
    # Baseline eval before training begins (logged at step 0).
    # Wandb's step counter starts at 0; the training loop's first commit
    # advances it to >=1, so step=0 here does not conflict with later steps.
    if self.sft_cfg.eval_before_train and eval_tokenized is not None:
        self._fire("on_eval_start")
        eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
        self._fire("on_eval_end", metrics=eval_metrics)
        baseline_log = {f"eval/{k}": v for k, v in eval_metrics.items()}
        self._fire("on_log", logs=baseline_log)
        self.tracker.log(baseline_log, step=self.global_step, commit=True)
        logger.info(
            f"Baseline eval before training: "
            f"eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
            f"over {num_eval_batches} batches"
        )
    # SkyRL starts counting at step 1
    self.global_step = start_step + 1 if start_step > 0 else 1
    self._fire("on_epoch_start")
    epoch_in_progress = True
    while self.global_step <= num_steps:
        all_timings: dict[str, float] = {}
        with Timer("step", all_timings):
            # Data loading with wrap-around
            with Timer("data_loading", all_timings):
                start_idx = (self.global_step * batch_size) % len(tokenized)
                end_idx = start_idx + batch_size
                if end_idx > len(tokenized):
                    batch_examples = tokenized[start_idx:] + tokenized[: end_idx - len(tokenized)]
                else:
                    batch_examples = tokenized[start_idx:end_idx]
                batch = self.collator(batch_examples, batch_size=batch_size)
            self._fire("on_step_start", batch=batch)
            # Training step
            step_result = self.train_step(batch, self.global_step)
            all_timings.update(step_result["timings"])
        # Compute throughput using actual (non-padding) tokens
        batch_padded_seq_len = batch["sequences"].shape[1]
        actual_num_tokens = batch["attention_mask"].sum().item()
        self._total_tokens_processed += actual_num_tokens
        tokens_per_second = actual_num_tokens / all_timings["step"]
        # Build log dict
        log_dict = {
            "train/loss": step_result["loss"],
            "train/grad_norm": step_result["grad_norm"],
            "train/tokens_per_second": tokens_per_second,
            "train/tokens_per_second_per_gpu": tokens_per_second / self._num_training_gpus,
            "train/actual_num_tokens": actual_num_tokens,
            "train/batch_padded_seq_len": batch_padded_seq_len,
            "train/total_tokens_processed": self._total_tokens_processed,
        }
        log_dict.update({f"timing/{k}": v for k, v in all_timings.items()})
        if self._ray_gpu_monitor is not None:
            log_dict.update(self._ray_gpu_monitor.flush())
        self._fire("on_step_end", batch=batch, metrics=step_result)
        # Capture callback-driven triggers, then reset so they only fire once.
        force_save = self._training_control.should_save
        force_eval = self._training_control.should_evaluate
        self._training_control.should_save = False
        self._training_control.should_evaluate = False
        # Checkpoint: interval-driven or callback-requested.
        interval_save = (
            self.sft_cfg.ckpt_interval > 0
            and self.global_step > 0
            and self.global_step % self.sft_cfg.ckpt_interval == 0
        )
        did_save_last_step = force_save or interval_save
        if did_save_last_step:
            with Timer("save_checkpoint", all_timings):
                ckpt_path = self.save_checkpoint()
            log_dict["timing/save_checkpoint"] = all_timings["save_checkpoint"]
            self._fire("on_save", ckpt_path=ckpt_path)
        # HF export at regular intervals
        if self.sft_cfg.hf_save_interval > 0 and self.global_step % self.sft_cfg.hf_save_interval == 0:
            with Timer("save_hf_model", all_timings):
                self.save_hf_model()
            log_dict["timing/save_hf_model"] = all_timings["save_hf_model"]
        eval_metrics = None
        num_eval_batches: int | None = None
        # Eval fires at step N where N % eval_interval == 0 and N > 0, OR
        # whenever a callback set ``control.should_evaluate``.
        interval_eval = self.sft_cfg.eval_interval > 0 and self.global_step % self.sft_cfg.eval_interval == 0
        if eval_tokenized is not None and (force_eval or interval_eval):
            self._fire("on_eval_start")
            with Timer("eval", all_timings):
                eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
            self._fire("on_eval_end", metrics=eval_metrics)
            if eval_metrics:
                log_dict.update({f"eval/{k}": v for k, v in eval_metrics.items()})
                log_dict["timing/eval"] = all_timings["eval"]
        log_dict.update({"train/epoch": current_epoch, "train/global_step": self.global_step})
        # Callbacks may mutate log_dict in place via on_log.
        self._fire("on_log", logs=log_dict)
        self.tracker.log(log_dict, step=self.global_step, commit=True)
        if self.global_step % 5 == 0:
            logger.info(
                f"Step {self.global_step}: loss={step_result['loss']:.4f}, " f"grad_norm={step_result['grad_norm']}"
            )
        if eval_metrics:
            logger.info(
                f"Step {self.global_step}: eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
                f"over {num_eval_batches} batches"
            )
        # Check for epoch boundary and reshuffle
        epoch = (self.global_step * batch_size) // len(tokenized)
        if epoch > current_epoch:
            self._fire("on_epoch_end")
            epoch_in_progress = False
            for _ in range(epoch - current_epoch):
                rng.shuffle(tokenized)
            current_epoch = epoch
            self._current_epoch = epoch
            if self.global_step + 1 <= num_steps:
                self._fire("on_epoch_start")
                epoch_in_progress = True
        self.global_step += 1
    self.global_step = min(self.global_step, num_steps)
    # Pair the leading on_epoch_start: fire on_epoch_end if we exited the
    # loop mid-epoch
    if epoch_in_progress:
        self._fire("on_epoch_end")
        epoch_in_progress = False
    # Save final checkpoint (if checkpointing is enabled). Skip if the last
    # in-loop iteration already saved (either via ckpt_interval or via a
    # callback-driven force-save) so we don't double-save.
    if self.sft_cfg.ckpt_path and not did_save_last_step:
        final_step = num_steps
        logger.info(f"Saving final checkpoint at step {final_step}")
        ckpt_path = self.save_checkpoint()
        self._fire("on_save", ckpt_path=ckpt_path)
    # Save final HF model if enabled (only if not already saved at last step)
    if self.sft_cfg.hf_save_interval > 0:
        final_step = num_steps
        already_saved = final_step % self.sft_cfg.hf_save_interval == 0
        if not already_saved:
            self.global_step = final_step
            logger.info(f"Saving final HF model at step {final_step}")
            self.save_hf_model()
    # Final eval pass (skip if the last step already ran eval).
    # NOTE: The last in-loop tracker.log(..., commit=True) at step=num_steps
    # advanced wandb's internal step counter to num_steps+1. Logging the
    # final eval at step=num_steps would be rejected by wandb with
    # "step N < current step N+1". We log the final eval at num_steps+1
    # (one past the last committed train step) in a single combined
    # tracker.log() call, preserving wandb step ordering. We use a local
    # ``final_eval_step`` rather than mutating ``self.global_step``: the
    # bump is purely a wandb-step accounting concern, not real trainer
    # state.
    if eval_tokenized is not None:
        already_ran = self.sft_cfg.eval_interval > 0 and num_steps % self.sft_cfg.eval_interval == 0
        if not already_ran:
            final_eval_step = num_steps + 1
            eval_timings: dict[str, float] = {}
            self._fire("on_eval_start")
            with Timer("eval", eval_timings):
                eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
            self._fire("on_eval_end", metrics=eval_metrics)
            if eval_metrics:
                eval_log = {f"eval/{k}": v for k, v in eval_metrics.items()}
                eval_log["timing/eval"] = eval_timings["eval"]
                self._fire("on_log", logs=eval_log)
                self.tracker.log(eval_log, step=final_eval_step, commit=True)
                logger.info(
                    f"Final eval at step {final_eval_step}: "
                    f"eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
                    f"over {num_eval_batches} batches"
                )
    self._fire("on_train_end")
    logger.info("SFT training complete!")
method abstractmethod save_checkpoint
save_checkpoint() -> str
Save a checkpoint at the given step. Returns the checkpoint folder path.
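The source listing for this method writes `latest_ckpt_global_step.txt` only after every other artifact has been saved, so a failed save leaves the marker pointing at the previous good checkpoint. The sketch below illustrates that ordering on the local filesystem. It is a simplified stand-in: `save_with_marker` and `read_marker` are hypothetical helpers, JSON replaces `torch.save`, and the `fail_before_marker` flag exists only to simulate a crash.

```python
import json
import os
import tempfile


def save_with_marker(root, step, state, fail_before_marker=False):
    """Write the checkpoint payload first and update the 'latest' marker last."""
    folder = os.path.join(root, f"global_step_{step}")
    os.makedirs(folder, exist_ok=True)
    with open(os.path.join(folder, "trainer_state.json"), "w") as f:
        json.dump(state, f)
    if fail_before_marker:
        # Simulated crash: the payload exists, but the marker was never updated.
        raise RuntimeError("simulated crash before the marker update")
    with open(os.path.join(root, "latest_ckpt_global_step.txt"), "w") as f:
        f.write(str(step))
    return folder


def read_marker(root):
    """Return the step recorded in the 'latest' marker file."""
    with open(os.path.join(root, "latest_ckpt_global_step.txt")) as f:
        return int(f.read().strip())


with tempfile.TemporaryDirectory() as root:
    save_with_marker(root, 10, {"global_step": 10})
    try:
        save_with_marker(root, 20, {"global_step": 20}, fail_before_marker=True)
    except RuntimeError:
        pass
    # The marker still names the last checkpoint that completed.
    marker_after_crash = read_marker(root)
    partial_folder_exists = os.path.isdir(os.path.join(root, "global_step_20"))
```

After the simulated crash the marker still reads 10 even though a partial `global_step_20` folder exists, which is the kind of stale folder a consistency check on resume has to account for.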
Source code in skyrl/train/sft_trainer.py:1696-1723
def save_checkpoint(self) -> str:
    """Save a checkpoint at the given step. Returns the checkpoint folder path."""
    step = self.global_step
    global_step_folder = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{step}")
    policy_save_dir = os.path.join(global_step_folder, "policy")
    io.makedirs(global_step_folder, exist_ok=True)
    logger.info(f"Saving checkpoint at step {step} to {global_step_folder}")
    self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer)
    # Save trainer state for cross-validation on resume (mirrors PPO's trainer_state.pt)
    trainer_state = {
        "global_step": step,
        "config": asdict(self.sft_cfg),
    }
    trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
    with io.open_file(trainer_state_path, "wb") as f:
        torch.save(trainer_state, f)
    logger.info(f"Saved trainer state to {trainer_state_path}")
    # Atomic tracking -- write this last after all saves succeed
    latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
    with io.open_file(latest_file, "w") as f:
        f.write(str(step))
    logger.info(f"Checkpoint saved for global_step_{step}")
    # Clean up old checkpoints after successful save
    cleanup_old_checkpoints(self.sft_cfg.ckpt_path, self.sft_cfg.max_ckpts_to_keep)
    return global_step_folder
method save_hf_model
save_hf_model()
Save policy weights in HuggingFace format.
Export path: cfg.trainer.export_path/global_step_{step}/policy. Mirrors the pattern used by the RL trainer's save_models().
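The export path layout, together with the interval rule the training loop uses to decide when to export, can be sketched as follows. The helper names are hypothetical and `GLOBAL_STEP_PREFIX` is assumed to be `"global_step_"`; the conditions follow the `hf_save_interval` checks in the `train` source listing.

```python
import os

GLOBAL_STEP_PREFIX = "global_step_"  # assumed value, matching the documented path layout


def hf_export_dir(export_path: str, step: int) -> str:
    """Build <export_path>/global_step_<step>/policy."""
    return os.path.join(export_path, f"{GLOBAL_STEP_PREFIX}{step}", "policy")


def exports_in_loop(step: int, hf_save_interval: int) -> bool:
    """In-loop export happens only on multiples of a positive interval."""
    return hf_save_interval > 0 and step % hf_save_interval == 0


def needs_final_export(num_steps: int, hf_save_interval: int) -> bool:
    """A final export is added only when the last step was not already exported."""
    return hf_save_interval > 0 and num_steps % hf_save_interval != 0


export_dir = hf_export_dir("exports", 50)
in_loop_steps = [s for s in range(1, 11) if exports_in_loop(s, 4)]
final_needed = needs_final_export(10, 4)
final_skipped = needs_final_export(8, 4)
disabled = needs_final_export(10, 0)
```

With `hf_save_interval=4` and 10 steps, exports happen at steps 4 and 8 inside the loop and once more at the end; with 8 steps the final export is skipped because step 8 was already exported.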
Source code in skyrl/train/sft_trainer.py:1725-1738
def save_hf_model(self):
    """Save policy weights in HuggingFace format.
    Export path: cfg.trainer.export_path/global_step_{step}/policy
    Mirrors the pattern used by the RL trainer's save_models().
    """
    step = self.global_step
    policy_export_dir = os.path.join(
        self.cfg.trainer.export_path,
        f"{GLOBAL_STEP_PREFIX}{step}",
        "policy",
    )
    self.dispatch.save_hf_model("policy", policy_export_dir, self.tokenizer)
    logger.info(f"Saved HF model weights at step {step} to {policy_export_dir}")
method shutdown
shutdown()
Finish tracking.
Does NOT call ray.shutdown() -- when running inside a Ray task
(the normal path via sft_entrypoint), shutting down Ray from
within the task would be incorrect. The head-node process owns
the Ray lifecycle.
Source code in skyrl/train/sft_trainer.py:1744-1755
def shutdown(self):
    """Finish tracking.
    Does NOT call ``ray.shutdown()`` -- when running inside a Ray task
    (the normal path via ``sft_entrypoint``), shutting down Ray from
    within the task would be incorrect. The head-node process owns
    the Ray lifecycle.
    """
    if self._ray_gpu_monitor is not None:
        self._ray_gpu_monitor.stop()
    if self.tracker is not None:
        self.tracker.finish()
Config Bridge
method validate_sft_cfg
validate_sft_cfg(cfg: SFTConfig) -> None
Validate SFT-specific configuration.
Only checks fields that are relevant to SFT training, unlike
validate_cfg which includes RL-specific validations.
method build_skyrl_config_for_sft
build_skyrl_config_for_sft(sft_cfg: SFTConfig) -> SkyRLTrainConfig
Map user-facing SFTConfig to the internal SkyRL backend config.