SkyRL
API Reference

Backends

Backend abstractions for training and inference.

Abstract Backend

The base class all backends implement.

class AbstractBackend

AbstractBackend(base_model: str, config: BaseModel)

Bases: ABC

Abstract base class for TinkerEngine backends.

Backends handle computation and model state manipulation. Database operations are handled by TinkerEngine.

Functions:

NameDescription
create_modelCreate a new model in the backend.
forward_backwardRun forward and backward pass on a batch.
forwardRun forward-only pass on a batch (no gradient computation).
optim_stepApply an optimizer step using accumulated gradients.
sampleGenerate samples for a batch of requests.
save_checkpointSave training checkpoint to disk.
load_checkpointLoad training checkpoint from disk.
save_sampler_checkpointPrepare model weights for sampling and optionally save to disk.
has_modelCheck if a model is registered with the backend.
delete_modelDelete a model and free all associated resources.

Initialize the backend.

Source code in skyrl/backends/backend.py:33-171
class AbstractBackend(ABC):
    """Interface contract for TinkerEngine backends.

    Implementations own all computation and model-state manipulation;
    database bookkeeping remains the responsibility of TinkerEngine.
    """

    @abstractmethod
    def __init__(self, base_model: str, config: BaseModel):
        """Set up the backend for *base_model* with the given configuration."""
        ...

    @abstractmethod
    def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
        """Register a new model in the backend.

        Sets up the optimizer and configures the LoRA adapter.

        Args:
            model_id: Identifier of the model to create
            lora_config: LoRA settings (rank and alpha)
        """
        ...

    @abstractmethod
    def forward_backward(
        self,
        prepared_batch: types.PreparedModelPassBatch,
    ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
        """Execute a combined forward and backward pass over a batch.

        Args:
            prepared_batch: PreparedModelPassBatch with all data already extracted from requests

        Returns:
            Mapping from request_id to a result or an error
        """
        ...

    @abstractmethod
    def forward(
        self,
        prepared_batch: types.PreparedModelPassBatch,
    ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
        """Execute a forward-only pass over a batch (no gradients are computed).

        Args:
            prepared_batch: PreparedModelPassBatch with all data already extracted from requests

        Returns:
            Mapping from request_id to a result or an error
        """
        ...

    @abstractmethod
    def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput:
        """Perform one optimizer step from the gradients accumulated so far.

        Args:
            model_id: Identifier of the model to step
            request_data: Optimizer step input parameters

        Returns:
            The resulting OptimStepOutput
        """
        ...

    @abstractmethod
    def sample(
        self,
        prepared_batch: types.PreparedSampleBatch,
    ) -> dict[str, types.SampleOutput | types.ErrorResponse]:
        """Produce generations for a batch of sampling requests.

        Args:
            prepared_batch: PreparedSampleBatch with all data already extracted from requests

        Returns:
            Mapping from request_id to a result or an error
        """
        ...

    @abstractmethod
    def save_checkpoint(self, output_path, model_id: str) -> None:
        """Persist a training checkpoint to disk.

        Args:
            output_path: Destination path for the checkpoint
            model_id: Identifier of the model to checkpoint
        """
        ...

    @abstractmethod
    def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
        """Restore a training checkpoint from disk.

        Args:
            checkpoint_path: Path of the checkpoint file to load
            model_id: Identifier of the model to restore into
        """
        ...

    @abstractmethod
    def save_sampler_checkpoint(self, output_path, model_id: str, persist: bool = True) -> None:
        """Make model weights available for sampling, optionally persisting them.

        Backends with colocated inference engines should sync weights
        in-memory regardless of ``persist``. When ``persist`` is *False*
        the backend may skip the expensive disk write and only place a
        lightweight marker at ``output_path``.

        Args:
            output_path: Destination path for the checkpoint tar.gz file
            model_id: Identifier of the model
            persist: If True, write a full model snapshot to disk.
                     If False, only sync weights in-memory (hot path).
        """
        ...

    @abstractmethod
    def has_model(self, model_id: str) -> bool:
        """Report whether *model_id* is registered with this backend.

        Args:
            model_id: Identifier of the model to look up

        Returns:
            True when the model is registered, otherwise False
        """
        ...

    @abstractmethod
    def delete_model(self, model_id: str) -> None:
        """Remove a model and release every resource associated with it.

        Args:
            model_id: Identifier of the model to delete
        """
        ...

method abstractmethod create_model

create_model(model_id: str, lora_config: types.LoraConfig) -> None

Create a new model in the backend.

Creates optimizer and configures LoRA adapter.

Parameters:

NameTypeDescriptionDefault
model_idstrThe model identifierrequired
lora_configLoraConfigLoRA configuration with rank and alpharequired
Source code in skyrl/backends/backend.py:45-55
    @abstractmethod
    def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
        """Create a new model in the backend.

        Creates optimizer and configures LoRA adapter.

        Args:
            model_id: The model identifier
            lora_config: LoRA configuration with rank and alpha
        """
        pass

method abstractmethod forward_backward

forward_backward(prepared_batch: types.PreparedModelPassBatch) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]

Run forward and backward pass on a batch.

Parameters:

NameTypeDescriptionDefault
prepared_batchPreparedModelPassBatchPreparedModelPassBatch with all data extracted from requestsrequired

Returns:

TypeDescription
dict[str, ForwardBackwardOutput | ErrorResponse]
Source code in skyrl/backends/backend.py:57-70
    @abstractmethod
    def forward_backward(
        self,
        prepared_batch: types.PreparedModelPassBatch,
    ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
        """Run forward and backward pass on a batch.

        Args:
            prepared_batch: PreparedModelPassBatch with all data extracted from requests

        Returns:
            Dict mapping request_id to result or error
        """
        pass

method abstractmethod forward

forward(prepared_batch: types.PreparedModelPassBatch) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]

Run forward-only pass on a batch (no gradient computation).

Parameters:

NameTypeDescriptionDefault
prepared_batchPreparedModelPassBatchPreparedModelPassBatch with all data extracted from requestsrequired

Returns:

TypeDescription
dict[str, ForwardBackwardOutput | ErrorResponse]
Source code in skyrl/backends/backend.py:72-85
    @abstractmethod
    def forward(
        self,
        prepared_batch: types.PreparedModelPassBatch,
    ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
        """Run forward-only pass on a batch (no gradient computation).

        Args:
            prepared_batch: PreparedModelPassBatch with all data extracted from requests

        Returns:
            Dict mapping request_id to result or error
        """
        pass

method abstractmethod optim_step

optim_step(model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput

Apply an optimizer step using accumulated gradients.

Parameters:

NameTypeDescriptionDefault
model_idstrThe model identifierrequired
request_dataOptimStepInputThe optimizer step input parametersrequired

Returns:

TypeDescription
OptimStepOutputOptimStepOutput result
Source code in skyrl/backends/backend.py:87-98
    @abstractmethod
    def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput:
        """Apply an optimizer step using accumulated gradients.

        Args:
            model_id: The model identifier
            request_data: The optimizer step input parameters

        Returns:
            OptimStepOutput result
        """
        pass

method abstractmethod sample

sample(prepared_batch: types.PreparedSampleBatch) -> dict[str, types.SampleOutput | types.ErrorResponse]

Generate samples for a batch of requests.

Parameters:

NameTypeDescriptionDefault
prepared_batchPreparedSampleBatchPreparedSampleBatch with all data extracted from requestsrequired

Returns:

TypeDescription
dict[str, SampleOutput | ErrorResponse]
Source code in skyrl/backends/backend.py:100-113
    @abstractmethod
    def sample(
        self,
        prepared_batch: types.PreparedSampleBatch,
    ) -> dict[str, types.SampleOutput | types.ErrorResponse]:
        """Generate samples for a batch of requests.

        Args:
            prepared_batch: PreparedSampleBatch with all data extracted from requests

        Returns:
            Dict mapping request_id to result or error
        """
        pass

method abstractmethod save_checkpoint

save_checkpoint(output_path, model_id: str) -> None

Save training checkpoint to disk.

Parameters:

NameTypeDescriptionDefault
output_pathPath to save the checkpointrequired
model_idstrThe model identifierrequired
Source code in skyrl/backends/backend.py:115-123
    @abstractmethod
    def save_checkpoint(self, output_path, model_id: str) -> None:
        """Save training checkpoint to disk.

        Args:
            output_path: Path to save the checkpoint
            model_id: The model identifier
        """
        pass

method abstractmethod load_checkpoint

load_checkpoint(checkpoint_path, model_id: str) -> None

Load training checkpoint from disk.

Parameters:

NameTypeDescriptionDefault
checkpoint_pathPath to the checkpoint filerequired
model_idstrThe model identifierrequired
Source code in skyrl/backends/backend.py:125-133
    @abstractmethod
    def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
        """Load training checkpoint from disk.

        Args:
            checkpoint_path: Path to the checkpoint file
            model_id: The model identifier
        """
        pass

method abstractmethod save_sampler_checkpoint

save_sampler_checkpoint(output_path, model_id: str, persist: bool = True) -> None

Prepare model weights for sampling and optionally save to disk.

Backends that use colocated inference engines should sync weights in-memory regardless of persist. When persist is False the backend may skip the expensive disk write and only place a lightweight marker at output_path.

Parameters:

NameTypeDescriptionDefault
output_pathPath to save the checkpoint tar.gz filerequired
model_idstrThe model identifierrequired
persistboolIf True, write a full model snapshot to disk. If False, only sync weights in-memory (hot path).True
Source code in skyrl/backends/backend.py:135-150
    @abstractmethod
    def save_sampler_checkpoint(self, output_path, model_id: str, persist: bool = True) -> None:
        """Prepare model weights for sampling and optionally save to disk.

        Backends that use colocated inference engines should sync weights
        in-memory regardless of ``persist``.  When ``persist`` is *False*
        the backend may skip the expensive disk write and only place a
        lightweight marker at ``output_path``.

        Args:
            output_path: Path to save the checkpoint tar.gz file
            model_id: The model identifier
            persist: If True, write a full model snapshot to disk.
                     If False, only sync weights in-memory (hot path).
        """
        pass

method abstractmethod has_model

has_model(model_id: str) -> bool

Check if a model is registered with the backend.

Parameters:

NameTypeDescriptionDefault
model_idstrThe model identifierrequired

Returns:

TypeDescription
boolTrue if the model is registered, False otherwise
Source code in skyrl/backends/backend.py:152-162
    @abstractmethod
    def has_model(self, model_id: str) -> bool:
        """Check if a model is registered with the backend.

        Args:
            model_id: The model identifier

        Returns:
            True if the model is registered, False otherwise
        """
        pass

method abstractmethod delete_model

delete_model(model_id: str) -> None

Delete a model and free all associated resources.

Parameters:

NameTypeDescriptionDefault
model_idstrThe model identifierrequired
Source code in skyrl/backends/backend.py:164-171
    @abstractmethod
    def delete_model(self, model_id: str) -> None:
        """Delete a model and free all associated resources.

        Args:
            model_id: The model identifier
        """
        pass

JAX Backend

JAX-based backend for training and inference.

class JaxBackend

JaxBackend(base_model: str, config: JaxBackendConfig)

Bases: JaxBackendImpl

Distributed wrapper that broadcasts commands before calling JaxBackendImpl methods.

Workers use runtime type introspection to re-hydrate arguments automatically.

Functions:

NameDescription
has_modelCheck if a model is registered with the backend.
delete_modelDelete a model and free all associated resources.
load_sampler_checkpointInsert sampler weights from checkpoint file.
load_sampler_weightsLoad sampler weights for all requests and return adapter indices array.
create_model
forward_backward
forward
optim_step
sample
save_checkpoint
load_checkpoint
save_sampler_checkpoint

Attributes:

Source code in skyrl/backends/jax.py:1069-1132
class JaxBackend(JaxBackendImpl):
    """Distributed wrapper that broadcasts commands before calling JaxBackendImpl methods.

    Workers use runtime type introspection to re-hydrate arguments automatically.
    """

    def __init__(self, base_model: str, config: JaxBackendConfig):
        """Initialize the coordinator process, then run JaxBackendImpl.__init__ everywhere.

        Args:
            base_model: Name or path of the base model.
            config: Backend configuration; when ``coordinator_address`` is set,
                JAX distributed mode is initialized before any other work.
        """
        self.process_id = 0  # Coordinator is always process 0
        if config.coordinator_address is not None:
            # Multi-process mode: join the JAX distributed cluster before any
            # device computation happens.
            jax.distributed.initialize(
                coordinator_address=config.coordinator_address,
                num_processes=config.num_processes,
                process_id=self.process_id,
            )
            logger.info(
                f"JAX distributed initialized: process_id={self.process_id} ({jax.process_count()} total), "
                f"local devices: {jax.local_device_count()}, total devices: {jax.device_count()}"
            )

        # Broadcast __init__ so worker processes construct their JaxBackendImpl too.
        self._broadcast_and_call("__init__", base_model=base_model, config=config, process_id=self.process_id)

    def _broadcast_and_call(self, method: str, **kwargs):
        """Broadcast method call to workers and execute locally via super()."""
        if jax.process_count() > 1:
            # Type hints from JaxBackendImpl drive JSON serialization so
            # workers can re-hydrate the kwargs into the right types.
            hints = get_type_hints(getattr(JaxBackendImpl, method))

            # TODO: Remove AnyPath special case once https://github.com/drivendataorg/cloudpathlib/issues/537 is released
            def serialize(k, v):
                # Unhinted kwargs pass through unchanged.
                if hints.get(k) is AnyPath:
                    return str(v)
                return TypeAdapter(hints[k]).dump_python(v, mode="json") if k in hints else v

            _broadcast_command(
                RpcPayload(method=method, kwargs={k: serialize(k, v) for k, v in kwargs.items()}),
                process_id=self.process_id,
            )
        # Execute the same method locally on the coordinator.
        return getattr(super(), method)(**kwargs)

    def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
        """Create the model on all processes (broadcast + local call)."""
        self._broadcast_and_call("create_model", model_id=model_id, lora_config=lora_config)

    def forward_backward(self, prepared_batch: types.PreparedModelPassBatch):
        """Run forward/backward on all processes; returns the coordinator's result."""
        return self._broadcast_and_call("forward_backward", prepared_batch=prepared_batch)

    def forward(self, prepared_batch: types.PreparedModelPassBatch):
        """Run forward-only pass on all processes; returns the coordinator's result."""
        return self._broadcast_and_call("forward", prepared_batch=prepared_batch)

    def optim_step(self, model_id: str, request_data: types.OptimStepInput):
        """Apply an optimizer step on all processes; returns the coordinator's result."""
        return self._broadcast_and_call("optim_step", model_id=model_id, request_data=request_data)

    def sample(self, prepared_batch: types.PreparedSampleBatch):
        """Generate samples on all processes; returns the coordinator's result."""
        return self._broadcast_and_call("sample", prepared_batch=prepared_batch)

    def save_checkpoint(self, output_path: AnyPath, model_id: str) -> None:
        """Save a training checkpoint on all processes."""
        self._broadcast_and_call("save_checkpoint", output_path=output_path, model_id=model_id)

    def load_checkpoint(self, checkpoint_path: AnyPath, model_id: str) -> None:
        """Load a training checkpoint on all processes."""
        self._broadcast_and_call("load_checkpoint", checkpoint_path=checkpoint_path, model_id=model_id)

    def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str, persist: bool = True) -> None:
        """Prepare sampler weights on all processes, optionally persisting to disk."""
        # Write probe so workers can detect shared filesystem
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.with_name(output_path.name + ".probe").write_text("write_probe")
        self._broadcast_and_call("save_sampler_checkpoint", output_path=output_path, model_id=model_id, persist=persist)

attr base_model

base_model = base_model

attr config

config = config

attr property metrics

metrics = types.EngineMetrics()

attr tokenizer

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)

model_config

model_config = Qwen3Config(base_config, max_lora_adapters=(config.max_lora_adapters), max_lora_rank=(config.max_lora_rank), shard_attention_heads=(config.shard_attention_heads), loss_chunk_size=(config.loss_chunk_size), gradient_checkpointing=(config.gradient_checkpointing), mhc_expansion_rate=(config.mhc_expansion_rate))

mesh

mesh = jax.make_mesh((config.fully_sharded_data_parallel_size, config.expert_parallel_size, config.tensor_parallel_size), ('fsdp', 'ep', 'tp'), axis_types=((jax.sharding.AxisType.Auto,) * 3))

attr model

model = model_class(self.model_config, dtype=(get_dtype(self.model_config.get_config().dtype)), rngs=(nnx.Rngs(0)))

accumulated_grads

accumulated_grads = AccumulatedGradients.create(self.lora_params, config.max_lora_adapters)

optimizers

optimizers: dict[str, nnx.Optimizer] = {}

models

models: dict[str, types.ModelMetadata] = {}

method abstractmethod has_model

has_model(model_id: str) -> bool

Check if a model is registered with the backend.

method abstractmethod delete_model

delete_model(model_id: str) -> None

Delete a model and free all associated resources.

load_sampler_checkpoint

load_sampler_checkpoint(model_id: str, checkpoint_id: str, checkpoint_path: AnyPath) -> None

Insert sampler weights from checkpoint file.

load_sampler_weights

load_sampler_weights(prepared_batch: types.PreparedSampleBatch) -> list[int]

Load sampler weights for all requests and return adapter indices array.

Ensures all required checkpoints are loaded before sampling.

Parameters:

NameTypeDescriptionDefault
prepared_batchPreparedSampleBatchPreparedSampleBatch with model_ids, checkpoint_ids, and other batch datarequired

Returns:

TypeDescription
list[int]The adapter_indices array for LoRA sampling [batch_size]
list[int]Uses adapter index 0 for base model sampling (no LoRA)

attr process_id

process_id = 0

method abstractmethod create_model

create_model(model_id: str, lora_config: types.LoraConfig) -> None
Source code in skyrl/backends/jax.py:1107-1108
    def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
        self._broadcast_and_call("create_model", model_id=model_id, lora_config=lora_config)

method abstractmethod forward_backward

forward_backward(prepared_batch: types.PreparedModelPassBatch)
Source code in skyrl/backends/jax.py:1110-1111
    def forward_backward(self, prepared_batch: types.PreparedModelPassBatch):
        return self._broadcast_and_call("forward_backward", prepared_batch=prepared_batch)

method abstractmethod forward

forward(prepared_batch: types.PreparedModelPassBatch)
Source code in skyrl/backends/jax.py:1113-1114
    def forward(self, prepared_batch: types.PreparedModelPassBatch):
        return self._broadcast_and_call("forward", prepared_batch=prepared_batch)

method abstractmethod optim_step

optim_step(model_id: str, request_data: types.OptimStepInput)
Source code in skyrl/backends/jax.py:1116-1117
    def optim_step(self, model_id: str, request_data: types.OptimStepInput):
        return self._broadcast_and_call("optim_step", model_id=model_id, request_data=request_data)

method abstractmethod sample

sample(prepared_batch: types.PreparedSampleBatch)
Source code in skyrl/backends/jax.py:1119-1120
    def sample(self, prepared_batch: types.PreparedSampleBatch):
        return self._broadcast_and_call("sample", prepared_batch=prepared_batch)

method abstractmethod save_checkpoint

save_checkpoint(output_path: AnyPath, model_id: str) -> None
Source code in skyrl/backends/jax.py:1122-1123
    def save_checkpoint(self, output_path: AnyPath, model_id: str) -> None:
        self._broadcast_and_call("save_checkpoint", output_path=output_path, model_id=model_id)

method abstractmethod load_checkpoint

load_checkpoint(checkpoint_path: AnyPath, model_id: str) -> None
Source code in skyrl/backends/jax.py:1125-1126
    def load_checkpoint(self, checkpoint_path: AnyPath, model_id: str) -> None:
        self._broadcast_and_call("load_checkpoint", checkpoint_path=checkpoint_path, model_id=model_id)

method abstractmethod save_sampler_checkpoint

save_sampler_checkpoint(output_path: AnyPath, model_id: str, persist: bool = True) -> None
Source code in skyrl/backends/jax.py:1128-1132
    def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str, persist: bool = True) -> None:
        # Write probe so workers can detect shared filesystem
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.with_name(output_path.name + ".probe").write_text("write_probe")
        self._broadcast_and_call("save_sampler_checkpoint", output_path=output_path, model_id=model_id, persist=persist)

SkyRL-Train Backend

Backend wrapping the SkyRL-Train training pipeline.

class SkyRLTrainBackend

SkyRLTrainBackend(base_model: str, config: SkyRLTrainBackendOverrides)

Bases: AbstractBackend

SkyRL-Train backend for supervised training.

Functions:

NameDescription
has_model
build_models
init_weight_sync_stateSetup the connection between policy model and inference engine for weight syncing.
create_model
delete_model
forward_backward
forward
optim_step
sampleGenerate samples using InferenceEngineClient.
save_checkpointSave full training checkpoint (model + optimizer + scheduler) as tar.
load_checkpointLoad full training checkpoint (model + optimizer + scheduler) from tar.
save_sampler_checkpointSync weights to colocated inference engines and optionally save to disk.

Attributes:

NameTypeDescription
base_model
config
Source code in skyrl/backends/skyrl_train_backend.py:96-712
class SkyRLTrainBackend(AbstractBackend):
    """SkyRL-Train backend for supervised training."""

    def __init__(self, base_model: str, config: SkyRLTrainBackendOverrides):
        """Store configuration and lazy state; heavy setup happens in create_model().

        Args:
            base_model: Name or path of the base model to train.
            config: Override values applied on top of the default SkyRL-Train config.

        Raises:
            ImportError: If ``ray`` is not available in the environment.
        """
        logger.warning("=" * 80)
        logger.warning("SkyRLTrainBackend is currently EXPERIMENTAL!")
        logger.warning("=" * 80)

        if ray is None:
            raise ImportError(
                "SkyRLTrainBackend requires `ray`. Install the appropriate extras (e.g. `--extra skyrl_train`)."
            )

        self.base_model = base_model
        # NOTE: We currently have two config attributes "config" which is just config overrides and "_cfg" which is the actual config object. This is a temporary state given that the Tinker engine expects a .config attribute
        self.config = config
        self._model_id: str | None = None  # the single registered model; None until create_model()
        self._model_metadata: types.ModelMetadata | None = None
        self._cfg = None  # full SkyRL-Train config object, built in create_model()
        self._dispatch: WorkerDispatch | None = None  # created in build_models()
        self._tokenizer: AutoTokenizer = get_tokenizer(self.base_model)
        self._inference_engine_client = None  # created lazily in _ensure_inference_engines()
        self._inference_engines_initialized = False

    def has_model(self, model_id: str) -> bool:
        """Return True when *model_id* is the single model registered with this backend."""
        return model_id == self._model_id

    def build_models(self, PolicyWorker):
        """Create the policy actor group and the unified worker dispatch.

        Args:
            PolicyWorker: Worker class (FSDP or Megatron) used to build the
                policy actor group.
        """
        cfg = self._cfg
        colocate_all = cfg.trainer.placement.colocate_all
        pg = self._colocate_pg

        if colocate_all:
            # Training and inference share the same GPUs, so the GPU totals
            # on both sides must match exactly.
            assert pg is not None, "placement group must be created when colocate_all=True"
            num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
            num_rollout_gpus = (
                cfg.generator.inference_engine.num_engines
                * cfg.generator.inference_engine.tensor_parallel_size
                * cfg.generator.inference_engine.pipeline_parallel_size
                * cfg.generator.inference_engine.data_parallel_size
            )
            assert (
                num_policy_gpus == num_rollout_gpus
            ), "num_policy_gpus and num_rollout_gpus must be the same when colocating all models"

        policy_model = PPORayActorGroup(
            cfg.trainer,
            cfg.trainer.placement.policy_num_nodes,
            cfg.trainer.placement.policy_num_gpus_per_node,
            PolicyWorker,
            pg=pg,
            # Fractional GPU per actor when colocating so policy actors can
            # share devices with the inference engines.
            num_gpus_per_actor=0.2 if colocate_all else 1,
            colocate_all=colocate_all,
            sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
            record_memory=cfg.trainer.policy.record_memory,
        )

        # set to a large number for megatron scheduler init
        # lr will be managed externally via set_lr()
        policy_num_training_steps = 1e9
        ray.get(
            policy_model.async_init_model(
                cfg.trainer.policy.model.path,
                num_training_steps=policy_num_training_steps,
            )
        )
        # Propagate the tokenizer's pad token id to every policy worker.
        ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self._tokenizer.pad_token_id))

        if colocate_all:
            # Move policy weights off-GPU until they are needed for training.
            policy_model.offload_to_cpu()

        # Create unified dispatch that manages all actor groups
        self._dispatch = WorkerDispatch(
            cfg=cfg,
            policy_actor_group=policy_model,
            inference_engine_client=self._inference_engine_client,
        )

        # Mark all models as offloaded
        if colocate_all:
            self._dispatch.mark_all_offloaded()

        logger.info("init policy model done")

    def init_weight_sync_state(self):
        """Connect the policy model to the inference engines for weight syncing."""
        self._dispatch.init_weight_sync_state(self._inference_engine_client)
        logger.info("Initialized weight sync state for policy model and inference engines.")

    def _ensure_inference_engines(self):
        """Lazily create inference engines and init weight sync on first sampling-related call."""
        # Idempotent: do nothing once the engines have been created.
        if self._inference_engines_initialized:
            return

        logger.info(f"Creating {self._cfg.generator.inference_engine.num_engines} inference engines")
        self._inference_engine_client = InferenceEngineClient(
            create_ray_wrapped_inference_engines_from_config(self._cfg, self._colocate_pg, self._tokenizer),
            self._tokenizer,
            self._cfg.trainer.policy.model.path,
            self._cfg.trainer.policy.model.lora,
            self._cfg.generator.inference_engine,
        )
        # Hand the client to the dispatch, then wire up weight syncing before
        # marking initialization complete.
        self._dispatch.set_inference_engine_client(self._inference_engine_client)
        self.init_weight_sync_state()
        self._inference_engines_initialized = True

    def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
        """Build the full SkyRL-Train stack for a single model.

        Constructs the config, initializes Ray, optionally creates a colocated
        placement group, selects the worker class for the configured strategy,
        and builds the policy actors.

        Args:
            model_id: Identifier for the new model.
            lora_config: LoRA configuration folded into the SkyRL-Train config.

        Raises:
            ValueError: If a model already exists (only one is supported) or
                the configured strategy is unknown.
        """
        if self._model_id is not None:
            raise ValueError(f"Model '{self._model_id}' already exists. Only one model supported.")

        # Build config
        self._cfg = _build_skyrl_train_config(self.base_model, self.config, lora_config)

        if not ray.is_initialized():
            logger.info("Initializing Ray with runtime environment")
            initialize_ray(self._cfg)

        # Create shared placement group only when colocating training + inference
        if self._cfg.trainer.placement.colocate_all:
            self._colocate_pg = self._create_colocate_pg()
        else:
            self._colocate_pg = None

        # Get worker types based on strategy
        if self._cfg.trainer.strategy in ("fsdp", "fsdp2"):
            from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker
        elif self._cfg.trainer.strategy == "megatron":
            from skyrl.backends.skyrl_train.workers.megatron.megatron_worker import (
                PolicyWorker,
            )
        else:
            raise ValueError(f"Unknown strategy type: {self._cfg.trainer.strategy}")

        logger.info("Building models.")
        self.build_models(PolicyWorker)

        # Record the single registered model only after setup succeeds.
        self._model_id = model_id
        self._model_metadata = types.ModelMetadata(adapter_index=0, lora_config=lora_config)
        logger.info(f"Created model {model_id} using RayPPOTrainer")

    def _create_colocate_pg(self):
        """Build and await a Ray placement group sized for colocated training + inference."""
        engine_cfg = self._cfg.generator.inference_engine
        gpus_per_engine = (
            engine_cfg.tensor_parallel_size
            * engine_cfg.pipeline_parallel_size
            * engine_cfg.data_parallel_size
        )
        total_gpu_slots = engine_cfg.num_engines * gpus_per_engine

        logger.info(f"Creating placement group with {total_gpu_slots} GPU slots for colocated training+inference")
        # One {GPU, CPU} bundle per GPU slot, packed onto as few nodes as possible.
        pg = placement_group([{"GPU": 1, "CPU": 1}] * total_gpu_slots, strategy="PACK")

        logger.info("Waiting for placement group to be ready...")
        get_ray_pg_ready_with_timeout(pg, timeout=SKYRL_RAY_PG_TIMEOUT_IN_S)
        logger.info("Placement group ready!")

        return pg

    def delete_model(self, model_id: str) -> None:
        """Delete *model_id*; currently unsupported beyond validating the id."""
        if model_id != self._model_id:
            raise ValueError(f"Model {model_id} not found")
        # TODO: For now, prefer shutting down the backend and re-launching. Will be improved shortly.
        raise NotImplementedError("Deleting models not yet implemented")

    def _to_training_batch(self, prepared_batch: types.PreparedModelPassBatch) -> TrainingInputBatch:
        """Convert PreparedModelPassBatch to TrainingInputBatch.

        Left-pads every example: sequences and attention masks are padded to
        the longest full sequence, while loss/response masks (and the optional
        RL fields) are padded to the longest response length.
        """
        if not prepared_batch.all_input_ids:
            # Nothing to convert: return an empty batch.
            return TrainingInputBatch({})

        # SkyRL-Train shifts internally, so provide the full sequence length by
        # appending the last target token to each already-shifted input.
        full_sequences = [
            list(input_ids) + ([targets[-1]] if targets else [])
            for input_ids, targets in zip(prepared_batch.all_input_ids, prepared_batch.all_targets)
        ]

        max_seq_len = max(len(seq) for seq in full_sequences)
        max_response_len = max(len(weights) for weights in prepared_batch.all_token_weights)

        sequences, attention_masks, loss_masks, response_masks = [], [], [], []
        action_log_probs_list, advantages_list = [], []

        # Assumes per-example lists (weights/logprobs/advs) are aligned with
        # full_sequences by index — TODO confirm against the batch builder.
        for seq, weights, logprobs, advs in zip(
            full_sequences,
            prepared_batch.all_token_weights,
            prepared_batch.all_sampling_logprobs,
            prepared_batch.all_advantages,
        ):
            # Left-pad the token sequence with pad_token_id; mask out the padding.
            pad_len = max_seq_len - len(seq)
            sequences.append([self._tokenizer.pad_token_id] * pad_len + list(seq))
            attention_masks.append([0] * pad_len + [1] * len(seq))
            # Response-side fields are left-padded to the longest response.
            action_pad = max_response_len - len(weights)
            loss_masks.append([0.0] * action_pad + [float(w) for w in weights])
            response_masks.append([0] * action_pad + [1] * len(weights))
            action_log_probs_list.append([0.0] * action_pad + [float(lp) for lp in logprobs])
            advantages_list.append([0.0] * action_pad + [float(a) for a in advs])

        # Shapes: [batch, max_seq_len] for sequences/attention; [batch, max_response_len] for masks.
        sequences_tensor = torch.tensor(sequences, dtype=torch.long)
        attention_mask_tensor = torch.tensor(attention_masks, dtype=torch.long)
        loss_mask_tensor = torch.tensor(loss_masks, dtype=torch.float32)
        response_mask_tensor = torch.tensor(response_masks, dtype=torch.long)

        batch_dict = {
            "sequences": sequences_tensor,
            "attention_mask": attention_mask_tensor,
            "loss_mask": loss_mask_tensor,
            "response_mask": response_mask_tensor,
        }

        # Include RL fields (action_log_probs, advantages) when data is present
        has_logprobs = any(len(lp) > 0 for lp in prepared_batch.all_sampling_logprobs)
        has_advantages = any(len(a) > 0 for a in prepared_batch.all_advantages)
        if has_logprobs:
            batch_dict["action_log_probs"] = torch.tensor(action_log_probs_list, dtype=torch.float32)
        if has_advantages:
            batch_dict["advantages"] = torch.tensor(advantages_list, dtype=torch.float32)

        batch = TrainingInputBatch(batch_dict)
        batch.metadata = {"response_length": max_response_len}
        return batch

    def _pad_batch(
        self, batch: TrainingInputBatch, micro_batch_size: int | None = None
    ) -> tuple[TrainingInputBatch, int]:
        """Pad the batch so its size is divisible by the required alignment.

        The dispatch layer splits the batch evenly across DP workers, so the
        batch size must be a multiple of dp_size.  When *micro_batch_size* is
        given (needed for the Megatron backend whose ``forward_backward_func``
        doesn't support ragged micro-batches), we align to
        ``dp_size * micro_batch_size`` so each per-worker shard is also evenly
        divisible by *micro_batch_size*.

        Returns:
            (padded_batch, pad_size)
        """
        dp_size = self._dispatch.get_lcm_dp_size()
        alignment = dp_size * micro_batch_size if micro_batch_size else dp_size
        pad_size = (alignment - batch.batch_size % alignment) % alignment
        if not pad_size:
            return batch, 0

        padded_tensors = {}
        for key, tensor in batch.items():
            if tensor is None:
                continue
            if key == "loss_mask":
                # Padding rows must not contribute to the loss.
                tail_shape = tensor.shape[1:]
                filler = torch.zeros(pad_size, *tail_shape, dtype=tensor.dtype, device=tensor.device)
            else:
                # Recycle real rows so shapes/dtypes stay valid for the model.
                src_rows = torch.arange(pad_size) % tensor.shape[0]
                filler = tensor[src_rows].clone()
            padded_tensors[key] = torch.cat([tensor, filler], dim=0)

        padded = TrainingInputBatch(padded_tensors)
        padded.metadata = batch.metadata
        logger.info(f"Padded batch from {batch.batch_size} to {batch.batch_size + pad_size} (alignment={alignment})")
        return padded, pad_size

    def _extract_metrics(self, data: dict) -> dict[str, float]:
        """Extract training metrics from dispatch return dict.

        Workers return metrics like 'loss', 'policy_loss', 'policy_entropy', etc.
        We convert to Tinker's colon-suffixed format (e.g. 'total_loss:sum').
        """
        metrics: dict[str, float] = {}

        # SFT path returns 'loss'; RL path returns 'final_loss' / 'policy_loss'
        if "loss" in data:
            metrics["total_loss:sum"] = float(data["loss"])
        elif "final_loss" in data:
            metrics["total_loss:sum"] = float(data["final_loss"])

        if "policy_loss" in data:
            metrics["pg_loss:sum"] = float(data["policy_loss"])
        if "policy_entropy" in data:
            metrics["entropy_loss:sum"] = float(data["policy_entropy"])
        if "response_length" in data:
            metrics["num_tokens:sum"] = float(data["response_length"])

        return metrics

    def _sleep_inference_engines(self):
        """Sleep inference engines to free GPU memory for training."""
        if self._inference_engines_initialized and self._cfg.trainer.placement.colocate_all:
            asyncio.run(self._inference_engine_client.sleep())

    def forward_backward(
        self,
        prepared_batch: types.PreparedModelPassBatch,
    ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
        """Run a combined forward+backward pass for every request in the batch.

        The whole prepared batch goes through the dispatch layer in one call;
        per-request outputs are then sliced back out via request_batch_slices.
        Gradients accumulate on the workers; optim_step() applies them later.

        Returns:
            Mapping of request_id -> ForwardBackwardOutput carrying per-example
            elementwise_loss / logprobs payloads plus aggregated metrics.
            Empty dict when the batch contains no inputs.
        """
        if not prepared_batch.all_input_ids:
            return {}

        # Release GPU memory held by colocated inference engines first.
        self._sleep_inference_engines()
        batch = self._to_training_batch(prepared_batch)
        # Megatron's forward_backward_func can't handle ragged micro-batches,
        # so padding must also align to the micro batch size on that strategy.
        micro_bs = (
            self._cfg.trainer.micro_train_batch_size_per_gpu if self._cfg.trainer.strategy == "megatron" else None
        )
        batch, pad_size = self._pad_batch(batch, micro_batch_size=micro_bs)

        # A single dispatch call supports only one loss function; if requests
        # in the batch disagree, fall back to the first one and warn.
        loss_fn = prepared_batch.all_loss_fns[0]
        if len(set(prepared_batch.all_loss_fns)) > 1:
            logger.warning(
                "SkyRL backend received mixed loss functions %s in one batch; using '%s' for all",
                set(prepared_batch.all_loss_fns),
                loss_fn,
            )
        # Same single-config limitation: first non-None config wins.
        loss_fn_config = next((c for c in prepared_batch.all_loss_fn_configs if c is not None), None)
        data = self._dispatch.forward_backward(
            "policy",
            batch,
            loss_fn=loss_fn,
            loss_fn_config=loss_fn_config,
        )

        # Trim padding entries from loss_fn_outputs
        if pad_size > 0 and "loss_fn_outputs" in data:
            data["loss_fn_outputs"] = data["loss_fn_outputs"][:-pad_size]

        metrics = self._extract_metrics(data)

        # Re-slice the flat per-example outputs back into per-request results,
        # wrapping each list in a Tinker tensor payload (data/dtype/shape).
        results = {}
        for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
            loss_fn_outputs = []
            for i in range(start_idx, end_idx):
                raw_output = data["loss_fn_outputs"][i]
                logprobs = list(raw_output.get("logprobs", []))
                elementwise_loss = list(raw_output.get("elementwise_loss", []))
                loss_fn_outputs.append(
                    {
                        "elementwise_loss": {
                            "data": elementwise_loss,
                            "dtype": "float32",
                            "shape": [len(elementwise_loss)],
                        },
                        "logprobs": {
                            "data": logprobs,
                            "dtype": "float32",
                            "shape": [len(logprobs)],
                        },
                    }
                )
            results[request_id] = types.ForwardBackwardOutput(
                loss_fn_output_type="scalar",
                loss_fn_outputs=loss_fn_outputs,
                metrics=metrics,
            )
        return results

    def forward(
        self,
        prepared_batch: types.PreparedModelPassBatch,
    ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
        """Run a forward-only pass (no gradients) and return per-request logprobs."""
        if not prepared_batch.all_input_ids:
            return {}

        self._sleep_inference_engines()
        batch = self._to_training_batch(prepared_batch)
        if self._cfg.trainer.strategy == "megatron":
            micro_bs = self._cfg.trainer.micro_forward_batch_size_per_gpu
        else:
            micro_bs = None
        batch, pad_size = self._pad_batch(batch, micro_batch_size=micro_bs)
        data = self._dispatch.forward("policy", batch)

        # dispatch.forward() returns TrainingOutputBatch({"output": tensor[batch, max_response_len]});
        # drop the rows that correspond to padding examples.
        logprob_matrix = data["output"]
        if pad_size:
            logprob_matrix = logprob_matrix[:-pad_size]

        results = {}
        for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
            outputs = []
            for idx in range(start_idx, end_idx):
                # Each example's true response length equals its token-weight
                # count; everything before that in the row is left padding.
                valid_len = len(prepared_batch.all_token_weights[idx])
                offset = max(logprob_matrix.shape[1] - valid_len, 0)
                row = logprob_matrix[idx, offset:].tolist()
                outputs.append(
                    {
                        "logprobs": {
                            "data": row,
                            "dtype": "float32",
                            "shape": [len(row)],
                        },
                    }
                )
            results[request_id] = types.ForwardBackwardOutput(
                loss_fn_output_type="scalar",
                loss_fn_outputs=outputs,
                metrics={},
            )
        return results

    def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput:
        """Apply one optimizer step on the policy using accumulated gradients.

        Only the learning rate is taken from the request's AdamParams;
        beta1/beta2/eps were fixed at optimizer creation and cannot be
        changed dynamically.
        """
        if model_id != self._model_id:
            raise ValueError(f"Model {model_id} not found")

        adam_params = request_data.adam_params
        self._dispatch.set_lr("policy", adam_params.learning_rate)

        grad_norm = self._dispatch.optim_step("policy")
        logger.info(f"optim_step: lr={adam_params.learning_rate}, grad_norm={grad_norm}")

        # grad_norm may be None depending on the worker strategy; only report
        # it when present.
        metrics: dict[str, float] = {}
        if grad_norm is not None:
            metrics["skyrl.ai/grad_norm"] = float(grad_norm)
        metrics["skyrl.ai/learning_rate"] = adam_params.learning_rate
        return types.OptimStepOutput(metrics=metrics)

    def sample(
        self,
        prepared_batch: types.PreparedSampleBatch,
    ) -> dict[str, types.SampleOutput | types.ErrorResponse]:
        """Generate samples using InferenceEngineClient.

        NOTE: Weight sync is NOT triggered automatically. The caller must call
        save_weights_for_sampler() explicitly before calling sample() if weights
        have been updated.
        """
        # Inference engines are created lazily on the first sampling call.
        self._ensure_inference_engines()

        # This backend serves exactly one model; reject anything else up front.
        unique_models = set(prepared_batch.all_model_ids)
        if unique_models != {self._model_id}:
            error = types.ErrorResponse(
                error=f"Model mismatch. Expected {self._model_id}, got {unique_models}", status="error"
            )
            return {req_id: error for req_id, _, _, _, _ in prepared_batch.request_batch_slices}

        def _to_engine_params(sp):
            # Tinker and vLLM agree on most knobs; only the stop conditions need
            # renaming (stop_strings/stop_tokens -> stop/stop_token_ids).
            out = {
                "temperature": sp.temperature,
                "max_tokens": sp.max_tokens,
                "seed": sp.seed,
                "top_k": sp.top_k,
                "top_p": sp.top_p,
                "logprobs": 0,
            }
            if sp.stop_strings:
                out["stop"] = sp.stop_strings
            if sp.stop_tokens:
                out["stop_token_ids"] = sp.stop_tokens
            return out

        async def _sample_all():
            tasks = []
            for idx, prompt in enumerate(prepared_batch.all_prompts):
                params = _to_engine_params(prepared_batch.all_sampling_params[idx])
                tasks.append(
                    self._inference_engine_client.sample(
                        prompt_token_ids=prompt,
                        num_samples=1,  # Tinker batches multiple samples separately
                        sampling_params=params,
                    )
                )
            return await asyncio.gather(*tasks, return_exceptions=True)

        # The backend runs in an engine subprocess with no event loop, so drive
        # one here.  Exceptions stay in-place (return_exceptions=True) and are
        # surfaced per-request during aggregation.
        sample_outputs = asyncio.run(_sample_all())

        return self._aggregate_sample_results(prepared_batch, sample_outputs)

    def _aggregate_sample_results(
        self,
        prepared_batch: types.PreparedSampleBatch,
        sample_outputs: list,
    ) -> dict[str, types.SampleOutput | types.ErrorResponse]:
        """Convert InferenceEngineClient outputs into Tinker-format results.

        Each request's slice of *sample_outputs* is collected into a
        SampleOutput; the first Exception or None in a slice turns the whole
        request into an ErrorResponse.
        """

        def _build_sequence(output):
            # Turn one successful engine output into a GeneratedSequence.
            tokens = output["response_ids"][0]
            logprobs = (output.get("response_logprobs") or [[]])[0]
            raw_reason = output["stop_reasons"][0]

            # vLLM distinguishes "stop" vs "stop_token"; Tinker only has stop/length.
            reason = "stop" if raw_reason in ["stop", "stop_token"] else "length"

            # RL consumers require logprobs; fall back to zeros rather than fail.
            if logprobs is None or len(logprobs) == 0:
                logger.warning("No logprobs returned - filling with zeros")
                logprobs = [0.0] * len(tokens)

            return types.GeneratedSequence(
                tokens=tokens,
                logprobs=logprobs,
                stop_reason=reason,
            )

        results = {}
        for request_id, model_id, start_idx, end_idx, needs_prompt_logprobs in prepared_batch.request_batch_slices:
            sequences = []
            error_msg = None

            for i in range(start_idx, end_idx):
                output = sample_outputs[i]

                if isinstance(output, Exception):
                    error_msg = f"Sampling failed for sample {i}: {type(output).__name__}: {str(output)}"
                    logger.error(error_msg)
                    break
                if output is None:
                    error_msg = f"Sampling failed for sample {i}: Unknown error (output is None)"
                    logger.error(error_msg)
                    break

                sequences.append(_build_sequence(output))

            if error_msg is not None:
                results[request_id] = types.ErrorResponse(
                    error=error_msg,
                    status="error",
                )
            else:
                # Note: prompt_logprobs not supported initially
                if needs_prompt_logprobs:
                    logger.warning("Prompt logprobs requested but not yet supported")

                results[request_id] = types.SampleOutput(
                    sequences=sequences,
                    prompt_logprobs=None,
                )

        return results

    def _validate_model_state(self, model_id: str) -> None:
        """Validate that model exists and is initialized."""
        if model_id != self._model_id:
            raise ValueError(f"Model {model_id} not found")
        if self._dispatch is None:
            raise RuntimeError("Model not initialized")

    def _create_tar_from_directory(self, source_dir: str, output_path: str) -> None:
        """Create an uncompressed tar archive from a directory."""
        # Ensure parent directory exists
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        # Use uncompressed tar - gzip adds 5-10min CPU time on 6-7GB FSDP checkpoints
        with tarfile.open(output_path, "w") as tar:
            tar.add(source_dir, arcname=".")

    def save_checkpoint(self, output_path, model_id: str) -> None:
        """Save the full training checkpoint (model + optimizer + scheduler) as a tar.

        The checkpoint is staged in a temporary directory, archived to
        *output_path*, and the staging area is removed automatically.
        """
        self._validate_model_state(model_id)

        with tempfile.TemporaryDirectory() as staging:
            ckpt_dir = os.path.join(staging, "checkpoint")

            # The dispatch-level save includes optimizer state automatically.
            self._dispatch.save_checkpoint(model="policy", ckpt_dir=ckpt_dir, tokenizer=self._tokenizer)

            # Archive while the staging directory still exists.
            self._create_tar_from_directory(ckpt_dir, output_path)

        logger.info(f"Saved checkpoint for {model_id} to {output_path}")

    def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
        """Load a full training checkpoint (model + optimizer + scheduler) from a tar."""
        self._validate_model_state(model_id)

        with tempfile.TemporaryDirectory() as staging:
            # filter="data" rejects absolute paths and ".." members, preventing
            # path traversal from a malicious archive.
            with tarfile.open(checkpoint_path, "r") as tar:
                tar.extractall(staging, filter="data")

            # Restore optimizer and LR-scheduler state along with the weights.
            self._dispatch.load_checkpoint(
                model="policy", ckpt_dir=staging, load_optimizer_states=True, load_lr_scheduler_states=True
            )

        logger.info(f"Loaded checkpoint for {model_id} from {checkpoint_path}")

    def save_sampler_checkpoint(self, output_path, model_id: str, persist: bool = True) -> None:
        """Sync weights to colocated inference engines and optionally save to disk.

        The NCCL broadcast always runs so inference engines have the latest
        policy weights.  When ``persist`` is False (the common hot-path in RL
        loops) the expensive HuggingFace model export is skipped entirely.
        """
        self._validate_model_state(model_id)

        # Inference engines are created lazily on the first sampling-related call.
        self._ensure_inference_engines()

        asyncio.run(self._dispatch.save_weights_for_sampler())
        logger.info(f"Synced weights for {model_id} to inference engines via NCCL")

        if not persist:
            # Hot path: write a lightweight marker so the engine's checkpoint
            # bookkeeping stays consistent.  Actual weights live in GPU memory.
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            with tarfile.open(output_path, "w"):
                pass  # empty tar — marker only
            logger.info(f"Synced weights for {model_id} (disk save skipped)")
            return

        # TODO(tyler): For LoRA, only save the adapters instead of the full merged model
        with tempfile.TemporaryDirectory() as staging:
            export_dir = os.path.join(staging, "model")
            self._dispatch.save_hf_model(model="policy", export_dir=export_dir, tokenizer=self._tokenizer)
            self._create_tar_from_directory(export_dir, output_path)
        logger.info(f"Saved sampler checkpoint for {model_id} to {output_path}")

attr base_model

base_model = base_model

attr config

config = config

method abstractmethod has_model

has_model(model_id: str) -> bool
Source code in skyrl/backends/skyrl_train_backend.py:120-121
    def has_model(self, model_id: str) -> bool:
        return self._model_id == model_id

method build_models

build_models(PolicyWorker)
Source code in skyrl/backends/skyrl_train_backend.py:123-178
    def build_models(self, PolicyWorker):
        cfg = self._cfg
        colocate_all = cfg.trainer.placement.colocate_all
        pg = self._colocate_pg

        if colocate_all:
            assert pg is not None, "placement group must be created when colocate_all=True"
            num_policy_gpus = cfg.trainer.placement.policy_num_gpus_per_node * cfg.trainer.placement.policy_num_nodes
            num_rollout_gpus = (
                cfg.generator.inference_engine.num_engines
                * cfg.generator.inference_engine.tensor_parallel_size
                * cfg.generator.inference_engine.pipeline_parallel_size
                * cfg.generator.inference_engine.data_parallel_size
            )
            assert (
                num_policy_gpus == num_rollout_gpus
            ), "num_policy_gpus and num_rollout_gpus must be the same when colocating all models"

        policy_model = PPORayActorGroup(
            cfg.trainer,
            cfg.trainer.placement.policy_num_nodes,
            cfg.trainer.placement.policy_num_gpus_per_node,
            PolicyWorker,
            pg=pg,
            num_gpus_per_actor=0.2 if colocate_all else 1,
            colocate_all=colocate_all,
            sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size,
            record_memory=cfg.trainer.policy.record_memory,
        )

        # set to a large number for megatron scheduler init
        # lr will be managed externally via set_lr()
        policy_num_training_steps = 1e9
        ray.get(
            policy_model.async_init_model(
                cfg.trainer.policy.model.path,
                num_training_steps=policy_num_training_steps,
            )
        )
        ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self._tokenizer.pad_token_id))

        if colocate_all:
            policy_model.offload_to_cpu()

        # Create unified dispatch that manages all actor groups
        self._dispatch = WorkerDispatch(
            cfg=cfg,
            policy_actor_group=policy_model,
            inference_engine_client=self._inference_engine_client,
        )

        # Mark all models as offloaded
        if colocate_all:
            self._dispatch.mark_all_offloaded()

        logger.info("init policy model done")

method init_weight_sync_state

init_weight_sync_state()

Setup the connection between policy model and inference engine for weight syncing.

Source code in skyrl/backends/skyrl_train_backend.py:180-185
    def init_weight_sync_state(self):
        """
        Setup the connection between policy model and inference engine for weight syncing.
        """
        self._dispatch.init_weight_sync_state(self._inference_engine_client)
        logger.info("Initialized weight sync state for policy model and inference engines.")

method abstractmethod create_model

create_model(model_id: str, lora_config: types.LoraConfig) -> None
Source code in skyrl/backends/skyrl_train_backend.py:204-236
    def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
        if self._model_id is not None:
            raise ValueError(f"Model '{self._model_id}' already exists. Only one model supported.")

        # Build config
        self._cfg = _build_skyrl_train_config(self.base_model, self.config, lora_config)

        if not ray.is_initialized():
            logger.info("Initializing Ray with runtime environment")
            initialize_ray(self._cfg)

        # Create shared placement group only when colocating training + inference
        if self._cfg.trainer.placement.colocate_all:
            self._colocate_pg = self._create_colocate_pg()
        else:
            self._colocate_pg = None

        # Get worker types based on strategy
        if self._cfg.trainer.strategy in ("fsdp", "fsdp2"):
            from skyrl.backends.skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker
        elif self._cfg.trainer.strategy == "megatron":
            from skyrl.backends.skyrl_train.workers.megatron.megatron_worker import (
                PolicyWorker,
            )
        else:
            raise ValueError(f"Unknown strategy type: {self._cfg.trainer.strategy}")

        logger.info("Building models.")
        self.build_models(PolicyWorker)

        self._model_id = model_id
        self._model_metadata = types.ModelMetadata(adapter_index=0, lora_config=lora_config)
        logger.info(f"Created model {model_id} using RayPPOTrainer")

method abstractmethod delete_model

delete_model(model_id: str) -> None
Source code in skyrl/backends/skyrl_train_backend.py:253-257
    def delete_model(self, model_id: str) -> None:
        if self._model_id != model_id:
            raise ValueError(f"Model {model_id} not found")
        # TODO: For now, prefer shutting down the backend and re-launching. Will be improved shortly.
        raise NotImplementedError("Deleting models not yet implemented")

method abstractmethod forward_backward

forward_backward(prepared_batch: types.PreparedModelPassBatch) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]
Source code in skyrl/backends/skyrl_train_backend.py:382-443
    def forward_backward(
        self,
        prepared_batch: types.PreparedModelPassBatch,
    ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
        if not prepared_batch.all_input_ids:
            return {}

        self._sleep_inference_engines()
        batch = self._to_training_batch(prepared_batch)
        micro_bs = (
            self._cfg.trainer.micro_train_batch_size_per_gpu if self._cfg.trainer.strategy == "megatron" else None
        )
        batch, pad_size = self._pad_batch(batch, micro_batch_size=micro_bs)

        loss_fn = prepared_batch.all_loss_fns[0]
        if len(set(prepared_batch.all_loss_fns)) > 1:
            logger.warning(
                "SkyRL backend received mixed loss functions %s in one batch; using '%s' for all",
                set(prepared_batch.all_loss_fns),
                loss_fn,
            )
        loss_fn_config = next((c for c in prepared_batch.all_loss_fn_configs if c is not None), None)
        data = self._dispatch.forward_backward(
            "policy",
            batch,
            loss_fn=loss_fn,
            loss_fn_config=loss_fn_config,
        )

        # Trim padding entries from loss_fn_outputs
        if pad_size > 0 and "loss_fn_outputs" in data:
            data["loss_fn_outputs"] = data["loss_fn_outputs"][:-pad_size]

        metrics = self._extract_metrics(data)

        results = {}
        for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
            loss_fn_outputs = []
            for i in range(start_idx, end_idx):
                raw_output = data["loss_fn_outputs"][i]
                logprobs = list(raw_output.get("logprobs", []))
                elementwise_loss = list(raw_output.get("elementwise_loss", []))
                loss_fn_outputs.append(
                    {
                        "elementwise_loss": {
                            "data": elementwise_loss,
                            "dtype": "float32",
                            "shape": [len(elementwise_loss)],
                        },
                        "logprobs": {
                            "data": logprobs,
                            "dtype": "float32",
                            "shape": [len(logprobs)],
                        },
                    }
                )
            results[request_id] = types.ForwardBackwardOutput(
                loss_fn_output_type="scalar",
                loss_fn_outputs=loss_fn_outputs,
                metrics=metrics,
            )
        return results

method abstractmethod forward

forward(prepared_batch: types.PreparedModelPassBatch) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]
Source code in skyrl/backends/skyrl_train_backend.py:445-488
    def forward(
        self,
        prepared_batch: types.PreparedModelPassBatch,
    ) -> dict[str, types.ForwardBackwardOutput | types.ErrorResponse]:
        if not prepared_batch.all_input_ids:
            return {}

        self._sleep_inference_engines()
        batch = self._to_training_batch(prepared_batch)
        micro_bs = (
            self._cfg.trainer.micro_forward_batch_size_per_gpu if self._cfg.trainer.strategy == "megatron" else None
        )
        batch, pad_size = self._pad_batch(batch, micro_batch_size=micro_bs)
        data = self._dispatch.forward("policy", batch)

        # dispatch.forward() returns TrainingOutputBatch({"output": tensor[batch, max_response_len]})
        # Trim padding entries from output
        output_logprobs = data["output"]
        if pad_size > 0:
            output_logprobs = output_logprobs[:-pad_size]

        results = {}
        for request_id, _, start_idx, end_idx in prepared_batch.request_batch_slices:
            loss_fn_outputs = []
            for i in range(start_idx, end_idx):
                # Use token weights length to determine each example's actual response length
                valid_len = len(prepared_batch.all_token_weights[i])
                start = max(output_logprobs.shape[1] - valid_len, 0)
                logprobs = output_logprobs[i, start:].tolist()
                loss_fn_outputs.append(
                    {
                        "logprobs": {
                            "data": logprobs,
                            "dtype": "float32",
                            "shape": [len(logprobs)],
                        },
                    }
                )
            results[request_id] = types.ForwardBackwardOutput(
                loss_fn_output_type="scalar",
                loss_fn_outputs=loss_fn_outputs,
                metrics={},
            )
        return results

method abstractmethod optim_step

optim_step(model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput
Source code in skyrl/backends/skyrl_train_backend.py:490-506
    def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput:
        if model_id != self._model_id:
            raise ValueError(f"Model {model_id} not found")

        # Apply learning rate from AdamParams before optimizer step
        # Note: beta1, beta2, eps are fixed at optimizer creation and cannot be changed dynamically
        adam_params = request_data.adam_params
        self._dispatch.set_lr("policy", adam_params.learning_rate)

        grad_norm = self._dispatch.optim_step("policy")
        logger.info(f"optim_step: lr={adam_params.learning_rate}, grad_norm={grad_norm}")

        metrics: dict[str, float] = {}
        if grad_norm is not None:
            metrics["skyrl.ai/grad_norm"] = float(grad_norm)
        metrics["skyrl.ai/learning_rate"] = adam_params.learning_rate
        return types.OptimStepOutput(metrics=metrics)

method abstractmethod sample

sample(prepared_batch: types.PreparedSampleBatch) -> dict[str, types.SampleOutput | types.ErrorResponse]

Generate samples using InferenceEngineClient.

NOTE: Weight sync is NOT triggered automatically. The caller must call save_weights_for_sampler() explicitly before calling sample() if weights have been updated.

Source code in skyrl/backends/skyrl_train_backend.py:508-568
    def sample(
        self,
        prepared_batch: types.PreparedSampleBatch,
    ) -> dict[str, types.SampleOutput | types.ErrorResponse]:
        """Generate samples via the colocated InferenceEngineClient.

        NOTE: Weight sync is NOT triggered automatically. The caller must call
        save_weights_for_sampler() explicitly before calling sample() if weights
        have been updated.
        """
        # Inference engines are created lazily on the first sampling-related call.
        self._ensure_inference_engines()

        # Every request in the batch must target the single model this backend serves.
        requested_models = set(prepared_batch.all_model_ids)
        if requested_models != {self._model_id}:
            mismatch = types.ErrorResponse(
                error=f"Model mismatch. Expected {self._model_id}, got {requested_models}", status="error"
            )
            return {req_id: mismatch for req_id, _, _, _, _ in prepared_batch.request_batch_slices}

        def _to_engine_params(sp):
            # Pass through common fields; only stop needs name translation
            # (Tinker uses stop_strings/stop_tokens, vLLM uses stop/stop_token_ids)
            translated = {
                "temperature": sp.temperature,
                "max_tokens": sp.max_tokens,
                "seed": sp.seed,
                "top_k": sp.top_k,
                "top_p": sp.top_p,
                "logprobs": 0,
            }
            if sp.stop_strings:
                translated["stop"] = sp.stop_strings
            if sp.stop_tokens:
                translated["stop_token_ids"] = sp.stop_tokens
            return translated

        async def _gather_samples():
            # Fan out one sampling call per prompt and await them concurrently.
            coros = [
                self._inference_engine_client.sample(
                    prompt_token_ids=prompt,
                    num_samples=1,  # Tinker batches multiple samples separately
                    sampling_params=_to_engine_params(sp),
                )
                for prompt, sp in zip(prepared_batch.all_prompts, prepared_batch.all_sampling_params)
            ]
            return await asyncio.gather(*coros, return_exceptions=True)

        # Backend runs in engine subprocess with no event loop, so drive one here.
        sample_outputs = asyncio.run(_gather_samples())

        # sample_outputs may hold Exception objects (from return_exceptions=True);
        # keep them so error messages can surface in the per-request responses.
        return self._aggregate_sample_results(prepared_batch, sample_outputs)

method abstractmethod save_checkpoint

save_checkpoint(output_path, model_id: str) -> None

Save full training checkpoint (model + optimizer + scheduler) as tar.

Source code in skyrl/backends/skyrl_train_backend.py:652-666
    def save_checkpoint(self, output_path, model_id: str) -> None:
        """Serialize the full training state (model + optimizer + scheduler) as a tar archive."""
        self._validate_model_state(model_id)

        # Stage the checkpoint in a scratch directory, then archive it in one pass.
        with tempfile.TemporaryDirectory() as scratch:
            staged_dir = os.path.join(scratch, "checkpoint")
            # The dispatch-level save captures optimizer state automatically.
            self._dispatch.save_checkpoint(model="policy", ckpt_dir=staged_dir, tokenizer=self._tokenizer)
            self._create_tar_from_directory(staged_dir, output_path)

        logger.info(f"Saved checkpoint for {model_id} to {output_path}")

method abstractmethod load_checkpoint

load_checkpoint(checkpoint_path, model_id: str) -> None

Load full training checkpoint (model + optimizer + scheduler) from tar.

Source code in skyrl/backends/skyrl_train_backend.py:668-682
    def load_checkpoint(self, checkpoint_path, model_id: str) -> None:
        """Restore the full training state (model + optimizer + scheduler) from a tar archive."""
        self._validate_model_state(model_id)

        with tempfile.TemporaryDirectory() as scratch:
            # filter='data' rejects path-traversal entries inside the archive.
            with tarfile.open(checkpoint_path, "r") as archive:
                archive.extractall(scratch, filter="data")

            # Optimizer and LR-scheduler states are restored alongside the weights.
            self._dispatch.load_checkpoint(
                model="policy", ckpt_dir=scratch, load_optimizer_states=True, load_lr_scheduler_states=True
            )

        logger.info(f"Loaded checkpoint for {model_id} from {checkpoint_path}")

method abstractmethod save_sampler_checkpoint

save_sampler_checkpoint(output_path, model_id: str, persist: bool = True) -> None

Sync weights to colocated inference engines and optionally save to disk.

The NCCL broadcast always runs so inference engines have the latest policy weights. When persist is False (the common hot-path in RL loops) the expensive HuggingFace model export is skipped entirely.

Source code in skyrl/backends/skyrl_train_backend.py:684-712
    def save_sampler_checkpoint(self, output_path, model_id: str, persist: bool = True) -> None:
        """Sync weights to colocated inference engines and optionally save to disk.

        The NCCL broadcast always runs so inference engines have the latest
        policy weights.  When ``persist`` is False (the common hot-path in RL
        loops) the expensive HuggingFace model export is skipped entirely.
        """
        self._validate_model_state(model_id)

        # Lazily create inference engines on first sampling-related call
        self._ensure_inference_engines()

        asyncio.run(self._dispatch.save_weights_for_sampler())
        logger.info(f"Synced weights for {model_id} to inference engines via NCCL")

        if not persist:
            # Hot path: write a lightweight marker so the engine's checkpoint
            # bookkeeping stays consistent.  Actual weights live in GPU memory.
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            with tarfile.open(output_path, "w"):
                pass  # empty tar — marker only
            logger.info(f"Synced weights for {model_id} (disk save skipped)")
            return

        # TODO(tyler): For LoRA, only save the adapters instead of the full merged model
        with tempfile.TemporaryDirectory() as scratch:
            export_dir = os.path.join(scratch, "model")
            self._dispatch.save_hf_model(model="policy", export_dir=export_dir, tokenizer=self._tokenizer)
            self._create_tar_from_directory(export_dir, output_path)
        logger.info(f"Saved sampler checkpoint for {model_id} to {output_path}")

On this page