Backends
Backend abstractions for training and inference.
Abstract Backend
The base class all backends implement.
class AbstractBackend
AbstractBackend(base_model: str, config: BaseModel)

Bases: ABC
Abstract base class for TinkerEngine backends.
Backends handle computation and model state manipulation. Database operations are handled by TinkerEngine.
Functions:
| Name | Description |
|---|---|
create_model | Create a new model in the backend. |
forward_backward | Run forward and backward pass on a batch. |
forward | Run forward-only pass on a batch (no gradient computation). |
optim_step | Apply an optimizer step using accumulated gradients. |
sample | Generate samples for a batch of requests. |
save_checkpoint | Save training checkpoint to disk. |
load_checkpoint | Load training checkpoint from disk. |
save_sampler_checkpoint | Prepare model weights for sampling and optionally save to disk. |
has_model | Check if a model is registered with the backend. |
delete_model | Delete a model and free all associated resources. |
Initialize the backend.
Source code in skyrl/backends/backend.py:33-171
class AbstractBackend(ABC):
    """Abstract base class for TinkerEngine backends.

    Backends handle computation and model state manipulation.
    Database operations are handled by TinkerEngine.
    """

    @abstractmethod
    def __init__(self, base_model: str, config: BaseModel):
        """Initialize the backend.

        Args:
            base_model: Identifier of the base model this backend serves.
            config: Backend-specific configuration object.
        """
        pass

    @abstractmethod
    def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
        """Create a new model in the backend.

        Creates optimizer and configures LoRA adapter.

        Args:
            model_id: The model identifier
            lora_config: LoRA configuration with rank and alpha
        """
        pass

    @abstractmethod
    def forward_backward(
        self,
        prepared_batch: types.PreparedModelPassBatch,
    ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
        """Run forward and backward pass on a batch.

        Args:
            prepared_batch: PreparedModelPassBatch with all data extracted from requests

        Returns:
            Dict mapping request_id to result or error
        """
        pass

    @abstractmethod
    def forward(
        self,
        prepared_batch: types.PreparedModelPassBatch,
    ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
        """Run forward-only pass on a batch (no gradient computation).

        Args:
            prepared_batch: PreparedModelPassBatch with all data extracted from requests

        Returns:
            Dict mapping request_id to result or error
        """
        pass

    @abstractmethod
    def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput:
        """Apply an optimizer step using accumulated gradients.

        Args:
            model_id: The model identifier
            request_data: The optimizer step input parameters

        Returns:
            OptimStepOutput result
        """
        pass

    @abstractmethod
    def sample(
        self,
        prepared_batch: types.PreparedSampleBatch,
    ) -> dict[str, types.SampleOutput | types.ErrorResponse]:
        """Generate samples for a batch of requests.

        Args:
            prepared_batch: PreparedSampleBatch with all data extracted from requests

        Returns:
            Dict mapping request_id to result or error
        """
        pass

    @abstractmethod
    def save_checkpoint(self, output_path, model_id: str) -> None:
        """Save training checkpoint to disk.

        Args:
            output_path: Path to save the checkpoint (path-like; the JAX
                backend annotates this as AnyPath — TODO confirm expected type)
            model_id: The model identifier
        """
        pass

    @abstractmethod
    def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
        """Load training checkpoint from disk.

        Args:
            checkpoint_path: Path to the checkpoint file
            model_id: The model identifier
        """
        pass

    @abstractmethod
    def save_sampler_checkpoint(self, output_path, model_id: str, persist: bool = True) -> None:
        """Prepare model weights for sampling and optionally save to disk.

        Backends that use colocated inference engines should sync weights
        in-memory regardless of ``persist``. When ``persist`` is *False*
        the backend may skip the expensive disk write and only place a
        lightweight marker at ``output_path``.

        Args:
            output_path: Path to save the checkpoint tar.gz file
            model_id: The model identifier
            persist: If True, write a full model snapshot to disk.
                If False, only sync weights in-memory (hot path).
        """
        pass

    @abstractmethod
    def has_model(self, model_id: str) -> bool:
        """Check if a model is registered with the backend.

        Args:
            model_id: The model identifier

        Returns:
            True if the model is registered, False otherwise
        """
        pass

    @abstractmethod
    def delete_model(self, model_id: str) -> None:
        """Delete a model and free all associated resources.

        Args:
            model_id: The model identifier
        """
    pass

method abstractmethod create_model

create_model(model_id: str, lora_config: types.LoraConfig) -> None

Create a new model in the backend.
Creates optimizer and configures LoRA adapter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model_id | str | The model identifier | required |
lora_config | LoraConfig | LoRA configuration with rank and alpha | required |
Source code in skyrl/backends/backend.py:45-55
# Listing of AbstractBackend.create_model; concrete backends must implement it.
@abstractmethod
def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
    """Create a new model in the backend.

    Creates optimizer and configures LoRA adapter.

    Args:
        model_id: The model identifier
        lora_config: LoRA configuration with rank and alpha
    """
    pass

method abstractmethod forward_backward

forward_backward(prepared_batch: types.PreparedModelPassBatch) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]

Run forward and backward pass on a batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
prepared_batch | PreparedModelPassBatch | PreparedModelPassBatch with all data extracted from requests | required |
Returns:
| Type | Description |
|---|---|
| dict[str, ForwardBackwardOutput | ErrorResponse] | Dict mapping request_id to result or error |
Source code in skyrl/backends/backend.py:57-70
# Listing of AbstractBackend.forward_backward; concrete backends must implement it.
@abstractmethod
def forward_backward(
    self,
    prepared_batch: types.PreparedModelPassBatch,
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
    """Run forward and backward pass on a batch.

    Args:
        prepared_batch: PreparedModelPassBatch with all data extracted from requests

    Returns:
        Dict mapping request_id to result or error
    """
    pass

method abstractmethod forward

forward(prepared_batch: types.PreparedModelPassBatch) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]

Run forward-only pass on a batch (no gradient computation).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
prepared_batch | PreparedModelPassBatch | PreparedModelPassBatch with all data extracted from requests | required |
Returns:
| Type | Description |
|---|---|
| dict[str, ForwardBackwardOutput | ErrorResponse] | Dict mapping request_id to result or error |
Source code in skyrl/backends/backend.py:72-85
# Listing of AbstractBackend.forward; concrete backends must implement it.
@abstractmethod
def forward(
    self,
    prepared_batch: types.PreparedModelPassBatch,
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
    """Run forward-only pass on a batch (no gradient computation).

    Args:
        prepared_batch: PreparedModelPassBatch with all data extracted from requests

    Returns:
        Dict mapping request_id to result or error
    """
    pass

method abstractmethod optim_step

optim_step(model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput

Apply an optimizer step using accumulated gradients.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model_id | str | The model identifier | required |
request_data | OptimStepInput | The optimizer step input parameters | required |
Returns:
| Type | Description |
|---|---|
| OptimStepOutput | OptimStepOutput result |
Source code in skyrl/backends/backend.py:87-98
# Listing of AbstractBackend.optim_step; concrete backends must implement it.
@abstractmethod
def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput:
    """Apply an optimizer step using accumulated gradients.

    Args:
        model_id: The model identifier
        request_data: The optimizer step input parameters

    Returns:
        OptimStepOutput result
    """
    pass

method abstractmethod sample

sample(prepared_batch: types.PreparedSampleBatch) -> dict[str, types.SampleOutput | types.ErrorResponse]

Generate samples for a batch of requests.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
prepared_batch | PreparedSampleBatch | PreparedSampleBatch with all data extracted from requests | required |
Returns:
| Type | Description |
|---|---|
| dict[str, SampleOutput | ErrorResponse] | Dict mapping request_id to result or error |
Source code in skyrl/backends/backend.py:100-113
# Listing of AbstractBackend.sample; concrete backends must implement it.
@abstractmethod
def sample(
    self,
    prepared_batch: types.PreparedSampleBatch,
) -> dict[str, types.SampleOutput | types.ErrorResponse]:
    """Generate samples for a batch of requests.

    Args:
        prepared_batch: PreparedSampleBatch with all data extracted from requests

    Returns:
        Dict mapping request_id to result or error
    """
    pass

method abstractmethod save_checkpoint

save_checkpoint(output_path, model_id: str) -> None

Save training checkpoint to disk.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
output_path | Path to save the checkpoint | required | |
model_id | str | The model identifier | required |
Source code in skyrl/backends/backend.py:115-123
# Listing of AbstractBackend.save_checkpoint; concrete backends must implement it.
@abstractmethod
def save_checkpoint(self, output_path, model_id: str) -> None:
    """Save training checkpoint to disk.

    Args:
        output_path: Path to save the checkpoint
        model_id: The model identifier
    """
    pass

method abstractmethod load_checkpoint

load_checkpoint(checkpoint_path, model_id: str) -> None

Load training checkpoint from disk.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
checkpoint_path | Path to the checkpoint file | required | |
model_id | str | The model identifier | required |
Source code in skyrl/backends/backend.py:125-133
# Listing of AbstractBackend.load_checkpoint; concrete backends must implement it.
@abstractmethod
def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
    """Load training checkpoint from disk.

    Args:
        checkpoint_path: Path to the checkpoint file
        model_id: The model identifier
    """
    pass

method abstractmethod save_sampler_checkpoint

save_sampler_checkpoint(output_path, model_id: str, persist: bool = True) -> None

Prepare model weights for sampling and optionally save to disk.
Backends that use colocated inference engines should sync weights
in-memory regardless of persist. When persist is False
the backend may skip the expensive disk write and only place a
lightweight marker at output_path.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
output_path | Path to save the checkpoint tar.gz file | required | |
model_id | str | The model identifier | required |
persist | bool | If True, write a full model snapshot to disk. If False, only sync weights in-memory (hot path). | True |
Source code in skyrl/backends/backend.py:135-150
# Listing of AbstractBackend.save_sampler_checkpoint; concrete backends must implement it.
@abstractmethod
def save_sampler_checkpoint(self, output_path, model_id: str, persist: bool = True) -> None:
    """Prepare model weights for sampling and optionally save to disk.

    Backends that use colocated inference engines should sync weights
    in-memory regardless of ``persist``. When ``persist`` is *False*
    the backend may skip the expensive disk write and only place a
    lightweight marker at ``output_path``.

    Args:
        output_path: Path to save the checkpoint tar.gz file
        model_id: The model identifier
        persist: If True, write a full model snapshot to disk.
            If False, only sync weights in-memory (hot path).
    """
    pass

method abstractmethod has_model

has_model(model_id: str) -> bool

Check if a model is registered with the backend.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model_id | str | The model identifier | required |
Returns:
| Type | Description |
|---|---|
| bool | True if the model is registered, False otherwise |
Source code in skyrl/backends/backend.py:152-162
# Listing of AbstractBackend.has_model; concrete backends must implement it.
@abstractmethod
def has_model(self, model_id: str) -> bool:
    """Check if a model is registered with the backend.

    Args:
        model_id: The model identifier

    Returns:
        True if the model is registered, False otherwise
    """
    pass

method abstractmethod delete_model

delete_model(model_id: str) -> None

Delete a model and free all associated resources.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model_id | str | The model identifier | required |
Source code in skyrl/backends/backend.py:164-171
# Listing of AbstractBackend.delete_model; concrete backends must implement it.
@abstractmethod
def delete_model(self, model_id: str) -> None:
    """Delete a model and free all associated resources.

    Args:
        model_id: The model identifier
    """
    pass

JAX Backend
JAX-based backend for training and inference.
class JaxBackend
JaxBackend(base_model: str, config: JaxBackendConfig)

Bases: JaxBackendImpl
Distributed wrapper that broadcasts commands before calling JaxBackendImpl methods.
Workers use runtime type introspection to re-hydrate arguments automatically.
Functions:
| Name | Description |
|---|---|
has_model | Check if a model is registered with the backend. |
delete_model | Delete a model and free all associated resources. |
load_sampler_checkpoint | Insert sampler weights from checkpoint file. |
load_sampler_weights | Load sampler weights for all requests and return adapter indices array. |
create_model | |
forward_backward | |
forward | |
optim_step | |
sample | |
save_checkpoint | |
load_checkpoint | |
save_sampler_checkpoint |
Attributes:
| Name | Type | Description |
|---|---|---|
base_model | ||
config | ||
metrics | ||
tokenizer | ||
model_config | ||
mesh | ||
model | ||
accumulated_grads | ||
optimizers | dict[str, Optimizer] | |
models | dict[str, ModelMetadata] | |
process_id |
Source code in skyrl/backends/jax.py:1069-1132
class JaxBackend(JaxBackendImpl):
    """Distributed wrapper that broadcasts commands before calling JaxBackendImpl methods.

    Workers use runtime type introspection to re-hydrate arguments automatically.
    """

    def __init__(self, base_model: str, config: JaxBackendConfig):
        self.process_id = 0  # Coordinator is always process 0
        if config.coordinator_address is not None:
            # Multi-process mode: join the JAX distributed runtime as the coordinator.
            jax.distributed.initialize(
                coordinator_address=config.coordinator_address,
                num_processes=config.num_processes,
                process_id=self.process_id,
            )
            logger.info(
                f"JAX distributed initialized: process_id={self.process_id} ({jax.process_count()} total), "
                f"local devices: {jax.local_device_count()}, total devices: {jax.device_count()}"
            )
        # Route our own construction through the broadcast path so workers build too.
        self._broadcast_and_call("__init__", base_model=base_model, config=config, process_id=self.process_id)

    def _broadcast_and_call(self, method: str, **kwargs):
        """Broadcast method call to workers and execute locally via super()."""
        if jax.process_count() > 1:
            # Re-hydration on the worker side relies on the declared type hints
            # of the JaxBackendImpl method being invoked.
            hints = get_type_hints(getattr(JaxBackendImpl, method))

            # TODO: Remove AnyPath special case once
            # https://github.com/drivendataorg/cloudpathlib/issues/537 is released
            def serialize(k, v):
                # AnyPath values go over the wire as plain strings (see TODO above);
                # everything else is JSON-dumped via its declared type hint.
                if hints.get(k) is AnyPath:
                    return str(v)
                return TypeAdapter(hints[k]).dump_python(v, mode="json") if k in hints else v

            _broadcast_command(
                RpcPayload(method=method, kwargs={k: serialize(k, v) for k, v in kwargs.items()}),
                process_id=self.process_id,
            )
        # Always run the real implementation locally on the coordinator.
        return getattr(super(), method)(**kwargs)

    def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
        self._broadcast_and_call("create_model", model_id=model_id, lora_config=lora_config)

    def forward_backward(self, prepared_batch: types.PreparedModelPassBatch):
        return self._broadcast_and_call("forward_backward", prepared_batch=prepared_batch)

    def forward(self, prepared_batch: types.PreparedModelPassBatch):
        return self._broadcast_and_call("forward", prepared_batch=prepared_batch)

    def optim_step(self, model_id: str, request_data: types.OptimStepInput):
        return self._broadcast_and_call("optim_step", model_id=model_id, request_data=request_data)

    def sample(self, prepared_batch: types.PreparedSampleBatch):
        return self._broadcast_and_call("sample", prepared_batch=prepared_batch)

    def save_checkpoint(self, output_path: AnyPath, model_id: str) -> None:
        self._broadcast_and_call("save_checkpoint", output_path=output_path, model_id=model_id)

    def load_checkpoint(self, checkpoint_path: AnyPath, model_id: str) -> None:
        self._broadcast_and_call("load_checkpoint", checkpoint_path=checkpoint_path, model_id=model_id)

    def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str, persist: bool = True) -> None:
        # Write probe so workers can detect shared filesystem
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.with_name(output_path.name + ".probe").write_text("write_probe")
    self._broadcast_and_call("save_sampler_checkpoint", output_path=output_path, model_id=model_id, persist=persist)

attr base_model

base_model = base_model

attr config

config = config

attr property metrics

metrics = types.EngineMetrics()

attr tokenizer

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)

attr model_config

model_config = Qwen3Config(base_config, max_lora_adapters=(config.max_lora_adapters), max_lora_rank=(config.max_lora_rank), shard_attention_heads=(config.shard_attention_heads), loss_chunk_size=(config.loss_chunk_size), gradient_checkpointing=(config.gradient_checkpointing), mhc_expansion_rate=(config.mhc_expansion_rate))

attr mesh

mesh = jax.make_mesh((config.fully_sharded_data_parallel_size, config.expert_parallel_size, config.tensor_parallel_size), ('fsdp', 'ep', 'tp'), axis_types=((jax.sharding.AxisType.Auto,) * 3))

attr model

model = model_class(self.model_config, dtype=(get_dtype(self.model_config.get_config().dtype)), rngs=(nnx.Rngs(0)))

attr accumulated_grads

accumulated_grads = AccumulatedGradients.create(self.lora_params, config.max_lora_adapters)

attr optimizers

optimizers: dict[str, nnx.Optimizer] = {}

attr models

models: dict[str, types.ModelMetadata] = {}

method abstractmethod has_model
has_model(model_id: str) -> bool

Check if a model is registered with the backend.

method abstractmethod delete_model

delete_model(model_id: str) -> None

Delete a model and free all associated resources.

load_sampler_checkpoint

load_sampler_checkpoint(model_id: str, checkpoint_id: str, checkpoint_path: AnyPath) -> None

Insert sampler weights from checkpoint file.

load_sampler_weights

load_sampler_weights(prepared_batch: types.PreparedSampleBatch) -> list[int]

Load sampler weights for all requests and return adapter indices array.
Ensures all required checkpoints are loaded before sampling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
prepared_batch | PreparedSampleBatch | PreparedSampleBatch with model_ids, checkpoint_ids, and other batch data | required |
Returns:
| Type | Description |
|---|---|
| list[int] | The adapter_indices array for LoRA sampling [batch_size] |
| list[int] | Uses adapter index 0 for base model sampling (no LoRA) |
attr process_id
process_id = 0

method abstractmethod create_model

create_model(model_id: str, lora_config: types.LoraConfig) -> None

Source code in skyrl/backends/jax.py:1107-1108

def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
    self._broadcast_and_call("create_model", model_id=model_id, lora_config=lora_config)

method abstractmethod forward_backward

forward_backward(prepared_batch: types.PreparedModelPassBatch)

Source code in skyrl/backends/jax.py:1110-1111

def forward_backward(self, prepared_batch: types.PreparedModelPassBatch):
    return self._broadcast_and_call("forward_backward", prepared_batch=prepared_batch)

method abstractmethod forward

forward(prepared_batch: types.PreparedModelPassBatch)

Source code in skyrl/backends/jax.py:1113-1114

def forward(self, prepared_batch: types.PreparedModelPassBatch):
    return self._broadcast_and_call("forward", prepared_batch=prepared_batch)

method abstractmethod optim_step

optim_step(model_id: str, request_data: types.OptimStepInput)

Source code in skyrl/backends/jax.py:1116-1117

def optim_step(self, model_id: str, request_data: types.OptimStepInput):
    return self._broadcast_and_call("optim_step", model_id=model_id, request_data=request_data)

method abstractmethod sample

sample(prepared_batch: types.PreparedSampleBatch)

Source code in skyrl/backends/jax.py:1119-1120

def sample(self, prepared_batch: types.PreparedSampleBatch):
    return self._broadcast_and_call("sample", prepared_batch=prepared_batch)

method abstractmethod save_checkpoint

save_checkpoint(output_path: AnyPath, model_id: str) -> None

Source code in skyrl/backends/jax.py:1122-1123

def save_checkpoint(self, output_path: AnyPath, model_id: str) -> None:
    self._broadcast_and_call("save_checkpoint", output_path=output_path, model_id=model_id)

method abstractmethod load_checkpoint

load_checkpoint(checkpoint_path: AnyPath, model_id: str) -> None

Source code in skyrl/backends/jax.py:1125-1126

def load_checkpoint(self, checkpoint_path: AnyPath, model_id: str) -> None:
    self._broadcast_and_call("load_checkpoint", checkpoint_path=checkpoint_path, model_id=model_id)

method abstractmethod save_sampler_checkpoint

save_sampler_checkpoint(output_path: AnyPath, model_id: str, persist: bool = True) -> None

Source code in skyrl/backends/jax.py:1128-1132

def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str, persist: bool = True) -> None:
    # Write probe so workers can detect shared filesystem
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.with_name(output_path.name + ".probe").write_text("write_probe")
    self._broadcast_and_call("save_sampler_checkpoint", output_path=output_path, model_id=model_id, persist=persist)

SkyRL-Train Backend
Backend wrapping the SkyRL-Train training pipeline.
class SkyRLTrainBackend
SkyRLTrainBackend(base_model: str, config: SkyRLTrainBackendOverrides)

Bases: AbstractBackend
SkyRL-Train backend for supervised training.
Functions:
| Name | Description |
|---|---|
has_model | |
build_models | |
init_weight_sync_state | Setup the connection between policy model and inference engine for weight syncing. |
create_model | |
delete_model | |
forward_backward | |
forward | |
optim_step | |
sample | Generate samples using InferenceEngineClient. |
save_checkpoint | Save full training checkpoint (model + optimizer + scheduler) as tar. |
load_checkpoint | Load full training checkpoint (model + optimizer + scheduler) from tar. |
save_sampler_checkpoint | Sync weights to colocated inference engines and optionally save to disk. |
Attributes:
| Name | Type | Description |
|---|---|---|
base_model | ||
config |
Source code in skyrl/backends/skyrl_train_backend.py:96-712
class SkyRLTrainBackend(AbstractBackend):
"""SkyRL-Train backend for supervised training."""
def __init__(self, base_model: str, config: SkyRLTrainBackendOverrides):
logger.warning("=" * 80)
logger.warning("SkyRLTrainBackend is currently EXPERIMENTAL!")
logger.warning("=" * 80)
if ray is None:
raise ImportError(
"SkyRLTrainBackend requires `ray`. Install the appropriate extras (e.g. `--extra skyrl_train`)."
)
self.base_model = base_model
# NOTE: We currently have two config attributes "config" which is just config overrides and "_cfg" which is the actual config object. This is a temporary state given that the Tinker engine expects a .config attribute
self.config = config
self._model_id: str | None = None
self._model_metadata: types.ModelMetadata | None = None
self._cfg = None
self._dispatch: WorkerDispatch | None = None
self._tokenizer: AutoTokenizer = get_tokenizer(self.base_model)
self._inference_engine_client = None
self._inference_engines_initialized = False
def has_model(self, model_id: str) -> bool:
return self._model_id == model_id
def build_models(self, PolicyWorker):
cfg = self._cfg
colocate_all = cfg.trainer.placement.colocate_all
pg = self._colocate_pg
if colocate_all:
assert pg is not None, "placement group must be created when colocate_all=True"
num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
num_rollout_gpus = (
cfg.generator.inference_engine.num_engines
* cfg.generator.inference_engine.tensor_parallel_size
* cfg.generator.inference_engine.pipeline_parallel_size
* cfg.generator.inference_engine.data_parallel_size
)
assert (
num_policy_gpus == num_rollout_gpus
), "num_policy_gpus and num_rollout_gpus must be the same when colocating all models"
policy_model = PPORayActorGroup(
cfg.trainer,
cfg.trainer.placement.policy_num_nodes,
cfg.trainer.placement.policy_num_gpus_per_node,
PolicyWorker,
pg=pg,
num_gpus_per_actor=0.2 if colocate_all else 1,
colocate_all=colocate_all,
sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
record_memory=cfg.trainer.policy.record_memory,
)
# set to a large number for megatron scheduler init
# lr will be managed externally via set_lr()
policy_num_training_steps = 1e9
ray.get(
policy_model.async_init_model(
cfg.trainer.policy.model.path,
num_training_steps=policy_num_training_steps,
)
)
ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self._tokenizer.pad_token_id))
if colocate_all:
policy_model.offload_to_cpu()
# Create unified dispatch that manages all actor groups
self._dispatch = WorkerDispatch(
cfg=cfg,
policy_actor_group=policy_model,
inference_engine_client=self._inference_engine_client,
)
# Mark all models as offloaded
if colocate_all:
self._dispatch.mark_all_offloaded()
logger.info("init policy model done")
def init_weight_sync_state(self):
"""
Setup the connection between policy model and inference engine for weight syncing.
"""
self._dispatch.init_weight_sync_state(self._inference_engine_client)
logger.info("Initialized weight sync state for policy model and inference engines.")
def _ensure_inference_engines(self):
"""Lazily create inference engines and init weight sync on first sampling-related call."""
if self._inference_engines_initialized:
return
logger.info(f"Creating {self._cfg.generator.inference_engine.num_engines} inference engines")
self._inference_engine_client = InferenceEngineClient(
create_ray_wrapped_inference_engines_from_config(self._cfg, self._colocate_pg, self._tokenizer),
self._tokenizer,
self._cfg.trainer.policy.model.path,
self._cfg.trainer.policy.model.lora,
self._cfg.generator.inference_engine,
)
self._dispatch.set_inference_engine_client(self._inference_engine_client)
self.init_weight_sync_state()
self._inference_engines_initialized = True
def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
if self._model_id is not None:
raise ValueError(f"Model '{self._model_id}' already exists. Only one model supported.")
# Build config
self._cfg = _build_skyrl_train_config(self.base_model, self.config, lora_config)
if not ray.is_initialized():
logger.info("Initializing Ray with runtime environment")
initialize_ray(self._cfg)
# Create shared placement group only when colocating training + inference
if self._cfg.trainer.placement.colocate_all:
self._colocate_pg = self._create_colocate_pg()
else:
self._colocate_pg = None
# Get worker types based on strategy
if self._cfg.trainer.strategy in ("fsdp", "fsdp2"):
from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker
elif self._cfg.trainer.strategy == "megatron":
from skyrl.backends.skyrl_train.workers.megatron.megatron_worker import (
PolicyWorker,
)
else:
raise ValueError(f"Unknown strategy type: {self._cfg.trainer.strategy}")
logger.info("Building models.")
self.build_models(PolicyWorker)
self._model_id = model_id
self._model_metadata = types.ModelMetadata(adapter_index=0, lora_config=lora_config)
logger.info(f"Created model {model_id} using RayPPOTrainer")
def _create_colocate_pg(self):
"""Create a placement group for colocated training + inference."""
ie_cfg = self._cfg.generator.inference_engine
per_engine_gpu_count = ie_cfg.tensor_parallel_size * ie_cfg.pipeline_parallel_size * ie_cfg.data_parallel_size
total_gpu_slots = ie_cfg.num_engines * per_engine_gpu_count
logger.info(f"Creating placement group with {total_gpu_slots} GPU slots for colocated training+inference")
pg = placement_group([{"GPU": 1, "CPU": 1}] * total_gpu_slots, strategy="PACK")
logger.info("Waiting for placement group to be ready...")
get_ray_pg_ready_with_timeout(pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
logger.info("Placement group ready!")
return pg
def delete_model(self, model_id: str) -> None:
if self._model_id != model_id:
raise ValueError(f"Model {model_id} not found")
# TODO: For now, prefer shutting down the backend and re-launching. Will be improved shortly.
raise NotImplementedError("Deleting models not yet implemented")
def _to_training_batch(self, prepared_batch: types.PreparedModelPassBatch) -> TrainingInputBatch:
"""Convert PreparedModelPassBatch to TrainingInputBatch."""
if not prepared_batch.all_input_ids:
return TrainingInputBatch({})
# SkyRL-Train shifts internally, so provide the full sequence length by
# appending the last target token to each already-shifted input.
full_sequences = [
list(input_ids) + ([targets[-1]] if targets else [])
for input_ids, targets in zip(prepared_batch.all_input_ids, prepared_batch.all_targets)
]
max_seq_len = max(len(seq) for seq in full_sequences)
max_response_len = max(len(weights) for weights in prepared_batch.all_token_weights)
sequences, attention_masks, loss_masks, response_masks = [], [], [], []
action_log_probs_list, advantages_list = [], []
for seq, weights, logprobs, advs in zip(
full_sequences,
prepared_batch.all_token_weights,
prepared_batch.all_sampling_logprobs,
prepared_batch.all_advantages,
):
pad_len = max_seq_len - len(seq)
sequences.append([self._tokenizer.pad_token_id] * pad_len + list(seq))
attention_masks.append([0] * pad_len + [1] * len(seq))
action_pad = max_response_len - len(weights)
loss_masks.append([0.0] * action_pad + [float(w) for w in weights])
response_masks.append([0] * action_pad + [1] * len(weights))
action_log_probs_list.append([0.0] * action_pad + [float(lp) for lp in logprobs])
advantages_list.append([0.0] * action_pad + [float(a) for a in advs])
sequences_tensor = torch.tensor(sequences, dtype=torch.long)
attention_mask_tensor = torch.tensor(attention_masks, dtype=torch.long)
loss_mask_tensor = torch.tensor(loss_masks, dtype=torch.float32)
response_mask_tensor = torch.tensor(response_masks, dtype=torch.long)
batch_dict = {
"sequences": sequences_tensor,
"attention_mask": attention_mask_tensor,
"loss_mask": loss_mask_tensor,
"response_mask": response_mask_tensor,
}
# Include RL fields (action_log_probs, advantages) when data is present
has_logprobs = any(len(lp) > 0 for lp in prepared_batch.all_sampling_logprobs)
has_advantages = any(len(a) > 0 for a in prepared_batch.all_advantages)
if has_logprobs:
batch_dict["action_log_probs"] = torch.tensor(action_log_probs_list, dtype=torch.float32)
if has_advantages:
batch_dict["advantages"] = torch.tensor(advantages_list, dtype=torch.float32)
batch = TrainingInputBatch(batch_dict)
batch.metadata = {"response_length": max_response_len}
return batch
def _pad_batch(
self, batch: TrainingInputBatch, micro_batch_size: int | None = None
) -> tuple[TrainingInputBatch, int]:
"""Pad the batch so its size is divisible by the required alignment.
The dispatch layer splits the batch evenly across DP workers, so the
batch size must be a multiple of dp_size. When *micro_batch_size* is
given (needed for the Megatron backend whose ``forward_backward_func``
doesn't support ragged micro-batches), we align to
``dp_size * micro_batch_size`` so each per-worker shard is also evenly
divisible by *micro_batch_size*.
Returns:
(padded_batch, pad_size)
"""
dp_size = self._dispatch.get_lcm_dp_size()
alignment = dp_size * micro_batch_size if micro_batch_size else dp_size
pad_size = (alignment - batch.batch_size % alignment) % alignment
if pad_size == 0:
return batch, 0
new_tensors = {}
for key, tensor in batch.items():
if tensor is not None:
if key == "loss_mask":
# Padding entries must not contribute to the loss
additional_dims = tensor.shape[1:]
padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
else:
# Clone real data so shapes/dtypes are valid for the model
padding_tensor = tensor[torch.arange(pad_size) % tensor.shape[0]].clone()
new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0)
padded = TrainingInputBatch(new_tensors)
padded.metadata = batch.metadata
logger.info(f"Padded batch from {batch.batch_size} to {batch.batch_size + pad_size} (alignment={alignment})")
return padded, pad_size
def _extract_metrics(self, data: dict) -> dict[str, float]:
"""Extract training metrics from dispatch return dict.
Workers return metrics like 'loss', 'policy_loss', 'policy_entropy', etc.
We convert to Tinker's colon-suffixed format (e.g. 'total_loss:sum').
"""
metrics: dict[str, float] = {}
# SFT path returns 'loss'; RL path returns 'final_loss' / 'policy_loss'
if "loss" in data:
metrics["total_loss:sum"] = float(data["loss"])
elif "final_loss" in data:
metrics["total_loss:sum"] = float(data["final_loss"])
if "policy_loss" in data:
metrics["pg_loss:sum"] = float(data["policy_loss"])
if "policy_entropy" in data:
metrics["entropy_loss:sum"] = float(data["policy_entropy"])
if "response_length" in data:
metrics["num_tokens:sum"] = float(data["response_length"])
return metrics
def _sleep_inference_engines(self):
    """Put colocated inference engines to sleep to free GPU memory for training."""
    if not self._inference_engines_initialized:
        return
    if not self._cfg.trainer.placement.colocate_all:
        return
    # Backend runs in the engine subprocess without a running event loop,
    # so drive the coroutine with asyncio.run().
    asyncio.run(self._inference_engine_client.sleep())
def forward_backward(
    self,
    prepared_batch: types.PreparedModelPassBatch,
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
    """Run a forward+backward pass over the whole prepared batch.

    Gradients accumulate in the policy workers and are applied later via
    ``optim_step``. Returns one ForwardBackwardOutput per request id.
    """
    if not prepared_batch.all_input_ids:
        return {}
    # Free GPU memory held by colocated inference engines before training.
    self._sleep_inference_engines()
    batch = self._to_training_batch(prepared_batch)
    # Megatron's forward_backward_func can't take ragged micro-batches, so
    # each per-worker shard must also align to the micro batch size.
    micro_bs = (
        self._cfg.trainer.micro_train_batch_size_per_gpu if self._cfg.trainer.strategy == "megatron" else None
    )
    batch, pad_size = self._pad_batch(batch, micro_batch_size=micro_bs)
    # The whole batch is trained with a single loss function; mixed batches
    # fall back to the first request's loss function (with a warning).
    loss_fn = prepared_batch.all_loss_fns[0]
    if len(set(prepared_batch.all_loss_fns)) > 1:
        logger.warning(
            "SkyRL backend received mixed loss functions %s in one batch; using '%s' for all",
            set(prepared_batch.all_loss_fns),
            loss_fn,
        )
    # First non-None config wins; per-request configs are not supported.
    loss_fn_config = next((c for c in prepared_batch.all_loss_fn_configs if c is not None), None)
    data = self._dispatch.forward_backward(
        "policy",
        batch,
        loss_fn=loss_fn,
        loss_fn_config=loss_fn_config,
    )
    # Trim padding entries from loss_fn_outputs
    if pad_size > 0 and "loss_fn_outputs" in data:
        data["loss_fn_outputs"] = data["loss_fn_outputs"][:-pad_size]
    metrics = self._extract_metrics(data)
    results = {}
    # Slice the flat per-example outputs back into per-request groups.
    for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
        loss_fn_outputs = []
        for i in range(start_idx, end_idx):
            raw_output = data["loss_fn_outputs"][i]
            logprobs = list(raw_output.get("logprobs", []))
            elementwise_loss = list(raw_output.get("elementwise_loss", []))
            # Serialize as Tinker tensor-dicts (data/dtype/shape).
            loss_fn_outputs.append(
                {
                    "elementwise_loss": {
                        "data": elementwise_loss,
                        "dtype": "float32",
                        "shape": [len(elementwise_loss)],
                    },
                    "logprobs": {
                        "data": logprobs,
                        "dtype": "float32",
                        "shape": [len(logprobs)],
                    },
                }
            )
        # Batch-level metrics are shared by every request in the batch.
        results[request_id] = types.ForwardBackwardOutput(
            loss_fn_output_type="scalar",
            loss_fn_outputs=loss_fn_outputs,
            metrics=metrics,
        )
    return results
def forward(
    self,
    prepared_batch: types.PreparedModelPassBatch,
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
    """Forward-only pass: returns per-token logprobs, accumulates no gradients."""
    if not prepared_batch.all_input_ids:
        return {}
    # Free GPU memory held by colocated inference engines first.
    self._sleep_inference_engines()
    batch = self._to_training_batch(prepared_batch)
    if self._cfg.trainer.strategy == "megatron":
        micro_bs = self._cfg.trainer.micro_forward_batch_size_per_gpu
    else:
        micro_bs = None
    batch, pad_size = self._pad_batch(batch, micro_batch_size=micro_bs)
    data = self._dispatch.forward("policy", batch)
    # dispatch.forward() returns TrainingOutputBatch({"output": tensor[batch, max_response_len]});
    # drop any padding rows appended by _pad_batch.
    logprob_matrix = data["output"]
    if pad_size > 0:
        logprob_matrix = logprob_matrix[:-pad_size]
    max_response_len = logprob_matrix.shape[1]
    results = {}
    for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
        per_example_outputs = []
        for row in range(start_idx, end_idx):
            # Each example's real response length is the length of its
            # token-weight list; logprobs are right-aligned in the matrix.
            valid_len = len(prepared_batch.all_token_weights[row])
            offset = max(max_response_len - valid_len, 0)
            row_logprobs = logprob_matrix[row, offset:].tolist()
            per_example_outputs.append(
                {
                    "logprobs": {
                        "data": row_logprobs,
                        "dtype": "float32",
                        "shape": [len(row_logprobs)],
                    },
                }
            )
        results[request_id] = types.ForwardBackwardOutput(
            loss_fn_output_type="scalar",
            loss_fn_outputs=per_example_outputs,
            metrics={},
        )
    return results
def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput:
    """Apply accumulated gradients using the learning rate from the request.

    Raises:
        ValueError: If *model_id* is not the registered model.
    """
    if model_id != self._model_id:
        raise ValueError(f"Model {model_id} not found")
    adam_params = request_data.adam_params
    # Only the learning rate may change between steps; beta1/beta2/eps are
    # fixed when the optimizer is created and cannot be changed dynamically.
    self._dispatch.set_lr("policy", adam_params.learning_rate)
    grad_norm = self._dispatch.optim_step("policy")
    logger.info(f"optim_step: lr={adam_params.learning_rate}, grad_norm={grad_norm}")
    metrics: dict[str, float] = {}
    if grad_norm is not None:
        metrics["skyrl.ai/grad_norm"] = float(grad_norm)
    metrics["skyrl.ai/learning_rate"] = adam_params.learning_rate
    return types.OptimStepOutput(metrics=metrics)
def sample(
    self,
    prepared_batch: types.PreparedSampleBatch,
) -> dict[str, types.SampleOutput | types.ErrorResponse]:
    """Generate samples using InferenceEngineClient.
    NOTE: Weight sync is NOT triggered automatically. The caller must call
    save_weights_for_sampler() explicitly before calling sample() if weights
    have been updated.
    """
    # 1. Ensure inference engines are initialized (created lazily)
    self._ensure_inference_engines()
    # 2. Validate single model — this backend serves exactly one model, so
    # every request in the batch must reference it.
    unique_models = set(prepared_batch.all_model_ids)
    if unique_models != {self._model_id}:
        error = types.ErrorResponse(
            error=f"Model mismatch. Expected {self._model_id}, got {unique_models}", status="error"
        )
        # Fan the same error out to every request in the batch.
        return {req_id: error for req_id, _, _, _, _ in prepared_batch.request_batch_slices}
    # 3. Sample all prompts in parallel
    async def sample_all():
        tasks = []
        for i in range(len(prepared_batch.all_prompts)):
            prompt = prepared_batch.all_prompts[i]
            sampling_params = prepared_batch.all_sampling_params[i]
            # Pass through common fields; only stop needs name translation
            # (Tinker uses stop_strings/stop_tokens, vLLM uses stop/stop_token_ids)
            params_dict = {
                "temperature": sampling_params.temperature,
                "max_tokens": sampling_params.max_tokens,
                "seed": sampling_params.seed,
                "top_k": sampling_params.top_k,
                "top_p": sampling_params.top_p,
                "logprobs": 0,  # request logprobs for sampled tokens only
            }
            if sampling_params.stop_strings:
                params_dict["stop"] = sampling_params.stop_strings
            if sampling_params.stop_tokens:
                params_dict["stop_token_ids"] = sampling_params.stop_tokens
            tasks.append(
                self._inference_engine_client.sample(
                    prompt_token_ids=prompt,
                    num_samples=1,  # Tinker batches multiple samples separately
                    sampling_params=params_dict,
                )
            )
        return await asyncio.gather(*tasks, return_exceptions=True)
    # Backend runs in engine subprocess with no event loop
    sample_outputs = asyncio.run(sample_all())
    # Note: sample_outputs may contain Exception objects (from return_exceptions=True)
    # We preserve these to include error messages in responses
    # 4. Aggregate results by request
    return self._aggregate_sample_results(prepared_batch, sample_outputs)
def _aggregate_sample_results(
    self,
    prepared_batch: types.PreparedSampleBatch,
    sample_outputs: list,
) -> dict[str, types.SampleOutput | types.ErrorResponse]:
    """Convert InferenceEngineClient outputs to Tinker format."""
    results = {}
    for request_id, model_id, start_idx, end_idx, needs_prompt_logprobs in prepared_batch.request_batch_slices:
        sequences = []
        # Sentinel: set to a message on the first failed sample, then stop.
        error_msg = None
        for i in range(start_idx, end_idx):
            output = sample_outputs[i]
            # asyncio.gather(return_exceptions=True) may hand us an Exception.
            if isinstance(output, Exception):
                error_msg = f"Sampling failed for sample {i}: {type(output).__name__}: {str(output)}"
                logger.error(error_msg)
                break
            if output is None:
                error_msg = f"Sampling failed for sample {i}: Unknown error (output is None)"
                logger.error(error_msg)
                break
            # Extract tokens / logprobs / stop reason for the single sample.
            response_tokens = output["response_ids"][0]
            response_logprobs = (output.get("response_logprobs") or [[]])[0]
            raw_stop_reason = output["stop_reasons"][0]
            # Map vLLM stop reasons to Tinker's two-value scheme; anything
            # other than a stop means the sequence was length-truncated.
            stop_reason = "stop" if raw_stop_reason in ["stop", "stop_token"] else "length"
            if response_logprobs is None or len(response_logprobs) == 0:
                # Logprobs are critical for RL; zero-fill rather than crash.
                logger.warning("No logprobs returned - filling with zeros")
                response_logprobs = [0.0] * len(response_tokens)
            sequences.append(
                types.GeneratedSequence(
                    tokens=response_tokens,
                    logprobs=response_logprobs,
                    stop_reason=stop_reason,
                )
            )
        if error_msg is not None:
            results[request_id] = types.ErrorResponse(
                error=error_msg,
                status="error",
            )
        else:
            # Note: prompt_logprobs not supported initially
            if needs_prompt_logprobs:
                logger.warning("Prompt logprobs requested but not yet supported")
            results[request_id] = types.SampleOutput(
                sequences=sequences,
                prompt_logprobs=None,
            )
    return results
def _validate_model_state(self, model_id: str) -> None:
    """Ensure *model_id* matches the registered model and dispatch is built.

    Raises:
        ValueError: If the model id is not the one this backend serves.
        RuntimeError: If the dispatch layer has not been initialized yet.
    """
    if model_id != self._model_id:
        raise ValueError(f"Model {model_id} not found")
    if self._dispatch is None:
        raise RuntimeError("Model not initialized")
def _create_tar_from_directory(self, source_dir: str, output_path: str) -> None:
"""Create an uncompressed tar archive from a directory."""
# Ensure parent directory exists
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# Use uncompressed tar - gzip adds 5-10min CPU time on 6-7GB FSDP checkpoints
with tarfile.open(output_path, "w") as tar:
tar.add(source_dir, arcname=".")
def save_checkpoint(self, output_path, model_id: str) -> None:
    """Save full training checkpoint (model + optimizer + scheduler) as tar."""
    self._validate_model_state(model_id)
    # Stage the checkpoint in a throwaway directory, then package it.
    with tempfile.TemporaryDirectory() as staging_dir:
        ckpt_dir = os.path.join(staging_dir, "checkpoint")
        # Worker-side save includes optimizer state automatically.
        self._dispatch.save_checkpoint(model="policy", ckpt_dir=ckpt_dir, tokenizer=self._tokenizer)
        self._create_tar_from_directory(ckpt_dir, output_path)
    logger.info(f"Saved checkpoint for {model_id} to {output_path}")
def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
    """Load full training checkpoint (model + optimizer + scheduler) from tar."""
    self._validate_model_state(model_id)
    with tempfile.TemporaryDirectory() as staging_dir:
        # filter='data' guards against path-traversal entries in the tar.
        with tarfile.open(checkpoint_path, "r") as tar:
            tar.extractall(staging_dir, filter="data")
        # Restore weights along with optimizer and LR-scheduler states.
        self._dispatch.load_checkpoint(
            model="policy", ckpt_dir=staging_dir, load_optimizer_states=True, load_lr_scheduler_states=True
        )
    logger.info(f"Loaded checkpoint for {model_id} from {checkpoint_path}")
def save_sampler_checkpoint(self, output_path, model_id: str, persist: bool = True) -> None:
    """Sync weights to colocated inference engines and optionally save to disk.

    The NCCL broadcast always runs so inference engines have the latest
    policy weights. When ``persist`` is False (the common hot-path in RL
    loops) the expensive HuggingFace model export is skipped entirely.
    """
    self._validate_model_state(model_id)
    # Inference engines are created lazily on the first sampling-related call.
    self._ensure_inference_engines()
    asyncio.run(self._dispatch.save_weights_for_sampler())
    logger.info(f"Synced weights for {model_id} to inference engines via NCCL")
    if not persist:
        # Hot path: write a lightweight marker so the engine's checkpoint
        # bookkeeping stays consistent. Actual weights live in GPU memory.
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with tarfile.open(output_path, "w"):
            pass  # empty tar — marker only
        logger.info(f"Synced weights for {model_id} (disk save skipped)")
        return
    # TODO(tyler): For LoRA, only save the adapters instead of the full merged model
    with tempfile.TemporaryDirectory() as staging_dir:
        export_dir = os.path.join(staging_dir, "model")
        self._dispatch.save_hf_model(model="policy", export_dir=export_dir, tokenizer=self._tokenizer)
        self._create_tar_from_directory(export_dir, output_path)
    logger.info(f"Saved sampler checkpoint for {model_id} to {output_path}")
base_model = base_model

attr config

config = config

method abstractmethod has_model
has_model(model_id: str) -> boolSource code in skyrl/backends/skyrl_train_backend.py:120-121
def has_model(self, model_id: str) -> bool:
return self._model_id == model_id

method build_models
build_models(PolicyWorker)Source code in skyrl/backends/skyrl_train_backend.py:123-178
def build_models(self, PolicyWorker):
cfg = self._cfg
colocate_all = cfg.trainer.placement.colocate_all
pg = self._colocate_pg
if colocate_all:
assert pg is not None, "placement group must be created when colocate_all=True"
num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
num_rollout_gpus = (
cfg.generator.inference_engine.num_engines
* cfg.generator.inference_engine.tensor_parallel_size
* cfg.generator.inference_engine.pipeline_parallel_size
* cfg.generator.inference_engine.data_parallel_size
)
assert (
num_policy_gpus == num_rollout_gpus
), "num_policy_gpus and num_rollout_gpus must be the same when colocating all models"
policy_model = PPORayActorGroup(
cfg.trainer,
cfg.trainer.placement.policy_num_nodes,
cfg.trainer.placement.policy_num_gpus_per_node,
PolicyWorker,
pg=pg,
num_gpus_per_actor=0.2 if colocate_all else 1,
colocate_all=colocate_all,
sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
record_memory=cfg.trainer.policy.record_memory,
)
# set to a large number for megatron scheduler init
# lr will be managed externally via set_lr()
policy_num_training_steps = 1e9
ray.get(
policy_model.async_init_model(
cfg.trainer.policy.model.path,
num_training_steps=policy_num_training_steps,
)
)
ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self._tokenizer.pad_token_id))
if colocate_all:
policy_model.offload_to_cpu()
# Create unified dispatch that manages all actor groups
self._dispatch = WorkerDispatch(
cfg=cfg,
policy_actor_group=policy_model,
inference_engine_client=self._inference_engine_client,
)
# Mark all models as offloaded
if colocate_all:
self._dispatch.mark_all_offloaded()
logger.info("init policy model done")

method init_weight_sync_state
init_weight_sync_state()Setup the connection between policy model and inference engine for weight syncing.
Source code in skyrl/backends/skyrl_train_backend.py:180-185
def init_weight_sync_state(self):
"""
Setup the connection between policy model and inference engine for weight syncing.
"""
self._dispatch.init_weight_sync_state(self._inference_engine_client)
logger.info("Initialized weight sync state for policy model and inference engines.")method abstractmethod create_model
create_model(model_id: str, lora_config: types.LoraConfig) -> NoneSource code in skyrl/backends/skyrl_train_backend.py:204-236
def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
if self._model_id is not None:
raise ValueError(f"Model '{self._model_id}' already exists. Only one model supported.")
# Build config
self._cfg = _build_skyrl_train_config(self.base_model, self.config, lora_config)
if not ray.is_initialized():
logger.info("Initializing Ray with runtime environment")
initialize_ray(self._cfg)
# Create shared placement group only when colocating training + inference
if self._cfg.trainer.placement.colocate_all:
self._colocate_pg = self._create_colocate_pg()
else:
self._colocate_pg = None
# Get worker types based on strategy
if self._cfg.trainer.strategy in ("fsdp", "fsdp2"):
from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker
elif self._cfg.trainer.strategy == "megatron":
from skyrl.backends.skyrl_train.workers.megatron.megatron_worker import (
PolicyWorker,
)
else:
raise ValueError(f"Unknown strategy type: {self._cfg.trainer.strategy}")
logger.info("Building models.")
self.build_models(PolicyWorker)
self._model_id = model_id
self._model_metadata = types.ModelMetadata(adapter_index=0, lora_config=lora_config)
logger.info(f"Created model {model_id} using RayPPOTrainer")

method abstractmethod delete_model
delete_model(model_id: str) -> NoneSource code in skyrl/backends/skyrl_train_backend.py:253-257
def delete_model(self, model_id: str) -> None:
if self._model_id != model_id:
raise ValueError(f"Model {model_id} not found")
# TODO: For now, prefer shutting down the backend and re-launching. Will be improved shortly.
raise NotImplementedError("Deleting models not yet implemented")method abstractmethod forward_backward
forward_backward(prepared_batch: types.PreparedModelPassBatch) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]Source code in skyrl/backends/skyrl_train_backend.py:382-443
def forward_backward(
self,
prepared_batch: types.PreparedModelPassBatch,
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
if not prepared_batch.all_input_ids:
return {}
self._sleep_inference_engines()
batch = self._to_training_batch(prepared_batch)
micro_bs = (
self._cfg.trainer.micro_train_batch_size_per_gpu if self._cfg.trainer.strategy == "megatron" else None
)
batch, pad_size = self._pad_batch(batch, micro_batch_size=micro_bs)
loss_fn = prepared_batch.all_loss_fns[0]
if len(set(prepared_batch.all_loss_fns)) > 1:
logger.warning(
"SkyRL backend received mixed loss functions %s in one batch; using '%s' for all",
set(prepared_batch.all_loss_fns),
loss_fn,
)
loss_fn_config = next((c for c in prepared_batch.all_loss_fn_configs if c is not None), None)
data = self._dispatch.forward_backward(
"policy",
batch,
loss_fn=loss_fn,
loss_fn_config=loss_fn_config,
)
# Trim padding entries from loss_fn_outputs
if pad_size > 0 and "loss_fn_outputs" in data:
data["loss_fn_outputs"] = data["loss_fn_outputs"][:-pad_size]
metrics = self._extract_metrics(data)
results = {}
for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
loss_fn_outputs = []
for i in range(start_idx, end_idx):
raw_output = data["loss_fn_outputs"][i]
logprobs = list(raw_output.get("logprobs", []))
elementwise_loss = list(raw_output.get("elementwise_loss", []))
loss_fn_outputs.append(
{
"elementwise_loss": {
"data": elementwise_loss,
"dtype": "float32",
"shape": [len(elementwise_loss)],
},
"logprobs": {
"data": logprobs,
"dtype": "float32",
"shape": [len(logprobs)],
},
}
)
results[request_id] = types.ForwardBackwardOutput(
loss_fn_output_type="scalar",
loss_fn_outputs=loss_fn_outputs,
metrics=metrics,
)
return resultsmethod abstractmethod forward
forward(prepared_batch: types.PreparedModelPassBatch) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]Source code in skyrl/backends/skyrl_train_backend.py:445-488
def forward(
self,
prepared_batch: types.PreparedModelPassBatch,
) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
if not prepared_batch.all_input_ids:
return {}
self._sleep_inference_engines()
batch = self._to_training_batch(prepared_batch)
micro_bs = (
self._cfg.trainer.micro_forward_batch_size_per_gpu if self._cfg.trainer.strategy == "megatron" else None
)
batch, pad_size = self._pad_batch(batch, micro_batch_size=micro_bs)
data = self._dispatch.forward("policy", batch)
# dispatch.forward() returns TrainingOutputBatch({"output": tensor[batch, max_response_len]})
# Trim padding entries from output
output_logprobs = data["output"]
if pad_size > 0:
output_logprobs = output_logprobs[:-pad_size]
results = {}
for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
loss_fn_outputs = []
for i in range(start_idx, end_idx):
# Use token weights length to determine each example's actual response length
valid_len = len(prepared_batch.all_token_weights[i])
start = max(output_logprobs.shape[1] - valid_len, 0)
logprobs = output_logprobs[i, start:].tolist()
loss_fn_outputs.append(
{
"logprobs": {
"data": logprobs,
"dtype": "float32",
"shape": [len(logprobs)],
},
}
)
results[request_id] = types.ForwardBackwardOutput(
loss_fn_output_type="scalar",
loss_fn_outputs=loss_fn_outputs,
metrics={},
)
return resultsmethod abstractmethod optim_step
optim_step(model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutputSource code in skyrl/backends/skyrl_train_backend.py:490-506
def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput:
if model_id != self._model_id:
raise ValueError(f"Model {model_id} not found")
# Apply learning rate from AdamParams before optimizer step
# Note: beta1, beta2, eps are fixed at optimizer creation and cannot be changed dynamically
adam_params = request_data.adam_params
self._dispatch.set_lr("policy", adam_params.learning_rate)
grad_norm = self._dispatch.optim_step("policy")
logger.info(f"optim_step: lr={adam_params.learning_rate}, grad_norm={grad_norm}")
metrics: dict[str, float] = {}
if grad_norm is not None:
metrics["skyrl.ai/grad_norm"] = float(grad_norm)
metrics["skyrl.ai/learning_rate"] = adam_params.learning_rate
return types.OptimStepOutput(metrics=metrics)method abstractmethod sample
sample(prepared_batch: types.PreparedSampleBatch) -> dict[str, types.SampleOutput | types.ErrorResponse]Generate samples using InferenceEngineClient.
NOTE: Weight sync is NOT triggered automatically. The caller must call save_weights_for_sampler() explicitly before calling sample() if weights have been updated.
Source code in skyrl/backends/skyrl_train_backend.py:508-568
def sample(
self,
prepared_batch: types.PreparedSampleBatch,
) -> dict[str, types.SampleOutput | types.ErrorResponse]:
"""Generate samples using InferenceEngineClient.
NOTE: Weight sync is NOT triggered automatically. The caller must call
save_weights_for_sampler() explicitly before calling sample() if weights
have been updated.
"""
# 1. Ensure inference engines are initialized
self._ensure_inference_engines()
# 2. Validate single model
unique_models = set(prepared_batch.all_model_ids)
if unique_models != {self._model_id}:
error = types.ErrorResponse(
error=f"Model mismatch. Expected {self._model_id}, got {unique_models}", status="error"
)
return {req_id: error for req_id, _, _, _, _ in prepared_batch.request_batch_slices}
# 3. Sample all prompts in parallel
async def sample_all():
tasks = []
for i in range(len(prepared_batch.all_prompts)):
prompt = prepared_batch.all_prompts[i]
sampling_params = prepared_batch.all_sampling_params[i]
# Pass through common fields; only stop needs name translation
# (Tinker uses stop_strings/stop_tokens, vLLM uses stop/stop_token_ids)
params_dict = {
"temperature": sampling_params.temperature,
"max_tokens": sampling_params.max_tokens,
"seed": sampling_params.seed,
"top_k": sampling_params.top_k,
"top_p": sampling_params.top_p,
"logprobs": 0,
}
if sampling_params.stop_strings:
params_dict["stop"] = sampling_params.stop_strings
if sampling_params.stop_tokens:
params_dict["stop_token_ids"] = sampling_params.stop_tokens
tasks.append(
self._inference_engine_client.sample(
prompt_token_ids=prompt,
num_samples=1, # Tinker batches multiple samples separately
sampling_params=params_dict,
)
)
return await asyncio.gather(*tasks, return_exceptions=True)
# Backend runs in engine subprocess with no event loop
sample_outputs = asyncio.run(sample_all())
# Note: sample_outputs may contain Exception objects (from return_exceptions=True)
# We preserve these to include error messages in responses
# 4. Aggregate results by request
return self._aggregate_sample_results(prepared_batch, sample_outputs)method abstractmethod save_checkpoint
save_checkpoint(output_path, model_id: str) -> NoneSave full training checkpoint (model + optimizer + scheduler) as tar.
Source code in skyrl/backends/skyrl_train_backend.py:652-666
def save_checkpoint(self, output_path, model_id: str) -> None:
"""Save full training checkpoint (model + optimizer + scheduler) as tar."""
self._validate_model_state(model_id)
# Create temp directory for checkpoint
with tempfile.TemporaryDirectory() as temp_dir:
ckpt_dir = os.path.join(temp_dir, "checkpoint")
# Save checkpoint directory (includes optimizer state automatically)
self._dispatch.save_checkpoint(model="policy", ckpt_dir=ckpt_dir, tokenizer=self._tokenizer)
# Create tar archive
self._create_tar_from_directory(ckpt_dir, output_path)
logger.info(f"Saved checkpoint for {model_id} to {output_path}")method abstractmethod load_checkpoint
load_checkpoint(checkpoint_path, model_id: str) -> NoneLoad full training checkpoint (model + optimizer + scheduler) from tar.
Source code in skyrl/backends/skyrl_train_backend.py:668-682
def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
"""Load full training checkpoint (model + optimizer + scheduler) from tar."""
self._validate_model_state(model_id)
# Extract tar to temp directory (filter='data' prevents path traversal attacks)
with tempfile.TemporaryDirectory() as temp_dir:
with tarfile.open(checkpoint_path, "r") as tar:
tar.extractall(temp_dir, filter="data")
# Load checkpoint (includes optimizer and scheduler states)
self._dispatch.load_checkpoint(
model="policy", ckpt_dir=temp_dir, load_optimizer_states=True, load_lr_scheduler_states=True
)
logger.info(f"Loaded checkpoint for {model_id} from {checkpoint_path}")method abstractmethod save_sampler_checkpoint
save_sampler_checkpoint(output_path, model_id: str, persist: bool = True) -> NoneSync weights to colocated inference engines and optionally save to disk.
The NCCL broadcast always runs so inference engines have the latest
policy weights. When persist is False (the common hot-path in RL
loops) the expensive HuggingFace model export is skipped entirely.
Source code in skyrl/backends/skyrl_train_backend.py:684-712
def save_sampler_checkpoint(self, output_path, model_id: str, persist: bool = True) -> None:
"""Sync weights to colocated inference engines and optionally save to disk.
The NCCL broadcast always runs so inference engines have the latest
policy weights. When ``persist`` is False (the common hot-path in RL
loops) the expensive HuggingFace model export is skipped entirely.
"""
self._validate_model_state(model_id)
# Lazily create inference engines on first sampling-related call
self._ensure_inference_engines()
asyncio.run(self._dispatch.save_weights_for_sampler())
logger.info(f"Synced weights for {model_id} to inference engines via NCCL")
if persist:
# TODO(tyler): For LoRA, only save the adapters instead of the full merged model
with tempfile.TemporaryDirectory() as temp_dir:
hf_dir = os.path.join(temp_dir, "model")
self._dispatch.save_hf_model(model="policy", export_dir=hf_dir, tokenizer=self._tokenizer)
self._create_tar_from_directory(hf_dir, output_path)
logger.info(f"Saved sampler checkpoint for {model_id} to {output_path}")
else:
# Hot path: write a lightweight marker so the engine's checkpoint
# bookkeeping stays consistent. Actual weights live in GPU memory.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with tarfile.open(output_path, "w"):
pass # empty tar — marker only
logger.info(f"Synced weights for {model_id} (disk save skipped)")