SFT
Supervised Fine-Tuning configuration and trainer.
Configuration
class SFTPlacementConfig
SFTPlacementConfig(num_nodes: int = 1, num_gpus_per_node: int = 4) -> None
Bases: BaseConfig
Placement configuration for SFT training
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
Attributes:
| Name | Type | Description |
|---|---|---|
num_nodes | int | |
num_gpus_per_node | int | |
Source code in skyrl/train/config/sft_config.py:28-33
@dataclass
class SFTPlacementConfig(BaseConfig):
"""Placement configuration for SFT training"""
num_nodes: int = 1
num_gpus_per_node: int = 4
attr num_nodes
num_nodes: int = 1
from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
attr num_gpus_per_node
num_gpus_per_node: int = 4
class SFTConfig
SFTConfig(model: ModelConfig = (lambda: ModelConfig(path='Qwen/Qwen3-0.6B'))(), optimizer_config: OptimizerConfig = OptimizerConfig(), placement: SFTPlacementConfig = SFTPlacementConfig(), megatron_config: MegatronConfig = (lambda: MegatronConfig(tensor_model_parallel_size=2, pipeline_model_parallel_size=2))(), fsdp_config: FSDPConfig = FSDPConfig(), sequence_parallel_size: int = 1, model_config_kwargs: dict = dict(), use_torch_compile: bool = False, record_memory: bool = False, strategy: str = 'megatron', dataset_name: str = 'yahma/alpaca-cleaned', dataset_split: str = 'train[:100]', messages_key: str = 'messages', max_length: int = 512, num_steps: int = 10, batch_size: int = 4, micro_train_batch_size_per_gpu: int = 2, logger: str = 'console', project_name: str = 'skyrl_sft', run_name: str = 'skyrl_sft_run', ckpt_path: str = '', ckpt_interval: int = 0, max_ckpts_to_keep: int = -1, resume_from: str = '', seed: int = 42, use_sample_packing: bool = True, dummy_run_full_ctx: bool = False, dummy_run_max_steps: int = 5) -> None
Bases: BaseConfig
Configuration for SFT training.
Usage::
cfg = SFTConfig(
    strategy="megatron",
    placement=SFTPlacementConfig(num_gpus_per_node=4),
    megatron_config=MegatronConfig(tensor_model_parallel_size=2, pipeline_model_parallel_size=2),
)
Or from CLI::
cfg = SFTConfig.from_cli_overrides(sys.argv[1:])
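For instance, explicit dotlist overrides can be supplied directly; a minimal sketch using only fields documented below (values are illustrative):

```python
cfg = SFTConfig.from_cli_overrides(
    [
        "strategy=fsdp2",                 # "megatron" or "fsdp2"
        "model.path=Qwen/Qwen3-0.6B",
        "placement.num_gpus_per_node=8",
        "num_steps=100",
        "logger=wandb",                   # "console" or "wandb"
    ]
)
```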
Functions:
| Name | Description |
|---|---|
from_dict_config | Construct a typed BaseConfig from a Hydra DictConfig. |
from_cli_overrides | Construct an SFTConfig from CLI arguments or a dict of overrides. |
Attributes:
| Name | Type | Description |
|---|---|---|
model | ModelConfig | |
optimizer_config | OptimizerConfig | |
placement | SFTPlacementConfig | |
megatron_config | MegatronConfig | |
fsdp_config | FSDPConfig | |
sequence_parallel_size | int | Ulysses sequence parallelism size |
model_config_kwargs | dict | Pass-through kwargs for the HuggingFace model config (FSDP backends). |
use_torch_compile | bool | Apply torch.compile to logits calculation. |
record_memory | bool | Save memory snapshots to {ckpt_path}/memory_snapshots/. |
strategy | str | |
dataset_name | str | |
dataset_split | str | |
messages_key | str | |
max_length | int | |
num_steps | int | |
batch_size | int | |
micro_train_batch_size_per_gpu | int | |
logger | str | |
project_name | str | |
run_name | str | |
ckpt_path | str | |
ckpt_interval | int | |
max_ckpts_to_keep | int | -1 to keep all checkpoints, N to keep only the last N. |
resume_from | str | |
seed | int | |
use_sample_packing | bool | |
dummy_run_full_ctx | bool | |
dummy_run_max_steps | int | |
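The dummy-run fields allow throughput benchmarking without loading a real dataset. A sketch of such a config, assuming a single node with 4 GPUs (values are illustrative):

```python
bench_cfg = SFTConfig(
    strategy="megatron",
    placement=SFTPlacementConfig(num_nodes=1, num_gpus_per_node=4),
    megatron_config=MegatronConfig(tensor_model_parallel_size=2, pipeline_model_parallel_size=2),
    max_length=4096,                    # dummy batches fabricate full-context sequences
    batch_size=8,                       # must be divisible by the data-parallel size (here 4 / (2*2) = 1)
    micro_train_batch_size_per_gpu=1,
    dummy_run_full_ctx=True,            # skip real data entirely
    dummy_run_max_steps=5,
)
```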
Source code in skyrl/train/config/sft_config.py:36-125
@dataclass
class SFTConfig(BaseConfig):
"""Configuration for SFT training.
Usage::
cfg = SFTConfig(
strategy="megatron",
placement=SFTPlacementConfig(num_gpus_per_node=4),
megatron_config=MegatronConfig(tensor_model_parallel_size=2,
pipeline_model_parallel_size=2),
)
Or from CLI::
cfg = SFTConfig.from_cli_overrides(sys.argv[1:])
"""
@classmethod
def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SFTConfig":
"""Construct an SFTConfig from CLI arguments or a dict of overrides.
Parses CLI dotlist arguments via OmegaConf and builds a typed config.
Dataclass field defaults are used for any values not specified.
Args:
args: Either a list of CLI arguments in 'key.path=value' format, or a dict
mapping dot-notation keys to values.
Example list: ['strategy=megatron', 'model.path=Qwen/Qwen3-0.6B']
Example dict: {'strategy': 'megatron', 'model.path': 'Qwen/Qwen3-0.6B'}
Returns:
A fully constructed SFTConfig with CLI overrides applied.
"""
if isinstance(args, dict):
args = [f"{k}={v}" for k, v in args.items()]
overrides = OmegaConf.from_cli(args)
return cls.from_dict_config(overrides)
# ---- Reused SkyRL config objects ----
model: ModelConfig = field(default_factory=lambda: ModelConfig(path="Qwen/Qwen3-0.6B"))
optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig)
placement: SFTPlacementConfig = field(default_factory=SFTPlacementConfig)
megatron_config: MegatronConfig = field(
default_factory=lambda: MegatronConfig(
tensor_model_parallel_size=2,
pipeline_model_parallel_size=2,
)
)
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
# Ulysses sequence parallelism
sequence_parallel_size: int = 1
"""Ulysses sequence parallelism size"""
model_config_kwargs: dict = field(default_factory=dict)
"""Pass-through kwargs for the HuggingFace model config (FSDP backends).
For Megatron, use ``megatron_config.transformer_config_kwargs`` instead."""
use_torch_compile: bool = False
"""Apply torch.compile to logits calculation."""
record_memory: bool = False
"""Save memory snapshots to ``{ckpt_path}/memory_snapshots/``.
Visualize by dragging pickle files to https://docs.pytorch.org/memory_viz."""
# ---- SFT-specific flat fields ----
strategy: str = "megatron" # "megatron" or "fsdp2"
dataset_name: str = "yahma/alpaca-cleaned"
dataset_split: str = "train[:100]"
messages_key: str = "messages" # column name for chat-format datasets
max_length: int = 512
num_steps: int = 10
batch_size: int = 4
micro_train_batch_size_per_gpu: int = 2
logger: str = "console" # "console" or "wandb"
project_name: str = "skyrl_sft"
run_name: str = "skyrl_sft_run"
ckpt_path: str = "" # empty string = no checkpointing
ckpt_interval: int = 0
max_ckpts_to_keep: int = -1
"""-1 to keep all checkpoints, N to keep only the last N."""
resume_from: str = "" # "" = no resume, "latest" = latest checkpoint, or path to global_step_N dir
seed: int = 42
# ---- Packing ----
use_sample_packing: bool = True # Pack multiple sequences per batch (requires flash_attn)
# ---- Dummy run / benchmarking ----
dummy_run_full_ctx: bool = False # Skip real data; fabricate full-context sequences
dummy_run_max_steps: int = 5 # Number of steps to run in dummy mode
from_dict_config
from_dict_config(cfg: DictConfig) -> BaseConfig
Construct a typed BaseConfig from a Hydra DictConfig.
method classmethod from_cli_overrides
from_cli_overrides(args: Union[List[str], dict]) -> SFTConfig
Construct an SFTConfig from CLI arguments or a dict of overrides.
Parses CLI dotlist arguments via OmegaConf and builds a typed config. Dataclass field defaults are used for any values not specified.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
args | Union[List[str], dict] | Either a list of CLI arguments in 'key.path=value' format, or a dict mapping dot-notation keys to values. Example list: ['strategy=megatron', 'model.path=Qwen/Qwen3-0.6B'] Example dict: {'strategy': 'megatron', 'model.path': 'Qwen/Qwen3-0.6B'} | required |
Returns:
| Type | Description |
|---|---|
| SFTConfig | A fully constructed SFTConfig with CLI overrides applied. |
Source code in skyrl/train/config/sft_config.py:54-74
@classmethod
def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SFTConfig":
"""Construct an SFTConfig from CLI arguments or a dict of overrides.
Parses CLI dotlist arguments via OmegaConf and builds a typed config.
Dataclass field defaults are used for any values not specified.
Args:
args: Either a list of CLI arguments in 'key.path=value' format, or a dict
mapping dot-notation keys to values.
Example list: ['strategy=megatron', 'model.path=Qwen/Qwen3-0.6B']
Example dict: {'strategy': 'megatron', 'model.path': 'Qwen/Qwen3-0.6B'}
Returns:
A fully constructed SFTConfig with CLI overrides applied.
"""
if isinstance(args, dict):
args = [f"{k}={v}" for k, v in args.items()]
overrides = OmegaConf.from_cli(args)
return cls.from_dict_config(overrides)
attr model
model: ModelConfig = field(default_factory=(lambda: ModelConfig(path='Qwen/Qwen3-0.6B')))
attr optimizer_config
optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig)
attr placement
placement: SFTPlacementConfig = field(default_factory=SFTPlacementConfig)
attr megatron_config
megatron_config: MegatronConfig = field(default_factory=(lambda: MegatronConfig(tensor_model_parallel_size=2, pipeline_model_parallel_size=2)))
attr fsdp_config
fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
attr sequence_parallel_size
sequence_parallel_size: int = 1
Ulysses sequence parallelism size
attr model_config_kwargs
model_config_kwargs: dict = field(default_factory=dict)
Pass-through kwargs for the HuggingFace model config (FSDP backends).
For Megatron, use megatron_config.transformer_config_kwargs instead.
attr use_torch_compile
use_torch_compile: bool = False
Apply torch.compile to logits calculation.
attr record_memory
record_memory: bool = False
Save memory snapshots to {ckpt_path}/memory_snapshots/.
Visualize by dragging pickle files to https://docs.pytorch.org/memory_viz.
attr strategy
strategy: str = 'megatron'
attr dataset_name
dataset_name: str = 'yahma/alpaca-cleaned'
attr dataset_split
dataset_split: str = 'train[:100]'
attr messages_key
messages_key: str = 'messages'
attr max_length
max_length: int = 512
attr num_steps
num_steps: int = 10
attr batch_size
batch_size: int = 4
attr micro_train_batch_size_per_gpu
micro_train_batch_size_per_gpu: int = 2
attr logger
logger: str = 'console'
attr project_name
project_name: str = 'skyrl_sft'
attr run_name
run_name: str = 'skyrl_sft_run'
attr ckpt_path
ckpt_path: str = ''
attr ckpt_interval
ckpt_interval: int = 0
attr max_ckpts_to_keep
max_ckpts_to_keep: int = -1
-1 to keep all checkpoints, N to keep only the last N.
attr resume_from
resume_from: str = ''
attr seed
seed: int = 42
attr use_sample_packing
use_sample_packing: bool = True
attr dummy_run_full_ctx
dummy_run_full_ctx: bool = False
attr dummy_run_max_steps
dummy_run_max_steps: int = 5
Trainer
class SFTTrainer
SFTTrainer(cfg: SFTConfig, skyrl_cfg: SkyRLTrainConfig | None = None)
SFT trainer supporting FSDP and Megatron backends.
Unlike RayPPOTrainer, this does NOT subclass it. SFT's concerns are fundamentally different: no generation, no critic, no advantages, no KL penalty. Sharing a base class would create confusing dead code paths.
Usage::
trainer = SFTTrainer(SFTConfig(strategy="megatron"))
trainer.setup()
trainer.train()
trainer.shutdown()
Functions:
| Name | Description |
|---|---|
setup | Initialize tokenizer, workers, dispatch, and tracker. |
load_dataset | Load and tokenize the training dataset. |
collate_batch | Collate examples into a TrainingInputBatch with loss normalization. |
load_checkpoint | Load a checkpoint and return the step number to resume from. |
train_step | Execute a single training step: forward_backward + optim_step. |
train | Full training loop: load data, iterate, log, checkpoint. |
save_checkpoint | Save a checkpoint at the given step. |
shutdown | Finish tracking. |
Attributes:
| Name | Type | Description |
|---|---|---|
sft_cfg | | |
cfg | | |
tokenizer | | |
dispatch | WorkerDispatch \| None | |
tracker | Tracking \| None | |
global_step | | |
Source code in skyrl/train/sft_trainer.py:193-716
class SFTTrainer:
"""SFT trainer supporting FSDP and Megatron backends.
Unlike RayPPOTrainer, this does NOT subclass it. SFT's concerns are
fundamentally different: no generation, no critic, no advantages, no
KL penalty. Sharing a base class would create confusing dead code paths.
Usage::
trainer = SFTTrainer(SFTConfig(strategy="megatron"))
trainer.setup()
trainer.train()
trainer.shutdown()
"""
def __init__(self, cfg: SFTConfig, skyrl_cfg: SkyRLTrainConfig | None = None):
self.sft_cfg = cfg
# Accept a pre-built bridge config to avoid redundant rebuilds.
# When not provided (e.g. standalone usage), build it here.
self.cfg = skyrl_cfg if skyrl_cfg is not None else build_skyrl_config_for_sft(cfg)
self.tokenizer = None
self.dispatch: WorkerDispatch | None = None
self.tracker: Tracking | None = None
self.global_step = 0
# ------------------------------------------------------------------ #
# Setup
# ------------------------------------------------------------------ #
def setup(self):
"""Initialize tokenizer, workers, dispatch, and tracker.
Ray must already be initialized before calling this (either via
``initialize_ray`` on the head node or inside a Ray task).
"""
self.tokenizer = get_tokenizer(
self.cfg.trainer.policy.model.path,
trust_remote_code=True,
use_fast=not self.cfg.trainer.disable_fast_tokenizer,
padding_side="left",
)
self._init_workers()
self._init_tracker()
def _init_workers(self):
"""Create PPORayActorGroup and WorkerDispatch.
Selects the correct PolicyWorker based on strategy.
"""
if self.sft_cfg.strategy == "megatron":
from skyrl.backends.skyrl_train.workers.megatron.megatron_worker import (
PolicyWorker,
)
else:
from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker
num_gpus = self.sft_cfg.placement.num_gpus_per_node
raw_pg = placement_group(
[{"GPU": num_gpus, "CPU": num_gpus}] * self.sft_cfg.placement.num_nodes,
strategy="PACK",
)
get_ray_pg_ready_with_timeout(raw_pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
pg = ResolvedPlacementGroup(raw_pg)
actor_group = PPORayActorGroup(
self.cfg.trainer,
num_nodes=self.sft_cfg.placement.num_nodes,
num_gpus_per_node=num_gpus,
ray_actor_type=PolicyWorker,
pg=pg,
num_gpus_per_actor=1,
colocate_all=False,
sequence_parallel_size=self.cfg.trainer.policy.sequence_parallel_size,
record_memory=self.cfg.trainer.policy.record_memory,
)
num_training_steps = (
self.sft_cfg.dummy_run_max_steps if self.sft_cfg.dummy_run_full_ctx else self.sft_cfg.num_steps
)
ray.get(
actor_group.async_init_model(
self.sft_cfg.model.path,
num_training_steps=num_training_steps,
)
)
ray.get(actor_group.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
self.dispatch = WorkerDispatch(self.cfg, policy_actor_group=actor_group)
def _init_tracker(self):
self.tracker = Tracking(
project_name=self.cfg.trainer.project_name,
experiment_name=self.cfg.trainer.run_name,
backends=self.cfg.trainer.logger,
config=self.sft_cfg,
)
# ------------------------------------------------------------------ #
# Data
# ------------------------------------------------------------------ #
def _load_and_tokenize(self, dataset_name: str, dataset_split: str) -> list:
"""Load and tokenize a dataset.
Auto-detects the dataset format based on column names:
- If a ``messages_key`` column exists, uses chat-format tokenization.
- If ``instruction`` and ``output`` columns exist, uses Alpaca-format
tokenization.
Args:
dataset_name: HuggingFace dataset name (e.g. ``"yahma/alpaca-cleaned"``).
dataset_split: Dataset split (e.g. ``"train[:100]"`` or ``"test"``).
Returns a list of tokenized examples (dicts with ``input_ids``,
``attention_mask``, ``num_actions``).
"""
logger.info(f"Loading dataset '{dataset_name}' split='{dataset_split}'...")
dataset = load_dataset(dataset_name, split=dataset_split)
columns = dataset.column_names
logger.info("Tokenizing dataset...")
if self.sft_cfg.messages_key in columns:
# Chat format
tokenized = [
tokenize_chat_example(ex, self.tokenizer, self.sft_cfg.max_length, self.sft_cfg.messages_key)
for ex in dataset
]
elif "instruction" in columns and "output" in columns:
# Alpaca format
tokenized = [tokenize_sft_example(ex, self.tokenizer, self.sft_cfg.max_length) for ex in dataset]
else:
raise ValueError(
f"Unrecognized dataset format. Expected '{self.sft_cfg.messages_key}' column "
f"(chat format) or 'instruction'+'output' columns (Alpaca format). "
f"Found columns: {columns}"
)
tokenized = [ex for ex in tokenized if ex is not None]
logger.info(f"Tokenized {len(tokenized)} examples (filtered from {len(dataset)})")
return tokenized
def load_dataset(self) -> list:
"""Load and tokenize the training dataset."""
return self._load_and_tokenize(self.sft_cfg.dataset_name, self.sft_cfg.dataset_split)
def collate_batch(self, examples: list) -> TrainingInputBatch:
"""Collate examples into a TrainingInputBatch with loss normalization.
Normalizes the loss_mask so that the sum-reduction in cross_entropy_loss
produces a per-non-pad-token mean, matching the standard convention.
NOTE: The scaling factor is ``batch_size / (micro_batch_size * total_nonpad)``
where ``total_nonpad`` is the count of non-masked (loss-contributing)
tokens in the full batch. This accounts for the ``microbatch_weight``
(FSDP) or ``1/num_microbatches`` (Megatron) applied during gradient
accumulation so that the effective gradient equals
``d[sum(-log_probs_on_nonpad) / total_nonpad]``.
"""
batch = collate_sft_batch(examples, self.tokenizer)
# Loss normalization: divide by non-pad token count (not padded seq length)
# NOTE (sumanthrh): This specific scaling factor is because SkyRL's workers internally normalize
# by number of micro batches, but aggregate otherwise
micro_batch_size = self.sft_cfg.micro_train_batch_size_per_gpu
total_nonpad = max(batch["loss_mask"].sum().item(), 1)
batch["loss_mask"] = batch["loss_mask"].float() * (self.sft_cfg.batch_size / (micro_batch_size * total_nonpad))
return batch
# ------------------------------------------------------------------ #
# Checkpoint resume
# ------------------------------------------------------------------ #
def load_checkpoint(self) -> int:
"""Load a checkpoint and return the step number to resume from.
Behaviour depends on ``sft_cfg.resume_from``:
- ``""`` (empty): no resume, return 0.
- ``"latest"``: read ``latest_ckpt_global_step.txt`` from ``ckpt_path``.
- otherwise: treat as a direct path to a ``global_step_N`` directory.
Returns:
The global step to resume from (0 if no checkpoint loaded).
"""
resume_from = self.sft_cfg.resume_from
if not resume_from:
return 0
if resume_from == "latest":
if not self.sft_cfg.ckpt_path:
logger.info("resume_from='latest' but ckpt_path is empty, starting from scratch")
return 0
latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
if not io.exists(latest_file):
logger.info("No latest checkpoint marker found, starting from scratch")
return 0
with io.open_file(latest_file, "r") as f:
ckpt_step = int(f.read().strip())
checkpoint_path = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_step}")
# Validate consistency: ensure no stale checkpoint folders from prior runs
validate_consistency_for_latest_checkpoint(
self.sft_cfg.ckpt_path,
ckpt_step,
checkpoint_path,
latest_file,
self.sft_cfg.ckpt_interval,
)
else:
checkpoint_path = resume_from
if not io.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
global_step = extract_step_from_path(checkpoint_path)
if global_step == -1:
raise ValueError(
f"Cannot extract step number from checkpoint path: {checkpoint_path}. "
f"Expected a directory named '{GLOBAL_STEP_PREFIX}<N>'."
)
# Load and validate trainer state if available
trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
if io.exists(trainer_state_path):
with io.open_file(trainer_state_path, "rb") as f:
trainer_state = torch.load(f, map_location="cpu", weights_only=False)
saved_global_step = trainer_state.get("global_step", global_step)
logger.info("Successfully loaded trainer state")
if saved_global_step != global_step:
logger.warning(
f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value."
)
else:
logger.warning(
f"No trainer_state.pt found at {trainer_state_path}. "
"This checkpoint was likely saved by an older version."
)
policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
logger.info(f"Loading checkpoint from {checkpoint_path} (step {global_step})")
self.dispatch.load_checkpoint(
"policy",
policy_ckpt_dir,
load_optimizer_states=True,
load_lr_scheduler_states=True,
)
logger.info(f"Successfully resumed from global_step_{global_step}")
return global_step
# ------------------------------------------------------------------ #
# Training
# ------------------------------------------------------------------ #
def train_step(self, batch: TrainingInputBatch, step: int) -> dict:
"""Execute a single training step: forward_backward + optim_step.
Args:
batch: The collated training batch.
step: Current global step (reserved for future use, e.g. scheduling).
Returns:
Dict with ``loss``, ``grad_norm``, and ``timings``.
"""
timings: dict[str, float] = {}
with Timer("forward_backward", timings):
metrics = self.dispatch.forward_backward("policy", batch, loss_fn="cross_entropy")
with Timer("optim_step", timings):
grad_norm = self.dispatch.optim_step("policy")
loss_val = metrics.get("final_loss", metrics.get("loss", float("nan")))
return {
"loss": loss_val,
"grad_norm": grad_norm,
"timings": timings,
}
def _validate_batch_parallelism(self):
"""Validate that batch_size is compatible with data-parallel and micro-batch sizes."""
batch_size = self.sft_cfg.batch_size
total_gpus = self.sft_cfg.placement.num_nodes * self.sft_cfg.placement.num_gpus_per_node
if self.sft_cfg.strategy == "megatron":
tp = self.sft_cfg.megatron_config.tensor_model_parallel_size
pp = self.sft_cfg.megatron_config.pipeline_model_parallel_size
dp_size = total_gpus // (tp * pp)
else:
# FSDP: all GPUs are data-parallel
dp_size = total_gpus
if batch_size % dp_size != 0:
raise ValueError(f"batch_size ({batch_size}) must be divisible by data-parallel size ({dp_size})")
per_dp_batch = batch_size // dp_size
micro_batch = self.sft_cfg.micro_train_batch_size_per_gpu
if per_dp_batch % micro_batch != 0:
raise ValueError(
f"batch_size / dp_size ({per_dp_batch}) must be divisible by "
f"micro_train_batch_size_per_gpu ({micro_batch})"
)
def _build_dummy_batch(self) -> TrainingInputBatch:
"""Build a dummy batch of random full-context sequences for benchmarking."""
batch_size = self.sft_cfg.batch_size
max_length = self.sft_cfg.max_length
micro_batch_size = self.sft_cfg.micro_train_batch_size_per_gpu
vocab_size = self.tokenizer.vocab_size
# num_actions is max_length - 1 because the autoregressive model
# produces log-probs for positions 1..T (predicting next token),
# so the first token has no corresponding log-prob.
num_actions = max_length - 1
sequences = torch.randint(0, vocab_size, (batch_size, max_length), dtype=torch.long)
attention_mask = torch.ones(batch_size, max_length, dtype=torch.long)
# All tokens are non-pad in the dummy batch, so total_nonpad = batch_size * num_actions.
# Scaling = batch_size / (micro_batch_size * total_nonpad)
# = 1 / (micro_batch_size * num_actions)
total_nonpad = batch_size * num_actions
loss_mask = torch.ones(batch_size, num_actions, dtype=torch.float) * (
batch_size / (micro_batch_size * total_nonpad)
)
batch = TrainingInputBatch(
{
"sequences": sequences,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
}
)
batch.metadata = {"response_length": num_actions}
return batch
def _train_dummy(self):
"""Dummy training loop for benchmarking. Skips real data, checkpoints, and resume."""
self._validate_batch_parallelism()
batch = self._build_dummy_batch()
num_steps = self.sft_cfg.dummy_run_max_steps
logger.info(
f"Starting dummy SFT training for {num_steps} steps "
f"(batch_size={self.sft_cfg.batch_size}, max_length={self.sft_cfg.max_length})..."
)
for step in range(num_steps):
all_timings: dict[str, float] = {}
with Timer("step", all_timings):
step_result = self.train_step(batch, step)
all_timings.update(step_result["timings"])
actual_num_tokens = batch["attention_mask"].sum().item()
tokens_per_second = actual_num_tokens / all_timings["step"]
log_dict = {
"train/loss": step_result["loss"],
"train/grad_norm": step_result["grad_norm"],
"train/tokens_per_second": tokens_per_second,
"train/actual_num_tokens": actual_num_tokens,
}
log_dict.update({f"timing/{k}": v for k, v in all_timings.items()})
self.tracker.log(log_dict, step=step, commit=True)
logger.info(
f"Step {step}: loss={step_result['loss']:.4f}, "
f"grad_norm={step_result['grad_norm']}, "
f"tokens_per_second={tokens_per_second:.0f}"
)
logger.info("Dummy SFT training complete!")
def train(self):
"""Full training loop: load data, iterate, log, checkpoint."""
if self.sft_cfg.dummy_run_full_ctx:
if self.sft_cfg.resume_from:
logger.warning("resume_from is ignored in dummy run mode")
return self._train_dummy()
tokenized = self.load_dataset()
batch_size = self.sft_cfg.batch_size
num_steps = self.sft_cfg.num_steps
# Early validation: dataset must have at least batch_size examples
if len(tokenized) < batch_size:
raise ValueError(
f"Dataset has {len(tokenized)} examples after tokenization, but batch_size={batch_size}. "
f"Reduce batch_size or use more data."
)
self._validate_batch_parallelism()
# Resume from checkpoint if configured
start_step = self.load_checkpoint()
# Shuffle data before training
rng = random.Random(self.sft_cfg.seed)
rng.shuffle(tokenized)
# When resuming, start_step is the last *completed* step (checkpoint is
# saved AFTER the optimizer update), so we begin at start_step + 1 to
# avoid replaying that step.
# Replay epoch shuffles for reproducibility on resume
start_epoch = (start_step * batch_size) // len(tokenized)
for _ in range(start_epoch):
rng.shuffle(tokenized)
current_epoch = start_epoch
# SkyRL starts counting at step 1
self.global_step = start_step + 1 if start_step > 0 else 1
logger.info(f"Starting SFT training for {num_steps} steps (batch_size={batch_size})...")
if start_step > 0:
logger.info(f"Resuming from step {start_step}")
while self.global_step <= num_steps:
all_timings: dict[str, float] = {}
with Timer("step", all_timings):
# Data loading with wrap-around
with Timer("data_loading", all_timings):
start_idx = (self.global_step * batch_size) % len(tokenized)
end_idx = start_idx + batch_size
if end_idx > len(tokenized):
batch_examples = tokenized[start_idx:] + tokenized[: end_idx - len(tokenized)]
else:
batch_examples = tokenized[start_idx:end_idx]
batch = self.collate_batch(batch_examples)
# Training step
step_result = self.train_step(batch, self.global_step)
all_timings.update(step_result["timings"])
# Compute throughput using actual (non-padding) tokens
batch_padded_seq_len = batch["sequences"].shape[1]
actual_num_tokens = batch["attention_mask"].sum().item()
tokens_per_second = actual_num_tokens / all_timings["step"]
# Build log dict
log_dict = {
"train/loss": step_result["loss"],
"train/grad_norm": step_result["grad_norm"],
"train/tokens_per_second": tokens_per_second,
"train/actual_num_tokens": actual_num_tokens,
"train/batch_padded_seq_len": batch_padded_seq_len,
}
log_dict.update({f"timing/{k}": v for k, v in all_timings.items()})
# Checkpoint at regular intervals
if (
self.sft_cfg.ckpt_path
and self.sft_cfg.ckpt_interval > 0
and self.global_step > 0
and self.global_step % self.sft_cfg.ckpt_interval == 0
):
with Timer("save_checkpoint", all_timings):
self.save_checkpoint()
log_dict["timing/save_checkpoint"] = all_timings["save_checkpoint"]
self.tracker.log(log_dict, step=self.global_step, commit=True)
if self.global_step % 5 == 0:
logger.info(
f"Step {self.global_step}: loss={step_result['loss']:.4f}, " f"grad_norm={step_result['grad_norm']}"
)
# Check for epoch boundary and reshuffle
epoch = (self.global_step * batch_size) // len(tokenized)
if epoch > current_epoch:
for _ in range(epoch - current_epoch):
rng.shuffle(tokenized)
current_epoch = epoch
self.global_step += 1
self.global_step = min(self.global_step, num_steps)
# Save final checkpoint (if checkpointing is enabled)
if self.sft_cfg.ckpt_path:
final_step = num_steps
already_saved = (
self.sft_cfg.ckpt_interval > 0 and final_step > 0 and final_step % self.sft_cfg.ckpt_interval == 0
)
if not already_saved:
logger.info(f"Saving final checkpoint at step {final_step}")
self.save_checkpoint()
logger.info("SFT training complete!")
def save_checkpoint(self):
"""Save a checkpoint at the given step."""
step = self.global_step
global_step_folder = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{step}")
policy_save_dir = os.path.join(global_step_folder, "policy")
io.makedirs(global_step_folder, exist_ok=True)
logger.info(f"Saving checkpoint at step {step} to {global_step_folder}")
self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer)
# Save trainer state for cross-validation on resume (mirrors PPO's trainer_state.pt)
trainer_state = {
"global_step": step,
"config": asdict(self.sft_cfg),
}
trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
with io.open_file(trainer_state_path, "wb") as f:
torch.save(trainer_state, f)
logger.info(f"Saved trainer state to {trainer_state_path}")
# Atomic tracking -- write this last after all saves succeed
latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
with io.open_file(latest_file, "w") as f:
f.write(str(step))
logger.info(f"Checkpoint saved for global_step_{step}")
# Clean up old checkpoints after successful save
cleanup_old_checkpoints(self.sft_cfg.ckpt_path, self.sft_cfg.max_ckpts_to_keep)
# ------------------------------------------------------------------ #
# Lifecycle
# ------------------------------------------------------------------ #
def shutdown(self):
"""Finish tracking.
Does NOT call ``ray.shutdown()`` -- when running inside a Ray task
(the normal path via ``sft_entrypoint``), shutting down Ray from
within the task would be incorrect. The head-node process owns
the Ray lifecycle.
"""
if self.tracker is not None:
self.tracker.finish()
attr sft_cfg
sft_cfg = cfg
attr cfg
cfg = skyrl_cfg if skyrl_cfg is not None else build_skyrl_config_for_sft(cfg)
attr tokenizer
tokenizer = None
attr dispatch
dispatch: WorkerDispatch | None = None
attr tracker
tracker: Tracking | None = None
attr global_step
global_step = 0
method setup
setup()
Initialize tokenizer, workers, dispatch, and tracker.
Ray must already be initialized before calling this (either via
initialize_ray on the head node or inside a Ray task).
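A minimal sketch of the expected call order; it assumes Ray is started with a plain ray.init() (the initialize_ray helper mentioned above is the alternative on the head node):

```python
import ray

ray.init()  # or initialize_ray on the head node

trainer = SFTTrainer(SFTConfig(strategy="megatron"))
trainer.setup()     # tokenizer, workers, dispatch, tracker
trainer.train()
trainer.shutdown()
```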
Source code in skyrl/train/sft_trainer.py:222-235
def setup(self):
"""Initialize tokenizer, workers, dispatch, and tracker.
Ray must already be initialized before calling this (either via
``initialize_ray`` on the head node or inside a Ray task).
"""
self.tokenizer = get_tokenizer(
self.cfg.trainer.policy.model.path,
trust_remote_code=True,
use_fast=not self.cfg.trainer.disable_fast_tokenizer,
padding_side="left",
)
self._init_workers()
self._init_tracker()
method load_dataset
load_dataset() -> list
Load and tokenize the training dataset.
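The loader auto-detects the dataset format from its columns. Roughly, records shaped like either of the following are accepted (a sketch; the chat messages are assumed to follow the usual role/content convention, and field values are illustrative):

```python
# Chat format: a column named by messages_key (default "messages")
chat_example = {
    "messages": [
        {"role": "user", "content": "What is supervised fine-tuning?"},
        {"role": "assistant", "content": "Training on labeled prompt/response pairs."},
    ]
}

# Alpaca format: "instruction" and "output" columns (as in yahma/alpaca-cleaned)
alpaca_example = {
    "instruction": "Summarize the following text.",
    "output": "A short summary.",
}
```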
Source code in skyrl/train/sft_trainer.py:334-336
def load_dataset(self) -> list:
"""Load and tokenize the training dataset."""
return self._load_and_tokenize(self.sft_cfg.dataset_name, self.sft_cfg.dataset_split)
method collate_batch
collate_batch(examples: list) -> TrainingInputBatch
Collate examples into a TrainingInputBatch with loss normalization.
Normalizes the loss_mask so that the sum-reduction in cross_entropy_loss produces a per-non-pad-token mean, matching the standard convention.
NOTE: The scaling factor is batch_size / (micro_batch_size * total_nonpad)
where total_nonpad is the count of non-masked (loss-contributing)
tokens in the full batch. This accounts for the microbatch_weight
(FSDP) or 1/num_microbatches (Megatron) applied during gradient
accumulation so that the effective gradient equals
d[sum(-log_probs_on_nonpad) / total_nonpad].
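As a numeric sanity check of that scaling factor, here is a small sketch; it assumes num_microbatches = batch_size / micro_batch_size and a worker that averages a sum-reduced loss over micro batches (data parallelism ignored):

```python
import torch

batch_size, micro_batch_size = 4, 2
num_microbatches = batch_size // micro_batch_size        # 2
total_nonpad = 10                                        # non-masked tokens in the whole batch

# Per-token negative log-probs on the non-masked positions (made-up values).
neg_logprobs = torch.arange(1.0, total_nonpad + 1)       # 1.0 .. 10.0

# Target: a plain per-non-pad-token mean.
target = neg_logprobs.sum() / total_nonpad               # 5.5

# What the trainer does: scale the loss mask, sum-reduce each micro batch,
# then average over micro batches.
scale = batch_size / (micro_batch_size * total_nonpad)   # 0.2
micro_losses = [chunk.sum() * scale for chunk in neg_logprobs.chunk(num_microbatches)]
worker_loss = sum(micro_losses) / num_microbatches       # (3.0 + 8.0) / 2 = 5.5

assert torch.isclose(worker_loss, target)
```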
Source code in skyrl/train/sft_trainer.py:338-358
def collate_batch(self, examples: list) -> TrainingInputBatch:
"""Collate examples into a TrainingInputBatch with loss normalization.
Normalizes the loss_mask so that the sum-reduction in cross_entropy_loss
produces a per-non-pad-token mean, matching the standard convention.
NOTE: The scaling factor is ``batch_size / (micro_batch_size * total_nonpad)``
where ``total_nonpad`` is the count of non-masked (loss-contributing)
tokens in the full batch. This accounts for the ``microbatch_weight``
(FSDP) or ``1/num_microbatches`` (Megatron) applied during gradient
accumulation so that the effective gradient equals
``d[sum(-log_probs_on_nonpad) / total_nonpad]``.
"""
batch = collate_sft_batch(examples, self.tokenizer)
# Loss normalization: divide by non-pad token count (not padded seq length)
# NOTE (sumanthrh): This specific scaling factor is because SkyRL's workers internally normalize
# by number of micro batches, but aggregate otherwise
micro_batch_size = self.sft_cfg.micro_train_batch_size_per_gpu
total_nonpad = max(batch["loss_mask"].sum().item(), 1)
batch["loss_mask"] = batch["loss_mask"].float() * (self.sft_cfg.batch_size / (micro_batch_size * total_nonpad))
return batch
method load_checkpoint
load_checkpoint() -> int
Load a checkpoint and return the step number to resume from.
Behaviour depends on sft_cfg.resume_from:
- "" (empty): no resume, return 0.
- "latest": read latest_ckpt_global_step.txt from ckpt_path.
- otherwise: treat as a direct path to a global_step_N directory.
Returns:
| Type | Description |
|---|---|
| int | The global step to resume from (0 if no checkpoint loaded). |
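For reference, the three modes correspond to config values like these (a sketch; the checkpoint paths are hypothetical):

```python
# No resume (default).
cfg = SFTConfig(resume_from="")

# Resume from the latest checkpoint recorded under ckpt_path.
cfg = SFTConfig(ckpt_path="/tmp/skyrl_sft_ckpts", resume_from="latest")

# Resume from an explicit global_step_N directory.
cfg = SFTConfig(resume_from="/tmp/skyrl_sft_ckpts/global_step_40")
```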
Source code in skyrl/train/sft_trainer.py:364-437
def load_checkpoint(self) -> int:
"""Load a checkpoint and return the step number to resume from.
Behaviour depends on ``sft_cfg.resume_from``:
- ``""`` (empty): no resume, return 0.
- ``"latest"``: read ``latest_ckpt_global_step.txt`` from ``ckpt_path``.
- otherwise: treat as a direct path to a ``global_step_N`` directory.
Returns:
The global step to resume from (0 if no checkpoint loaded).
"""
resume_from = self.sft_cfg.resume_from
if not resume_from:
return 0
if resume_from == "latest":
if not self.sft_cfg.ckpt_path:
logger.info("resume_from='latest' but ckpt_path is empty, starting from scratch")
return 0
latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
if not io.exists(latest_file):
logger.info("No latest checkpoint marker found, starting from scratch")
return 0
with io.open_file(latest_file, "r") as f:
ckpt_step = int(f.read().strip())
checkpoint_path = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{ckpt_step}")
# Validate consistency: ensure no stale checkpoint folders from prior runs
validate_consistency_for_latest_checkpoint(
self.sft_cfg.ckpt_path,
ckpt_step,
checkpoint_path,
latest_file,
self.sft_cfg.ckpt_interval,
)
else:
checkpoint_path = resume_from
if not io.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
global_step = extract_step_from_path(checkpoint_path)
if global_step == -1:
raise ValueError(
f"Cannot extract step number from checkpoint path: {checkpoint_path}. "
f"Expected a directory named '{GLOBAL_STEP_PREFIX}<N>'."
)
# Load and validate trainer state if available
trainer_state_path = os.path.join(checkpoint_path, "trainer_state.pt")
if io.exists(trainer_state_path):
with io.open_file(trainer_state_path, "rb") as f:
trainer_state = torch.load(f, map_location="cpu", weights_only=False)
saved_global_step = trainer_state.get("global_step", global_step)
logger.info("Successfully loaded trainer state")
if saved_global_step != global_step:
logger.warning(
f"Global step mismatch: path={global_step}, saved={saved_global_step}. Using path value."
)
else:
logger.warning(
f"No trainer_state.pt found at {trainer_state_path}. "
"This checkpoint was likely saved by an older version."
)
policy_ckpt_dir = os.path.join(checkpoint_path, "policy")
logger.info(f"Loading checkpoint from {checkpoint_path} (step {global_step})")
self.dispatch.load_checkpoint(
"policy",
policy_ckpt_dir,
load_optimizer_states=True,
load_lr_scheduler_states=True,
)
logger.info(f"Successfully resumed from global_step_{global_step}")
return global_step
method train_step
train_step(batch: TrainingInputBatch, step: int) -> dict
Execute a single training step: forward_backward + optim_step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch | TrainingInputBatch | The collated training batch. | required |
step | int | Current global step (reserved for future use, e.g. scheduling). | required |
Returns:
| Type | Description |
|---|---|
| dict | Dict with loss, grad_norm, and timings. |
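A sketch of driving one step by hand after setup(), using the dataset and collation helpers documented above (the regular entrypoint is trainer.train()):

```python
examples = trainer.load_dataset()
batch = trainer.collate_batch(examples[: trainer.sft_cfg.batch_size])

result = trainer.train_step(batch, step=1)
print(result["loss"], result["grad_norm"], result["timings"])
```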
Source code in skyrl/train/sft_trainer.py:443-464
def train_step(self, batch: TrainingInputBatch, step: int) -> dict:
"""Execute a single training step: forward_backward + optim_step.
Args:
batch: The collated training batch.
step: Current global step (reserved for future use, e.g. scheduling).
Returns:
Dict with ``loss``, ``grad_norm``, and ``timings``.
"""
timings: dict[str, float] = {}
with Timer("forward_backward", timings):
metrics = self.dispatch.forward_backward("policy", batch, loss_fn="cross_entropy")
with Timer("optim_step", timings):
grad_norm = self.dispatch.optim_step("policy")
loss_val = metrics.get("final_loss", metrics.get("loss", float("nan")))
return {
"loss": loss_val,
"grad_norm": grad_norm,
"timings": timings,
}
method train
train()
Full training loop: load data, iterate, log, checkpoint.
Source code in skyrl/train/sft_trainer.py:557-673
def train(self):
"""Full training loop: load data, iterate, log, checkpoint."""
if self.sft_cfg.dummy_run_full_ctx:
if self.sft_cfg.resume_from:
logger.warning("resume_from is ignored in dummy run mode")
return self._train_dummy()
tokenized = self.load_dataset()
batch_size = self.sft_cfg.batch_size
num_steps = self.sft_cfg.num_steps
# Early validation: dataset must have at least batch_size examples
if len(tokenized) < batch_size:
raise ValueError(
f"Dataset has {len(tokenized)} examples after tokenization, but batch_size={batch_size}. "
f"Reduce batch_size or use more data."
)
self._validate_batch_parallelism()
# Resume from checkpoint if configured
start_step = self.load_checkpoint()
# Shuffle data before training
rng = random.Random(self.sft_cfg.seed)
rng.shuffle(tokenized)
# When resuming, start_step is the last *completed* step (checkpoint is
# saved AFTER the optimizer update), so we begin at start_step + 1 to
# avoid replaying that step.
# Replay epoch shuffles for reproducibility on resume
start_epoch = (start_step * batch_size) // len(tokenized)
for _ in range(start_epoch):
rng.shuffle(tokenized)
current_epoch = start_epoch
# SkyRL starts counting at step 1
self.global_step = start_step + 1 if start_step > 0 else 1
logger.info(f"Starting SFT training for {num_steps} steps (batch_size={batch_size})...")
if start_step > 0:
logger.info(f"Resuming from step {start_step}")
while self.global_step <= num_steps:
all_timings: dict[str, float] = {}
with Timer("step", all_timings):
# Data loading with wrap-around
with Timer("data_loading", all_timings):
start_idx = (self.global_step * batch_size) % len(tokenized)
end_idx = start_idx + batch_size
if end_idx > len(tokenized):
batch_examples = tokenized[start_idx:] + tokenized[: end_idx - len(tokenized)]
else:
batch_examples = tokenized[start_idx:end_idx]
batch = self.collate_batch(batch_examples)
# Training step
step_result = self.train_step(batch, self.global_step)
all_timings.update(step_result["timings"])
# Compute throughput using actual (non-padding) tokens
batch_padded_seq_len = batch["sequences"].shape[1]
actual_num_tokens = batch["attention_mask"].sum().item()
tokens_per_second = actual_num_tokens / all_timings["step"]
# Build log dict
log_dict = {
"train/loss": step_result["loss"],
"train/grad_norm": step_result["grad_norm"],
"train/tokens_per_second": tokens_per_second,
"train/actual_num_tokens": actual_num_tokens,
"train/batch_padded_seq_len": batch_padded_seq_len,
}
log_dict.update({f"timing/{k}": v for k, v in all_timings.items()})
# Checkpoint at regular intervals
if (
self.sft_cfg.ckpt_path
and self.sft_cfg.ckpt_interval > 0
and self.global_step > 0
and self.global_step % self.sft_cfg.ckpt_interval == 0
):
with Timer("save_checkpoint", all_timings):
self.save_checkpoint()
log_dict["timing/save_checkpoint"] = all_timings["save_checkpoint"]
self.tracker.log(log_dict, step=self.global_step, commit=True)
if self.global_step % 5 == 0:
logger.info(
f"Step {self.global_step}: loss={step_result['loss']:.4f}, " f"grad_norm={step_result['grad_norm']}"
)
# Check for epoch boundary and reshuffle
epoch = (self.global_step * batch_size) // len(tokenized)
if epoch > current_epoch:
for _ in range(epoch - current_epoch):
rng.shuffle(tokenized)
current_epoch = epoch
self.global_step += 1
self.global_step = min(self.global_step, num_steps)
# Save final checkpoint (if checkpointing is enabled)
if self.sft_cfg.ckpt_path:
final_step = num_steps
already_saved = (
self.sft_cfg.ckpt_interval > 0 and final_step > 0 and final_step % self.sft_cfg.ckpt_interval == 0
)
if not already_saved:
logger.info(f"Saving final checkpoint at step {final_step}")
self.save_checkpoint()
logger.info("SFT training complete!")
method save_checkpoint
save_checkpoint()
Save a checkpoint at the given step.
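Based on the source below, a checkpoint directory ends up looking roughly like this (a sketch; what the workers write under policy/ depends on the backend):

```text
{ckpt_path}/
├── latest_ckpt_global_step.txt   # written last, records the newest saved step
└── global_step_{N}/
    ├── policy/                   # model/optimizer state saved by the workers
    └── trainer_state.pt          # global_step + SFTConfig snapshot
```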
Source code in skyrl/train/sft_trainer.py:675-701
def save_checkpoint(self):
"""Save a checkpoint at the given step."""
step = self.global_step
global_step_folder = os.path.join(self.sft_cfg.ckpt_path, f"{GLOBAL_STEP_PREFIX}{step}")
policy_save_dir = os.path.join(global_step_folder, "policy")
io.makedirs(global_step_folder, exist_ok=True)
logger.info(f"Saving checkpoint at step {step} to {global_step_folder}")
self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer)
# Save trainer state for cross-validation on resume (mirrors PPO's trainer_state.pt)
trainer_state = {
"global_step": step,
"config": asdict(self.sft_cfg),
}
trainer_state_path = os.path.join(global_step_folder, "trainer_state.pt")
with io.open_file(trainer_state_path, "wb") as f:
torch.save(trainer_state, f)
logger.info(f"Saved trainer state to {trainer_state_path}")
# Atomic tracking -- write this last after all saves succeed
latest_file = os.path.join(self.sft_cfg.ckpt_path, "latest_ckpt_global_step.txt")
with io.open_file(latest_file, "w") as f:
f.write(str(step))
logger.info(f"Checkpoint saved for global_step_{step}")
# Clean up old checkpoints after successful save
cleanup_old_checkpoints(self.sft_cfg.ckpt_path, self.sft_cfg.max_ckpts_to_keep)
method shutdown
shutdown()
Finish tracking.
Does NOT call ray.shutdown() -- when running inside a Ray task
(the normal path via sft_entrypoint), shutting down Ray from
within the task would be incorrect. The head-node process owns
the Ray lifecycle.
Source code in skyrl/train/sft_trainer.py:707-716
def shutdown(self):
"""Finish tracking.
Does NOT call ``ray.shutdown()`` -- when running inside a Ray task
(the normal path via ``sft_entrypoint``), shutting down Ray from
within the task would be incorrect. The head-node process owns
the Ray lifecycle.
"""
if self.tracker is not None:
self.tracker.finish()
Config Bridge
function validate_sft_cfg
validate_sft_cfg(cfg: SFTConfig) -> None
Validate SFT-specific configuration.
Only checks fields that are relevant to SFT training, unlike
validate_cfg which includes RL-specific validations.
function build_skyrl_config_for_sft
build_skyrl_config_for_sft(sft_cfg: SFTConfig) -> SkyRLTrainConfig
Map user-facing SFTConfig to the internal SkyRL backend config.
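A sketch of using the bridge directly when constructing the trainer yourself; passing the pre-built SkyRLTrainConfig avoids the redundant rebuild noted in SFTTrainer.__init__ (the call order here is illustrative):

```python
sft_cfg = SFTConfig(strategy="fsdp2", num_steps=20)

validate_sft_cfg(sft_cfg)                           # SFT-specific checks only
skyrl_cfg = build_skyrl_config_for_sft(sft_cfg)     # map to the internal backend config

trainer = SFTTrainer(sft_cfg, skyrl_cfg=skyrl_cfg)  # avoids a second build inside __init__
trainer.setup()
trainer.train()
trainer.shutdown()
```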