Callbacks
Training callbacks for the SFT and RL trainers.
Callback Interface
Subclass TrainingCallback and override the events you need.
class TrainingCallback
Base class. Subclass and override the events you care about.
Functions:
| Name | Description |
|---|---|
| on_train_start | Fires once before the training loop begins (after checkpoint resume). |
| on_train_end | Fires once after the training loop and all final saves/eval complete. |
| on_epoch_start | Fires at the start of each epoch. |
| on_epoch_end | Fires at the end of each epoch. |
| on_step_start | Fires before each training step. callback_input.batch is populated. |
| on_step_end | Fires after each training step. callback_input.batch and .metrics are populated. |
| on_eval_start | Fires before an evaluation pass. |
| on_eval_end | Fires after an evaluation pass. callback_input.metrics holds the eval metrics. |
| on_save | Fires after a checkpoint is written. callback_input.ckpt_path is the folder path. |
| on_log | Fires before metrics are committed to the tracker. Mutate callback_input.logs to add fields. |
Source code in skyrl/train/utils/callbacks.py:83-114
class TrainingCallback:
"""Base class. Subclass and override the events you care about."""
def on_train_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires once before the training loop begins (after checkpoint resume)."""
def on_train_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires once after the training loop and all final saves/eval complete."""
def on_epoch_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires at the start of each epoch."""
def on_epoch_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires at the end of each epoch."""
def on_step_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires before each training step. ``callback_input.batch`` is populated."""
def on_step_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires after each training step. ``callback_input.batch`` and ``.metrics`` are populated."""
def on_eval_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires before an evaluation pass."""
def on_eval_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires after an evaluation pass. ``callback_input.metrics`` holds the eval metrics."""
def on_save(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires after a checkpoint is written. ``callback_input.ckpt_path`` is the folder path."""
def on_log(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires before metrics are committed to the tracker. Mutate ``callback_input.logs`` to add fields."""
method on_train_start
on_train_start(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Fires once before the training loop begins (after checkpoint resume).
Source code in skyrl/train/utils/callbacks.py:86-87
def on_train_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires once before the training loop begins (after checkpoint resume)."""
method on_train_end
on_train_end(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Fires once after the training loop and all final saves/eval complete.
Source code in skyrl/train/utils/callbacks.py:89-90
def on_train_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires once after the training loop and all final saves/eval complete."""
method on_epoch_start
on_epoch_start(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Fires at the start of each epoch.
Source code in skyrl/train/utils/callbacks.py:92-93
def on_epoch_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires at the start of each epoch."""
method on_epoch_end
on_epoch_end(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Fires at the end of each epoch.
Source code in skyrl/train/utils/callbacks.py:95-96
def on_epoch_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires at the end of each epoch."""
method on_step_start
on_step_start(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Fires before each training step. callback_input.batch is populated.
Source code in skyrl/train/utils/callbacks.py:98-99
def on_step_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires before each training step. ``callback_input.batch`` is populated."""
method on_step_end
on_step_end(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Fires after each training step. callback_input.batch and .metrics are populated.
Source code in skyrl/train/utils/callbacks.py:101-102
def on_step_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires after each training step. ``callback_input.batch`` and ``.metrics`` are populated."""
method on_eval_start
on_eval_start(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Fires before an evaluation pass.
Source code in skyrl/train/utils/callbacks.py:104-105
def on_eval_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires before an evaluation pass."""
method on_eval_end
on_eval_end(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Fires after an evaluation pass. callback_input.metrics holds the eval metrics.
Source code in skyrl/train/utils/callbacks.py:107-108
def on_eval_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires after an evaluation pass. ``callback_input.metrics`` holds the eval metrics."""
method on_save
on_save(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Fires after a checkpoint is written. callback_input.ckpt_path is the folder path.
Source code in skyrl/train/utils/callbacks.py:110-111
def on_save(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires after a checkpoint is written. ``callback_input.ckpt_path`` is the folder path."""
method on_log
on_log(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Fires before metrics are committed to the tracker. Mutate callback_input.logs to add fields.
Source code in skyrl/train/utils/callbacks.py:113-114
def on_log(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
"""Fires before metrics are committed to the tracker. Mutate ``callback_input.logs`` to add fields."""
Callback I/O
State passed to each event and the mutable control flags callbacks can set.
class CallbackInput
CallbackInput(global_step: int, epoch: int, total_steps: int, steps_per_epoch: int, batch: Optional['TrainingInputBatch'] = None, metrics: Optional[Dict[str, Any]] = None, logs: Optional[Dict[str, Any]] = None, ckpt_path: Optional[str] = None) -> None
State passed to every callback event.
The trainer rebuilds this before each event dispatch. Read-only from the
callback's perspective, except for logs, which on_log callbacks may mutate
in place to add fields. Per-event fields are None when not relevant
to the firing event - callbacks should null-check the fields they use.
Attributes:
| Name | Type | Description |
|---|---|---|
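| global_step | int | Always populated. |
| epoch | int | Always populated. |
| total_steps | int | Always populated. |
| steps_per_epoch | int | Always populated. |
| batch | Optional['TrainingInputBatch'] | Populated for step events. |
| metrics | Optional[Dict[str, Any]] | Populated for step end / eval end. |
| logs | Optional[Dict[str, Any]] | on_log only - the dict the trainer is about to commit. Callbacks may mutate it in place to add extra fields. |
| ckpt_path | Optional[str] | on_save only. |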
Source code in skyrl/train/utils/callbacks.py:34-60
@dataclass
class CallbackInput:
"""State passed to every callback event.
The trainer rebuilds this before each event dispatch. Read-only from the
callback's perspective. Per-event fields are ``None`` when not relevant
to the firing event - callbacks should null-check the fields they use.
"""
# Always populated
global_step: int
epoch: int
total_steps: int
steps_per_epoch: int
# Step events
batch: Optional["TrainingInputBatch"] = None
# Step end / eval end
metrics: Optional[Dict[str, Any]] = None
# on_log only - the dict the trainer is about to commit. Callbacks may
# mutate it in place to add extra fields.
logs: Optional[Dict[str, Any]] = None
# on_save only
ckpt_path: Optional[str] = None
attr global_step
global_step: int
attr epoch
epoch: int
attr total_steps
total_steps: int
attr steps_per_epoch
steps_per_epoch: int
attr batch
batch: Optional['TrainingInputBatch'] = None
attr property metrics
metrics: Optional[Dict[str, Any]] = None
attr logs
logs: Optional[Dict[str, Any]] = None
attr ckpt_path
ckpt_path: Optional[str] = None
class TrainingControl
TrainingControl(should_save: bool = False, should_evaluate: bool = False) -> None
Mutable flags callbacks can set to influence the trainer.
The trainer reads these once per step — right after on_step_end — then
honors and resets them. As a result, flags are only acted on for the
current step when set during on_step_end (or earlier in the same
step). Setting a flag from a later event in the same step (on_eval_end,
on_save, on_log) takes effect on the next step's read, so prefer
setting control flags from on_step_end.
Functions:
| Name | Description |
|---|---|
| reset | Sets should_save and should_evaluate back to False. |
Attributes:
| Name | Type | Description |
|---|---|---|
| should_save | bool | Defaults to False. |
| should_evaluate | bool | Defaults to False. |
Source code in skyrl/train/utils/callbacks.py:63-80
@dataclass
class TrainingControl:
"""Mutable flags callbacks can set to influence the trainer.
The trainer reads these once per step — right after ``on_step_end`` — then
honors and resets them. As a result, flags are only acted on for the
*current* step when set during ``on_step_end`` (or earlier in the same
step). Setting a flag from a later event in the same step (``on_eval_end``,
``on_save``, ``on_log``) takes effect on the *next* step's read, so prefer
setting control flags from ``on_step_end``.
"""
should_save: bool = False
should_evaluate: bool = False
def reset(self) -> None:
self.should_save = False
self.should_evaluate = False
attr should_save
should_save: bool = False
attr should_evaluate
should_evaluate: bool = False
method reset
reset() -> None
Source code in skyrl/train/utils/callbacks.py:78-80
def reset(self) -> None:
self.should_save = False
self.should_evaluate = False
Dispatch
Fan-out handler that trainers hold internally.
class CallbackHandler
CallbackHandler(callbacks: Optional[List[TrainingCallback]] = None)
Bases: TrainingCallback
Fan-out dispatcher. Itself a TrainingCallback (composite pattern).
Trainers hold a single CallbackHandler and call event methods on it;
the handler invokes each registered callback in registration order.
Functions:
| Name | Description |
|---|---|
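| add | Appends a callback to the list of registered callbacks. |
| on_train_start | Dispatches on_train_start to each registered callback in registration order. |
| on_train_end | Dispatches on_train_end to each registered callback in registration order. |
| on_epoch_start | Dispatches on_epoch_start to each registered callback in registration order. |
| on_epoch_end | Dispatches on_epoch_end to each registered callback in registration order. |
| on_step_start | Dispatches on_step_start to each registered callback in registration order. |
| on_step_end | Dispatches on_step_end to each registered callback in registration order. |
| on_eval_start | Dispatches on_eval_start to each registered callback in registration order. |
| on_eval_end | Dispatches on_eval_end to each registered callback in registration order. |
| on_save | Dispatches on_save to each registered callback in registration order. |
| on_log | Dispatches on_log to each registered callback in registration order. |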
Attributes:
| Name | Type | Description |
|---|---|---|
| callbacks | List[TrainingCallback] | |
Source code in skyrl/train/utils/callbacks.py:117-162
class CallbackHandler(TrainingCallback):
"""Fan-out dispatcher. Itself a ``TrainingCallback`` (composite pattern).
Trainers hold a single ``CallbackHandler`` and call event methods on it;
the handler invokes each registered callback in registration order.
"""
def __init__(self, callbacks: Optional[List[TrainingCallback]] = None):
self.callbacks: List[TrainingCallback] = list(callbacks or [])
def add(self, callback: TrainingCallback) -> None:
self.callbacks.append(callback)
def _dispatch(self, name: str, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
for cb in self.callbacks:
getattr(cb, name)(trainer, callback_input, control)
def on_train_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_train_start", trainer, callback_input, control)
def on_train_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_train_end", trainer, callback_input, control)
def on_epoch_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_epoch_start", trainer, callback_input, control)
def on_epoch_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_epoch_end", trainer, callback_input, control)
def on_step_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_step_start", trainer, callback_input, control)
def on_step_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_step_end", trainer, callback_input, control)
def on_eval_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_eval_start", trainer, callback_input, control)
def on_eval_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_eval_end", trainer, callback_input, control)
def on_save(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_save", trainer, callback_input, control)
def on_log(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_log", trainer, callback_input, control)
attr callbacks
callbacks: List[TrainingCallback] = list(callbacks or [])
method add
add(callback: TrainingCallback) -> None
Source code in skyrl/train/utils/callbacks.py:127-128
def add(self, callback: TrainingCallback) -> None:
self.callbacks.append(callback)
method on_train_start
on_train_start(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Source code in skyrl/train/utils/callbacks.py:134-135
def on_train_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_train_start", trainer, callback_input, control)
method on_train_end
on_train_end(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Source code in skyrl/train/utils/callbacks.py:137-138
def on_train_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_train_end", trainer, callback_input, control)
method on_epoch_start
on_epoch_start(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Source code in skyrl/train/utils/callbacks.py:140-141
def on_epoch_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_epoch_start", trainer, callback_input, control)
method on_epoch_end
on_epoch_end(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Source code in skyrl/train/utils/callbacks.py:143-144
def on_epoch_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_epoch_end", trainer, callback_input, control)
method on_step_start
on_step_start(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Source code in skyrl/train/utils/callbacks.py:146-147
def on_step_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_step_start", trainer, callback_input, control)
method on_step_end
on_step_end(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Source code in skyrl/train/utils/callbacks.py:149-150
def on_step_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_step_end", trainer, callback_input, control)
method on_eval_start
on_eval_start(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Source code in skyrl/train/utils/callbacks.py:152-153
def on_eval_start(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_eval_start", trainer, callback_input, control)
method on_eval_end
on_eval_end(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Source code in skyrl/train/utils/callbacks.py:155-156
def on_eval_end(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_eval_end", trainer, callback_input, control)
method on_save
on_save(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Source code in skyrl/train/utils/callbacks.py:158-159
def on_save(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_save", trainer, callback_input, control)
method on_log
on_log(trainer, callback_input: CallbackInput, control: TrainingControl) -> None
Source code in skyrl/train/utils/callbacks.py:161-162
def on_log(self, trainer, callback_input: CallbackInput, control: TrainingControl) -> None:
self._dispatch("on_log", trainer, callback_input, control)