SkyRL

SFT

Supervised Fine-Tuning configuration and trainer.

Configuration

class SFTPlacementConfig

SFTPlacementConfig(num_nodes: int = 1, num_gpus_per_node: int = 4) -> None

Bases: BaseConfig

Placement configuration for SFT training

Functions:

Name | Description
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig.

Attributes:

Name | Type | Description
num_nodes | int |
num_gpus_per_node | int |
Source code in skyrl/train/config/sft_config.py:48-53
@dataclass
class SFTPlacementConfig(BaseConfig):
    """Placement configuration for SFT training"""

    num_nodes: int = 1
    num_gpus_per_node: int = 4

from_dict_config

from_dict_config(cfg: DictConfig) -> BaseConfig

Construct a typed BaseConfig from a Hydra DictConfig.

attr num_nodes

num_nodes: int = 1

attr num_gpus_per_node

num_gpus_per_node: int = 4
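
The SFTTrainer source further down this page multiplies the two placement fields to get the total GPU count, and its _dp_size helper divides that count by the product of the tensor, pipeline, and context parallel sizes. The sketch below reproduces only that arithmetic in plain Python. The function name data_parallel_size and the divisibility check are assumptions added for illustration, and cp defaults to 1 here because the MegatronConfig default for context_parallel_size is not shown on this page.

```python
def data_parallel_size(num_nodes: int, num_gpus_per_node: int,
                       tp: int = 1, pp: int = 1, cp: int = 1) -> int:
    """Total GPUs divided by the model-parallel group size (tp * pp * cp)."""
    total_gpus = num_nodes * num_gpus_per_node
    model_parallel = tp * pp * cp
    # Illustrative guard only: the trainer source shown on this page uses
    # plain floor division and does not perform this check.
    if total_gpus % model_parallel != 0:
        raise ValueError(
            f"{total_gpus} GPUs cannot be split evenly into groups of {model_parallel}"
        )
    return total_gpus // model_parallel


# SFTPlacementConfig defaults (1 node x 4 GPUs) with tp=2, pp=2.
print(data_parallel_size(1, 4, tp=2, pp=2))
# A hypothetical 2-node x 8-GPU placement with the same parallelism.
print(data_parallel_size(2, 8, tp=2, pp=2))
```

With the defaults on this page the whole 4-GPU node is consumed by one model replica, so there is a single data-parallel rank.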

class SFTConfig

SFTConfig(model: ModelConfig = (lambda: ModelConfig(path='Qwen/Qwen3-0.6B'))(), optimizer_config: OptimizerConfig = OptimizerConfig(), placement: SFTPlacementConfig = SFTPlacementConfig(), megatron_config: MegatronConfig = (lambda: MegatronConfig(tensor_model_parallel_size=2, pipeline_model_parallel_size=2))(), fsdp_config: FSDPConfig = FSDPConfig(), sequence_parallel_size: int = 1, model_config_kwargs: dict = dict(), use_torch_compile: bool = False, record_memory: bool = False, strategy: str = 'megatron', dataset_name: str = 'yahma/alpaca-cleaned', dataset_split: str = 'train[:100]', messages_key: str = 'messages', tools_key: str = 'tools', system_key: str = 'system', eval_dataset_name: Optional[str] = None, eval_dataset_split: str = 'validation', eval_interval: int = 0, eval_before_train: bool = False, max_length: Optional[int] = None, num_steps: Optional[int] = None, num_epochs: Optional[int] = 1, batch_size: int = 4, micro_train_batch_size_per_gpu: int = 2, logger: str = 'console', project_name: str = 'skyrl_sft', run_name: str = 'skyrl_sft_run', tags: Optional[List[str]] = None, ckpt_path: str = '', ckpt_interval: int = 0, enable_ray_gpu_monitor: bool = True, max_ckpts_to_keep: int = -1, resume_from: str = '', hf_save_interval: int = 0, export_path: str = '', seed: int = 42, num_workers: int = 8, cache_dir: str = os.path.join(os.environ.get('XDG_CACHE_HOME', os.path.expanduser('~/.cache')), 'skyrl', 'tokenized_datasets'), force_recache: bool = False, disable_cache: bool = False, train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE, remove_microbatch_padding: bool = True, use_sequence_packing: bool = False, max_tokens_per_microbatch: Optional[int] = None, dummy_run_full_ctx: bool = False, dummy_run_max_steps: int = 5) -> None

Bases: BaseConfig

Configuration for SFT training.

Usage:

    cfg = SFTConfig(
        strategy="megatron",
        placement=SFTPlacementConfig(num_gpus_per_node=4),
        megatron_config=MegatronConfig(
            tensor_model_parallel_size=2,
            pipeline_model_parallel_size=2,
        ),
    )

Or from CLI:

    cfg = SFTConfig.from_cli_overrides(sys.argv[1:])

Functions:

Name | Description
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig.
from_cli_overrides | Construct an SFTConfig from CLI arguments or a dict of overrides.
resolved_bin_capacity | FFD bin capacity (max tokens per bin) when sequence packing is enabled.

Attributes:

Name | Type | Description
model | ModelConfig |
optimizer_config | OptimizerConfig |
placement | SFTPlacementConfig |
megatron_config | MegatronConfig |
fsdp_config | FSDPConfig |
sequence_parallel_size | int | Ulysses sequence parallelism size.
model_config_kwargs | dict | Pass-through kwargs for the HuggingFace model config (FSDP backends).
use_torch_compile | bool | Apply torch.compile to logits calculation.
record_memory | bool | Save memory snapshots to {ckpt_path}/memory_snapshots/.
strategy | str |
dataset_name | str |
dataset_split | str |
messages_key | str |
tools_key | str | Column name holding per-row tool/function schemas for tool-calling datasets (e.g. APIGen-MT, xLAM, ToolACE).
system_key | str | Column name holding a per-row system prompt to prepend when messages does not already start with a system turn.
eval_dataset_name | Optional[str] | HuggingFace dataset name (or path) used to compute eval loss during training.
eval_dataset_split | str | Split of the eval dataset to load (e.g. "validation", "test[:500]").
eval_interval | int | Run eval every N training steps. Eval also runs once at the end of training when an eval dataset is configured.
eval_before_train | bool | If True, run a baseline eval pass before training begins (logged at step 0).
max_length | Optional[int] | Maximum length of tokenized sequences. If specified, all sequences will be truncated to this value.
num_steps | Optional[int] | Number of training steps. If None, num_epochs is used to derive the step count.
num_epochs | Optional[int] | Number of training epochs. Used when num_steps is None. Default: 1 epoch.
batch_size | int |
micro_train_batch_size_per_gpu | int |
logger | str |
project_name | str |
run_name | str |
tags | Optional[List[str]] | Optional list of tags to apply to the W&B run. Has no effect on other backends.
ckpt_path | str |
ckpt_interval | int |
enable_ray_gpu_monitor | bool | Enable background Ray GPU/RAM metrics collection and logging to wandb.
max_ckpts_to_keep | int | -1 to keep all checkpoints, N to keep only the last N.
resume_from | str |
hf_save_interval | int | Save HuggingFace-format weights every N steps. 0 = disabled.
export_path | str | Directory for HF-format exports. Defaults to ckpt_path/hf_exports if empty.
seed | int |
num_workers | int | Number of worker processes for parallel tokenization during dataset loading. Set to 0 for single-threaded.
cache_dir | str | Directory to cache tokenized datasets. For multi-node training, set this to an NFS-mounted path so all nodes can share the cache.
force_recache | bool | If True, ignore existing cache and re-tokenize the dataset.
disable_cache | bool | If True, disable cache completely (always tokenize from scratch).
train_on_what | TrainOnWhat | Which tokens to compute loss on. See TrainOnWhat for options.
remove_microbatch_padding | bool |
use_sequence_packing | bool | Enable controller-level FFD bin-packing across the global mini-batch.
max_tokens_per_microbatch | Optional[int] | FFD bin capacity (max tokens per bin) when use_sequence_packing=True.
dummy_run_full_ctx | bool |
dummy_run_max_steps | int |
Source code in skyrl/train/config/sft_config.py:56-256
@dataclass
class SFTConfig(BaseConfig):
    """Configuration for SFT training.

    Usage::

        cfg = SFTConfig(
            strategy="megatron",
            placement=SFTPlacementConfig(num_gpus_per_node=4),
            megatron_config=MegatronConfig(tensor_model_parallel_size=2,
                                    pipeline_model_parallel_size=2),
        )

    Or from CLI::

        cfg = SFTConfig.from_cli_overrides(sys.argv[1:])
    """

    @classmethod
    def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SFTConfig":
        """Construct an SFTConfig from CLI arguments or a dict of overrides.

        Parses CLI dotlist arguments via OmegaConf and builds a typed config.
        Dataclass field defaults are used for any values not specified.

        Args:
            args: Either a list of CLI arguments in 'key.path=value' format, or a dict
                  mapping dot-notation keys to values.
                  Example list: ['strategy=megatron', 'model.path=Qwen/Qwen3-0.6B']
                  Example dict: {'strategy': 'megatron', 'model.path': 'Qwen/Qwen3-0.6B'}

        Returns:
            A fully constructed SFTConfig with CLI overrides applied.

        Raises:
            ValueError: If both ``num_epochs`` and ``num_steps`` are explicitly provided.
        """
        if isinstance(args, dict):
            args = [f"{k}={v}" for k, v in args.items()]

        overrides = OmegaConf.from_cli(args)
        # Check for mutual exclusion before constructing the full config
        if "num_epochs" in overrides and "num_steps" in overrides:
            raise ValueError("Cannot specify both num_epochs and num_steps")
        # Accept the deprecated ``use_sample_packing`` key as an alias for
        # ``remove_microbatch_padding``. Remap it before construction so the
        # strict key validation does not reject the old name.
        if "use_sample_packing" in overrides:
            if "remove_microbatch_padding" in overrides:
                raise ValueError(
                    "Specify only one of use_sample_packing (deprecated) and remove_microbatch_padding, not both."
                )
            import warnings

            warnings.warn(
                "use_sample_packing has been renamed to remove_microbatch_padding; "
                "use remove_microbatch_padding instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            overrides["remove_microbatch_padding"] = overrides["use_sample_packing"]
            del overrides["use_sample_packing"]
        return cls.from_dict_config(overrides)

    # ---- Reused SkyRL config objects ----
    model: ModelConfig = field(default_factory=lambda: ModelConfig(path="Qwen/Qwen3-0.6B"))
    optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig)
    placement: SFTPlacementConfig = field(default_factory=SFTPlacementConfig)
    megatron_config: MegatronConfig = field(
        default_factory=lambda: MegatronConfig(
            tensor_model_parallel_size=2,
            pipeline_model_parallel_size=2,
        )
    )
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)

    # Ulysses sequence parallelism
    sequence_parallel_size: int = 1
    """Ulysses sequence parallelism size"""

    model_config_kwargs: dict = field(default_factory=dict)
    """Pass-through kwargs for the HuggingFace model config (FSDP backends).
    For Megatron, use ``megatron_config.transformer_config_kwargs`` instead."""
    use_torch_compile: bool = False
    """Apply torch.compile to logits calculation."""
    record_memory: bool = False
    """Save memory snapshots to ``{ckpt_path}/memory_snapshots/``.
    Visualize by dragging pickle files to https://docs.pytorch.org/memory_viz."""

    # ---- SFT-specific flat fields ----
    strategy: str = "megatron"  # "megatron" or "fsdp"
    dataset_name: str = "yahma/alpaca-cleaned"
    dataset_split: str = "train[:100]"
    messages_key: str = "messages"  # column name for chat-format datasets
    tools_key: str = "tools"
    """Column name holding per-row tool/function schemas for tool-calling datasets
    (e.g. APIGen-MT, xLAM, ToolACE). May be a list[dict] or a JSON-encoded string.
    Ignored if the column is absent from the dataset."""
    system_key: str = "system"
    """Column name holding a per-row system prompt to prepend when ``messages``
    does not already start with a system turn. Ignored if absent."""

    # ---- Evaluation dataset ----
    eval_dataset_name: Optional[str] = None
    """HuggingFace dataset name (or path) used to compute eval loss during training.
    When ``None`` (default), eval is disabled."""
    eval_dataset_split: str = "validation"
    """Split of the eval dataset to load (e.g. ``"validation"``, ``"test[:500]"``)."""
    eval_interval: int = 0
    """Run eval every N training steps. Eval also runs once at the end of training
    when an eval dataset is configured. ``0`` disables periodic eval."""
    eval_before_train: bool = False
    """If True, run a baseline eval pass before training begins (logged at step 0)."""
    max_length: Optional[int] = None
    """Maximum length of tokenized sequences. If specified, all sequences will be truncated to this value.
    By default, no truncation is performed."""
    num_steps: Optional[int] = None
    """Number of training steps. If None, num_epochs is used to derive the step count."""
    num_epochs: Optional[int] = 1
    """Number of training epochs. Used when num_steps is None. Default: 1 epoch."""
    batch_size: int = 4
    micro_train_batch_size_per_gpu: int = 2
    logger: str = "console"  # "console" or "wandb"
    project_name: str = "skyrl_sft"
    run_name: str = "skyrl_sft_run"
    tags: Optional[List[str]] = None
    """Optional list of tags to apply to the W&B run. Has no effect on other backends."""
    ckpt_path: str = ""
    ckpt_interval: int = 0  # <= 0 -> no checkpointing
    enable_ray_gpu_monitor: bool = True
    """Enable background Ray GPU/RAM metrics collection and logging to wandb."""
    max_ckpts_to_keep: int = -1
    """-1 to keep all checkpoints, N to keep only the last N."""
    resume_from: str = ""  # "" = no resume, "latest" = latest checkpoint, or path to global_step_N dir

    # ---- HF export ----
    hf_save_interval: int = 0
    """Save HuggingFace-format weights every N steps. 0 = disabled."""
    export_path: str = ""
    """Directory for HF-format exports. Defaults to ckpt_path/hf_exports if empty."""

    seed: int = 42

    # ---- Data loading ----
    num_workers: int = 8
    """Number of worker processes for parallel tokenization during dataset loading. Set to 0 for single-threaded."""

    # ---- Tokenized dataset caching ----
    cache_dir: str = os.path.join(
        os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")), "skyrl", "tokenized_datasets"
    )
    """Directory to cache tokenized datasets. For multi-node training, set this to an NFS-mounted path so all nodes can
    share the cache."""
    force_recache: bool = False
    """If True, ignore existing cache and re-tokenize the dataset."""
    disable_cache: bool = False
    """If True, disable cache completely (always tokenize from scratch)."""

    # ---- Training target ----
    train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE
    """Which tokens to compute loss on. See :class:`TrainOnWhat` for options."""

    # ---- Packing ----
    remove_microbatch_padding: bool = True  # Pack multiple sequences per microbatch (requires flash_attn)
    use_sequence_packing: bool = False
    """Enable controller-level FFD bin-packing across the global mini-batch.
    Requires ``remove_microbatch_padding=True`` and the Megatron backend. When
    enabled, ``SFTTrainer`` uses ``PackedDataCollator`` instead of
    ``DefaultCollator``. Each bin row becomes one row in the dispatched batch
    and one worker micro-batch.
    """
    max_tokens_per_microbatch: Optional[int] = None
    """FFD bin capacity (max tokens per bin) when ``use_sequence_packing=True``.
    Each bin row becomes one worker micro-batch, so this is the token budget for
    one micro-batch. Must be ``>= max_length`` so any single sequence fits in a
    bin. ``None`` (default) resolves to ``max_length`` (each bin holds one
    sequence)."""

    # ---- Dummy run / benchmarking ----
    dummy_run_full_ctx: bool = False  # Skip real data; fabricate full-context sequences
    dummy_run_max_steps: int = 5  # Number of steps to run in dummy mode

    def resolved_bin_capacity(self) -> int:
        """FFD bin capacity (max tokens per bin) when sequence packing is enabled.

        Resolves ``max_tokens_per_microbatch`` against ``max_length``: when the
        token budget is ``None`` it falls back to ``max_length`` (each bin holds
        one sequence). Requires ``max_length`` to be set and the resolved budget
        to be ``>= max_length`` so any single sequence fits in a bin.
        """
        if self.max_length is None:
            raise ValueError("max_tokens_per_microbatch requires max_length to be set.")
        max_tokens = self.max_tokens_per_microbatch
        if max_tokens is None:
            max_tokens = self.max_length
        if max_tokens < self.max_length:
            raise ValueError(
                f"max_tokens_per_microbatch ({max_tokens}) must be >= max_length "
                f"({self.max_length}) so any single sequence fits in a bin."
            )
        return max_tokens

from_dict_config

from_dict_config(cfg: DictConfig) -> BaseConfig

Construct a typed BaseConfig from a Hydra DictConfig.

method classmethod from_cli_overrides

from_cli_overrides(args: Union[List[str], dict]) -> SFTConfig

Construct an SFTConfig from CLI arguments or a dict of overrides.

Parses CLI dotlist arguments via OmegaConf and builds a typed config. Dataclass field defaults are used for any values not specified.

Parameters:

Name | Type | Description | Default
args | Union[List[str], dict] | Either a list of CLI arguments in 'key.path=value' format, or a dict mapping dot-notation keys to values. Example list: ['strategy=megatron', 'model.path=Qwen/Qwen3-0.6B']. Example dict: {'strategy': 'megatron', 'model.path': 'Qwen/Qwen3-0.6B'}. | required

Returns:

Type | Description
SFTConfig | A fully constructed SFTConfig with CLI overrides applied.

Raises:

Type | Description
ValueError | If both num_epochs and num_steps are explicitly provided.
Source code in skyrl/train/config/sft_config.py:74-118
    @classmethod
    def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SFTConfig":
        """Construct an SFTConfig from CLI arguments or a dict of overrides.

        Parses CLI dotlist arguments via OmegaConf and builds a typed config.
        Dataclass field defaults are used for any values not specified.

        Args:
            args: Either a list of CLI arguments in 'key.path=value' format, or a dict
                  mapping dot-notation keys to values.
                  Example list: ['strategy=megatron', 'model.path=Qwen/Qwen3-0.6B']
                  Example dict: {'strategy': 'megatron', 'model.path': 'Qwen/Qwen3-0.6B'}

        Returns:
            A fully constructed SFTConfig with CLI overrides applied.

        Raises:
            ValueError: If both ``num_epochs`` and ``num_steps`` are explicitly provided.
        """
        if isinstance(args, dict):
            args = [f"{k}={v}" for k, v in args.items()]

        overrides = OmegaConf.from_cli(args)
        # Check for mutual exclusion before constructing the full config
        if "num_epochs" in overrides and "num_steps" in overrides:
            raise ValueError("Cannot specify both num_epochs and num_steps")
        # Accept the deprecated ``use_sample_packing`` key as an alias for
        # ``remove_microbatch_padding``. Remap it before construction so the
        # strict key validation does not reject the old name.
        if "use_sample_packing" in overrides:
            if "remove_microbatch_padding" in overrides:
                raise ValueError(
                    "Specify only one of use_sample_packing (deprecated) and remove_microbatch_padding, not both."
                )
            import warnings

            warnings.warn(
                "use_sample_packing has been renamed to remove_microbatch_padding; "
                "use remove_microbatch_padding instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            overrides["remove_microbatch_padding"] = overrides["use_sample_packing"]
            del overrides["use_sample_packing"]
        return cls.from_dict_config(overrides)
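
The two checks that from_cli_overrides runs before building the config can be tried without SkyRL or OmegaConf installed. The following is a minimal sketch that applies the same rules to a plain dict; normalize_overrides is a hypothetical helper name, and a flat dict stands in for the parsed OmegaConf overrides.

```python
import warnings


def normalize_overrides(overrides: dict) -> dict:
    """Apply the mutual-exclusion check and the deprecated-key remap."""
    if "num_epochs" in overrides and "num_steps" in overrides:
        raise ValueError("Cannot specify both num_epochs and num_steps")

    result = dict(overrides)  # copy so the caller's dict is left untouched
    if "use_sample_packing" in result:
        if "remove_microbatch_padding" in result:
            raise ValueError(
                "Specify only one of use_sample_packing (deprecated) "
                "and remove_microbatch_padding, not both."
            )
        warnings.warn(
            "use_sample_packing has been renamed to remove_microbatch_padding",
            DeprecationWarning,
            stacklevel=2,
        )
        # Move the value over to the new key name.
        result["remove_microbatch_padding"] = result.pop("use_sample_packing")
    return result


print(normalize_overrides({"strategy": "fsdp", "use_sample_packing": False}))
```

Passing both num_epochs and num_steps raises ValueError even when one of them only repeats its default value, because the check looks at which keys were supplied, not at their values.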

attr model

model: ModelConfig = field(default_factory=(lambda: ModelConfig(path='Qwen/Qwen3-0.6B')))

attr optimizer_config

optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig)

attr placement

placement: SFTPlacementConfig = field(default_factory=SFTPlacementConfig)

attr megatron_config

megatron_config: MegatronConfig = field(default_factory=(lambda: MegatronConfig(tensor_model_parallel_size=2, pipeline_model_parallel_size=2)))

attr fsdp_config

fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)

attr sequence_parallel_size

sequence_parallel_size: int = 1

Ulysses sequence parallelism size

attr model_config_kwargs

model_config_kwargs: dict = field(default_factory=dict)

Pass-through kwargs for the HuggingFace model config (FSDP backends). For Megatron, use megatron_config.transformer_config_kwargs instead.

attr use_torch_compile

use_torch_compile: bool = False

Apply torch.compile to logits calculation.

attr record_memory

record_memory: bool = False

Save memory snapshots to {ckpt_path}/memory_snapshots/. Visualize by dragging pickle files to https://docs.pytorch.org/memory_viz.

attr strategy

strategy: str = 'megatron'

attr dataset_name

dataset_name: str = 'yahma/alpaca-cleaned'

attr dataset_split

dataset_split: str = 'train[:100]'

attr messages_key

messages_key: str = 'messages'

attr tools_key

tools_key: str = 'tools'

Column name holding per-row tool/function schemas for tool-calling datasets (e.g. APIGen-MT, xLAM, ToolACE). May be a list[dict] or a JSON-encoded string. Ignored if the column is absent from the dataset.
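
Because the tools column may arrive either as a list of dicts or as a JSON-encoded string, a loader has to normalize it before use. The sketch below shows one way to do that with the standard library. normalize_tools is a hypothetical helper, not the function SkyRL uses, and the sample row is made up.

```python
import json
from typing import Any, Optional


def normalize_tools(value: Any) -> Optional[list]:
    """Return tool schemas as a list, or None when the column is absent."""
    if value is None:
        return None  # column missing from the row: nothing to pass along
    if isinstance(value, str):
        value = json.loads(value)  # JSON-encoded string form
    if not isinstance(value, list):
        raise TypeError(f"expected a list of tool schemas, got {type(value).__name__}")
    return value


# Hypothetical dataset row using the default tools_key ("tools").
row = {"messages": [], "tools": '[{"name": "get_weather"}]'}
print(normalize_tools(row.get("tools")))
```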

attr system_key

system_key: str = 'system'

Column name holding a per-row system prompt to prepend when messages does not already start with a system turn. Ignored if absent.
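
The prepend rule can be illustrated in a few lines. This sketch assumes messages are dicts with "role" and "content" keys, which is the common chat format but is not stated on this page; with_system_prompt is a hypothetical helper name.

```python
from typing import Optional


def with_system_prompt(messages: list, system: Optional[str]) -> list:
    """Prepend a system turn unless the conversation already starts with one."""
    if not system:
        return list(messages)  # column absent or empty: leave messages as-is
    if messages and messages[0].get("role") == "system":
        return list(messages)  # an existing system turn takes precedence
    return [{"role": "system", "content": system}] + list(messages)


chat = [{"role": "user", "content": "Hi"}]
print(with_system_prompt(chat, "You are a helpful assistant."))
```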

attr eval_dataset_name

eval_dataset_name: Optional[str] = None

HuggingFace dataset name (or path) used to compute eval loss during training. When None (default), eval is disabled.

attr eval_dataset_split

eval_dataset_split: str = 'validation'

Split of the eval dataset to load (e.g. "validation", "test[:500]").

attr eval_interval

eval_interval: int = 0

Run eval every N training steps. Eval also runs once at the end of training when an eval dataset is configured. 0 disables periodic eval.

attr eval_before_train

eval_before_train: bool = False

If True, run a baseline eval pass before training begins (logged at step 0).
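
Taken together, the four eval settings determine at which steps an eval pass runs. The sketch below is one reading of the descriptions above, not the trainer's actual loop: it assumes periodic eval fires when the step number is a multiple of eval_interval and that the final eval is not repeated if it coincides with a periodic one. The has_eval_dataset flag stands in for eval_dataset_name being set.

```python
def eval_steps(total_steps: int, eval_interval: int = 0,
               eval_before_train: bool = False,
               has_eval_dataset: bool = True) -> list:
    """Steps at which eval would run under the assumptions stated above."""
    if not has_eval_dataset:
        return []  # eval_dataset_name=None disables eval entirely
    steps = set()
    if eval_before_train:
        steps.add(0)  # baseline pass, logged at step 0
    if eval_interval > 0:
        steps.update(range(eval_interval, total_steps + 1, eval_interval))
    steps.add(total_steps)  # eval also runs once at the end of training
    return sorted(steps)


print(eval_steps(10, eval_interval=4, eval_before_train=True))
print(eval_steps(10))  # periodic eval disabled, final eval only
```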

attr max_length

max_length: Optional[int] = None

Maximum length of tokenized sequences. If specified, all sequences will be truncated to this value. By default, no truncation is performed.


attr num_steps

num_steps: Optional[int] = None

Number of training steps. If None, num_epochs is used to derive the step count.

attr num_epochs

num_epochs: Optional[int] = 1

Number of training epochs. Used when num_steps is None. Default: 1 epoch.

attr batch_size

batch_size: int = 4

attr micro_train_batch_size_per_gpu

micro_train_batch_size_per_gpu: int = 2
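
This page does not spell out how batch_size relates to micro_train_batch_size_per_gpu. Under the common convention that the global batch is split evenly across data-parallel ranks and each rank then works through its share in micro-batches, the number of micro-batches per rank follows from simple division. The sketch below encodes that convention as an assumption; micro_batches_per_rank is a hypothetical helper and SkyRL may validate or round these values differently.

```python
def micro_batches_per_rank(batch_size: int,
                           micro_train_batch_size_per_gpu: int,
                           dp_size: int) -> int:
    """Micro-batches each DP rank processes per step (assumed convention)."""
    per_rank, remainder = divmod(batch_size, dp_size)
    if remainder:
        raise ValueError("batch_size must be divisible by dp_size")
    count, remainder = divmod(per_rank, micro_train_batch_size_per_gpu)
    if remainder:
        raise ValueError(
            "per-rank batch must be divisible by micro_train_batch_size_per_gpu"
        )
    return count


# Defaults on this page: batch_size=4, micro batch 2, one DP rank.
print(micro_batches_per_rank(4, 2, dp_size=1))
```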

attr logger

logger: str = 'console'

attr project_name

project_name: str = 'skyrl_sft'

attr run_name

run_name: str = 'skyrl_sft_run'

attr tags

tags: Optional[List[str]] = None

Optional list of tags to apply to the W&B run. Has no effect on other backends.

attr ckpt_path

ckpt_path: str = ''

attr ckpt_interval

ckpt_interval: int = 0

attr enable_ray_gpu_monitor

enable_ray_gpu_monitor: bool = True

Enable background Ray GPU/RAM metrics collection and logging to wandb.

attr max_ckpts_to_keep

max_ckpts_to_keep: int = -1

-1 to keep all checkpoints, N to keep only the last N.
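
A retention rule of this kind can be sketched as a pure function over directory names. The example assumes checkpoint folders are named global_step_N, as the resume_from comment in the source listing suggests, and treats any negative value like -1. checkpoints_to_delete is a hypothetical helper, not SkyRL's cleanup routine.

```python
import re


def checkpoints_to_delete(dir_names: list, max_ckpts_to_keep: int) -> list:
    """Names of the oldest checkpoints beyond the retention limit."""
    if max_ckpts_to_keep < 0:
        return []  # -1 keeps every checkpoint
    found = []
    for name in dir_names:
        match = re.fullmatch(r"global_step_(\d+)", name)
        if match:  # ignore anything that is not a checkpoint folder
            found.append((int(match.group(1)), name))
    found.sort()  # oldest step first, compared numerically
    excess = max(len(found) - max_ckpts_to_keep, 0)
    return [name for _, name in found[:excess]]


dirs = ["global_step_10", "global_step_2", "global_step_30", "memory_snapshots"]
print(checkpoints_to_delete(dirs, max_ckpts_to_keep=2))
```

Sorting on the parsed integer matters: a plain string sort would place global_step_10 before global_step_2.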

attr resume_from

resume_from: str = ''

attr hf_save_interval

hf_save_interval: int = 0

Save HuggingFace-format weights every N steps. 0 = disabled.

attr export_path

export_path: str = ''

Directory for HF-format exports. Defaults to ckpt_path/hf_exports if empty.

attr seed

seed: int = 42

attr num_workers

num_workers: int = 8

Number of worker processes for parallel tokenization during dataset loading. Set to 0 for single-threaded.

attr cache_dir

cache_dir: str = os.path.join(os.environ.get('XDG_CACHE_HOME', os.path.expanduser('~/.cache')), 'skyrl', 'tokenized_datasets')

Directory to cache tokenized datasets. For multi-node training, set this to an NFS-mounted path so all nodes can share the cache.
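
The default value is an ordinary expression evaluated when the class is defined, so XDG_CACHE_HOME has to be set before the config module is imported for it to take effect. The sketch below repeats the same expression inside a function so the resolution can be checked in isolation. default_cache_dir and the example path are illustrative only.

```python
import os
from typing import Mapping, Optional


def default_cache_dir(environ: Optional[Mapping[str, str]] = None) -> str:
    """Same expression as the cache_dir default, with an injectable environment."""
    env = os.environ if environ is None else environ
    base = env.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
    return os.path.join(base, "skyrl", "tokenized_datasets")


# With XDG_CACHE_HOME pointing at a (hypothetical) shared mount.
print(default_cache_dir({"XDG_CACHE_HOME": "/mnt/nfs/cache"}))
# Without it, the path falls back to ~/.cache under the current user's home.
print(default_cache_dir({}))
```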

attr force_recache

force_recache: bool = False

If True, ignore existing cache and re-tokenize the dataset.

attr disable_cache

disable_cache: bool = False

If True, disable cache completely (always tokenize from scratch).

attr train_on_what

train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE

Which tokens to compute loss on. See TrainOnWhat for options.

attr remove_microbatch_padding

remove_microbatch_padding: bool = True

attr use_sequence_packing

use_sequence_packing: bool = False

Enable controller-level FFD bin-packing across the global mini-batch. Requires remove_microbatch_padding=True and the Megatron backend. When enabled, SFTTrainer uses PackedDataCollator instead of DefaultCollator. Each bin row becomes one row in the dispatched batch and one worker micro-batch.

attr max_tokens_per_microbatch

max_tokens_per_microbatch: Optional[int] = None

FFD bin capacity (max tokens per bin) when use_sequence_packing=True. Each bin row becomes one worker micro-batch, so this is the token budget for one micro-batch. Must be >= max_length so any single sequence fits in a bin. None (default) resolves to max_length (each bin holds one sequence).
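
To make the bin capacity concrete, here is a generic first-fit-decreasing (FFD) packer over sequence lengths. It is a textbook sketch, not the PackedDataCollator implementation, and it ignores any alignment or parallelism constraints the real collator may apply. Each returned bin corresponds to what the description above calls one worker micro-batch.

```python
def first_fit_decreasing(lengths: list, capacity: int) -> list:
    """Pack sequence lengths into bins holding at most `capacity` tokens."""
    bins = []   # each bin is a list of sequence lengths
    loads = []  # running token count per bin
    for length in sorted(lengths, reverse=True):  # longest sequences first
        if length > capacity:
            # Mirrors the rule that the capacity must be >= max_length.
            raise ValueError(f"sequence of {length} tokens exceeds capacity {capacity}")
        for i, load in enumerate(loads):
            if load + length <= capacity:  # first bin with enough room
                bins[i].append(length)
                loads[i] += length
                break
        else:
            bins.append([length])  # no existing bin fits: open a new one
            loads.append(length)
    return bins


print(first_fit_decreasing([5, 3, 4, 2, 1], capacity=8))  # [[5, 3], [4, 2, 1]]
```

When the capacity equals the longest sequence length, long sequences each occupy a bin of their own, which matches the "each bin holds one sequence" behavior described for the None default.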

attr dummy_run_full_ctx

dummy_run_full_ctx: bool = False

attr dummy_run_max_steps

dummy_run_max_steps: int = 5

method resolved_bin_capacity

resolved_bin_capacity() -> int

FFD bin capacity (max tokens per bin) when sequence packing is enabled.

Resolves max_tokens_per_microbatch against max_length: when the token budget is None it falls back to max_length (each bin holds one sequence). Requires max_length to be set and the resolved budget to be >= max_length so any single sequence fits in a bin.

Source code in skyrl/train/config/sft_config.py:238-256
    def resolved_bin_capacity(self) -> int:
        """FFD bin capacity (max tokens per bin) when sequence packing is enabled.

        Resolves ``max_tokens_per_microbatch`` against ``max_length``: when the
        token budget is ``None`` it falls back to ``max_length`` (each bin holds
        one sequence). Requires ``max_length`` to be set and the resolved budget
        to be ``>= max_length`` so any single sequence fits in a bin.
        """
        if self.max_length is None:
            raise ValueError("max_tokens_per_microbatch requires max_length to be set.")
        max_tokens = self.max_tokens_per_microbatch
        if max_tokens is None:
            max_tokens = self.max_length
        if max_tokens < self.max_length:
            raise ValueError(
                f"max_tokens_per_microbatch ({max_tokens}) must be >= max_length "
                f"({self.max_length}) so any single sequence fits in a bin."
            )
        return max_tokens
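
The outcomes of the method are easiest to see on concrete values. The standalone function below copies the logic from the listing above so the cases run without SkyRL installed; resolve_bin_capacity is a stand-in name and the numbers are arbitrary.

```python
from typing import Optional


def resolve_bin_capacity(max_length: Optional[int],
                         max_tokens_per_microbatch: Optional[int]) -> int:
    """Standalone copy of the resolution rule, for experimentation only."""
    if max_length is None:
        raise ValueError("max_tokens_per_microbatch requires max_length to be set.")
    max_tokens = max_tokens_per_microbatch
    if max_tokens is None:
        max_tokens = max_length  # fall back: one sequence per bin
    if max_tokens < max_length:
        raise ValueError("max_tokens_per_microbatch must be >= max_length")
    return max_tokens


print(resolve_bin_capacity(2048, None))   # falls back to max_length
print(resolve_bin_capacity(2048, 8192))   # explicit budget is used as-is

for bad_args in [(None, 8192), (2048, 1024)]:
    try:
        resolve_bin_capacity(*bad_args)
    except ValueError as err:
        print("rejected:", err)
```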

Trainer

class SFTTrainer

SFTTrainer(cfg: SFTConfig, skyrl_cfg: SkyRLTrainConfig | None = None, callbacks: Optional[list[TrainingCallback]] = None)

SFT trainer supporting FSDP and Megatron backends.

SFTTrainer deliberately does NOT subclass RayPPOTrainer. SFT's concerns are fundamentally different: no generation, no critic, no advantages, no KL penalty. Sharing a base class would create confusing dead code paths.

Usage:

    trainer = SFTTrainer(SFTConfig(strategy="megatron"))
    trainer.setup()
    trainer.train()
    trainer.shutdown()

Functions:

Name | Description
setup | Initialize tokenizer, workers, dispatch, and tracker.
add_callback | Register a callback. Can be called anytime; events fired after this
load_dataset | Load and tokenize the training dataset.
load_eval_dataset | Load and tokenize the eval dataset, or return None if not configured.
collate_batch | Collate examples into a TrainingInputBatch via the configured collator.
load_checkpoint | Load a checkpoint and return the step number to resume from.
run_eval | Compute eval loss over the full eval dataset.
train_step | Execute a single training step: forward_backward + optim_step.
train | Full training loop: load data, iterate, log, checkpoint.
save_checkpoint | Save a checkpoint at the given step. Returns the checkpoint folder path.
save_hf_model | Save policy weights in HuggingFace format.
shutdown | Finish tracking.

Source code in skyrl/train/sft_trainer.py:675-1755
class SFTTrainer:
    """SFT trainer supporting FSDP and Megatron backends.

    Unlike RayPPOTrainer, this does NOT subclass it. SFT's concerns are
    fundamentally different: no generation, no critic, no advantages, no
    KL penalty. Sharing a base class would create confusing dead code paths.

    Usage::

        trainer = SFTTrainer(SFTConfig(strategy="megatron"))
        trainer.setup()
        trainer.train()
        trainer.shutdown()
    """

    def __init__(
        self,
        cfg: SFTConfig,
        skyrl_cfg: SkyRLTrainConfig | None = None,
        callbacks: Optional[list[TrainingCallback]] = None,
    ):
        self.sft_cfg = cfg
        # Accept a pre-built bridge config to avoid redundant rebuilds.
        # When not provided (e.g. standalone usage), build it here.
        self.cfg = skyrl_cfg if skyrl_cfg is not None else build_skyrl_config_for_sft(cfg)
        self.tokenizer = None
        self.dispatch: WorkerDispatch | None = None
        self.tracker: Tracking | None = None
        self.global_step = 0
        # running count of total non-padding tokens trained on
        self._total_tokens_processed = 0
        self.collator = None  # built in setup() once the tokenizer is available

        self._num_training_gpus: int = cfg.placement.num_nodes * cfg.placement.num_gpus_per_node
        self._ray_gpu_monitor = RayGpuMonitor() if cfg.enable_ray_gpu_monitor else None

        self._callback_handler = CallbackHandler(callbacks)
        self._training_control = TrainingControl()
        # Loop metadata used to build CallbackInput. Populated in train().
        self._total_steps: int = 0
        self._steps_per_epoch: int = 0
        self._current_epoch: int = 0

    def _build_collator(self, tokenizer):
        """Select the batch collator from the configured packing mode.

        ``PackedDataCollator`` performs controller-level FFD bin-packing
        (Megatron-only, ``use_sequence_packing=True``); ``DefaultCollator``
        left-pads each example. The choice is fixed by static config; the
        ``tokenizer`` is passed in by :meth:`setup` once it is available. The
        packed config is validated here.
        """
        # Imported lazily to avoid a circular import: ``collators`` imports
        # ``collate_sft_batch`` from this module.
        from skyrl.train.dataset.collators import DefaultCollator, PackedDataCollator

        if self.sft_cfg.use_sequence_packing:
            self._validate_packing_cfg()
            return PackedDataCollator(
                tokenizer=tokenizer,
                max_tokens_per_microbatch=self.sft_cfg.resolved_bin_capacity(),
                tp_size=self.sft_cfg.megatron_config.tensor_model_parallel_size,
                pp_size=self.sft_cfg.megatron_config.pipeline_model_parallel_size,
                cp_size=self.sft_cfg.megatron_config.context_parallel_size,
                dp_size=self._dp_size(),
                batch_size=self.sft_cfg.batch_size,
                micro_train_batch_size_per_gpu=self.sft_cfg.micro_train_batch_size_per_gpu,
            )
        return DefaultCollator(
            tokenizer=tokenizer,
            micro_train_batch_size_per_gpu=self.sft_cfg.micro_train_batch_size_per_gpu,
        )

    def _dp_size(self) -> int:
        """Number of DP ranks under the configured Megatron parallelism."""
        total_gpus = self.sft_cfg.placement.num_nodes * self.sft_cfg.placement.num_gpus_per_node
        tp = self.sft_cfg.megatron_config.tensor_model_parallel_size
        pp = self.sft_cfg.megatron_config.pipeline_model_parallel_size
        cp = self.sft_cfg.megatron_config.context_parallel_size
        return total_gpus // (tp * pp * cp)

    def _validate_packing_cfg(self):
        """Validate the config when ``use_sequence_packing=True``."""
        if self.sft_cfg.strategy != "megatron":
            raise ValueError(
                f"use_sequence_packing=True only supports strategy='megatron'; got "
                f"{self.sft_cfg.strategy!r}. Use the FSDP packing path instead."
            )
        # Sequence packing needs the THD layout, so it implies
        # remove_microbatch_padding=True. Auto-enable it (warning if the user
        # explicitly set it False) instead of erroring on the contradiction.
        if not self.sft_cfg.remove_microbatch_padding:
            logger.warning(
                "use_sequence_packing=True requires the THD layout; "
                "setting remove_microbatch_padding=True (was False)."
            )
            self.sft_cfg.remove_microbatch_padding = True

    # ------------------------------------------------------------------ #
    # Setup
    # ------------------------------------------------------------------ #

    def setup(self):
        """Initialize tokenizer, workers, dispatch, and tracker.

        Ray must already be initialized before calling this (either via
        ``initialize_ray`` on the head node or inside a Ray task).
        """
        self.tokenizer = get_tokenizer(
            self.cfg.trainer.policy.model.path,
            trust_remote_code=True,
            use_fast=not self.cfg.trainer.disable_fast_tokenizer,
            padding_side="left",
        )
        self.collator = self._build_collator(self.tokenizer)
        self._init_tracker()
        self._init_workers()

    def _init_workers(self):
        """Create PPORayActorGroup and WorkerDispatch.

        Selects the correct PolicyWorker based on strategy.
        """
        if self.sft_cfg.strategy == "megatron":
            from skyrl.backends.skyrl_train.workers.megatron.megatron_worker import (
                PolicyWorker,
            )
        else:
            from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker

        num_gpus = self.sft_cfg.placement.num_gpus_per_node
        raw_pg = placement_group(
            [{"GPU": num_gpus, "CPU": num_gpus}] * self.sft_cfg.placement.num_nodes,
            strategy="PACK",
        )
        get_ray_pg_ready_with_timeout(raw_pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
        pg = ResolvedPlacementGroup(raw_pg)

        actor_group = PPORayActorGroup(
            self.cfg.trainer,
            num_nodes=self.sft_cfg.placement.num_nodes,
            num_gpus_per_node=num_gpus,
            ray_actor_type=PolicyWorker,
            pg=pg,
            num_gpus_per_actor=1,
            colocate_all=False,
            sequence_parallel_size=self.cfg.trainer.policy.sequence_parallel_size,
            record_memory=self.cfg.trainer.policy.record_memory,
        )
        num_training_steps = (
            self.sft_cfg.dummy_run_max_steps if self.sft_cfg.dummy_run_full_ctx else self.sft_cfg.num_steps
        )
        # num_steps may be None when num_epochs is used; the worker will use its
        # default (large value) for the LR scheduler in that case.
        ray.get(
            actor_group.async_init_model(
                self.sft_cfg.model.path,
                num_training_steps=num_training_steps,
            )
        )
        ray.get(actor_group.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))

        self.dispatch = WorkerDispatch(self.cfg, policy_actor_group=actor_group)

    def _init_tracker(self):
        self.tracker = Tracking(
            project_name=self.cfg.trainer.project_name,
            experiment_name=self.cfg.trainer.run_name,
            backends=self.cfg.trainer.logger,
            config=self.sft_cfg,
            tags=self.cfg.trainer.tags,
        )

    def add_callback(self, callback: TrainingCallback) -> None:
        """Register a callback. Can be called at any time; only events fired
        after this call reach the new callback."""
        self._callback_handler.add(callback)

    def _build_callback_input(self, **fields) -> CallbackInput:
        """Snapshot loop counters + per-event fields into a CallbackInput."""
        return CallbackInput(
            global_step=self.global_step,
            epoch=self._current_epoch,
            total_steps=self._total_steps,
            steps_per_epoch=self._steps_per_epoch,
            **fields,
        )

    def _fire(self, event_name: str, **fields) -> None:
        """Build a CallbackInput and dispatch the given event to all callbacks."""
        cb_input = self._build_callback_input(**fields)
        getattr(self._callback_handler, event_name)(self, cb_input, self._training_control)

    # ------------------------------------------------------------------ #
    # Data
    # ------------------------------------------------------------------ #

    def _load_and_tokenize(self, dataset_name: str, dataset_split: str) -> list:
        """Load and tokenize a dataset with caching support.

        Auto-detects the dataset format based on column names:
        - If a ``messages_key`` column exists, uses chat-format tokenization.
        - If ``instruction`` and ``output`` columns exist, uses Alpaca-format
          tokenization.

        Uses manual multiprocessing for parallel tokenization when num_workers > 0.
        With parallel mode, uses slice-based loading where each worker loads its
        own data slice directly from HuggingFace to eliminate pickle overhead.

        Caching:
        - Tokenized datasets are cached to disk as a HuggingFace ``Dataset``
          (arrow-backed, memory-mapped) for reuse across runs.
        - Cache key is a hash of dataset name, split, model, and tokenization params.
        - Set ``force_recache=True`` to ignore cache and re-tokenize.
        - Set ``disable_cache=True`` to disable caching entirely.

        Args:
            dataset_name: HuggingFace dataset name (e.g. ``"yahma/alpaca-cleaned"``).
            dataset_split: Dataset split (e.g. ``"train[:100]"`` or ``"test"``).

        Returns:
            A list of tokenized examples (dicts with ``input_ids``,
            ``attention_mask``, ``num_actions``).
        """
        # Check cache first (unless disabled or force_recache)
        if not self.sft_cfg.disable_cache:
            cache_dir = self.sft_cfg.cache_dir

            # Compute cache key
            tools_key = self.sft_cfg.tools_key if self.sft_cfg.tools_key else None
            system_key = self.sft_cfg.system_key if self.sft_cfg.system_key else None
            cache_key = _compute_cache_key(
                dataset_name=dataset_name,
                dataset_split=dataset_split,
                model_path=self.sft_cfg.model.path,
                max_length=self.sft_cfg.max_length,
                messages_key=self.sft_cfg.messages_key,
                train_on_what=self.sft_cfg.train_on_what.value,
                tools_key=tools_key,
                system_key=system_key,
            )
            cache_path = _get_cache_path(cache_dir, cache_key)

            # Try to load from cache (unless force_recache)
            if not self.sft_cfg.force_recache:
                cached = _load_from_cache(cache_path)
                if cached is not None:
                    return cached

            logger.info("Cache miss or force_recache=True, tokenizing dataset...")
            logger.info(f"Cache key: {cache_key}")

        logger.info(f"Loading dataset '{dataset_name}' split='{dataset_split}'...")
        dataset = load_dataset(dataset_name, split=dataset_split)

        columns = dataset.column_names
        num_workers = self.sft_cfg.num_workers

        # Sequential tokenization path
        if num_workers == 0:
            logger.info("Tokenizing dataset (sequential)...")
            if self.sft_cfg.messages_key in columns:
                tools_key = self.sft_cfg.tools_key if self.sft_cfg.tools_key in columns else None
                system_key = self.sft_cfg.system_key if self.sft_cfg.system_key in columns else None
                tokenized = [
                    tokenize_chat_example(
                        ex,
                        self.tokenizer,
                        self.sft_cfg.max_length,
                        self.sft_cfg.messages_key,
                        train_on_what=self.sft_cfg.train_on_what,
                        tools_key=tools_key,
                        system_key=system_key,
                    )
                    for ex in dataset
                ]
            elif "instruction" in columns and "output" in columns:
                tokenized = [tokenize_sft_example(ex, self.tokenizer, self.sft_cfg.max_length) for ex in dataset]
            else:
                raise ValueError(
                    f"Unrecognized dataset format. Expected '{self.sft_cfg.messages_key}' column "
                    f"(chat format) or 'instruction'+'output' columns (Alpaca format). "
                    f"Found columns: {columns}"
                )
            tokenized = [ex for ex in tokenized if ex is not None]
            logger.info(f"Tokenized {len(tokenized)} examples (filtered from {len(dataset)})")

            # Save to cache if enabled
            if not self.sft_cfg.disable_cache:
                # TODO (sumanthrh): Currently we use a simple list instead of dataset + stateful dataloader
                # for simplicity but for caching we use HF Dataset since file sizes can get large
                # We should migrate to using HF datasets + a dataloader so that we don't materialize
                # the full dataset in memory
                _save_to_cache(cache_path, tokenized)

            return tokenized

        # Parallel tokenization path with slice-based loading
        logger.info(f"Tokenizing dataset with {num_workers} workers (slice-based loading)...")

        # Cache tokenizer to temp dir for fast worker loading
        tokenizer_cache_dir = tempfile.mkdtemp(prefix="skyrl_tokenizer_")
        try:
            self.tokenizer.save_pretrained(tokenizer_cache_dir)

            # Compute slice boundaries from the already-loaded dataset's length.
            # The original split string is forwarded to workers verbatim so HF
            # parses it (no local regex); each worker then loads its own slice.
            dataset_size = len(dataset)
            chunk_size = max(1, dataset_size // num_workers)
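            # Worked example (illustrative numbers): dataset_size=100 and
            # num_workers=8 give chunk_size=12. Workers 0-6 take 12 examples
            # each ([0, 12), ..., [72, 84)) and the last worker takes the
            # remaining 16 ([84, 100)). With dataset_size=3 and num_workers=8,
            # chunk_size=1, so only three non-empty slices are produced.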

            # Generate worker slice boundaries
            worker_args = []
            for worker_idx in range(num_workers):
                worker_start = worker_idx * chunk_size
                # Last worker takes any remainder
                if worker_idx == num_workers - 1:
                    worker_end = dataset_size
                else:
                    worker_end = min((worker_idx + 1) * chunk_size, dataset_size)

                # Skip empty slices
                if worker_start >= worker_end:
                    continue

                # Prepare worker arguments based on format
                if self.sft_cfg.messages_key in columns:
                    tools_key = self.sft_cfg.tools_key if self.sft_cfg.tools_key in columns else None
                    system_key = self.sft_cfg.system_key if self.sft_cfg.system_key in columns else None
                    worker_args.append(
                        (
                            dataset_name,
                            dataset_split,
                            worker_start,
                            worker_end,
                            tokenizer_cache_dir,
                            self.sft_cfg.max_length,
                            self.sft_cfg.messages_key,
                            self.sft_cfg.train_on_what.value,
                            tools_key,
                            system_key,
                        )
                    )
                elif "instruction" in columns and "output" in columns:
                    worker_args.append(
                        (
                            dataset_name,
                            dataset_split,
                            worker_start,
                            worker_end,
                            tokenizer_cache_dir,
                            self.sft_cfg.max_length,
                        )
                    )
                else:
                    raise ValueError(
                        f"Unrecognized dataset format. Expected '{self.sft_cfg.messages_key}' column "
                        f"(chat format) or 'instruction'+'output' columns (Alpaca format). "
                        f"Found columns: {columns}"
                    )

            # Select worker function based on format
            if self.sft_cfg.messages_key in columns:
                worker_fn = _tokenize_chat_slice_worker
            else:
                worker_fn = _tokenize_alpaca_slice_worker

            logger.info(f"Dividing {dataset_size} examples among {len(worker_args)} workers")

            # Use spawn to avoid Ray fork issues
            ctx = mp.get_context("spawn")

            # Process in parallel
            with ctx.Pool(processes=num_workers) as pool:
                results = pool.map(worker_fn, worker_args)

            # Flatten results
            tokenized = []
            for chunk_results in results:
                tokenized.extend(chunk_results)

            logger.info(f"Tokenized {len(tokenized)} examples (filtered from {dataset_size})")

            # Save to cache if enabled
            if not self.sft_cfg.disable_cache:
                _save_to_cache(cache_path, tokenized)

            return tokenized

        finally:
            # Cleanup temp tokenizer cache
            import shutil

            shutil.rmtree(tokenizer_cache_dir, ignore_errors=True)

    def load_dataset(self) -> list:
        """Load and tokenize the training dataset."""
        return self._load_and_tokenize(self.sft_cfg.dataset_name, self.sft_cfg.dataset_split)

    def load_eval_dataset(self) -> Optional[list]:
        """Load and tokenize the eval dataset, or return ``None`` if not configured."""
        if not self.sft_cfg.eval_dataset_name:
            return None
        return self._load_and_tokenize(self.sft_cfg.eval_dataset_name, self.sft_cfg.eval_dataset_split)

    def _log_dataset_stats(self, tokenized: list) -> None:
        """Log tokenized sequence length statistics over the training set.

        Reports the example count, total token count, mean, median (q50), q25,
        q75, min, and max of the tokenized ``input_ids`` lengths. Logs once via
        ``logger.info``.
        """
        if not tokenized:
            logger.warning("No tokenized examples to compute stats over")
            return

        lengths = [len(ex["input_ids"]) for ex in tokenized]
        n = len(lengths)
        sorted_lengths = sorted(lengths)

        def pct(p: float) -> int:
            # Simple nearest-rank percentile over ints; adequate for dataset stats.
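            # Worked example (illustrative): with sorted_lengths=[3, 5, 8, 13, 21]
            # (n=5), pct(25) uses idx=round(0.25 * 4)=1 -> 5, pct(50) uses idx=2 -> 8,
            # and pct(75) uses idx=3 -> 13.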
            idx = max(0, min(n - 1, int(round((p / 100.0) * (n - 1)))))
            return sorted_lengths[idx]

        mean_len = sum(lengths) / n
        q25 = pct(25)
        q50 = pct(50)
        q75 = pct(75)
        min_len = sorted_lengths[0]
        max_len = sorted_lengths[-1]

        logger.info(
            f"Dataset stats (tokenized lengths over {n} examples):\n"
            f"total={sum(lengths)}, mean={mean_len:.1f}, median={q50}, q25={q25}, q75={q75}, min={min_len}, max={max_len}"
        )

    def collate_batch(self, examples: list, batch_size: int) -> TrainingInputBatch:
        """Collate examples into a TrainingInputBatch via the configured collator.

        Delegates to ``self.collator`` (``DefaultCollator`` or, when sequence
        packing is enabled, ``PackedDataCollator``).

        Args:
            examples: Tokenized examples to collate.
            batch_size: Global batch dimension. The train path passes
                ``sft_cfg.batch_size`` and the eval path passes its
                per-dispatch chunk size.
        """
        return self.collator(examples, batch_size=batch_size)

    # ------------------------------------------------------------------ #
    # Checkpoint resume
    # ------------------------------------------------------------------ #

    def load_checkpoint(self) -> int:
        """Load a checkpoint and return the step number to resume from.

        Behaviour depends on ``sft_cfg.resume_from``:
        - ``""`` (empty): no resume, return 0.
        - ``"latest"``: read ``latest_ckpt_global_step.txt`` from ``ckpt_path``.
        - otherwise: treat as a direct path to a ``global_step_N`` directory.

        Returns:
            The global step to resume from (0 if no checkpoint loaded).
        """
        resume_from = self.sft_cfg.resume_from
        if not resume_from:
            return 0

        if resume_from == "latest":
            if not self.sft_cfg.ckpt_path:
                logger.info("resume_from='latest' but ckpt_path is empty, starting from scratch")
                return 0
            latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
            if not io.exists(latest_file):
                logger.info("No latest checkpoint marker found, starting from scratch")
                return 0
            with io.open_file(latest_file, "r") as f:
                ckpt_step = int(f.read().strip())
            checkpoint_path = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_step}")
            # Validate consistency: ensure no stale checkpoint folders from prior runs
            validate_consistency_for_latest_checkpoint(
                self.sft_cfg.ckpt_path,
                ckpt_step,
                checkpoint_path,
                latest_file,
                self.sft_cfg.ckpt_interval,
            )
        else:
            checkpoint_path = resume_from

        if not io.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")

        global_step = extract_step_from_path(checkpoint_path)
        if global_step == -1:
            raise ValueError(
                f"Cannot extract step number from checkpoint path: {checkpoint_path}. "
                f"Expected a directory named '{GLOBAL_STEP_PREFIX}<N>'."
            )

        # Load and validate trainer state if available
        trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
        if io.exists(trainer_state_path):
            with io.open_file(trainer_state_path, "rb") as f:
                trainer_state = torch.load(f, map_location="cpu", weights_only=False)
            saved_global_step = trainer_state.get("global_step", global_step)
            logger.info("Successfully loaded trainer state")
            if saved_global_step != global_step:
                logger.warning(
                    f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value."
                )
        else:
            logger.warning(
                f"No trainer_state.pt found at {trainer_state_path}. "
                "This checkpoint was likely saved by an older version."
            )

        policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
        logger.info(f"Loading checkpoint from {checkpoint_path} (step {global_step})")
        self.dispatch.load_checkpoint(
            "policy",
            policy_ckpt_dir,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        logger.info(f"Successfully resumed from {GLOBAL_STEP_PREFIX}{global_step}")
        return global_step

    # ------------------------------------------------------------------ #
    # Training
    # ------------------------------------------------------------------ #

    def run_eval(self, eval_tokenized: list) -> tuple[dict, int]:
        """Compute eval loss over the full eval dataset.

        Iterates the eval dataset in chunks of ``micro_train_batch_size_per_gpu * dp_size``
        (i.e. exactly one micro-batch per DP rank per dispatch call) and calls
        :meth:`WorkerDispatch.forward` with ``loss_fn="cross_entropy"`` (which
        runs the model in ``eval()`` mode under ``no_grad``).

        The per-batch losses, each a per-non-pad-token mean within its batch,
        are aggregated into a token-weighted mean. This yields the true
        per-non-pad-token mean across the eval dataset.

        Args:
            eval_tokenized: Pre-tokenized eval dataset (output of
                :meth:`load_eval_dataset`).

        Returns:
            ``(metrics, num_eval_batches)`` where ``metrics`` contains
            ``eval_loss`` and ``num_eval_batches`` is bookkeeping for
            stdout logging (not a wandb metric).
        """
        num_eval = len(eval_tokenized)
        if num_eval == 0:
            raise ValueError(
                "Eval dataset is empty. Provide a non-empty eval split or disable eval "
                "by setting eval_dataset_name=None."
            )

        # One micro-batch per DP rank per dispatch call — keeps memory usage bounded
        # and removes the need for a separate `eval_batch_size` knob.
        dp_size = self.dispatch.dp_size("policy")
        eval_chunk_size = self.sft_cfg.micro_train_batch_size_per_gpu * dp_size

        # Pad a trailing partial batch up to ``eval_chunk_size`` via
        # ``pad_training_input_batch`` (which zeros ``loss_mask`` on padded rows).
        # Padded rows contribute 0 to the cross-entropy numerator, and the
        # pre-padding ``total_nonpad`` scaling in ``collate_batch`` excludes
        # them from the denominator, so the reported ``eval_loss`` is the
        # per-real-token mean over the full (non-padded) eval set.
        num_eval_batches = ceil(num_eval / eval_chunk_size)

        total_loss_weighted = 0.0
        total_tokens = 0
        for batch_idx in range(num_eval_batches):
            start = batch_idx * eval_chunk_size
            end = min(start + eval_chunk_size, num_eval)
            batch_examples = eval_tokenized[start:end]
            batch = self.collator(batch_examples, batch_size=eval_chunk_size)
            # Pad the last (possibly-short) chunk so every dispatch sees exactly
            # ``eval_chunk_size`` rows. ``pad_training_input_batch`` zeros the
            # ``loss_mask`` for padding rows; with ``pad_size=0`` it is a no-op.
            pad_rows = eval_chunk_size - len(batch_examples)
            if pad_rows > 0:
                logger.info(
                    f"Padding final eval batch by {pad_rows} rows "
                    f"({len(batch_examples)} real -> {eval_chunk_size} total); "
                    f"padded rows are masked out of the loss."
                )
                batch = pad_training_input_batch(batch, pad_rows)
            # Count the non-pad response tokens that contribute to the loss.
            # The ``loss_mask`` built by collate_sft_batch was 0/1 before
            # scaling, so counting its positive entries recovers the unscaled
            # count. Padded rows have loss_mask=0 and are excluded here.
            nonpad_tokens = int((batch["loss_mask"] > 0).sum().item())
            output = self.dispatch.forward(
                "policy",
                batch,
                loss_fn="cross_entropy",
                loss_fn_config=None,
            )
            batch_loss = float(output.metrics.get("loss", float("nan")))
            total_loss_weighted += batch_loss * nonpad_tokens
            total_tokens += nonpad_tokens
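            # Worked example (illustrative numbers): two batches with losses
            # 2.0 over 10 tokens and 1.0 over 30 tokens accumulate to
            # total_loss_weighted=50.0 and total_tokens=40, so eval_loss is
            # 50.0 / 40 = 1.25 rather than the unweighted mean of 1.5.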

        eval_loss = total_loss_weighted / max(total_tokens, 1)
        return {"eval_loss": eval_loss}, num_eval_batches

    def train_step(self, batch: TrainingInputBatch, step: int) -> dict:
        """Execute a single training step: forward_backward + optim_step.

        Args:
            batch: The collated training batch.
            step: Current global step (reserved for future use, e.g. scheduling).

        Returns:
            Dict with ``loss``, ``grad_norm``, and ``timings``.
        """
        timings: dict[str, float] = {}
        with Timer("forward_backward", timings):
            output = self.dispatch.forward_backward("policy", batch, loss_fn="cross_entropy")
        with Timer("optim_step", timings):
            grad_norm = self.dispatch.optim_step("policy")

        metrics = output.metrics
        loss_val = metrics.get("final_loss", metrics.get("loss", float("nan")))
        return {
            "loss": loss_val,
            "grad_norm": grad_norm,
            "timings": timings,
        }

    def _validate_batch_parallelism(self):
        """Validate that batch_size is compatible with data-parallel and micro-batch sizes."""
        batch_size = self.sft_cfg.batch_size
        total_gpus = self.sft_cfg.placement.num_nodes * self.sft_cfg.placement.num_gpus_per_node
        if self.sft_cfg.use_sequence_packing:
            # With packing, batch_size counts *examples* (not bins), and the
            # per-DP-rank bin count equals bins_per_shard.
            # micro_train_batch_size_per_gpu counts bin rows per micro-batch,
            # derived from the ``max_tokens_per_microbatch`` token budget, so
            # it is a separate knob from examples per micro-batch. FFD rounds
            # the bin count up to a multiple of dp_size. We therefore only
            # require batch_size >= dp_size (every DP rank needs >= 1 bin) and
            # do NOT require batch_size % micro_train_batch_size_per_gpu == 0.
            dp_size = self._dp_size()
            if batch_size < dp_size:
                raise ValueError(
                    f"batch_size ({batch_size}) must be >= dp_size ({dp_size}) when "
                    f"use_sequence_packing=True (each DP rank needs at least one bin)."
                )
            return
        if self.sft_cfg.strategy == "megatron":
            # Reuse _dp_size() so context parallelism is accounted for,
            # consistent with the packing branch above.
            dp_size = self._dp_size()
        else:
            # FSDP: all GPUs are data-parallel
            dp_size = total_gpus
        if batch_size % dp_size != 0:
            raise ValueError(f"batch_size ({batch_size}) must be divisible by data-parallel size ({dp_size})")
        per_dp_batch = batch_size // dp_size
        micro_batch = self.sft_cfg.micro_train_batch_size_per_gpu
        if per_dp_batch % micro_batch != 0:
            raise ValueError(
                f"batch_size ({self.sft_cfg.batch_size}) / dp_size ({dp_size}) must be divisible by "
                f"micro_train_batch_size_per_gpu ({micro_batch})"
            )

    def _build_dummy_batch(self) -> TrainingInputBatch:
        """Build a dummy batch of random full-context sequences for benchmarking."""
        batch_size = self.sft_cfg.batch_size
        max_length = self.sft_cfg.max_length
        micro_batch_size = self.sft_cfg.micro_train_batch_size_per_gpu
        vocab_size = self.tokenizer.vocab_size

        # num_actions is max_length - 1 because the autoregressive model
        # produces log-probs for positions 1..T (predicting next token),
        # so the first token has no corresponding log-prob.
        num_actions = max_length - 1

        sequences = torch.randint(0, vocab_size, (batch_size, max_length), dtype=torch.long)
        attention_mask = torch.ones(batch_size, max_length, dtype=torch.long)
        # All tokens are non-pad in the dummy batch, so total_nonpad = batch_size * num_actions.
        # Scaling = batch_size / (micro_batch_size * total_nonpad)
        #         = 1 / (micro_batch_size * num_actions)
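        # Worked example (illustrative numbers): batch_size=4,
        # micro_batch_size=2, and max_length=9 give num_actions=8 and
        # total_nonpad=32, so each mask entry is 4 / (2 * 32) = 0.0625,
        # i.e. 1 / (2 * 8).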
        total_nonpad = batch_size * num_actions
        loss_mask = torch.ones(batch_size, num_actions, dtype=torch.float) * (
            batch_size / (micro_batch_size * total_nonpad)
        )

        batch = TrainingInputBatch(
            {
                "sequences": sequences,
                "attention_mask": attention_mask,
                "loss_mask": loss_mask,
            }
        )
        batch.metadata = {"response_length": num_actions}
        return batch

    def _train_dummy(self):
        """Dummy training loop for benchmarking. Skips real data, checkpoints, and resume."""
        self._validate_batch_parallelism()
        batch = self._build_dummy_batch()
        num_steps = self.sft_cfg.dummy_run_max_steps

        logger.info(
            f"Starting dummy SFT training for {num_steps} steps "
            f"(batch_size={self.sft_cfg.batch_size}, max_length={self.sft_cfg.max_length})..."
        )

        if self._ray_gpu_monitor is not None:
            self._ray_gpu_monitor.start()
        for step in range(num_steps):
            all_timings: dict[str, float] = {}

            with Timer("step", all_timings):
                step_result = self.train_step(batch, step)
                all_timings.update(step_result["timings"])

            actual_num_tokens = batch["attention_mask"].sum().item()
            self._total_tokens_processed += actual_num_tokens
            tokens_per_second = actual_num_tokens / all_timings["step"]

            log_dict = {
                "train/loss": step_result["loss"],
                "train/grad_norm": step_result["grad_norm"],
                "train/tokens_per_second": tokens_per_second,
                "train/tokens_per_second_per_gpu": tokens_per_second / self._num_training_gpus,
                "train/actual_num_tokens": actual_num_tokens,
                "train/total_tokens_processed": self._total_tokens_processed,
            }
            log_dict.update({f"timing/{k}": v for k, v in all_timings.items()})
            if self._ray_gpu_monitor is not None:
                log_dict.update(self._ray_gpu_monitor.flush())

            self.tracker.log(log_dict, step=step, commit=True)
            logger.info(
                f"Step {step}: loss={step_result['loss']:.4f}, "
                f"grad_norm={step_result['grad_norm']}, "
                f"tokens_per_second={tokens_per_second:.0f}"
            )

        logger.info("Dummy SFT training complete!")

    def train(self):
        """Full training loop: load data, iterate, log, checkpoint."""
        if self.sft_cfg.dummy_run_full_ctx:
            if self.sft_cfg.resume_from:
                logger.warning("resume_from is ignored in dummy run mode")
            return self._train_dummy()

        tokenized = self.load_dataset()

        # Log tokenized sequence length statistics (once, before training loop)
        self._log_dataset_stats(tokenized)

        # Load eval dataset (if configured). We load once up-front so the
        # tokenization cost is amortized across all eval invocations.
        eval_tokenized = self.load_eval_dataset()
        if eval_tokenized is not None:
            logger.info(f"Eval dataset loaded: {len(eval_tokenized)} examples")

        batch_size = self.sft_cfg.batch_size

        # steps_per_epoch is always derived from the data; callbacks rely on it.
        steps_per_epoch = max(1, ceil(len(tokenized) / batch_size))

        # Resolve num_steps: explicit num_steps takes precedence; otherwise derive from num_epochs.
        if self.sft_cfg.num_steps is not None:
            num_steps = self.sft_cfg.num_steps
        else:
            num_steps = self.sft_cfg.num_epochs * steps_per_epoch
            logger.info(
                f"num_steps not set; deriving from num_epochs={self.sft_cfg.num_epochs}: "
                f"ceil({len(tokenized)} / {batch_size}) * {self.sft_cfg.num_epochs} = {num_steps} steps"
            )

        # Early validation: dataset must have at least batch_size examples
        if len(tokenized) < batch_size:
            raise ValueError(
                f"Dataset has {len(tokenized)} examples after tokenization, but batch_size={batch_size}. "
                f"Reduce batch_size or use more data."
            )

        self._validate_batch_parallelism()

        # Resume from checkpoint if configured
        start_step = self.load_checkpoint()

        # Shuffle data before training
        rng = random.Random(self.sft_cfg.seed)
        rng.shuffle(tokenized)

        # When resuming, start_step is the last *completed* step (checkpoint is
        # saved AFTER the optimizer update), so we begin at start_step + 1 to
        # avoid replaying that step.

        # Replay epoch shuffles for reproducibility on resume
        start_epoch = (start_step * batch_size) // len(tokenized)
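        # Worked example (illustrative numbers): start_step=30, batch_size=4,
        # and 100 examples give (30 * 4) // 100 = 1, so one extra shuffle is
        # replayed on top of the initial one.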
        for _ in range(start_epoch):
            rng.shuffle(tokenized)
        current_epoch = start_epoch

        # Initialize `global_step`
        self.global_step = start_step

        # Publish loop metadata so CallbackInput can be built consistently.
        self._total_steps = num_steps
        self._steps_per_epoch = steps_per_epoch
        self._current_epoch = current_epoch
        self._training_control.reset()

        logger.info(f"Starting SFT training for {num_steps} steps (batch_size={batch_size})...")
        if start_step > 0:
            logger.info(f"Resuming from step {start_step}")

        if self._ray_gpu_monitor is not None:
            self._ray_gpu_monitor.start()

        # Tracks whether the most recent in-loop iteration saved a checkpoint
        # (either via the ckpt_interval or via a callback-driven ``should_save``).
        did_save_last_step = False

        self._fire("on_train_start")

        # Baseline eval before training begins (logged at step 0).
        # Wandb's step counter starts at 0; the training loop's first commit
        # advances it to >=1, so step=0 here does not conflict with later steps.
        if self.sft_cfg.eval_before_train and eval_tokenized is not None:
            self._fire("on_eval_start")
            eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
            self._fire("on_eval_end", metrics=eval_metrics)
            baseline_log = {f"eval/{k}": v for k, v in eval_metrics.items()}
            self._fire("on_log", logs=baseline_log)
            self.tracker.log(baseline_log, step=self.global_step, commit=True)
            logger.info(
                f"Baseline eval before training: "
                f"eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
                f"over {num_eval_batches} batches"
            )

        # SkyRL starts counting at step 1
        self.global_step = start_step + 1 if start_step > 0 else 1
        self._fire("on_epoch_start")
        epoch_in_progress = True

        while self.global_step <= num_steps:
            all_timings: dict[str, float] = {}

            with Timer("step", all_timings):

                # Data loading with wrap-around
                with Timer("data_loading", all_timings):
                    start_idx = (self.global_step * batch_size) % len(tokenized)
                    end_idx = start_idx + batch_size
                    if end_idx > len(tokenized):
                        batch_examples = tokenized[start_idx:] + tokenized[: end_idx - len(tokenized)]
                    else:
                        batch_examples = tokenized[start_idx:end_idx]
                    batch = self.collator(batch_examples, batch_size=batch_size)

                self._fire("on_step_start", batch=batch)

                # Training step
                step_result = self.train_step(batch, self.global_step)
                all_timings.update(step_result["timings"])

            # Compute throughput using actual (non-padding) tokens
            batch_padded_seq_len = batch["sequences"].shape[1]
            actual_num_tokens = batch["attention_mask"].sum().item()
            self._total_tokens_processed += actual_num_tokens
            tokens_per_second = actual_num_tokens / all_timings["step"]

            # Build log dict
            log_dict = {
                "train/loss": step_result["loss"],
                "train/grad_norm": step_result["grad_norm"],
                "train/tokens_per_second": tokens_per_second,
                "train/tokens_per_second_per_gpu": tokens_per_second / self._num_training_gpus,
                "train/actual_num_tokens": actual_num_tokens,
                "train/batch_padded_seq_len": batch_padded_seq_len,
                "train/total_tokens_processed": self._total_tokens_processed,
            }
            log_dict.update({f"timing/{k}": v for k, v in all_timings.items()})
            if self._ray_gpu_monitor is not None:
                log_dict.update(self._ray_gpu_monitor.flush())

            self._fire("on_step_end", batch=batch, metrics=step_result)

            # Capture callback-driven triggers, then reset so they only fire once.
            force_save = self._training_control.should_save
            force_eval = self._training_control.should_evaluate
            self._training_control.should_save = False
            self._training_control.should_evaluate = False

            # Checkpoint: interval-driven or callback-requested.
            interval_save = (
                self.sft_cfg.ckpt_interval > 0
                and self.global_step > 0
                and self.global_step % self.sft_cfg.ckpt_interval == 0
            )
            did_save_last_step = force_save or interval_save
            if did_save_last_step:
                with Timer("save_checkpoint", all_timings):
                    ckpt_path = self.save_checkpoint()
                log_dict["timing/save_checkpoint"] = all_timings["save_checkpoint"]
                self._fire("on_save", ckpt_path=ckpt_path)

            # HF export at regular intervals
            if self.sft_cfg.hf_save_interval > 0 and self.global_step % self.sft_cfg.hf_save_interval == 0:
                with Timer("save_hf_model", all_timings):
                    self.save_hf_model()
                log_dict["timing/save_hf_model"] = all_timings["save_hf_model"]

            eval_metrics = None
            num_eval_batches: int | None = None
            # Eval fires at step N where N % eval_interval == 0 and N > 0, OR
            # whenever a callback set ``control.should_evaluate``.
            interval_eval = self.sft_cfg.eval_interval > 0 and self.global_step % self.sft_cfg.eval_interval == 0
            if eval_tokenized is not None and (force_eval or interval_eval):
                self._fire("on_eval_start")
                with Timer("eval", all_timings):
                    eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
                self._fire("on_eval_end", metrics=eval_metrics)
                if eval_metrics:
                    log_dict.update({f"eval/{k}": v for k, v in eval_metrics.items()})
                    log_dict["timing/eval"] = all_timings["eval"]

            log_dict.update({"train/epoch": current_epoch, "train/global_step": self.global_step})
            # Callbacks may mutate log_dict in place via on_log.
            self._fire("on_log", logs=log_dict)
            self.tracker.log(log_dict, step=self.global_step, commit=True)

            if self.global_step % 5 == 0:
                logger.info(
                    f"Step {self.global_step}: loss={step_result['loss']:.4f}, " f"grad_norm={step_result['grad_norm']}"
                )

            if eval_metrics:
                logger.info(
                    f"Step {self.global_step}: eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
                    f"over {num_eval_batches} batches"
                )

            # Check for epoch boundary and reshuffle
            epoch = (self.global_step * batch_size) // len(tokenized)
            if epoch > current_epoch:
                self._fire("on_epoch_end")
                epoch_in_progress = False
                for _ in range(epoch - current_epoch):
                    rng.shuffle(tokenized)
                current_epoch = epoch
                self._current_epoch = epoch
                if self.global_step + 1 <= num_steps:
                    self._fire("on_epoch_start")
                    epoch_in_progress = True

            self.global_step += 1
        self.global_step = min(self.global_step, num_steps)

        # Pair the leading on_epoch_start: fire on_epoch_end if we exited the
        # loop mid-epoch
        if epoch_in_progress:
            self._fire("on_epoch_end")
            epoch_in_progress = False

        # Save final checkpoint (if checkpointing is enabled). Skip if the last
        # in-loop iteration already saved (either via ckpt_interval or via a
        # callback-driven force-save) so we don't double-save.
        if self.sft_cfg.ckpt_path and not did_save_last_step:
            final_step = num_steps
            logger.info(f"Saving final checkpoint at step {final_step}")
            ckpt_path = self.save_checkpoint()
            self._fire("on_save", ckpt_path=ckpt_path)

        # Save final HF model if enabled (only if not already saved at last step)
        if self.sft_cfg.hf_save_interval > 0:
            final_step = num_steps
            already_saved = final_step % self.sft_cfg.hf_save_interval == 0
            if not already_saved:
                self.global_step = final_step
                logger.info(f"Saving final HF model at step {final_step}")
                self.save_hf_model()

        # Final eval pass (skip if the last step already ran eval).
        # NOTE: The last in-loop tracker.log(..., commit=True) at step=num_steps
        # advanced wandb's internal step counter to num_steps+1. Logging the
        # final eval at step=num_steps would be rejected by wandb with
        # "step N < current step N+1". We log the final eval at num_steps+1
        # (one past the last committed train step) in a single combined
        # tracker.log() call, preserving wandb step ordering. We use a local
        # ``final_eval_step`` rather than mutating ``self.global_step``: the
        # bump is purely a wandb-step accounting concern, not real trainer
        # state.
        if eval_tokenized is not None:
            already_ran = self.sft_cfg.eval_interval > 0 and num_steps % self.sft_cfg.eval_interval == 0
            if not already_ran:
                final_eval_step = num_steps + 1
                eval_timings: dict[str, float] = {}
                self._fire("on_eval_start")
                with Timer("eval", eval_timings):
                    eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
                self._fire("on_eval_end", metrics=eval_metrics)
                if eval_metrics:
                    eval_log = {f"eval/{k}": v for k, v in eval_metrics.items()}
                    eval_log["timing/eval"] = eval_timings["eval"]
                    self._fire("on_log", logs=eval_log)
                    self.tracker.log(eval_log, step=final_eval_step, commit=True)
                    logger.info(
                        f"Final eval at step {final_eval_step}: "
                        f"eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
                        f"over {num_eval_batches} batches"
                    )

        self._fire("on_train_end")
        logger.info("SFT training complete!")

    def save_checkpoint(self) -> str:
        """Save a checkpoint at the current ``global_step``. Returns the checkpoint folder path."""
        step = self.global_step
        global_step_folder = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{step}")
        policy_save_dir = os.path.join(global_step_folder, "policy")
        io.makedirs(global_step_folder, exist_ok=True)
        logger.info(f"Saving checkpoint at step {step} to {global_step_folder}")
        self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer)

        # Save trainer state for cross-validation on resume (mirrors PPO's trainer_state.pt)
        trainer_state = {
            "global_step": step,
            "config": asdict(self.sft_cfg),
        }
        trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
        with io.open_file(trainer_state_path, "wb") as f:
            torch.save(trainer_state, f)
        logger.info(f"Saved trainer state to {trainer_state_path}")

        # Atomic tracking -- write this last after all saves succeed
        latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
        with io.open_file(latest_file, "w") as f:
            f.write(str(step))
        logger.info(f"Checkpoint saved for global_step_{step}")

        # Clean up old checkpoints after successful save
        cleanup_old_checkpoints(self.sft_cfg.ckpt_path, self.sft_cfg.max_ckpts_to_keep)
        return global_step_folder

    def save_hf_model(self):
        """Save policy weights in HuggingFace format.

        Export path: cfg.trainer.export_path/global_step_{step}/policy
        Mirrors the pattern used by the RL trainer's save_models().
        """
        step = self.global_step
        policy_export_dir = os.path.join(
            self.cfg.trainer.export_path,
            f"{GLOBAL_STEP_PREFIX}{step}",
            "policy",
        )
        self.dispatch.save_hf_model("policy", policy_export_dir, self.tokenizer)
        logger.info(f"Saved HF model weights at step {step} to {policy_export_dir}")

    # ------------------------------------------------------------------ #
    # Lifecycle
    # ------------------------------------------------------------------ #

    def shutdown(self):
        """Finish tracking.

        Does NOT call ``ray.shutdown()`` -- when running inside a Ray task
        (the normal path via ``sft_entrypoint``), shutting down Ray from
        within the task would be incorrect.  The head-node process owns
        the Ray lifecycle.
        """
        if self._ray_gpu_monitor is not None:
            self._ray_gpu_monitor.stop()
        if self.tracker is not None:
            self.tracker.finish()

attr sft_cfg

sft_cfg = cfg

attr cfg

cfg = skyrl_cfg if skyrl_cfg is not None else build_skyrl_config_for_sft(cfg)

attr tokenizer

tokenizer = None

attr dispatch

dispatch: WorkerDispatch | None = None

attr tracker

tracker: Tracking | None = None

attr global_step

global_step = 0

attr collator

collator = None

method setup

setup()

Initialize tokenizer, workers, dispatch, and tracker.

Ray must already be initialized before calling this (either via initialize_ray on the head node or inside a Ray task).

Source code in skyrl/train/sft_trainer.py:777-791
    def setup(self):
        """Initialize tokenizer, workers, dispatch, and tracker.

        Ray must already be initialized before calling this (either via
        ``initialize_ray`` on the head node or inside a Ray task).
        """
        self.tokenizer = get_tokenizer(
            self.cfg.trainer.policy.model.path,
            trust_remote_code=True,
            use_fast=not self.cfg.trainer.disable_fast_tokenizer,
            padding_side="left",
        )
        self.collator = self._build_collator(self.tokenizer)
        self._init_tracker()
        self._init_workers()

method add_callback

add_callback(callback: TrainingCallback) -> None

Register a callback. Can be called anytime; events fired after this call will reach the new callback.

Source code in skyrl/train/sft_trainer.py:848-851
    def add_callback(self, callback: TrainingCallback) -> None:
        """Register a callback. Can be called anytime; events fired after this
        call will reach the new callback."""
        self._callback_handler.add(callback)

method load_dataset

load_dataset() -> list

Load and tokenize the training dataset.

Source code in skyrl/train/sft_trainer.py:1068-1070
    def load_dataset(self) -> list:
        """Load and tokenize the training dataset."""
        return self._load_and_tokenize(self.sft_cfg.dataset_name, self.sft_cfg.dataset_split)

method load_eval_dataset

load_eval_dataset() -> Optional[list]

Load and tokenize the eval dataset, or return None if not configured.

Source code in skyrl/train/sft_trainer.py:1072-1076
    def load_eval_dataset(self) -> Optional[list]:
        """Load and tokenize the eval dataset, or return ``None`` if not configured."""
        if not self.sft_cfg.eval_dataset_name:
            return None
        return self._load_and_tokenize(self.sft_cfg.eval_dataset_name, self.sft_cfg.eval_dataset_split)

method collate_batch

collate_batch(examples: list, batch_size: int) -> TrainingInputBatch

Collate examples into a TrainingInputBatch via the configured collator.

Delegates to self.collator (DefaultCollator or, when sequence packing is enabled, PackedDataCollator).

Parameters:

NameTypeDescriptionDefault
exampleslistTokenized examples to collate.required
batch_sizeintGlobal batch dimension. The train path passes sft_cfg.batch_size and the eval path passes its per-dispatch chunk size.required
Source code in skyrl/train/sft_trainer.py:1109-1121
    def collate_batch(self, examples: list, batch_size: int) -> TrainingInputBatch:
        """Collate examples into a TrainingInputBatch via the configured collator.

        Delegates to ``self.collator`` (``DefaultCollator`` or, when sequence
        packing is enabled, ``PackedDataCollator``).

        Args:
            examples: Tokenized examples to collate.
            batch_size: Global batch dimension. The train path passes
                ``sft_cfg.batch_size`` and the eval path passes its
                per-dispatch chunk size.
        """
        return self.collator(examples, batch_size=batch_size)

method abstractmethod load_checkpoint

load_checkpoint() -> int

Load a checkpoint and return the step number to resume from.

Behaviour depends on sft_cfg.resume_from:

  • "" (empty): no resume, return 0.
  • "latest": read latest_ckpt_global_step.txt from ckpt_path.
  • otherwise: treat as a direct path to a global_step_N directory.

Returns:

TypeDescription
intThe global step to resume from (0 if no checkpoint loaded).
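
The three resume_from cases above can be sketched with a small standalone helper. This is an illustrative approximation, not the trainer's code: resolve_resume_path is a hypothetical name, the value "global_step_" for GLOBAL_STEP_PREFIX is an assumption inferred from the directory names in the docstring, and plain os calls stand in for the trainer's io helper. The consistency check and the existence check done by the real method are left out.

```python
import os
import re
import tempfile

# Assumed value, inferred from the "global_step_N" directory naming above.
GLOBAL_STEP_PREFIX = "global_step_"


def resolve_resume_path(resume_from, ckpt_path):
    """Hypothetical helper: return (checkpoint_dir, step), or (None, 0) for a fresh start."""
    if not resume_from:
        return None, 0  # empty string: no resume
    if resume_from == "latest":
        if not ckpt_path:
            return None, 0
        marker = os.path.join(ckpt_path, "latest_ckpt_global_step.txt")
        if not os.path.exists(marker):
            return None, 0  # no marker file: start from scratch
        with open(marker) as f:
            step = int(f.read().strip())
        return os.path.join(ckpt_path, f"{GLOBAL_STEP_PREFIX}{step}"), step
    # Otherwise treat resume_from as a direct path to a global_step_N directory.
    name = os.path.basename(os.path.normpath(resume_from))
    match = re.fullmatch(rf"{re.escape(GLOBAL_STEP_PREFIX)}(\d+)", name)
    if match is None:
        raise ValueError(f"Cannot extract step number from {resume_from!r}")
    return resume_from, int(match.group(1))


with tempfile.TemporaryDirectory() as root:
    # Lay out a fake checkpoint directory plus the marker file.
    os.makedirs(os.path.join(root, "global_step_20", "policy"))
    with open(os.path.join(root, "latest_ckpt_global_step.txt"), "w") as f:
        f.write("20")

    latest_dir, latest_step = resolve_resume_path("latest", root)
    direct_dir, direct_step = resolve_resume_path(os.path.join(root, "global_step_20"), root)
    fresh = resolve_resume_path("", root)
```

In this sketch both the "latest" route and the direct-path route resolve to step 20, while an empty resume_from yields step 0.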
Source code in skyrl/train/sft_trainer.py:1127-1200
    def load_checkpoint(self) -> int:
        """Load a checkpoint and return the step number to resume from.

        Behaviour depends on ``sft_cfg.resume_from``:
        - ``""`` (empty): no resume, return 0.
        - ``"latest"``: read ``latest_ckpt_global_step.txt`` from ``ckpt_path``.
        - otherwise: treat as a direct path to a ``global_step_N`` directory.

        Returns:
            The global step to resume from (0 if no checkpoint loaded).
        """
        resume_from = self.sft_cfg.resume_from
        if not resume_from:
            return 0

        if resume_from == "latest":
            if not self.sft_cfg.ckpt_path:
                logger.info("resume_from='latest' but ckpt_path is empty, starting from scratch")
                return 0
            latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
            if not io.exists(latest_file):
                logger.info("No latest checkpoint marker found, starting from scratch")
                return 0
            with io.open_file(latest_file, "r") as f:
                ckpt_step = int(f.read().strip())
            checkpoint_path = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_step}")
            # Validate consistency: ensure no stale checkpoint folders from prior runs
            validate_consistency_for_latest_checkpoint(
                self.sft_cfg.ckpt_path,
                ckpt_step,
                checkpoint_path,
                latest_file,
                self.sft_cfg.ckpt_interval,
            )
        else:
            checkpoint_path = resume_from

        if not io.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")

        global_step = extract_step_from_path(checkpoint_path)
        if global_step == -1:
            raise ValueError(
                f"Cannot extract step number from checkpoint path: {checkpoint_path}. "
                f"Expected a directory named '{GLOBAL_STEP_PREFIX}<N>'."
            )

        # Load and validate trainer state if available
        trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
        if io.exists(trainer_state_path):
            with io.open_file(trainer_state_path, "rb") as f:
                trainer_state = torch.load(f, map_location="cpu", weights_only=False)
            saved_global_step = trainer_state.get("global_step", global_step)
            logger.info("Successfully loaded trainer state")
            if saved_global_step != global_step:
                logger.warning(
                    f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value."
                )
        else:
            logger.warning(
                f"No trainer_state.pt found at {trainer_state_path}. "
                "This checkpoint was likely saved by an older version."
            )

        policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
        logger.info(f"Loading checkpoint from {checkpoint_path} (step {global_step})")
        self.dispatch.load_checkpoint(
            "policy",
            policy_ckpt_dir,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        logger.info(f"Successfully resumed from global_step_{global_step}")
        return global_step

method run_eval

run_eval(eval_tokenized: list) -> tuple[dict, int]

Compute eval loss over the full eval dataset.

Iterates the eval dataset in chunks of micro_train_batch_size_per_gpu * dp_size (i.e. exactly one micro-batch per DP rank per dispatch call), calls WorkerDispatch.forward with loss_fn="cross_entropy" (which runs the model in eval() mode under no_grad), and aggregates the per-batch losses into a token-weighted mean.

Each per-batch loss is itself a per-non-pad-token mean within that batch, so weighting it by the batch's non-pad token count yields the true per-non-pad-token mean across the whole eval dataset.
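
A small worked example may make the weighting concrete. The function below is a hypothetical standalone sketch (token_weighted_mean is not part of SkyRL) and the numbers are made up for illustration; it mirrors only the aggregation arithmetic described above.

```python
def token_weighted_mean(batch_losses, batch_token_counts):
    """Combine per-batch mean losses into a single per-token mean."""
    total_weighted = 0.0
    total_tokens = 0
    for loss, n_tokens in zip(batch_losses, batch_token_counts):
        total_weighted += loss * n_tokens  # undo the per-batch averaging
        total_tokens += n_tokens
    # Guard against a fully masked eval set by dividing by at least 1.
    return total_weighted / max(total_tokens, 1)


# Illustrative numbers: 10 tokens at mean loss 2.0, then 30 tokens at mean loss 1.0.
losses = [2.0, 1.0]
counts = [10, 30]

weighted = token_weighted_mean(losses, counts)  # (2.0*10 + 1.0*30) / 40
unweighted = sum(losses) / len(losses)          # plain mean of batch means
```

Here the weighted result is 1.25, whereas a plain mean of the two batch losses would give 1.5 and over-count the short batch.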

Parameters:

NameTypeDescriptionDefault
eval_tokenizedlistPre-tokenized eval dataset (output of load_eval_dataset).required

Returns:

TypeDescription
tuple[dict, int](metrics, num_eval_batches) where metrics contains eval_loss and num_eval_batches is bookkeeping for stdout logging (not a wandb metric).
Source code in skyrl/train/sft_trainer.py:1206-1282
    def run_eval(self, eval_tokenized: list) -> tuple[dict, int]:
        """Compute eval loss over the full eval dataset.

        Iterates the eval dataset in chunks of ``micro_train_batch_size_per_gpu * dp_size``
        (i.e. exactly one micro-batch per DP rank per dispatch call), calls
        :meth:`WorkerDispatch.forward` with ``loss_fn="cross_entropy"`` (which
        runs the model in ``eval()`` mode under ``no_grad``), and aggregates the
        per-batch losses into a token-weighted mean.

        The aggregated loss is a token-weighted mean of the per-batch losses,
        which are themselves per-non-pad-token means within each batch. This
        yields the true per-non-pad-token mean across the eval dataset.

        Args:
            eval_tokenized: Pre-tokenized eval dataset (output of
                :meth:`load_eval_dataset`).

        Returns:
            ``(metrics, num_eval_batches)`` where ``metrics`` contains
            ``eval_loss`` and ``num_eval_batches`` is bookkeeping for
            stdout logging (not a wandb metric).
        """
        num_eval = len(eval_tokenized)
        if num_eval == 0:
            raise ValueError(
                "Eval dataset is empty. Provide a non-empty eval split or disable eval "
                "by setting eval_dataset_name=None."
            )

        # One micro-batch per DP rank per dispatch call — keeps memory usage bounded
        # and removes the need for a separate `eval_batch_size` knob.
        dp_size = self.dispatch.dp_size("policy")
        eval_chunk_size = self.sft_cfg.micro_train_batch_size_per_gpu * dp_size

        # Pad a trailing partial batch up to ``eval_chunk_size`` via
        # ``pad_training_input_batch`` (which zeros ``loss_mask`` on padded rows).
        # Padded rows contribute 0 to the cross-entropy numerator, and the
        # pre-padding ``total_nonpad`` scaling in ``collate_batch`` excludes
        # them from the denominator, so the reported ``eval_loss`` is the
        # per-real-token mean over the full (non-padded) eval set.
        num_eval_batches = ceil(num_eval / eval_chunk_size)

        total_loss_weighted = 0.0
        total_tokens = 0
        for batch_idx in range(num_eval_batches):
            start = batch_idx * eval_chunk_size
            end = min(start + eval_chunk_size, num_eval)
            batch_examples = eval_tokenized[start:end]
            batch = self.collator(batch_examples, batch_size=eval_chunk_size)
            # Pad the last (possibly-short) chunk so every dispatch sees exactly
            # ``eval_chunk_size`` rows. ``pad_training_input_batch`` zeros the
            # ``loss_mask`` for padding rows; with ``pad_size=0`` it is a no-op.
            pad_rows = eval_chunk_size - len(batch_examples)
            if pad_rows > 0:
                logger.info(
                    f"Padding final eval batch by {pad_rows} rows "
                    f"({len(batch_examples)} real -> {eval_chunk_size} total); "
                    f"padded rows are masked out of the loss."
                )
                batch = pad_training_input_batch(batch, pad_rows)
            # Count non-pad response tokens (from the unscaled mask, recovered from the batch)
            # We use the attention_mask response window via collate_sft_batch's loss_mask which
            # was 0/1 before scaling. Recover the count from the batch by counting positive entries.
            # Padded rows have loss_mask=0 so they are excluded here.
            nonpad_tokens = int((batch["loss_mask"] > 0).sum().item())
            output = self.dispatch.forward(
                "policy",
                batch,
                loss_fn="cross_entropy",
                loss_fn_config=None,
            )
            batch_loss = float(output.metrics.get("loss", float("nan")))
            total_loss_weighted += batch_loss * nonpad_tokens
            total_tokens += nonpad_tokens

        eval_loss = total_loss_weighted / max(total_tokens, 1)
        return {"eval_loss": eval_loss}, num_eval_batches

method train_step

train_step(batch: TrainingInputBatch, step: int) -> dict

Execute a single training step: forward_backward + optim_step.

Parameters:

NameTypeDescriptionDefault
batchTrainingInputBatchThe collated training batch.required
stepintCurrent global step (reserved for future use, e.g. scheduling).required

Returns:

TypeDescription
dictDict with loss, grad_norm, and timings.
Source code in skyrl/train/sft_trainer.py:1284-1306
    def train_step(self, batch: TrainingInputBatch, step: int) -> dict:
        """Execute a single training step: forward_backward + optim_step.

        Args:
            batch: The collated training batch.
            step: Current global step (reserved for future use, e.g. scheduling).

        Returns:
            Dict with ``loss``, ``grad_norm``, and ``timings``.
        """
        timings: dict[str, float] = {}
        with Timer("forward_backward", timings):
            output = self.dispatch.forward_backward("policy", batch, loss_fn="cross_entropy")
        with Timer("optim_step", timings):
            grad_norm = self.dispatch.optim_step("policy")

        metrics = output.metrics
        loss_val = metrics.get("final_loss", metrics.get("loss", float("nan")))
        return {
            "loss": loss_val,
            "grad_norm": grad_norm,
            "timings": timings,
        }

method train

train()

Full training loop: load data, iterate, log, checkpoint.
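
The step arithmetic inside the loop can be tried in isolation. The sketch below is a simplified approximation of the indexing shown in the source listing, using integer indices in place of tokenized examples; batch_window, epoch_after, and shuffled_order are hypothetical helper names, and the dataset size, batch size, and seed are made-up values.

```python
import random
from math import ceil


def batch_window(step, batch_size, dataset_len):
    """Indices consumed at `step`, wrapping past the end of the dataset."""
    start = (step * batch_size) % dataset_len
    return [(start + i) % dataset_len for i in range(batch_size)]


def epoch_after(step, batch_size, dataset_len):
    """Epoch index reached once `step` has completed."""
    return (step * batch_size) // dataset_len


def shuffled_order(seed, dataset_len, completed_epochs):
    """Replay the initial shuffle plus one reshuffle per finished epoch."""
    rng = random.Random(seed)
    order = list(range(dataset_len))
    rng.shuffle(order)
    for _ in range(completed_epochs):
        rng.shuffle(order)
    return order


dataset_len, batch_size, num_epochs = 10, 4, 1
steps_per_epoch = max(1, ceil(dataset_len / batch_size))
num_steps = num_epochs * steps_per_epoch

# Steps are 1-based, so the first window starts at index batch_size.
schedule = {s: batch_window(s, batch_size, dataset_len) for s in range(1, num_steps + 1)}

# A resumed run that replays the shuffles sees the same order as the original run.
original = shuffled_order(seed=42, dataset_len=dataset_len, completed_epochs=2)
resumed = shuffled_order(seed=42, dataset_len=dataset_len, completed_epochs=2)
```

With these values one epoch is three steps: step 1 reads indices 4 to 7, step 2 wraps around to 8, 9, 0, 1, and the epoch counter moves from 0 to 1 after step 3.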

Source code in skyrl/train/sft_trainer.py:1423-1694
    def train(self):
        """Full training loop: load data, iterate, log, checkpoint."""
        if self.sft_cfg.dummy_run_full_ctx:
            if self.sft_cfg.resume_from:
                logger.warning("resume_from is ignored in dummy run mode")
            return self._train_dummy()

        tokenized = self.load_dataset()

        # Log tokenized sequence length statistics (once, before training loop)
        self._log_dataset_stats(tokenized)

        # Load eval dataset (if configured). We load once up-front so the
        # tokenization cost is amortized across all eval invocations.
        eval_tokenized = self.load_eval_dataset()
        if eval_tokenized is not None:
            logger.info(f"Eval dataset loaded: {len(eval_tokenized)} examples")

        batch_size = self.sft_cfg.batch_size

        # steps_per_epoch is always derived from the data; callbacks rely on it.
        steps_per_epoch = max(1, ceil(len(tokenized) / batch_size))

        # Resolve num_steps: explicit num_steps takes precedence; otherwise derive from num_epochs.
        if self.sft_cfg.num_steps is not None:
            num_steps = self.sft_cfg.num_steps
        else:
            num_steps = self.sft_cfg.num_epochs * steps_per_epoch
            logger.info(
                f"num_steps not set; deriving from num_epochs={self.sft_cfg.num_epochs}: "
                f"ceil({len(tokenized)} / {batch_size}) * {self.sft_cfg.num_epochs} = {num_steps} steps"
            )

        # Early validation: dataset must have at least batch_size examples
        if len(tokenized) < batch_size:
            raise ValueError(
                f"Dataset has {len(tokenized)} examples after tokenization, but batch_size={batch_size}. "
                f"Reduce batch_size or use more data."
            )

        self._validate_batch_parallelism()

        # Resume from checkpoint if configured
        start_step = self.load_checkpoint()

        # Shuffle data before training
        rng = random.Random(self.sft_cfg.seed)
        rng.shuffle(tokenized)

        # When resuming, start_step is the last *completed* step (checkpoint is
        # saved AFTER the optimizer update), so we begin at start_step + 1 to
        # avoid replaying that step.

        # Replay epoch shuffles for reproducibility on resume
        start_epoch = (start_step * batch_size) // len(tokenized)
        for _ in range(start_epoch):
            rng.shuffle(tokenized)
        current_epoch = start_epoch

        # Initialize `global_step`
        self.global_step = start_step

        # Publish loop metadata so CallbackInput can be built consistently.
        self._total_steps = num_steps
        self._steps_per_epoch = steps_per_epoch
        self._current_epoch = current_epoch
        self._training_control.reset()

        logger.info(f"Starting SFT training for {num_steps} steps (batch_size={batch_size})...")
        if start_step > 0:
            logger.info(f"Resuming from step {start_step}")

        if self._ray_gpu_monitor is not None:
            self._ray_gpu_monitor.start()

        # Tracks whether the most recent in-loop iteration saved a checkpoint
        # (either via the ckpt_interval or via a callback-driven ``should_save``).
        did_save_last_step = False

        self._fire("on_train_start")

        # Baseline eval before training begins (logged at step 0).
        # Wandb's step counter starts at 0; the training loop's first commit
        # advances it to >=1, so step=0 here does not conflict with later steps.
        if self.sft_cfg.eval_before_train and eval_tokenized is not None:
            self._fire("on_eval_start")
            eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
            self._fire("on_eval_end", metrics=eval_metrics)
            baseline_log = {f"eval/{k}": v for k, v in eval_metrics.items()}
            self._fire("on_log", logs=baseline_log)
            self.tracker.log(baseline_log, step=self.global_step, commit=True)
            logger.info(
                f"Baseline eval before training: "
                f"eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
                f"over {num_eval_batches} batches"
            )

        # SkyRL starts counting at step 1
        self.global_step = start_step + 1 if start_step > 0 else 1
        self._fire("on_epoch_start")
        epoch_in_progress = True

        while self.global_step <= num_steps:
            all_timings: dict[str, float] = {}

            with Timer("step", all_timings):

                # Data loading with wrap-around
                with Timer("data_loading", all_timings):
                    start_idx = (self.global_step * batch_size) % len(tokenized)
                    end_idx = start_idx + batch_size
                    if end_idx > len(tokenized):
                        batch_examples = tokenized[start_idx:] + tokenized[: end_idx - len(tokenized)]
                    else:
                        batch_examples = tokenized[start_idx:end_idx]
                    batch = self.collator(batch_examples, batch_size=batch_size)

                self._fire("on_step_start", batch=batch)

                # Training step
                step_result = self.train_step(batch, self.global_step)
                all_timings.update(step_result["timings"])

            # Compute throughput using actual (non-padding) tokens
            batch_padded_seq_len = batch["sequences"].shape[1]
            actual_num_tokens = batch["attention_mask"].sum().item()
            self._total_tokens_processed += actual_num_tokens
            tokens_per_second = actual_num_tokens / all_timings["step"]

            # Build log dict
            log_dict = {
                "train/loss": step_result["loss"],
                "train/grad_norm": step_result["grad_norm"],
                "train/tokens_per_second": tokens_per_second,
                "train/tokens_per_second_per_gpu": tokens_per_second / self._num_training_gpus,
                "train/actual_num_tokens": actual_num_tokens,
                "train/batch_padded_seq_len": batch_padded_seq_len,
                "train/total_tokens_processed": self._total_tokens_processed,
            }
            log_dict.update({f"timing/{k}": v for k, v in all_timings.items()})
            if self._ray_gpu_monitor is not None:
                log_dict.update(self._ray_gpu_monitor.flush())

            self._fire("on_step_end", batch=batch, metrics=step_result)

            # Capture callback-driven triggers, then reset so they only fire once.
            force_save = self._training_control.should_save
            force_eval = self._training_control.should_evaluate
            self._training_control.should_save = False
            self._training_control.should_evaluate = False

            # Checkpoint: interval-driven or callback-requested.
            interval_save = (
                self.sft_cfg.ckpt_interval > 0
                and self.global_step > 0
                and self.global_step % self.sft_cfg.ckpt_interval == 0
            )
            did_save_last_step = force_save or interval_save
            if did_save_last_step:
                with Timer("save_checkpoint", all_timings):
                    ckpt_path = self.save_checkpoint()
                log_dict["timing/save_checkpoint"] = all_timings["save_checkpoint"]
                self._fire("on_save", ckpt_path=ckpt_path)

            # HF export at regular intervals
            if self.sft_cfg.hf_save_interval > 0 and self.global_step % self.sft_cfg.hf_save_interval == 0:
                with Timer("save_hf_model", all_timings):
                    self.save_hf_model()
                log_dict["timing/save_hf_model"] = all_timings["save_hf_model"]

            eval_metrics = None
            num_eval_batches: int | None = None
            # Eval fires at step N where N % eval_interval == 0 and N > 0, OR
            # whenever a callback set ``control.should_evaluate``.
            interval_eval = self.sft_cfg.eval_interval > 0 and self.global_step % self.sft_cfg.eval_interval == 0
            if eval_tokenized is not None and (force_eval or interval_eval):
                self._fire("on_eval_start")
                with Timer("eval", all_timings):
                    eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
                self._fire("on_eval_end", metrics=eval_metrics)
                if eval_metrics:
                    log_dict.update({f"eval/{k}": v for k, v in eval_metrics.items()})
                    log_dict["timing/eval"] = all_timings["eval"]

            log_dict.update({"train/epoch": current_epoch, "train/global_step": self.global_step})
            # Callbacks may mutate log_dict in place via on_log.
            self._fire("on_log", logs=log_dict)
            self.tracker.log(log_dict, step=self.global_step, commit=True)

            if self.global_step % 5 == 0:
                logger.info(
                    f"Step {self.global_step}: loss={step_result['loss']:.4f}, " f"grad_norm={step_result['grad_norm']}"
                )

            if eval_metrics:
                logger.info(
                    f"Step {self.global_step}: eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
                    f"over {num_eval_batches} batches"
                )

            # Check for epoch boundary and reshuffle
            epoch = (self.global_step * batch_size) // len(tokenized)
            if epoch > current_epoch:
                self._fire("on_epoch_end")
                epoch_in_progress = False
                for _ in range(epoch - current_epoch):
                    rng.shuffle(tokenized)
                current_epoch = epoch
                self._current_epoch = epoch
                if self.global_step + 1 <= num_steps:
                    self._fire("on_epoch_start")
                    epoch_in_progress = True

            self.global_step += 1
        self.global_step = min(self.global_step, num_steps)

        # Pair the leading on_epoch_start: fire on_epoch_end if we exited the
        # loop mid-epoch
        if epoch_in_progress:
            self._fire("on_epoch_end")
            epoch_in_progress = False

        # Save final checkpoint (if checkpointing is enabled). Skip if the last
        # in-loop iteration already saved (either via ckpt_interval or via a
        # callback-driven force-save) so we don't double-save.
        if self.sft_cfg.ckpt_path and not did_save_last_step:
            final_step = num_steps
            logger.info(f"Saving final checkpoint at step {final_step}")
            ckpt_path = self.save_checkpoint()
            self._fire("on_save", ckpt_path=ckpt_path)

        # Save final HF model if enabled (only if not already saved at last step)
        if self.sft_cfg.hf_save_interval > 0:
            final_step = num_steps
            already_saved = final_step % self.sft_cfg.hf_save_interval == 0
            if not already_saved:
                self.global_step = final_step
                logger.info(f"Saving final HF model at step {final_step}")
                self.save_hf_model()

        # Final eval pass (skip if the last step already ran eval).
        # NOTE: The last in-loop tracker.log(..., commit=True) at step=num_steps
        # advanced wandb's internal step counter to num_steps+1. Logging the
        # final eval at step=num_steps would be rejected by wandb with
        # "step N < current step N+1". We log the final eval at num_steps+1
        # (one past the last committed train step) in a single combined
        # tracker.log() call, preserving wandb step ordering. We use a local
        # ``final_eval_step`` rather than mutating ``self.global_step``: the
        # bump is purely a wandb-step accounting concern, not real trainer
        # state.
        if eval_tokenized is not None:
            already_ran = self.sft_cfg.eval_interval > 0 and num_steps % self.sft_cfg.eval_interval == 0
            if not already_ran:
                final_eval_step = num_steps + 1
                eval_timings: dict[str, float] = {}
                self._fire("on_eval_start")
                with Timer("eval", eval_timings):
                    eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
                self._fire("on_eval_end", metrics=eval_metrics)
                if eval_metrics:
                    eval_log = {f"eval/{k}": v for k, v in eval_metrics.items()}
                    eval_log["timing/eval"] = eval_timings["eval"]
                    self._fire("on_log", logs=eval_log)
                    self.tracker.log(eval_log, step=final_eval_step, commit=True)
                    logger.info(
                        f"Final eval at step {final_eval_step}: "
                        f"eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
                        f"over {num_eval_batches} batches"
                    )

        self._fire("on_train_end")
        logger.info("SFT training complete!")
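The data-loading and epoch-boundary arithmetic in the loop above can be tried in isolation. The following is a minimal sketch, not the trainer's code: `select_batch` and `epoch_after_step` are hypothetical helper names, the dataset is a plain list, `batch_size` is assumed to be no larger than the dataset, and the per-epoch reshuffle is left out. Because step counting starts at 1, the first slice begins at offset `batch_size` rather than 0.

```python
from typing import List, Sequence


def select_batch(examples: Sequence, global_step: int, batch_size: int) -> List:
    """Return the wrap-around slice for one step (hypothetical helper)."""
    n = len(examples)
    start = (global_step * batch_size) % n
    end = start + batch_size
    if end > n:
        # Slice runs past the end: take the tail, then wrap to the head.
        return list(examples[start:]) + list(examples[: end - n])
    return list(examples[start:end])


def epoch_after_step(global_step: int, batch_size: int, dataset_len: int) -> int:
    """Epoch index computed after a step, as in the epoch-boundary check."""
    return (global_step * batch_size) // dataset_len


data = list(range(10))
for step in range(1, 6):
    print(step, select_batch(data, step, 4), epoch_after_step(step, 4, len(data)))
```

With 10 examples and a batch size of 4, step 2 wraps around the end of the list and step 3 is the first step whose computed epoch index is 1, which is where the loop would fire `on_epoch_end` and reshuffle.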

method abstractmethod save_checkpoint

save_checkpoint() -> str

Save a checkpoint at the current global step (self.global_step; the method takes no step argument). Returns the checkpoint folder path.

Source code in skyrl/train/sft_trainer.py:1696-1723
    def save_checkpoint(self) -> str:
        """Save a checkpoint at the given step. Returns the checkpoint folder path."""
        step = self.global_step
        global_step_folder = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{step}")
        policy_save_dir = os.path.join(global_step_folder, "policy")
        io.makedirs(global_step_folder, exist_ok=True)
        logger.info(f"Saving checkpoint at step {step} to {global_step_folder}")
        self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer)

        # Save trainer state for cross-validation on resume (mirrors PPO's trainer_state.pt)
        trainer_state = {
            "global_step": step,
            "config": asdict(self.sft_cfg),
        }
        trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
        with io.open_file(trainer_state_path, "wb") as f:
            torch.save(trainer_state, f)
        logger.info(f"Saved trainer state to {trainer_state_path}")

        # Atomic tracking -- write this last after all saves succeed
        latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
        with io.open_file(latest_file, "w") as f:
            f.write(str(step))
        logger.info(f"Checkpoint saved for global_step_{step}")

        # Clean up old checkpoints after successful save
        cleanup_old_checkpoints(self.sft_cfg.ckpt_path, self.sft_cfg.max_ckpts_to_keep)
        return global_step_folder
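Since latest_ckpt_global_step.txt is written only after the policy and trainer state have been saved, a reader can treat it as the pointer to the newest complete checkpoint. The sketch below shows one way to resolve that pointer on a local filesystem. It is an illustration under assumptions: `find_latest_checkpoint` is a hypothetical helper, `GLOBAL_STEP_PREFIX` is assumed to be `"global_step_"` (inferred from the log message above), and the trainer's `io` abstraction is replaced by plain `os` calls.

```python
import os
import tempfile
from typing import Optional

# Assumption: matches the prefix implied by the "global_step_{step}" log line.
GLOBAL_STEP_PREFIX = "global_step_"


def find_latest_checkpoint(ckpt_root: str) -> Optional[str]:
    """Return the folder named by the marker file, or None (hypothetical helper)."""
    marker = os.path.join(ckpt_root, "latest_ckpt_global_step.txt")
    if not os.path.isfile(marker):
        return None
    with open(marker) as f:
        step = int(f.read().strip())
    folder = os.path.join(ckpt_root, f"{GLOBAL_STEP_PREFIX}{step}")
    # The marker may outlive a folder that was removed by hand.
    return folder if os.path.isdir(folder) else None


# Demo on a throwaway directory that mimics the layout described above.
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, f"{GLOBAL_STEP_PREFIX}20", "policy"))
    with open(os.path.join(root, "latest_ckpt_global_step.txt"), "w") as f:
        f.write("20")
    print(find_latest_checkpoint(root))
```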

method save_hf_model

save_hf_model()

Save policy weights in HuggingFace format.

Export path: cfg.trainer.export_path/global_step_{step}/policy. Mirrors the pattern used by the RL trainer's save_models().

Source code in skyrl/train/sft_trainer.py:1725-1738
    def save_hf_model(self):
        """Save policy weights in HuggingFace format.

        Export path: cfg.trainer.export_path/global_step_{step}/policy
        Mirrors the pattern used by the RL trainer's save_models().
        """
        step = self.global_step
        policy_export_dir = os.path.join(
            self.cfg.trainer.export_path,
            f"{GLOBAL_STEP_PREFIX}{step}",
            "policy",
        )
        self.dispatch.save_hf_model("policy", policy_export_dir, self.tokenizer)
        logger.info(f"Saved HF model weights at step {step} to {policy_export_dir}")
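To see which steps produce an export, the interval check in the training loop and the final-export check after it can be combined into one small function. This is a sketch under assumptions: `hf_export_dir` and `hf_export_steps` are hypothetical names, `GLOBAL_STEP_PREFIX` is assumed to be `"global_step_"`, and the run is assumed to start fresh at step 1 (no resume).

```python
import os
from typing import List

# Assumption: matches the "global_step_{step}" pattern in the docstring above.
GLOBAL_STEP_PREFIX = "global_step_"


def hf_export_dir(export_path: str, step: int) -> str:
    """Compose the export directory the same way save_hf_model does."""
    return os.path.join(export_path, f"{GLOBAL_STEP_PREFIX}{step}", "policy")


def hf_export_steps(num_steps: int, hf_save_interval: int) -> List[int]:
    """Steps at which an HF export would happen (hypothetical helper)."""
    if hf_save_interval <= 0:
        return []  # exports disabled
    steps = [s for s in range(1, num_steps + 1) if s % hf_save_interval == 0]
    if num_steps % hf_save_interval != 0:
        # The last step missed the interval, so a final export is added.
        steps.append(num_steps)
    return steps


for step in hf_export_steps(num_steps=10, hf_save_interval=4):
    print(hf_export_dir("exports", step))
```

With `num_steps=10` and `hf_save_interval=4`, the sketch lists steps 4, 8, and 10; with `num_steps=8` the final step already falls on the interval, so no extra export is added.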

method shutdown

shutdown()

Stop the Ray GPU monitor (if one is running) and finish tracking.

Does NOT call ray.shutdown() -- when running inside a Ray task (the normal path via sft_entrypoint), shutting down Ray from within the task would be incorrect. The head-node process owns the Ray lifecycle.

Source code in skyrl/train/sft_trainer.py:1744-1755
    def shutdown(self):
        """Finish tracking.

        Does NOT call ``ray.shutdown()`` -- when running inside a Ray task
        (the normal path via ``sft_entrypoint``), shutting down Ray from
        within the task would be incorrect.  The head-node process owns
        the Ray lifecycle.
        """
        if self._ray_gpu_monitor is not None:
            self._ray_gpu_monitor.stop()
        if self.tracker is not None:
            self.tracker.finish()
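Because shutdown() only releases trainer-owned resources, a caller would typically pair it with training in a try/finally so the tracker is finished even when training fails. The sketch below illustrates that calling pattern with stand-in classes. `FakeTracker`, `FakeTrainer`, and `run` are hypothetical and are not part of SkyRL; the real trainer's constructor and `train` signature are not shown in this section.

```python
class FakeTracker:
    """Stand-in for an experiment tracker (hypothetical)."""

    def __init__(self) -> None:
        self.finished = False

    def finish(self) -> None:
        self.finished = True


class FakeTrainer:
    """Stand-in trainer whose train() always fails (hypothetical)."""

    def __init__(self, tracker: FakeTracker) -> None:
        self.tracker = tracker

    def train(self) -> None:
        raise RuntimeError("simulated failure")

    def shutdown(self) -> None:
        # Same None-guard shape as the method above; Ray itself is untouched.
        if self.tracker is not None:
            self.tracker.finish()


def run(trainer: FakeTrainer) -> str:
    """Run training and always release trainer-owned resources."""
    try:
        trainer.train()
    except RuntimeError as exc:
        return f"failed: {exc}"
    finally:
        trainer.shutdown()
    return "ok"


tracker = FakeTracker()
print(run(FakeTrainer(tracker)), tracker.finished)
```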

Config Bridge

method validate_sft_cfg

validate_sft_cfg(cfg: SFTConfig) -> None

Validate SFT-specific configuration.

Checks only the fields relevant to SFT training, unlike validate_cfg, which also includes RL-specific validations.

method build_skyrl_config_for_sft

build_skyrl_config_for_sft(sft_cfg: SFTConfig) -> SkyRLTrainConfig

Map user-facing SFTConfig to the internal SkyRL backend config.
