SkyRL
API Reference

SFT

Supervised Fine-Tuning configuration and trainer.

Configuration

class SFTPlacementConfig

SFTPlacementConfig(num_nodes: int = 1, num_gpus_per_node: int = 4) -> None

Bases: BaseConfig

Placement configuration for SFT training.

Functions:

| Name | Description |
| ---- | ----------- |
| from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |

Attributes:

| Name | Type | Description |
| ---- | ---- | ----------- |
| num_nodes | int | |
| num_gpus_per_node | int | |
Source code in skyrl/train/config/sft_config.py:28-33
@dataclass
class SFTPlacementConfig(BaseConfig):
    """Placement configuration for SFT training"""

    num_nodes: int = 1
    num_gpus_per_node: int = 4

attr num_nodes

num_nodes: int = 1

method classmethod from_dict_config

from_dict_config(cfg: DictConfig) -> BaseConfig

Construct a typed BaseConfig from a Hydra DictConfig.

attr num_gpus_per_node

num_gpus_per_node: int = 4
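
For example, a hypothetical two-node placement (values chosen for illustration) interacts with the Megatron parallelism settings exactly as the trainer's batch validation does:

    placement = SFTPlacementConfig(num_nodes=2, num_gpus_per_node=8)
    # With megatron_config(tensor_model_parallel_size=2, pipeline_model_parallel_size=2),
    # the data-parallel size is total_gpus // (tp * pp) = 16 // 4 = 4,
    # so batch_size must be divisible by 4 (see _validate_batch_parallelism below).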

class SFTConfig

SFTConfig(
    model: ModelConfig = (lambda: ModelConfig(path='Qwen/Qwen3-0.6B'))(),
    optimizer_config: OptimizerConfig = OptimizerConfig(),
    placement: SFTPlacementConfig = SFTPlacementConfig(),
    megatron_config: MegatronConfig = (lambda: MegatronConfig(tensor_model_parallel_size=2, pipeline_model_parallel_size=2))(),
    fsdp_config: FSDPConfig = FSDPConfig(),
    sequence_parallel_size: int = 1,
    model_config_kwargs: dict = dict(),
    use_torch_compile: bool = False,
    record_memory: bool = False,
    strategy: str = 'megatron',
    dataset_name: str = 'yahma/alpaca-cleaned',
    dataset_split: str = 'train[:100]',
    messages_key: str = 'messages',
    max_length: int = 512,
    num_steps: int = 10,
    batch_size: int = 4,
    micro_train_batch_size_per_gpu: int = 2,
    logger: str = 'console',
    project_name: str = 'skyrl_sft',
    run_name: str = 'skyrl_sft_run',
    ckpt_path: str = '',
    ckpt_interval: int = 0,
    max_ckpts_to_keep: int = -1,
    resume_from: str = '',
    seed: int = 42,
    use_sample_packing: bool = True,
    dummy_run_full_ctx: bool = False,
    dummy_run_max_steps: int = 5,
) -> None

Bases: BaseConfig

Configuration for SFT training.

Usage::

    cfg = SFTConfig(
        strategy="megatron",
        placement=SFTPlacementConfig(num_gpus_per_node=4),
        megatron_config=MegatronConfig(tensor_model_parallel_size=2,
                                       pipeline_model_parallel_size=2),
    )

Or from CLI::

    cfg = SFTConfig.from_cli_overrides(sys.argv[1:])
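
Overrides can also be supplied as a dict of dot-notation keys, mirroring the examples in the from_cli_overrides docstring:

    cfg = SFTConfig.from_cli_overrides({"strategy": "megatron", "model.path": "Qwen/Qwen3-0.6B"})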

Functions:

| Name | Description |
| ---- | ----------- |
| from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
| from_cli_overrides | Construct an SFTConfig from CLI arguments or a dict of overrides. |

Attributes:

| Name | Type | Description |
| ---- | ---- | ----------- |
| model | ModelConfig | |
| optimizer_config | OptimizerConfig | |
| placement | SFTPlacementConfig | |
| megatron_config | MegatronConfig | |
| fsdp_config | FSDPConfig | |
| sequence_parallel_size | int | Ulysses sequence parallelism size |
| model_config_kwargs | dict | Pass-through kwargs for the HuggingFace model config (FSDP backends). |
| use_torch_compile | bool | Apply torch.compile to logits calculation. |
| record_memory | bool | Save memory snapshots to {ckpt_path}/memory_snapshots/. |
| strategy | str | |
| dataset_name | str | |
| dataset_split | str | |
| messages_key | str | |
| max_length | int | |
| num_steps | int | |
| batch_size | int | |
| micro_train_batch_size_per_gpu | int | |
| logger | str | |
| project_name | str | |
| run_name | str | |
| ckpt_path | str | |
| ckpt_interval | int | |
| max_ckpts_to_keep | int | -1 to keep all checkpoints, N to keep only the last N. |
| resume_from | str | |
| seed | int | |
| use_sample_packing | bool | |
| dummy_run_full_ctx | bool | |
| dummy_run_max_steps | int | |
Source code in skyrl/train/config/sft_config.py:36-125
@dataclass
class SFTConfig(BaseConfig):
    """Configuration for SFT training.

    Usage::

        cfg = SFTConfig(
            strategy="megatron",
            placement=SFTPlacementConfig(num_gpus_per_node=4),
            megatron_config=MegatronConfig(tensor_model_parallel_size=2,
                                    pipeline_model_parallel_size=2),
        )

    Or from CLI::

        cfg = SFTConfig.from_cli_overrides(sys.argv[1:])
    """

    @classmethod
    def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SFTConfig":
        """Construct an SFTConfig from CLI arguments or a dict of overrides.

        Parses CLI dotlist arguments via OmegaConf and builds a typed config.
        Dataclass field defaults are used for any values not specified.

        Args:
            args: Either a list of CLI arguments in 'key.path=value' format, or a dict
                  mapping dot-notation keys to values.
                  Example list: ['strategy=megatron', 'model.path=Qwen/Qwen3-0.6B']
                  Example dict: {'strategy': 'megatron', 'model.path': 'Qwen/Qwen3-0.6B'}

        Returns:
            A fully constructed SFTConfig with CLI overrides applied.
        """
        if isinstance(args, dict):
            args = [f"{k}={v}" for k, v in args.items()]

        overrides = OmegaConf.from_cli(args)
        return cls.from_dict_config(overrides)

    # ---- Reused SkyRL config objects ----
    model: ModelConfig = field(default_factory=lambda: ModelConfig(path="Qwen/Qwen3-0.6B"))
    optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig)
    placement: SFTPlacementConfig = field(default_factory=SFTPlacementConfig)
    megatron_config: MegatronConfig = field(
        default_factory=lambda: MegatronConfig(
            tensor_model_parallel_size=2,
            pipeline_model_parallel_size=2,
        )
    )
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)

    # Ulysses sequence parallelism
    sequence_parallel_size: int = 1
    """Ulysses sequence parallelism size"""

    model_config_kwargs: dict = field(default_factory=dict)
    """Pass-through kwargs for the HuggingFace model config (FSDP backends).
    For Megatron, use ``megatron_config.transformer_config_kwargs`` instead."""
    use_torch_compile: bool = False
    """Apply torch.compile to logits calculation."""
    record_memory: bool = False
    """Save memory snapshots to ``{ckpt_path}/memory_snapshots/``.
    Visualize by dragging pickle files to https://docs.pytorch.org/memory_viz."""

    # ---- SFT-specific flat fields ----
    strategy: str = "megatron"  # "megatron" or "fsdp2"
    dataset_name: str = "yahma/alpaca-cleaned"
    dataset_split: str = "train[:100]"
    messages_key: str = "messages"  # column name for chat-format datasets
    max_length: int = 512
    num_steps: int = 10
    batch_size: int = 4
    micro_train_batch_size_per_gpu: int = 2
    logger: str = "console"  # "console" or "wandb"
    project_name: str = "skyrl_sft"
    run_name: str = "skyrl_sft_run"
    ckpt_path: str = ""  # empty string = no checkpointing
    ckpt_interval: int = 0
    max_ckpts_to_keep: int = -1
    """-1 to keep all checkpoints, N to keep only the last N."""
    resume_from: str = ""  # "" = no resume, "latest" = latest checkpoint, or path to global_step_N dir
    seed: int = 42

    # ---- Packing ----
    use_sample_packing: bool = True  # Pack multiple sequences per batch (requires flash_attn)

    # ---- Dummy run / benchmarking ----
    dummy_run_full_ctx: bool = False  # Skip real data; fabricate full-context sequences
    dummy_run_max_steps: int = 5  # Number of steps to run in dummy mode

method classmethod from_dict_config

from_dict_config(cfg: DictConfig) -> BaseConfig

Construct a typed BaseConfig from a Hydra DictConfig.
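
A minimal sketch of calling it directly, assuming an OmegaConf DictConfig of overrides (this mirrors what from_cli_overrides does internally after parsing the CLI):

    from omegaconf import OmegaConf

    overrides = OmegaConf.create({"strategy": "fsdp2", "num_steps": 20})
    cfg = SFTConfig.from_dict_config(overrides)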

method classmethod from_cli_overrides

from_cli_overrides(args: Union[List[str], dict]) -> SFTConfig

Construct an SFTConfig from CLI arguments or a dict of overrides.

Parses CLI dotlist arguments via OmegaConf and builds a typed config. Dataclass field defaults are used for any values not specified.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| args | Union[List[str], dict] | Either a list of CLI arguments in 'key.path=value' format, or a dict mapping dot-notation keys to values. Example list: ['strategy=megatron', 'model.path=Qwen/Qwen3-0.6B'] Example dict: {'strategy': 'megatron', 'model.path': 'Qwen/Qwen3-0.6B'} | required |

Returns:

| Type | Description |
| ---- | ----------- |
| SFTConfig | A fully constructed SFTConfig with CLI overrides applied. |
Source code in skyrl/train/config/sft_config.py:54-74
    @classmethod
    def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SFTConfig":
        """Construct an SFTConfig from CLI arguments or a dict of overrides.

        Parses CLI dotlist arguments via OmegaConf and builds a typed config.
        Dataclass field defaults are used for any values not specified.

        Args:
            args: Either a list of CLI arguments in 'key.path=value' format, or a dict
                  mapping dot-notation keys to values.
                  Example list: ['strategy=megatron', 'model.path=Qwen/Qwen3-0.6B']
                  Example dict: {'strategy': 'megatron', 'model.path': 'Qwen/Qwen3-0.6B'}

        Returns:
            A fully constructed SFTConfig with CLI overrides applied.
        """
        if isinstance(args, dict):
            args = [f"{k}={v}" for k, v in args.items()]

        overrides = OmegaConf.from_cli(args)
        return cls.from_dict_config(overrides)

attr model

model: ModelConfig = field(default_factory=(lambda: ModelConfig(path='Qwen/Qwen3-0.6B')))

attr optimizer_config

optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig)

attr placement

placement: SFTPlacementConfig = field(default_factory=SFTPlacementConfig)

attr megatron_config

megatron_config: MegatronConfig = field(default_factory=(lambda: MegatronConfig(tensor_model_parallel_size=2, pipeline_model_parallel_size=2)))

attr fsdp_config

fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)

attr sequence_parallel_size

sequence_parallel_size: int = 1

Ulysses sequence parallelism size

attr model_config_kwargs

model_config_kwargs: dict = field(default_factory=dict)

Pass-through kwargs for the HuggingFace model config (FSDP backends). For Megatron, use megatron_config.transformer_config_kwargs instead.

attr use_torch_compile

use_torch_compile: bool = False

Apply torch.compile to logits calculation.

attr record_memory

record_memory: bool = False

Save memory snapshots to {ckpt_path}/memory_snapshots/. Visualize by dragging pickle files to https://docs.pytorch.org/memory_viz.

attr strategy

strategy: str = 'megatron'

attr dataset_name

dataset_name: str = 'yahma/alpaca-cleaned'

attr dataset_split

dataset_split: str = 'train[:100]'

attr messages_key

messages_key: str = 'messages'

attr max_length

max_length: int = 512

attr num_steps

num_steps: int = 10

attr batch_size

batch_size: int = 4

attr micro_train_batch_size_per_gpu

micro_train_batch_size_per_gpu: int = 2

attr logger

logger: str = 'console'

attr project_name

project_name: str = 'skyrl_sft'

attr run_name

run_name: str = 'skyrl_sft_run'

attr ckpt_path

ckpt_path: str = ''

attr ckpt_interval

ckpt_interval: int = 0

attr max_ckpts_to_keep

max_ckpts_to_keep: int = -1

-1 to keep all checkpoints, N to keep only the last N.

attr resume_from

resume_from: str = ''

attr seed

seed: int = 42

attr use_sample_packing

use_sample_packing: bool = True

attr dummy_run_full_ctx

dummy_run_full_ctx: bool = False

attr dummy_run_max_steps

dummy_run_max_steps: int = 5
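
Putting the flat fields together, a hypothetical run with checkpointing and resume enabled might look like:

    cfg = SFTConfig(
        strategy="fsdp2",
        num_steps=100,
        batch_size=8,
        micro_train_batch_size_per_gpu=2,
        ckpt_path="/tmp/skyrl_sft_ckpts",  # empty string disables checkpointing
        ckpt_interval=20,                  # save every 20 steps
        max_ckpts_to_keep=3,               # keep only the last 3 checkpoints
        resume_from="latest",              # "" = no resume; also accepts a global_step_N path
    )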

Trainer

class SFTTrainer

SFTTrainer(cfg: SFTConfig, skyrl_cfg: SkyRLTrainConfig | None = None)

SFT trainer supporting FSDP and Megatron backends.

This trainer deliberately does NOT subclass RayPPOTrainer. SFT's concerns are fundamentally different: no generation, no critic, no advantages, no KL penalty. Sharing a base class would create confusing dead code paths.

Usage::

    trainer = SFTTrainer(SFTConfig(strategy="megatron"))
    trainer.setup()
    trainer.train()
    trainer.shutdown()

Functions:

| Name | Description |
| ---- | ----------- |
| setup | Initialize tokenizer, workers, dispatch, and tracker. |
| load_dataset | Load and tokenize the training dataset. |
| collate_batch | Collate examples into a TrainingInputBatch with loss normalization. |
| load_checkpoint | Load a checkpoint and return the step number to resume from. |
| train_step | Execute a single training step: forward_backward + optim_step. |
| train | Full training loop: load data, iterate, log, checkpoint. |
| save_checkpoint | Save a checkpoint at the given step. |
| shutdown | Finish tracking. |

Attributes:

| Name | Type | Description |
| ---- | ---- | ----------- |
| sft_cfg | | |
| cfg | | |
| tokenizer | | |
| dispatch | WorkerDispatch \| None | |
| tracker | Tracking \| None | |
| global_step | | |
Source code in skyrl/train/sft_trainer.py:193-716
class SFTTrainer:
    """SFT trainer supporting FSDP and Megatron backends.

    Unlike RayPPOTrainer, this does NOT subclass it. SFT's concerns are
    fundamentally different: no generation, no critic, no advantages, no
    KL penalty. Sharing a base class would create confusing dead code paths.

    Usage::

        trainer = SFTTrainer(SFTConfig(strategy="megatron"))
        trainer.setup()
        trainer.train()
        trainer.shutdown()
    """

    def __init__(self, cfg: SFTConfig, skyrl_cfg: SkyRLTrainConfig | None = None):
        self.sft_cfg = cfg
        # Accept a pre-built bridge config to avoid redundant rebuilds.
        # When not provided (e.g. standalone usage), build it here.
        self.cfg = skyrl_cfg if skyrl_cfg is not None else build_skyrl_config_for_sft(cfg)
        self.tokenizer = None
        self.dispatch: WorkerDispatch | None = None
        self.tracker: Tracking | None = None
        self.global_step = 0

    # ------------------------------------------------------------------ #
    # Setup
    # ------------------------------------------------------------------ #

    def setup(self):
        """Initialize tokenizer, workers, dispatch, and tracker.

        Ray must already be initialized before calling this (either via
        ``initialize_ray`` on the head node or inside a Ray task).
        """
        self.tokenizer = get_tokenizer(
            self.cfg.trainer.policy.model.path,
            trust_remote_code=True,
            use_fast=not self.cfg.trainer.disable_fast_tokenizer,
            padding_side="left",
        )
        self._init_workers()
        self._init_tracker()

    def _init_workers(self):
        """Create PPORayActorGroup and WorkerDispatch.

        Selects the correct PolicyWorker based on strategy.
        """
        if self.sft_cfg.strategy == "megatron":
            from skyrl.backends.skyrl_train.workers.megatron.megatron_worker import (
                PolicyWorker,
            )
        else:
            from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker

        num_gpus = self.sft_cfg.placement.num_gpus_per_node
        raw_pg = placement_group(
            [{"GPU": num_gpus, "CPU": num_gpus}] * self.sft_cfg.placement.num_nodes,
            strategy="PACK",
        )
        get_ray_pg_ready_with_timeout(raw_pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
        pg = ResolvedPlacementGroup(raw_pg)

        actor_group = PPORayActorGroup(
            self.cfg.trainer,
            num_nodes=self.sft_cfg.placement.num_nodes,
            num_gpus_per_node=num_gpus,
            ray_actor_type=PolicyWorker,
            pg=pg,
            num_gpus_per_actor=1,
            colocate_all=False,
            sequence_parallel_size=self.cfg.trainer.policy.sequence_parallel_size,
            record_memory=self.cfg.trainer.policy.record_memory,
        )
        num_training_steps = (
            self.sft_cfg.dummy_run_max_steps if self.sft_cfg.dummy_run_full_ctx else self.sft_cfg.num_steps
        )
        ray.get(
            actor_group.async_init_model(
                self.sft_cfg.model.path,
                num_training_steps=num_training_steps,
            )
        )
        ray.get(actor_group.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))

        self.dispatch = WorkerDispatch(self.cfg, policy_actor_group=actor_group)

    def _init_tracker(self):
        self.tracker = Tracking(
            project_name=self.cfg.trainer.project_name,
            experiment_name=self.cfg.trainer.run_name,
            backends=self.cfg.trainer.logger,
            config=self.sft_cfg,
        )

    # ------------------------------------------------------------------ #
    # Data
    # ------------------------------------------------------------------ #

    def _load_and_tokenize(self, dataset_name: str, dataset_split: str) -> list:
        """Load and tokenize a dataset.

        Auto-detects the dataset format based on column names:
        - If a ``messages_key`` column exists, uses chat-format tokenization.
        - If ``instruction`` and ``output`` columns exist, uses Alpaca-format
          tokenization.

        Args:
            dataset_name: HuggingFace dataset name (e.g. ``"yahma/alpaca-cleaned"``).
            dataset_split: Dataset split (e.g. ``"train[:100]"`` or ``"test"``).

        Returns a list of tokenized examples (dicts with ``input_ids``,
        ``attention_mask``, ``num_actions``).
        """
        logger.info(f"Loading dataset '{dataset_name}' split='{dataset_split}'...")
        dataset = load_dataset(dataset_name, split=dataset_split)

        columns = dataset.column_names
        logger.info("Tokenizing dataset...")

        if self.sft_cfg.messages_key in columns:
            # Chat format
            tokenized = [
                tokenize_chat_example(ex, self.tokenizer, self.sft_cfg.max_length, self.sft_cfg.messages_key)
                for ex in dataset
            ]
        elif "instruction" in columns and "output" in columns:
            # Alpaca format
            tokenized = [tokenize_sft_example(ex, self.tokenizer, self.sft_cfg.max_length) for ex in dataset]
        else:
            raise ValueError(
                f"Unrecognized dataset format. Expected '{self.sft_cfg.messages_key}' column "
                f"(chat format) or 'instruction'+'output' columns (Alpaca format). "
                f"Found columns: {columns}"
            )

        tokenized = [ex for ex in tokenized if ex is not None]
        logger.info(f"Tokenized {len(tokenized)} examples (filtered from {len(dataset)})")
        return tokenized

    def load_dataset(self) -> list:
        """Load and tokenize the training dataset."""
        return self._load_and_tokenize(self.sft_cfg.dataset_name, self.sft_cfg.dataset_split)

    def collate_batch(self, examples: list) -> TrainingInputBatch:
        """Collate examples into a TrainingInputBatch with loss normalization.

        Normalizes the loss_mask so that the sum-reduction in cross_entropy_loss
        produces a per-non-pad-token mean, matching the standard convention.

        NOTE: The scaling factor is ``batch_size / (micro_batch_size * total_nonpad)``
        where ``total_nonpad`` is the count of non-masked (loss-contributing)
        tokens in the full batch.  This accounts for the ``microbatch_weight``
        (FSDP) or ``1/num_microbatches`` (Megatron) applied during gradient
        accumulation so that the effective gradient equals
        ``d[sum(-log_probs_on_nonpad) / total_nonpad]``.
        """
        batch = collate_sft_batch(examples, self.tokenizer)
        # Loss normalization: divide by non-pad token count (not padded seq length)
        # NOTE (sumanthrh): This specific scaling factor is because SkyRL's workers internally normalize
        # by number of micro batches, but aggregate otherwise
        micro_batch_size = self.sft_cfg.micro_train_batch_size_per_gpu
        total_nonpad = max(batch["loss_mask"].sum().item(), 1)
        batch["loss_mask"] = batch["loss_mask"].float() * (self.sft_cfg.batch_size / (micro_batch_size * total_nonpad))
        return batch

    # ------------------------------------------------------------------ #
    # Checkpoint resume
    # ------------------------------------------------------------------ #

    def load_checkpoint(self) -> int:
        """Load a checkpoint and return the step number to resume from.

        Behaviour depends on ``sft_cfg.resume_from``:
        - ``""`` (empty): no resume, return 0.
        - ``"latest"``: read ``latest_ckpt_global_step.txt`` from ``ckpt_path``.
        - otherwise: treat as a direct path to a ``global_step_N`` directory.

        Returns:
            The global step to resume from (0 if no checkpoint loaded).
        """
        resume_from = self.sft_cfg.resume_from
        if not resume_from:
            return 0

        if resume_from == "latest":
            if not self.sft_cfg.ckpt_path:
                logger.info("resume_from='latest' but ckpt_path is empty, starting from scratch")
                return 0
            latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
            if not io.exists(latest_file):
                logger.info("No latest checkpoint marker found, starting from scratch")
                return 0
            with io.open_file(latest_file, "r") as f:
                ckpt_step = int(f.read().strip())
            checkpoint_path = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_step}")
            # Validate consistency: ensure no stale checkpoint folders from prior runs
            validate_consistency_for_latest_checkpoint(
                self.sft_cfg.ckpt_path,
                ckpt_step,
                checkpoint_path,
                latest_file,
                self.sft_cfg.ckpt_interval,
            )
        else:
            checkpoint_path = resume_from

        if not io.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")

        global_step = extract_step_from_path(checkpoint_path)
        if global_step == -1:
            raise ValueError(
                f"Cannot extract step number from checkpoint path: {checkpoint_path}. "
                f"Expected a directory named '{GLOBAL_STEP_PREFIX}<N>'."
            )

        # Load and validate trainer state if available
        trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
        if io.exists(trainer_state_path):
            with io.open_file(trainer_state_path, "rb") as f:
                trainer_state = torch.load(f, map_location="cpu", weights_only=False)
            saved_global_step = trainer_state.get("global_step", global_step)
            logger.info("Successfully loaded trainer state")
            if saved_global_step != global_step:
                logger.warning(
                    f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value."
                )
        else:
            logger.warning(
                f"No trainer_state.pt found at {trainer_state_path}. "
                "This checkpoint was likely saved by an older version."
            )

        policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
        logger.info(f"Loading checkpoint from {checkpoint_path} (step {global_step})")
        self.dispatch.load_checkpoint(
            "policy",
            policy_ckpt_dir,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        logger.info(f"Successfully resumed from global_step_{global_step}")
        return global_step

    # ------------------------------------------------------------------ #
    # Training
    # ------------------------------------------------------------------ #

    def train_step(self, batch: TrainingInputBatch, step: int) -> dict:
        """Execute a single training step: forward_backward + optim_step.

        Args:
            batch: The collated training batch.
            step: Current global step (reserved for future use, e.g. scheduling).

        Returns:
            Dict with ``loss``, ``grad_norm``, and ``timings``.
        """
        timings: dict[str, float] = {}
        with Timer("forward_backward", timings):
            metrics = self.dispatch.forward_backward("policy", batch, loss_fn="cross_entropy")
        with Timer("optim_step", timings):
            grad_norm = self.dispatch.optim_step("policy")

        loss_val = metrics.get("final_loss", metrics.get("loss", float("nan")))
        return {
            "loss": loss_val,
            "grad_norm": grad_norm,
            "timings": timings,
        }

    def _validate_batch_parallelism(self):
        """Validate that batch_size is compatible with data-parallel and micro-batch sizes."""
        batch_size = self.sft_cfg.batch_size
        total_gpus = self.sft_cfg.placement.num_nodes * self.sft_cfg.placement.num_gpus_per_node
        if self.sft_cfg.strategy == "megatron":
            tp = self.sft_cfg.megatron_config.tensor_model_parallel_size
            pp = self.sft_cfg.megatron_config.pipeline_model_parallel_size
            dp_size = total_gpus // (tp * pp)
        else:
            # FSDP: all GPUs are data-parallel
            dp_size = total_gpus
        if batch_size % dp_size != 0:
            raise ValueError(f"batch_size ({batch_size}) must be divisible by data-parallel size ({dp_size})")
        per_dp_batch = batch_size // dp_size
        micro_batch = self.sft_cfg.micro_train_batch_size_per_gpu
        if per_dp_batch % micro_batch != 0:
            raise ValueError(
                f"batch_size / dp_size ({per_dp_batch}) must be divisible by "
                f"micro_train_batch_size_per_gpu ({micro_batch})"
            )

    def _build_dummy_batch(self) -> TrainingInputBatch:
        """Build a dummy batch of random full-context sequences for benchmarking."""
        batch_size = self.sft_cfg.batch_size
        max_length = self.sft_cfg.max_length
        micro_batch_size = self.sft_cfg.micro_train_batch_size_per_gpu
        vocab_size = self.tokenizer.vocab_size

        # num_actions is max_length - 1 because the autoregressive model
        # produces log-probs for positions 1..T (predicting next token),
        # so the first token has no corresponding log-prob.
        num_actions = max_length - 1

        sequences = torch.randint(0, vocab_size, (batch_size, max_length), dtype=torch.long)
        attention_mask = torch.ones(batch_size, max_length, dtype=torch.long)
        # All tokens are non-pad in the dummy batch, so total_nonpad = batch_size * num_actions.
        # Scaling = batch_size / (micro_batch_size * total_nonpad)
        #         = 1 / (micro_batch_size * num_actions)
        total_nonpad = batch_size * num_actions
        loss_mask = torch.ones(batch_size, num_actions, dtype=torch.float) * (
            batch_size / (micro_batch_size * total_nonpad)
        )

        batch = TrainingInputBatch(
            {
                "sequences": sequences,
                "attention_mask": attention_mask,
                "loss_mask": loss_mask,
            }
        )
        batch.metadata = {"response_length": num_actions}
        return batch

    def _train_dummy(self):
        """Dummy training loop for benchmarking. Skips real data, checkpoints, and resume."""
        self._validate_batch_parallelism()
        batch = self._build_dummy_batch()
        num_steps = self.sft_cfg.dummy_run_max_steps

        logger.info(
            f"Starting dummy SFT training for {num_steps} steps "
            f"(batch_size={self.sft_cfg.batch_size}, max_length={self.sft_cfg.max_length})..."
        )

        for step in range(num_steps):
            all_timings: dict[str, float] = {}

            with Timer("step", all_timings):
                step_result = self.train_step(batch, step)
                all_timings.update(step_result["timings"])

            actual_num_tokens = batch["attention_mask"].sum().item()
            tokens_per_second = actual_num_tokens / all_timings["step"]

            log_dict = {
                "train/loss": step_result["loss"],
                "train/grad_norm": step_result["grad_norm"],
                "train/tokens_per_second": tokens_per_second,
                "train/actual_num_tokens": actual_num_tokens,
            }
            log_dict.update({f"timing/{k}": v for k, v in all_timings.items()})

            self.tracker.log(log_dict, step=step, commit=True)
            logger.info(
                f"Step {step}: loss={step_result['loss']:.4f}, "
                f"grad_norm={step_result['grad_norm']}, "
                f"tokens_per_second={tokens_per_second:.0f}"
            )

        logger.info("Dummy SFT training complete!")

    def train(self):
        """Full training loop: load data, iterate, log, checkpoint."""
        if self.sft_cfg.dummy_run_full_ctx:
            if self.sft_cfg.resume_from:
                logger.warning("resume_from is ignored in dummy run mode")
            return self._train_dummy()

        tokenized = self.load_dataset()

        batch_size = self.sft_cfg.batch_size
        num_steps = self.sft_cfg.num_steps

        # Early validation: dataset must have at least batch_size examples
        if len(tokenized) < batch_size:
            raise ValueError(
                f"Dataset has {len(tokenized)} examples after tokenization, but batch_size={batch_size}. "
                f"Reduce batch_size or use more data."
            )

        self._validate_batch_parallelism()

        # Resume from checkpoint if configured
        start_step = self.load_checkpoint()

        # Shuffle data before training
        rng = random.Random(self.sft_cfg.seed)
        rng.shuffle(tokenized)

        # When resuming, start_step is the last *completed* step (checkpoint is
        # saved AFTER the optimizer update), so we begin at start_step + 1 to
        # avoid replaying that step.

        # Replay epoch shuffles for reproducibility on resume
        start_epoch = (start_step * batch_size) // len(tokenized)
        for _ in range(start_epoch):
            rng.shuffle(tokenized)
        current_epoch = start_epoch

        # SkyRL starts counting at step 1
        self.global_step = start_step + 1 if start_step > 0 else 1

        logger.info(f"Starting SFT training for {num_steps} steps (batch_size={batch_size})...")
        if start_step > 0:
            logger.info(f"Resuming from step {start_step}")
        while self.global_step <= num_steps:
            all_timings: dict[str, float] = {}

            with Timer("step", all_timings):

                # Data loading with wrap-around
                with Timer("data_loading", all_timings):
                    start_idx = (self.global_step * batch_size) % len(tokenized)
                    end_idx = start_idx + batch_size
                    if end_idx > len(tokenized):
                        batch_examples = tokenized[start_idx:] + tokenized[: end_idx - len(tokenized)]
                    else:
                        batch_examples = tokenized[start_idx:end_idx]
                    batch = self.collate_batch(batch_examples)

                # Training step
                step_result = self.train_step(batch, self.global_step)
                all_timings.update(step_result["timings"])

            # Compute throughput using actual (non-padding) tokens
            batch_padded_seq_len = batch["sequences"].shape[1]
            actual_num_tokens = batch["attention_mask"].sum().item()
            tokens_per_second = actual_num_tokens / all_timings["step"]

            # Build log dict
            log_dict = {
                "train/loss": step_result["loss"],
                "train/grad_norm": step_result["grad_norm"],
                "train/tokens_per_second": tokens_per_second,
                "train/actual_num_tokens": actual_num_tokens,
                "train/batch_padded_seq_len": batch_padded_seq_len,
            }
            log_dict.update({f"timing/{k}": v for k, v in all_timings.items()})

            # Checkpoint at regular intervals
            if (
                self.sft_cfg.ckpt_path
                and self.sft_cfg.ckpt_interval > 0
                and self.global_step > 0
                and self.global_step % self.sft_cfg.ckpt_interval == 0
            ):
                with Timer("save_checkpoint", all_timings):
                    self.save_checkpoint()
                log_dict["timing/save_checkpoint"] = all_timings["save_checkpoint"]

            self.tracker.log(log_dict, step=self.global_step, commit=True)

            if self.global_step % 5 == 0:
                logger.info(
                    f"Step {self.global_step}: loss={step_result['loss']:.4f}, " f"grad_norm={step_result['grad_norm']}"
                )

            # Check for epoch boundary and reshuffle
            epoch = (self.global_step * batch_size) // len(tokenized)
            if epoch > current_epoch:
                for _ in range(epoch - current_epoch):
                    rng.shuffle(tokenized)
                current_epoch = epoch

            self.global_step += 1
        self.global_step = min(self.global_step, num_steps)

        # Save final checkpoint (if checkpointing is enabled)
        if self.sft_cfg.ckpt_path:
            final_step = num_steps
            already_saved = (
                self.sft_cfg.ckpt_interval > 0 and final_step > 0 and final_step % self.sft_cfg.ckpt_interval == 0
            )
            if not already_saved:
                logger.info(f"Saving final checkpoint at step {final_step}")
                self.save_checkpoint()

        logger.info("SFT training complete!")

    def save_checkpoint(self):
        """Save a checkpoint at the given step."""
        step = self.global_step
        global_step_folder = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{step}")
        policy_save_dir = os.path.join(global_step_folder, "policy")
        io.makedirs(global_step_folder, exist_ok=True)
        logger.info(f"Saving checkpoint at step {step} to {global_step_folder}")
        self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer)

        # Save trainer state for cross-validation on resume (mirrors PPO's trainer_state.pt)
        trainer_state = {
            "global_step": step,
            "config": asdict(self.sft_cfg),
        }
        trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
        with io.open_file(trainer_state_path, "wb") as f:
            torch.save(trainer_state, f)
        logger.info(f"Saved trainer state to {trainer_state_path}")

        # Atomic tracking -- write this last after all saves succeed
        latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
        with io.open_file(latest_file, "w") as f:
            f.write(str(step))
        logger.info(f"Checkpoint saved for global_step_{step}")

        # Clean up old checkpoints after successful save
        cleanup_old_checkpoints(self.sft_cfg.ckpt_path, self.sft_cfg.max_ckpts_to_keep)

    # ------------------------------------------------------------------ #
    # Lifecycle
    # ------------------------------------------------------------------ #

    def shutdown(self):
        """Finish tracking.

        Does NOT call ``ray.shutdown()`` -- when running inside a Ray task
        (the normal path via ``sft_entrypoint``), shutting down Ray from
        within the task would be incorrect.  The head-node process owns
        the Ray lifecycle.
        """
        if self.tracker is not None:
            self.tracker.finish()

attr sft_cfg

sft_cfg = cfg

attr cfg

cfg = skyrl_cfg if skyrl_cfg is not None else build_skyrl_config_for_sft(cfg)

attr tokenizer

tokenizer = None

attr dispatch

dispatch: WorkerDispatch | None = None

attr tracker

tracker: Tracking | None = None

attr global_step

global_step = 0

method setup

setup()

Initialize tokenizer, workers, dispatch, and tracker.

Ray must already be initialized before calling this (either via initialize_ray on the head node or inside a Ray task).
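
Since setup() assumes a live Ray runtime, a standalone script might look like the following sketch (ray.init() stands in here for SkyRL's initialize_ray helper):

    import ray

    ray.init()  # Ray must be up before setup() creates worker actors
    trainer = SFTTrainer(SFTConfig(strategy="megatron"))
    trainer.setup()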

Source code in skyrl/train/sft_trainer.py:222-235
    def setup(self):
        """Initialize tokenizer, workers, dispatch, and tracker.

        Ray must already be initialized before calling this (either via
        ``initialize_ray`` on the head node or inside a Ray task).
        """
        self.tokenizer = get_tokenizer(
            self.cfg.trainer.policy.model.path,
            trust_remote_code=True,
            use_fast=not self.cfg.trainer.disable_fast_tokenizer,
            padding_side="left",
        )
        self._init_workers()
        self._init_tracker()

method load_dataset

load_dataset() -> list

Load and tokenize the training dataset.
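
The underlying loader auto-detects the dataset format from column names (see _load_and_tokenize in the trainer source). Illustrative rows for the two accepted formats, with hypothetical field values:

    # Chat format: detected via the messages_key column (default "messages")
    {"messages": [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}
    # Alpaca format: detected via "instruction" + "output" columns
    {"instruction": "Summarize the text.", "output": "A short summary."}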

Source code in skyrl/train/sft_trainer.py:334-336
    def load_dataset(self) -> list:
        """Load and tokenize the training dataset."""
        return self._load_and_tokenize(self.sft_cfg.dataset_name, self.sft_cfg.dataset_split)

method collate_batch

collate_batch(examples: list) -> TrainingInputBatch

Collate examples into a TrainingInputBatch with loss normalization.

Normalizes the loss_mask so that the sum-reduction in cross_entropy_loss produces a per-non-pad-token mean, matching the standard convention.

NOTE: The scaling factor is batch_size / (micro_batch_size * total_nonpad) where total_nonpad is the count of non-masked (loss-contributing) tokens in the full batch. This accounts for the microbatch_weight (FSDP) or 1/num_microbatches (Megatron) applied during gradient accumulation so that the effective gradient equals d[sum(-log_probs_on_nonpad) / total_nonpad].
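
A worked example with hypothetical numbers: take batch_size=4, micro_train_batch_size_per_gpu=2, and total_nonpad=1000 non-masked tokens, on a single data-parallel rank (so 2 microbatches):

    scale = 4 / (2 * 1000)  # = 0.002, multiplied into every loss_mask entry
    # Each microbatch's sum-loss is weighted by 1/num_microbatches = 1/2 internally,
    # so the effective per-token weight is 0.002 * 0.5 = 0.001 = 1 / total_nonpad,
    # i.e. a mean over non-pad tokens, as the docstring requires.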

Source code in skyrl/train/sft_trainer.py:338-358
    def collate_batch(self, examples: list) -> TrainingInputBatch:
        """Collate examples into a TrainingInputBatch with loss normalization.

        Normalizes the loss_mask so that the sum-reduction in cross_entropy_loss
        produces a per-non-pad-token mean, matching the standard convention.

        NOTE: The scaling factor is ``batch_size / (micro_batch_size * total_nonpad)``
        where ``total_nonpad`` is the count of non-masked (loss-contributing)
        tokens in the full batch.  This accounts for the ``microbatch_weight``
        (FSDP) or ``1/num_microbatches`` (Megatron) applied during gradient
        accumulation so that the effective gradient equals
        ``d[sum(-log_probs_on_nonpad) / total_nonpad]``.
        """
        batch = collate_sft_batch(examples, self.tokenizer)
        # Loss normalization: divide by non-pad token count (not padded seq length)
        # NOTE (sumanthrh): This specific scaling factor is because SkyRL's workers internally normalize
        # by number of micro batches, but aggregate otherwise
        micro_batch_size = self.sft_cfg.micro_train_batch_size_per_gpu
        total_nonpad = max(batch["loss_mask"].sum().item(), 1)
        batch["loss_mask"] = batch["loss_mask"].float() * (self.sft_cfg.batch_size / (micro_batch_size * total_nonpad))
        return batch

method load_checkpoint

load_checkpoint() -> int

Load a checkpoint and return the step number to resume from.

Behaviour depends on sft_cfg.resume_from:

  • "" (empty): no resume, return 0.
  • "latest": read latest_ckpt_global_step.txt from ckpt_path.
  • otherwise: treat as a direct path to a global_step_N directory.

Returns:

| Type | Description |
| ---- | ----------- |
| int | The global step to resume from (0 if no checkpoint loaded). |
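
In configuration terms (paths hypothetical):

    cfg = SFTConfig(ckpt_path="/ckpts/run1", resume_from="latest")
    # or point directly at a step directory:
    cfg = SFTConfig(resume_from="/ckpts/run1/global_step_40")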
Source code in skyrl/train/sft_trainer.py:364-437
    def load_checkpoint(self) -> int:
        """Load a checkpoint and return the step number to resume from.

        Behaviour depends on ``sft_cfg.resume_from``:
        - ``""`` (empty): no resume, return 0.
        - ``"latest"``: read ``latest_ckpt_global_step.txt`` from ``ckpt_path``.
        - otherwise: treat as a direct path to a ``global_step_N`` directory.

        Returns:
            The global step to resume from (0 if no checkpoint loaded).
        """
        resume_from = self.sft_cfg.resume_from
        if not resume_from:
            return 0

        if resume_from == "latest":
            if not self.sft_cfg.ckpt_path:
                logger.info("resume_from='latest' but ckpt_path is empty, starting from scratch")
                return 0
            latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
            if not io.exists(latest_file):
                logger.info("No latest checkpoint marker found, starting from scratch")
                return 0
            with io.open_file(latest_file, "r") as f:
                ckpt_step = int(f.read().strip())
            checkpoint_path = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_step}")
            # Validate consistency: ensure no stale checkpoint folders from prior runs
            validate_consistency_for_latest_checkpoint(
                self.sft_cfg.ckpt_path,
                ckpt_step,
                checkpoint_path,
                latest_file,
                self.sft_cfg.ckpt_interval,
            )
        else:
            checkpoint_path = resume_from

        if not io.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")

        global_step = extract_step_from_path(checkpoint_path)
        if global_step == -1:
            raise ValueError(
                f"Cannot extract step number from checkpoint path: {checkpoint_path}. "
                f"Expected a directory named '{GLOBAL_STEP_PREFIX}<N>'."
            )

        # Load and validate trainer state if available
        trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
        if io.exists(trainer_state_path):
            with io.open_file(trainer_state_path, "rb") as f:
                trainer_state = torch.load(f, map_location="cpu", weights_only=False)
            saved_global_step = trainer_state.get("global_step", global_step)
            logger.info("Successfully loaded trainer state")
            if saved_global_step != global_step:
                logger.warning(
                    f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value."
                )
        else:
            logger.warning(
                f"No trainer_state.pt found at {trainer_state_path}. "
                "This checkpoint was likely saved by an older version."
            )

        policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
        logger.info(f"Loading checkpoint from {checkpoint_path} (step {global_step})")
        self.dispatch.load_checkpoint(
            "policy",
            policy_ckpt_dir,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        logger.info(f"Successfully resumed from global_step_{global_step}")
        return global_step

method train_step

train_step(batch: TrainingInputBatch, step: int) -> dict

Execute a single training step: forward_backward + optim_step.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| batch | TrainingInputBatch | The collated training batch. | required |
| step | int | Current global step (reserved for future use, e.g. scheduling). | required |

Returns:

| Type | Description |
| ---- | ----------- |
| dict | Dict with loss, grad_norm, and timings. |
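
A minimal usage sketch, assuming examples produced by load_dataset():

    batch = trainer.collate_batch(examples)
    result = trainer.train_step(batch, step=1)
    print(result["loss"], result["grad_norm"], result["timings"]["forward_backward"])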
Source code in skyrl/train/sft_trainer.py:443-464
    def train_step(self, batch: TrainingInputBatch, step: int) -> dict:
        """Execute a single training step: forward_backward + optim_step.

        Args:
            batch: The collated training batch.
            step: Current global step (reserved for future use, e.g. scheduling).

        Returns:
            Dict with ``loss``, ``grad_norm``, and ``timings``.
        """
        timings: dict[str, float] = {}
        with Timer("forward_backward", timings):
            metrics = self.dispatch.forward_backward("policy", batch, loss_fn="cross_entropy")
        with Timer("optim_step", timings):
            grad_norm = self.dispatch.optim_step("policy")

        loss_val = metrics.get("final_loss", metrics.get("loss", float("nan")))
        return {
            "loss": loss_val,
            "grad_norm": grad_norm,
            "timings": timings,
        }

method train

train()

Full training loop: load data, iterate, log, checkpoint.

Source code in skyrl/train/sft_trainer.py:557-673
    def train(self):
        """Full training loop: load data, iterate, log, checkpoint."""
        if self.sft_cfg.dummy_run_full_ctx:
            if self.sft_cfg.resume_from:
                logger.warning("resume_from is ignored in dummy run mode")
            return self._train_dummy()

        tokenized = self.load_dataset()

        batch_size = self.sft_cfg.batch_size
        num_steps = self.sft_cfg.num_steps

        # Early validation: dataset must have at least batch_size examples
        if len(tokenized) < batch_size:
            raise ValueError(
                f"Dataset has {len(tokenized)} examples after tokenization, but batch_size={batch_size}. "
                f"Reduce batch_size or use more data."
            )

        self._validate_batch_parallelism()

        # Resume from checkpoint if configured
        start_step = self.load_checkpoint()

        # Shuffle data before training
        rng = random.Random(self.sft_cfg.seed)
        rng.shuffle(tokenized)

        # When resuming, start_step is the last *completed* step (checkpoint is
        # saved AFTER the optimizer update), so we begin at start_step + 1 to
        # avoid replaying that step.

        # Replay epoch shuffles for reproducibility on resume
        start_epoch = (start_step * batch_size) // len(tokenized)
        for _ in range(start_epoch):
            rng.shuffle(tokenized)
        current_epoch = start_epoch

        # SkyRL starts counting at step 1
        self.global_step = start_step + 1 if start_step > 0 else 1

        logger.info(f"Starting SFT training for {num_steps} steps (batch_size={batch_size})...")
        if start_step > 0:
            logger.info(f"Resuming from step {start_step}")
        while self.global_step <= num_steps:
            all_timings: dict[str, float] = {}

            with Timer("step", all_timings):

                # Data loading with wrap-around
                with Timer("data_loading", all_timings):
                    start_idx = (self.global_step * batch_size) % len(tokenized)
                    end_idx = start_idx + batch_size
                    if end_idx > len(tokenized):
                        batch_examples = tokenized[start_idx:] + tokenized[: end_idx - len(tokenized)]
                    else:
                        batch_examples = tokenized[start_idx:end_idx]
                    batch = self.collate_batch(batch_examples)

                # Training step
                step_result = self.train_step(batch, self.global_step)
                all_timings.update(step_result["timings"])

            # Compute throughput using actual (non-padding) tokens
            batch_padded_seq_len = batch["sequences"].shape[1]
            actual_num_tokens = batch["attention_mask"].sum().item()
            tokens_per_second = actual_num_tokens / all_timings["step"]

            # Build log dict
            log_dict = {
                "train/loss": step_result["loss"],
                "train/grad_norm": step_result["grad_norm"],
                "train/tokens_per_second": tokens_per_second,
                "train/actual_num_tokens": actual_num_tokens,
                "train/batch_padded_seq_len": batch_padded_seq_len,
            }
            log_dict.update({f"timing/{k}": v for k, v in all_timings.items()})

            # Checkpoint at regular intervals
            if (
                self.sft_cfg.ckpt_path
                and self.sft_cfg.ckpt_interval > 0
                and self.global_step > 0
                and self.global_step % self.sft_cfg.ckpt_interval == 0
            ):
                with Timer("save_checkpoint", all_timings):
                    self.save_checkpoint()
                log_dict["timing/save_checkpoint"] = all_timings["save_checkpoint"]

            self.tracker.log(log_dict, step=self.global_step, commit=True)

            if self.global_step % 5 == 0:
                logger.info(
                    f"Step {self.global_step}: loss={step_result['loss']:.4f}, " f"grad_norm={step_result['grad_norm']}"
                )

            # Check for epoch boundary and reshuffle
            epoch = (self.global_step * batch_size) // len(tokenized)
            if epoch > current_epoch:
                for _ in range(epoch - current_epoch):
                    rng.shuffle(tokenized)
                current_epoch = epoch

            self.global_step += 1
        self.global_step = min(self.global_step, num_steps)

        # Save final checkpoint (if checkpointing is enabled)
        if self.sft_cfg.ckpt_path:
            final_step = num_steps
            already_saved = (
                self.sft_cfg.ckpt_interval > 0 and final_step > 0 and final_step % self.sft_cfg.ckpt_interval == 0
            )
            if not already_saved:
                logger.info(f"Saving final checkpoint at step {final_step}")
                self.save_checkpoint()

        logger.info("SFT training complete!")

method save_checkpoint

save_checkpoint()

Save a checkpoint at the given step.
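
Based on the source below, a checkpoint at step N produces this layout under ckpt_path:

    {ckpt_path}/
        global_step_N/                  # GLOBAL_STEP_PREFIX + step
            policy/                     # model/optimizer state saved via the dispatch
            trainer_state.pt            # {"global_step": N, "config": asdict(sft_cfg)}
        latest_ckpt_global_step.txt     # written last, only after all saves succeed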

Source code in skyrl/train/sft_trainer.py:675-701
    def save_checkpoint(self):
        """Save a checkpoint at the given step."""
        step = self.global_step
        global_step_folder = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{step}")
        policy_save_dir = os.path.join(global_step_folder, "policy")
        io.makedirs(global_step_folder, exist_ok=True)
        logger.info(f"Saving checkpoint at step {step} to {global_step_folder}")
        self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer)

        # Save trainer state for cross-validation on resume (mirrors PPO's trainer_state.pt)
        trainer_state = {
            "global_step": step,
            "config": asdict(self.sft_cfg),
        }
        trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
        with io.open_file(trainer_state_path, "wb") as f:
            torch.save(trainer_state, f)
        logger.info(f"Saved trainer state to {trainer_state_path}")

        # Atomic tracking -- write this last after all saves succeed
        latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
        with io.open_file(latest_file, "w") as f:
            f.write(str(step))
        logger.info(f"Checkpoint saved for global_step_{step}")

        # Clean up old checkpoints after successful save
        cleanup_old_checkpoints(self.sft_cfg.ckpt_path, self.sft_cfg.max_ckpts_to_keep)

method shutdown

shutdown()

Finish tracking.

Does NOT call ray.shutdown() -- when running inside a Ray task (the normal path via sft_entrypoint), shutting down Ray from within the task would be incorrect. The head-node process owns the Ray lifecycle.

Source code in skyrl/train/sft_trainer.py:707-716
    def shutdown(self):
        """Finish tracking.

        Does NOT call ``ray.shutdown()`` -- when running inside a Ray task
        (the normal path via ``sft_entrypoint``), shutting down Ray from
        within the task would be incorrect.  The head-node process owns
        the Ray lifecycle.
        """
        if self.tracker is not None:
            self.tracker.finish()

Config Bridge

function validate_sft_cfg

validate_sft_cfg(cfg: SFTConfig) -> None

Validate SFT-specific configuration.

Only checks fields that are relevant to SFT training, unlike validate_cfg which includes RL-specific validations.

function build_skyrl_config_for_sft

build_skyrl_config_for_sft(sft_cfg: SFTConfig) -> SkyRLTrainConfig

Map user-facing SFTConfig to the internal SkyRL backend config.
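
A sketch of invoking the bridge explicitly, so that SFTTrainer reuses the pre-built config instead of rebuilding it in its constructor:

    cfg = SFTConfig(strategy="megatron")
    validate_sft_cfg(cfg)
    skyrl_cfg = build_skyrl_config_for_sft(cfg)
    trainer = SFTTrainer(cfg, skyrl_cfg=skyrl_cfg)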
