SkyRL API Reference

Tinker Engine

Orchestration engine for RL training.

Engine

The TinkerEngine orchestrates backends, models, and training loops.

class TinkerEngine

TinkerEngine(config: EngineConfig)

Background engine for processing training requests.

The engine handles:

  • Database operations (futures, checkpoints)
  • Request finding/scheduling
  • File I/O (download/upload checkpoints)
  • Validating requests against loaded models

Computation and model management are delegated to the backend.
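This division of labor can be illustrated with a toy version of the request loop. Everything below (the dict standing in for the futures table, the `process` stub) is invented for illustration and is not part of the skyrl API:

```python
# Illustrative sketch of the engine's request lifecycle: futures start
# pending, get processed, and are completed with a result.
pending = {
    1: {"type": "create_model", "status": "pending"},
    2: {"type": "forward_backward", "status": "pending"},
}

def process(request):
    # A real engine would delegate this computation to the backend.
    return {"ok": True, "echo": request["type"]}

def drain(queue):
    """One pass of the polling loop: process every pending request,
    then mark its future completed with the result."""
    for request_id, request in list(queue.items()):
        if request["status"] != "pending":
            continue
        result = process(request)
        request["status"] = "completed"
        request["result"] = result

drain(pending)
print(pending[1]["status"])  # completed
```

In the real engine the queue lives in a SQL database and results are written back via `_complete_futures`, but the pending/completed lifecycle is the same.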

Functions:

| Name | Description |
| --- | --- |
| `find_batchable_model_passes` | Find all requests of the given type that come before any destructive update for their model. |
| `find_batchable_sample` | Find all sample ops that can be safely batched together. |
| `find_single_requests` | Find all requests that need to be processed individually (not batchable). |
| `process_create_model` | Create and initialize a model. |
| `process_unload_model` | Unload a model and free all resources. |
| `cleanup_stale_sessions` | Clean up sessions with no recent heartbeat and unload their models. |
| `process_optim_step` | Process an optim_step request and apply accumulated gradients. |
| `process_forward_backward` | Run forward and backward pass on a batch of requests. |
| `process_forward` | Run forward-only pass on a batch of requests. |
| `process_sample` | Generate samples for a batch of requests. |
| `process_load_weights` | Load a clean, trimmed training checkpoint. |
| `process_save_weights` | Save a clean training checkpoint by converting the trimmed NNX graph to a pure dictionary before serialization. |
| `process_save_weights_for_sampler` | Process a save_weights_for_sampler request and save model weights. |
| `process_single_request` | Dispatch a single request to the matching handler based on its request type. |
| `process_single_requests` | Process a collection of single (non-batchable) requests. |
| `process_batch_requests` | Process a batch of requests with error handling and future completion. |
| `process_pending_requests` | Main loop to process pending requests. |
| `run` | Entry point to start the engine. |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `config` | | |
| `db_engine` | | |
| `backend` | | |
| `metrics` | `EngineMetrics` | Pass-through to backend metrics for backwards compatibility. |

Initialize the engine with a database connection and base model.

Source code in skyrl/tinker/engine.py:187-718
class TinkerEngine:
    """Background engine for processing training requests.

    The engine handles:
    - Database operations (futures, checkpoints)
    - Request finding/scheduling
    - File I/O (download/upload checkpoints)
    - Validating requests against loaded models

    Computation and model management are delegated to the backend.
    """

    def _filter_valid_requests(
        self,
        requests: dict[str, tuple[str, BaseModel]],
    ) -> tuple[dict[str, types.ErrorResponse], dict[str, tuple[str, BaseModel]]]:
        """Filter out requests with invalid model_ids and return error results for them.

        Args:
            requests: Dict mapping request_id to (model_id, request_data) tuples

        Returns:
            Tuple of (error_results, valid_requests)
        """
        results = {}
        valid_requests = {}

        for request_id, (model_id, request_data) in requests.items():
            error = None
            if model_id and not self.backend.has_model(model_id):
                error = f"Model {model_id} not loaded"
            elif not model_id and isinstance(request_data, types.SampleInput):
                if request_data.base_model != self.config.base_model:
                    error = f"Engine is configured for '{self.config.base_model}' but request specified '{request_data.base_model}'"
                elif request_data.checkpoint_id:
                    error = "checkpoint_id must be empty for base model sampling"

            if error:
                results[request_id] = types.ErrorResponse(error=error, status="failed")
            else:
                valid_requests[request_id] = (model_id, request_data)

        return results, valid_requests

    def __init__(
        self,
        config: EngineConfig,
    ):
        """Initialize the engine with a database connection and base model."""
        self.config = config
        self.db_engine = create_engine(config.database_url, echo=False)
        enable_sqlite_wal(self.db_engine)

        # Initialize the backend (handles model state, computation, and adapter management)
        backend_class, backend_config_class = get_backend_classes(config.backend)
        backend_config = backend_config_class(**config.backend_config)
        self.backend = backend_class(config.base_model, backend_config)

        # Track last cleanup time for periodic stale session cleanup
        self._last_cleanup_time: float = time.time()

        logger.info(f"Initialized TinkerEngine with backend={type(self.backend).__name__}")

    @property
    def metrics(self) -> types.EngineMetrics:
        """Pass-through to backend metrics for backwards compatibility."""
        return self.backend.metrics

    @contextmanager
    def _checkpoint_status_context(self, model_id: str, checkpoint_id: str, checkpoint_type: types.CheckpointType):
        """Context manager to handle checkpoint DB status updates.

        Fetches the checkpoint entry, yields it, and updates its status to COMPLETED
        or FAILED based on whether an exception occurred.
        """
        with Session(self.db_engine) as session:
            # Fail fast if API didn't create the checkpoint row first.
            if session.get(CheckpointDB, (model_id, checkpoint_id, checkpoint_type)) is None:
                raise ValueError(
                    f"Checkpoint entry not found for model '{model_id}', checkpoint '{checkpoint_id}', type '{checkpoint_type}'"
                )

        status = CheckpointStatus.FAILED
        error_message = "checkpoint operation interrupted"
        try:
            # Run potentially slow checkpoint I/O without an open DB transaction.
            yield
            status = CheckpointStatus.COMPLETED
            error_message = None
        except Exception as e:
            logger.exception(f"Error saving checkpoint for model {model_id}, checkpoint {checkpoint_id}: {e}")
            error_message = str(e)
            raise
        finally:
            # Persist final status in a short write transaction.
            with Session(self.db_engine) as session:
                result = session.exec(
                    update(CheckpointDB)
                    .where(CheckpointDB.model_id == model_id)
                    .where(CheckpointDB.checkpoint_id == checkpoint_id)
                    .where(CheckpointDB.checkpoint_type == checkpoint_type)
                    .values(
                        status=status,
                        error_message=error_message,
                        completed_at=datetime.now(timezone.utc),
                    )
                )
                if not result.rowcount:
                    logger.warning(
                        f"Checkpoint row disappeared before status update: "
                        f"model_id={model_id}, checkpoint_id={checkpoint_id}, checkpoint_type={checkpoint_type}"
                    )
                session.commit()

    def find_batchable_model_passes(
        self, session: Session, request_type: types.RequestType
    ) -> dict[str, tuple[str, types.ForwardBackwardInput]]:
        """Find all requests of the given type that come before any destructive update for their model.

        Uses look-ahead scheduling: for each model, only returns operations
        that have no optim_step or load_weights blocking them in the queue.

        Args:
            session: Database session
            request_type: The type of request to find (e.g., FORWARD or FORWARD_BACKWARD)

        Returns:
            Dict mapping request_id to (model_id, request_data) tuples
        """
        # Find the earliest pending optim_step or load_weights per model (these act as barriers)
        barriers_query = (
            select(FutureDB.model_id, func.min(FutureDB.request_id).label("barrier_id"))
            .where(
                (FutureDB.request_type == types.RequestType.OPTIM_STEP)
                | (FutureDB.request_type == types.RequestType.LOAD_WEIGHTS)
            )
            .where(FutureDB.status == RequestStatus.PENDING)
            .group_by(FutureDB.model_id)
        )
        barriers = dict(session.exec(barriers_query).all())

        # Get all pending operations of the requested type ordered by request_id
        query = (
            select(FutureDB)
            .where(FutureDB.request_type == request_type)
            .where(FutureDB.status == RequestStatus.PENDING)
            .order_by(FutureDB.request_id)
        )
        ops = session.exec(query).all()

        # Filter: only include ops that come before their model's barrier
        batchable = [op for op in ops if op.model_id not in barriers or op.request_id < barriers[op.model_id]]

        return {
            str(f.request_id): (f.model_id, types.ForwardBackwardInput.model_validate(f.request_data))
            for f in batchable
        }

    def find_batchable_sample(self, session: Session) -> dict[str, tuple[str, types.SampleInput]]:
        """Find all sample ops that can be safely batched together.

        Returns sample operations ensuring that each model_id has only one checkpoint_id
        to avoid loading different checkpoints for the same model in a single batch.

        If sample_max_num_sequences is configured, limits to that many requests so we don't
        produce partial batches in process_sample_batch. If num_samples > 1 for some requests,
        this may not be perfect, but it's good until we implement continuous batching.

        Args:
            session: Database session

        Returns:
            Dict mapping request_id to (model_id, request_data) tuples
        """
        sample_query = (
            select(FutureDB)
            .where(FutureDB.request_type == types.RequestType.SAMPLE)
            .where(FutureDB.status == RequestStatus.PENDING)
            .order_by(FutureDB.request_id)
        )
        sample_ops = session.exec(sample_query).all()

        batchable = []
        model_checkpoints = {}  # Map from model_id to checkpoint_id of first request to that model
        for op in sample_ops:
            checkpoint_id = op.request_data["checkpoint_id"]
            # Base model requests (empty checkpoint_id) are always compatible; otherwise,
            # only take requests with one checkpoint_id for a given model_id
            if checkpoint_id == "" or model_checkpoints.setdefault(op.model_id, checkpoint_id) == checkpoint_id:
                batchable.append(op)

        # TODO: This leaks the abstraction by accessing backend-specific config.
        # We should find a better way to handle this going forward.
        if self.config.backend == "jax" and self.backend.config.sample_max_num_sequences > 0:
            batchable = batchable[: self.backend.config.sample_max_num_sequences]

        return {str(f.request_id): (f.model_id, types.SampleInput.model_validate(f.request_data)) for f in batchable}

    def find_single_requests(self, session: Session) -> dict[str, tuple[str, types.RequestType, dict]]:
        """Find all requests that need to be processed individually (not batchable).

        Args:
            session: Database session

        Returns:
            Dict mapping request_id to (model_id, request_type, request_data) tuples
        """
        statement = (
            select(FutureDB)
            .where(FutureDB.status == RequestStatus.PENDING)
            .where(FutureDB.request_type != types.RequestType.FORWARD_BACKWARD)
            .where(FutureDB.request_type != types.RequestType.FORWARD)
            .where(FutureDB.request_type != types.RequestType.SAMPLE)
            .where(FutureDB.request_type != types.RequestType.EXTERNAL)
            .order_by(FutureDB.request_id)
        )
        other_futures = session.exec(statement).all()

        return {str(f.request_id): (f.model_id, f.request_type, f.request_data) for f in other_futures}

    def process_create_model(self, model_id: str, request_data: types.CreateModelInput) -> types.CreateModelOutput:
        """Create and initialize a model."""
        # Create model in backend (allocates adapter_index, creates optimizer, and configures adapter)
        self.backend.create_model(model_id, request_data.lora_config)

        logger.info(f"Created LoRA model {model_id}")

        return types.CreateModelOutput(
            model_id=model_id,
            base_model=self.config.base_model,
            lora_config=request_data.lora_config,
        )

    def process_unload_model(self, model_id: str, request_data: types.UnloadModelInput) -> types.UnloadModelOutput:
        """Unload a model and free all resources."""
        if not self.backend.has_model(model_id):
            logger.warning(f"Ignoring unload request for model {model_id} that is not loaded.")
        else:
            self.backend.delete_model(model_id)

            # Update model status in DB
            with Session(self.db_engine) as session:
                _ = session.exec(update(ModelDB).where(ModelDB.model_id == model_id).values(status="unloaded"))
                session.commit()

            logger.info(f"Unloaded model {model_id}")

        return types.UnloadModelOutput(model_id=model_id, status="unloaded")

    def cleanup_stale_sessions(self) -> int:
        """Cleanup sessions with no recent heartbeat and unload their models.

        Returns:
            Number of models unloaded
        """
        cutoff = datetime.now(timezone.utc) - timedelta(seconds=self.config.session_timeout_sec)
        unloaded_count = 0

        with Session(self.db_engine) as session:
            # Find stale sessions (active sessions with heartbeat older than cutoff)
            stale_sessions = session.exec(
                select(SessionDB).where(
                    SessionDB.status == "active",
                    SessionDB.last_heartbeat_at < cutoff,
                )
            ).all()

            if not stale_sessions:
                return 0

            stale_session_ids = {s.session_id for s in stale_sessions}

            # Find all models for all stale sessions in one query
            models_to_process = session.exec(
                select(ModelDB).where(
                    ModelDB.session_id.in_(stale_session_ids),
                    ModelDB.status != "unloaded",
                )
            ).all()

        # Unload models outside DB transactions to minimize lock time.
        sessions_with_failed_unloads: set[str] = set()
        unloaded_model_ids: set[str] = set()
        for model in models_to_process:
            if self.backend.has_model(model.model_id):
                try:
                    self.backend.delete_model(model.model_id)
                    unloaded_model_ids.add(model.model_id)
                    unloaded_count += 1
                    logger.info(f"Auto-unloaded stale model {model.model_id} from session {model.session_id}")
                except Exception as e:
                    logger.error(f"Failed to auto-unload model {model.model_id}: {e}")
                    sessions_with_failed_unloads.add(model.session_id)
            else:
                # Model already missing in backend; only DB state needs cleanup.
                unloaded_model_ids.add(model.model_id)

        sessions_to_expire = [s.session_id for s in stale_sessions if s.session_id not in sessions_with_failed_unloads]

        # Apply DB status updates in one short write transaction.
        with Session(self.db_engine) as session:
            if unloaded_model_ids:
                _ = session.exec(
                    update(ModelDB).where(ModelDB.model_id.in_(unloaded_model_ids)).values(status="unloaded")
                )
            if sessions_to_expire:
                _ = session.exec(
                    update(SessionDB).where(SessionDB.session_id.in_(sessions_to_expire)).values(status="expired")
                )
            session.commit()

        for session_id in sessions_to_expire:
            logger.info(f"Expired stale session {session_id}")

        return unloaded_count

    def process_optim_step(
        self, model_id: str, request_data: types.OptimStepInput
    ) -> types.OptimStepOutput | types.ErrorResponse:
        """Process an optim_step request and apply accumulated gradients."""
        if not self.backend.has_model(model_id):
            return _model_not_found_error(model_id)

        return self.backend.optim_step(model_id, request_data)

    def process_forward_backward(self, requests: dict[str, tuple[str, types.ForwardBackwardInput]]) -> dict:
        """Run forward and backward pass on a batch of requests."""
        prepared = prepare_model_pass_batch(requests)
        return self.backend.forward_backward(prepared)

    def process_forward(self, requests: dict[str, tuple[str, types.ForwardBackwardInput]]) -> dict:
        """Run forward-only pass on a batch of requests."""
        prepared = prepare_model_pass_batch(requests)
        return self.backend.forward(prepared)

    def process_sample(self, requests: dict[str, tuple[str, types.SampleInput]]) -> dict:
        """Generate samples for a batch of requests."""
        prepared = prepare_sample_batch(requests, self.config.checkpoints_base)
        return self.backend.sample(prepared)

    def process_load_weights(
        self, model_id: str, request_data: types.LoadWeightsInput
    ) -> types.LoadWeightsOutput | types.ErrorResponse:
        """Loads a clean, trimmed training checkpoint."""
        if not self.backend.has_model(model_id):
            return _model_not_found_error(model_id)

        checkpoint_path = (
            self.config.checkpoints_base / request_data.source_model_id / f"{request_data.checkpoint_id}.tar.gz"
        )

        self.backend.load_checkpoint(checkpoint_path, model_id)

        return types.LoadWeightsOutput(type="load_weights")

    def process_save_weights(
        self, model_id: str, request_data: types.SaveWeightsInput
    ) -> types.SaveWeightsOutput | types.ErrorResponse:
        """
        Saves a clean training checkpoint by converting the trimmed NNX graph
        to a pure dictionary before serialization, following official Flax docs.
        """
        if not self.backend.has_model(model_id):
            return _model_not_found_error(model_id)

        checkpoint_id = request_data.path
        output_path = self.config.checkpoints_base / model_id / f"{checkpoint_id}.tar.gz"

        with self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.TRAINING):
            self.backend.save_checkpoint(output_path, model_id)
            logger.info(f"Saved trimmed training checkpoint for model {model_id} to {output_path}")

        return types.SaveWeightsOutput(
            path=f"tinker://{model_id}/weights/{checkpoint_id}",
            type="save_weights",
        )

    def process_save_weights_for_sampler(
        self, model_id: str, request_data: types.SaveWeightsForSamplerInput
    ) -> types.SaveWeightsForSamplerOutput | types.ErrorResponse:
        """Process a save_weights_for_sampler request and save model weights."""
        if not self.backend.has_model(model_id):
            return _model_not_found_error(model_id)

        # Make sure the user cannot store checkpoints in places like ../../<important file>
        checkpoint_id = Path(request_data.path).name
        output_path = self.config.checkpoints_base / model_id / "sampler_weights" / f"{checkpoint_id}.tar.gz"

        # When the caller provides a sampling_session_seq_id the save is
        # transient — weights only need to reach the inference engines, not
        # disk.  Backends can skip the expensive write in that case.
        persist = request_data.sampling_session_seq_id is None

        with self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.SAMPLER):
            self.backend.save_sampler_checkpoint(output_path, model_id, persist=persist)
            logger.info(f"Saved sampler checkpoint for model {model_id} to {output_path}")

        # Return path=None when using sampling_session_seq_id and seq_id (SDK expects this)
        if request_data.sampling_session_seq_id is not None and request_data.seq_id is not None:
            output_path_str = None
        else:
            output_path_str = f"tinker://{model_id}/{checkpoint_id}"

        return types.SaveWeightsForSamplerOutput(
            path=output_path_str,
            type="save_weights_for_sampler",
            sampling_session_id=request_data.sampling_session_id,
        )

    def _complete_futures(self, results: dict[str, BaseModel]):
        """Helper method to complete multiple futures in the database.

        Args:
            results: Dict mapping request_id to result (Pydantic BaseModel)
        """
        completed_at = datetime.now(timezone.utc)
        params = [
            {
                "request_id": int(request_id),
                "result_data": result.model_dump(),
                "status": RequestStatus.FAILED if isinstance(result, types.ErrorResponse) else RequestStatus.COMPLETED,
                "completed_at": completed_at,
            }
            for request_id, result in results.items()
        ]

        with Session(self.db_engine) as session:
            session.execute(update(FutureDB), params)
            session.commit()

    def process_single_request(self, request_type: types.RequestType, model_id: str, request_data: dict) -> BaseModel:
        match request_type:
            case types.RequestType.CREATE_MODEL:
                return self.process_create_model(model_id, types.CreateModelInput.model_validate(request_data))
            case types.RequestType.OPTIM_STEP:
                return self.process_optim_step(model_id, types.OptimStepInput.model_validate(request_data))
            case types.RequestType.SAVE_WEIGHTS_FOR_SAMPLER:
                return self.process_save_weights_for_sampler(
                    model_id, types.SaveWeightsForSamplerInput.model_validate(request_data)
                )
            case types.RequestType.SAVE_WEIGHTS:
                return self.process_save_weights(model_id, types.SaveWeightsInput.model_validate(request_data))
            case types.RequestType.LOAD_WEIGHTS:
                return self.process_load_weights(model_id, types.LoadWeightsInput.model_validate(request_data))
            case types.RequestType.UNLOAD_MODEL:
                return self.process_unload_model(model_id, types.UnloadModelInput.model_validate(request_data))
            case _:
                raise ValueError(f"Unknown request type: {request_type}")

    def process_single_requests(self, requests: dict[str, tuple[str, types.RequestType, dict]]):
        """Process a collection of single (non-batchable) requests.

        Args:
            requests: Dict mapping request_id to (model_id, request_type, request_data) tuples
        """
        if not requests:
            return
        results = {}
        for request_id, (model_id, request_type, request_data) in requests.items():
            with log_timing(f"process_single_request({request_type.value})"):
                try:
                    result = self.process_single_request(request_type, model_id, request_data)
                except Exception as e:
                    logger.exception(f"Error processing request {request_id}: {e}")
                    result = types.ErrorResponse(error=str(e), status="failed")
            results[request_id] = result
        self._complete_futures(results)

    def process_batch_requests(
        self,
        requests: dict[str, tuple[str, BaseModel]],
        processor: Callable[[dict[str, tuple[str, BaseModel]]], dict[str, BaseModel]],
        name: str,
    ):
        """Process a batch of requests with error handling and future completion.

        Args:
            requests: Dict mapping request_id to (model_id, request_data) tuples
            processor: Function that processes requests and returns results dict
            name: Name for logging
        """
        if not requests:
            return
        with log_timing(f"process_batch_requests({name}, n={len(requests)})"):
            try:
                error_results, valid_requests = self._filter_valid_requests(requests)
                if valid_requests:
                    results = processor(valid_requests)
                    results.update(error_results)
                else:
                    results = error_results
            except Exception as e:
                logger.exception(f"Error processing batch: {e}")
                results = {request_id: types.ErrorResponse(error=str(e), status="failed") for request_id in requests}
        self._complete_futures(results)

    def process_pending_requests(self):
        """Main loop to process pending requests."""
        while True:
            # Query for pending requests and extract data within session context
            with Session(self.db_engine) as session:
                # Use look-ahead scheduling to find batchable forward_backward and forward model passes
                forward_backward_requests = self.find_batchable_model_passes(
                    session, types.RequestType.FORWARD_BACKWARD
                )
                forward_requests = self.find_batchable_model_passes(session, types.RequestType.FORWARD)
                # Find pending sample requests that can be batched
                sample_requests = self.find_batchable_sample(session)
                # Get other pending requests (non forward_backward and non sampling)
                other_requests = self.find_single_requests(session)

            # Process batches outside of session context
            self.process_batch_requests(forward_backward_requests, self.process_forward_backward, "forward_backward")
            self.process_batch_requests(forward_requests, self.process_forward, "forward")
            self.process_batch_requests(sample_requests, self.process_sample, "sample")

            # Process other request types individually (in the future we can also batch independent optim_steps)
            self.process_single_requests(other_requests)

            # Periodically cleanup stale sessions (disabled if either config is negative)
            cleanup_enabled = self.config.session_cleanup_interval_sec >= 0 and self.config.session_timeout_sec >= 0
            if cleanup_enabled and time.time() - self._last_cleanup_time > self.config.session_cleanup_interval_sec:
                _ = self.cleanup_stale_sessions()
                self._last_cleanup_time = time.time()

            # Poll every 100ms
            time.sleep(0.1)

    def run(self):
        """Entry point to start the engine."""
        logger.info("Starting background engine...")
        self.process_pending_requests()

attr config

config = config

attr db_engine

db_engine = create_engine(config.database_url, echo=False)

attr backend

backend = backend_class(config.base_model, backend_config)

attr property metrics

metrics: types.EngineMetrics

Pass-through to backend metrics for backwards compatibility.

method find_batchable_model_passes

find_batchable_model_passes(session: Session, request_type: types.RequestType) -> dict[str, tuple[str, types.ForwardBackwardInput]]

Find all requests of the given type that come before any destructive update for their model.

Uses look-ahead scheduling: for each model, only returns operations that have no optim_step or load_weights blocking them in the queue.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `session` | `Session` | Database session | required |
| `request_type` | `RequestType` | The type of request to find (e.g., FORWARD or FORWARD_BACKWARD) | required |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, tuple[str, ForwardBackwardInput]]` | Dict mapping request_id to (model_id, request_data) tuples |
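The barrier logic can be shown standalone: ops for a model are batchable only if they precede that model's earliest pending optim_step or load_weights. This sketch uses plain tuples in place of the real FutureDB rows:

```python
# Standalone illustration of the look-ahead scheduling filter.
# Each pending op is (request_id, model_id, request_type); ids define FIFO order.
ops = [
    (1, "m1", "forward_backward"),
    (2, "m1", "optim_step"),        # barrier for m1
    (3, "m1", "forward_backward"),  # blocked: comes after m1's barrier
    (4, "m2", "forward_backward"),  # m2 has no barrier, so it is batchable
]

# Earliest pending barrier (optim_step or load_weights) per model.
barriers = {}
for request_id, model_id, request_type in ops:
    if request_type in ("optim_step", "load_weights"):
        barriers.setdefault(model_id, request_id)

# Keep only ops of the requested type that precede their model's barrier.
batchable = [
    (rid, mid)
    for rid, mid, rtype in ops
    if rtype == "forward_backward" and (mid not in barriers or rid < barriers[mid])
]
print(batchable)  # [(1, 'm1'), (4, 'm2')]
```

Request 3 stays queued until the optim_step for `m1` completes, at which point a later pass of the loop picks it up.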
Source code in skyrl/tinker/engine.py:301-343
    def find_batchable_model_passes(
        self, session: Session, request_type: types.RequestType
    ) -> dict[str, tuple[str, types.ForwardBackwardInput]]:
        """Find all requests of the given type that come before any destructive update for their model.

        Uses look-ahead scheduling: for each model, only returns operations
        that have no optim_step or load_weights blocking them in the queue.

        Args:
            session: Database session
            request_type: The type of request to find (e.g., FORWARD or FORWARD_BACKWARD)

        Returns:
            Dict mapping request_id to (model_id, request_data) tuples
        """
        # Find the earliest pending optim_step or load_weights per model (these act as barriers)
        barriers_query = (
            select(FutureDB.model_id, func.min(FutureDB.request_id).label("barrier_id"))
            .where(
                (FutureDB.request_type == types.RequestType.OPTIM_STEP)
                | (FutureDB.request_type == types.RequestType.LOAD_WEIGHTS)
            )
            .where(FutureDB.status == RequestStatus.PENDING)
            .group_by(FutureDB.model_id)
        )
        barriers = dict(session.exec(barriers_query).all())

        # Get all pending operations of the requested type ordered by request_id
        query = (
            select(FutureDB)
            .where(FutureDB.request_type == request_type)
            .where(FutureDB.status == RequestStatus.PENDING)
            .order_by(FutureDB.request_id)
        )
        ops = session.exec(query).all()

        # Filter: only include ops that come before their model's barrier
        batchable = [op for op in ops if op.model_id not in barriers or op.request_id < barriers[op.model_id]]

        return {
            str(f.request_id): (f.model_id, types.ForwardBackwardInput.model_validate(f.request_data))
            for f in batchable
        }

method find_batchable_sample

find_batchable_sample(session: Session) -> dict[str, tuple[str, types.SampleInput]]

Find all sample ops that can be safely batched together.

Returns sample operations ensuring that each model_id has only one checkpoint_id to avoid loading different checkpoints for the same model in a single batch.

If sample_max_num_sequences is configured, limits to that many requests so we don't produce partial batches in process_sample_batch. If num_samples > 1 for some requests, this may not be perfect, but it's good until we implement continuous batching.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `session` | `Session` | Database session | required |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, tuple[str, SampleInput]]` | Dict mapping request_id to (model_id, request_data) tuples |
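The one-checkpoint-per-model rule reduces to a `setdefault` pattern: the first non-base request pins a model to its checkpoint, and later requests for that model only join the batch if they match. A self-contained sketch with tuples standing in for FutureDB rows:

```python
# Each pending sample op is (request_id, model_id, checkpoint_id);
# "" means a base-model request, which is always compatible.
ops = [
    (1, "m1", "ckpt-a"),
    (2, "m1", "ckpt-b"),  # skipped: m1 is already pinned to ckpt-a in this batch
    (3, "m1", ""),        # base-model request: always batchable
    (4, "m2", "ckpt-c"),
]

batchable = []
pinned = {}  # model_id -> checkpoint_id of the first non-base request seen
for request_id, model_id, checkpoint_id in ops:
    if checkpoint_id == "" or pinned.setdefault(model_id, checkpoint_id) == checkpoint_id:
        batchable.append(request_id)

print(batchable)  # [1, 3, 4]
```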
Source code in skyrl/tinker/engine.py:345-383
    def find_batchable_sample(self, session: Session) -> dict[str, tuple[str, types.SampleInput]]:
        """Find all sample ops that can be safely batched together.

        Returns sample operations ensuring that each model_id has only one checkpoint_id
        to avoid loading different checkpoints for the same model in a single batch.

        If sample_max_num_sequences is configured, limits to that many requests so we don't
        produce partial batches in process_sample_batch. If num_samples > 1 for some requests,
        this may not be perfect, but it's good until we implement continuous batching.

        Args:
            session: Database session

        Returns:
            Dict mapping request_id to (model_id, request_data) tuples
        """
        sample_query = (
            select(FutureDB)
            .where(FutureDB.request_type == types.RequestType.SAMPLE)
            .where(FutureDB.status == RequestStatus.PENDING)
            .order_by(FutureDB.request_id)
        )
        sample_ops = session.exec(sample_query).all()

        batchable = []
        model_checkpoints = {}  # Map from model_id to checkpoint_id of first request to that model
        for op in sample_ops:
            checkpoint_id = op.request_data["checkpoint_id"]
            # Base model requests (empty checkpoint_id) are always compatible; otherwise,
            # only take requests with one checkpoint_id for a given model_id
            if checkpoint_id == "" or model_checkpoints.setdefault(op.model_id, checkpoint_id) == checkpoint_id:
                batchable.append(op)

        # TODO: This leaks the abstraction by accessing backend-specific config.
        # We should find a better way to handle this going forward.
        if self.config.backend == "jax" and self.backend.config.sample_max_num_sequences > 0:
            batchable = batchable[: self.backend.config.sample_max_num_sequences]

        return {str(f.request_id): (f.model_id, types.SampleInput.model_validate(f.request_data)) for f in batchable}
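
The checkpoint-compatibility filter above can be sketched in isolation. This is a minimal sketch on hypothetical `(model_id, checkpoint_id)` tuples rather than real `FutureDB` rows; the `setdefault` call records the first checkpoint seen per model and rejects later conflicting ones.

```python
# Hypothetical pending sample ops as (model_id, checkpoint_id) pairs.
ops = [
    ("model_a", "ckpt_1"),
    ("model_a", "ckpt_2"),  # different checkpoint for model_a: skipped
    ("model_b", ""),        # base-model request: always compatible
    ("model_a", "ckpt_1"),  # matches the first checkpoint for model_a: kept
]

batchable = []
model_checkpoints: dict[str, str] = {}
for model_id, checkpoint_id in ops:
    # Base-model requests (empty checkpoint_id) are always compatible;
    # otherwise only the first checkpoint_id seen per model is accepted.
    if checkpoint_id == "" or model_checkpoints.setdefault(model_id, checkpoint_id) == checkpoint_id:
        batchable.append((model_id, checkpoint_id))

print(batchable)
# → [('model_a', 'ckpt_1'), ('model_b', ''), ('model_a', 'ckpt_1')]
```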

method find_single_requests

find_single_requests(session: Session) -> dict[str, tuple[str, types.RequestType, dict]]

Find all requests that need to be processed individually (not batchable).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| session | Session | Database session | required |

Returns:

| Type | Description |
| --- | --- |
| dict[str, tuple[str, RequestType, dict]] | Dict mapping request_id to (model_id, request_type, request_data) tuples |
Source code in skyrl/tinker/engine.py:385-405
    def find_single_requests(self, session: Session) -> dict[str, tuple[str, types.RequestType, dict]]:
        """Find all requests that need to be processed individually (not batchable).

        Args:
            session: Database session

        Returns:
            Dict mapping request_id to (model_id, request_type, request_data) tuples
        """
        statement = (
            select(FutureDB)
            .where(FutureDB.status == RequestStatus.PENDING)
            .where(FutureDB.request_type != types.RequestType.FORWARD_BACKWARD)
            .where(FutureDB.request_type != types.RequestType.FORWARD)
            .where(FutureDB.request_type != types.RequestType.SAMPLE)
            .where(FutureDB.request_type != types.RequestType.EXTERNAL)
            .order_by(FutureDB.request_id)
        )
        other_futures = session.exec(statement).all()

        return {str(f.request_id): (f.model_id, f.request_type, f.request_data) for f in other_futures}

method process_create_model

process_create_model(model_id: str, request_data: types.CreateModelInput) -> types.CreateModelOutput

Create and initialize a model.

Source code in skyrl/tinker/engine.py:407-418
    def process_create_model(self, model_id: str, request_data: types.CreateModelInput) -> types.CreateModelOutput:
        """Create and initialize a model."""
        # Create model in backend (allocates adapter_index, creates optimizer, and configures adapter)
        self.backend.create_model(model_id, request_data.lora_config)

        logger.info(f"Created LoRA model {model_id}")

        return types.CreateModelOutput(
            model_id=model_id,
            base_model=self.config.base_model,
            lora_config=request_data.lora_config,
        )

method process_unload_model

process_unload_model(model_id: str, request_data: types.UnloadModelInput) -> types.UnloadModelOutput

Unload a model and free all resources.

Source code in skyrl/tinker/engine.py:420-434
    def process_unload_model(self, model_id: str, request_data: types.UnloadModelInput) -> types.UnloadModelOutput:
        """Unload a model and free all resources."""
        if not self.backend.has_model(model_id):
            logger.warning(f"Ignoring unload request for model {model_id} that is not loaded.")
        else:
            self.backend.delete_model(model_id)

            # Update model status in DB
            with Session(self.db_engine) as session:
                _ = session.exec(update(ModelDB).where(ModelDB.model_id == model_id).values(status="unloaded"))
                session.commit()

            logger.info(f"Unloaded model {model_id}")

        return types.UnloadModelOutput(model_id=model_id, status="unloaded")

method cleanup_stale_sessions

cleanup_stale_sessions() -> int

Cleanup sessions with no recent heartbeat and unload their models.

Returns:

| Type | Description |
| --- | --- |
| int | Number of models unloaded |
Source code in skyrl/tinker/engine.py:436-501
    def cleanup_stale_sessions(self) -> int:
        """Cleanup sessions with no recent heartbeat and unload their models.

        Returns:
            Number of models unloaded
        """
        cutoff = datetime.now(timezone.utc) - timedelta(seconds=self.config.session_timeout_sec)
        unloaded_count = 0

        with Session(self.db_engine) as session:
            # Find stale sessions (active sessions with heartbeat older than cutoff)
            stale_sessions = session.exec(
                select(SessionDB).where(
                    SessionDB.status == "active",
                    SessionDB.last_heartbeat_at < cutoff,
                )
            ).all()

            if not stale_sessions:
                return 0

            stale_session_ids = {s.session_id for s in stale_sessions}

            # Find all models for all stale sessions in one query
            models_to_process = session.exec(
                select(ModelDB).where(
                    ModelDB.session_id.in_(stale_session_ids),
                    ModelDB.status != "unloaded",
                )
            ).all()

        # Unload models outside DB transactions to minimize lock time.
        sessions_with_failed_unloads: set[str] = set()
        unloaded_model_ids: set[str] = set()
        for model in models_to_process:
            if self.backend.has_model(model.model_id):
                try:
                    self.backend.delete_model(model.model_id)
                    unloaded_model_ids.add(model.model_id)
                    unloaded_count += 1
                    logger.info(f"Auto-unloaded stale model {model.model_id} from session {model.session_id}")
                except Exception as e:
                    logger.error(f"Failed to auto-unload model {model.model_id}: {e}")
                    sessions_with_failed_unloads.add(model.session_id)
            else:
                # Model already missing in backend; only DB state needs cleanup.
                unloaded_model_ids.add(model.model_id)

        sessions_to_expire = [s.session_id for s in stale_sessions if s.session_id not in sessions_with_failed_unloads]

        # Apply DB status updates in one short write transaction.
        with Session(self.db_engine) as session:
            if unloaded_model_ids:
                _ = session.exec(
                    update(ModelDB).where(ModelDB.model_id.in_(unloaded_model_ids)).values(status="unloaded")
                )
            if sessions_to_expire:
                _ = session.exec(
                    update(SessionDB).where(SessionDB.session_id.in_(sessions_to_expire)).values(status="expired")
                )
            session.commit()

        for session_id in sessions_to_expire:
            logger.info(f"Expired stale session {session_id}")

        return unloaded_count
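
The staleness test at the top of the method can be sketched on its own. The 600-second timeout and 900-second heartbeat age below are hypothetical values standing in for `config.session_timeout_sec` and a `SessionDB.last_heartbeat_at` column.

```python
from datetime import datetime, timedelta, timezone

# Hypothetical timeout standing in for config.session_timeout_sec.
session_timeout_sec = 600
cutoff = datetime.now(timezone.utc) - timedelta(seconds=session_timeout_sec)

# A session whose last heartbeat predates the cutoff is considered stale.
last_heartbeat_at = datetime.now(timezone.utc) - timedelta(seconds=900)
is_stale = last_heartbeat_at < cutoff
print(is_stale)  # → True: 900s since the last heartbeat exceeds the 600s timeout
```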

method process_optim_step

process_optim_step(model_id: str, request_data: types.OptimStepInput) -> types.OptimStepOutput | types.ErrorResponse

Process an optim_step request and apply accumulated gradients.

Source code in skyrl/tinker/engine.py:503-510
    def process_optim_step(
        self, model_id: str, request_data: types.OptimStepInput
    ) -> types.OptimStepOutput | types.ErrorResponse:
        """Process an optim_step request and apply accumulated gradients."""
        if not self.backend.has_model(model_id):
            return _model_not_found_error(model_id)

        return self.backend.optim_step(model_id, request_data)

method process_forward_backward

process_forward_backward(requests: dict[str, tuple[str, types.ForwardBackwardInput]]) -> dict

Run forward and backward pass on a batch of requests.

Source code in skyrl/tinker/engine.py:512-515
    def process_forward_backward(self, requests: dict[str, tuple[str, types.ForwardBackwardInput]]) -> dict:
        """Run forward and backward pass on a batch of requests."""
        prepared = prepare_model_pass_batch(requests)
        return self.backend.forward_backward(prepared)

method process_forward

process_forward(requests: dict[str, tuple[str, types.ForwardBackwardInput]]) -> dict

Run forward-only pass on a batch of requests.

Source code in skyrl/tinker/engine.py:517-520
    def process_forward(self, requests: dict[str, tuple[str, types.ForwardBackwardInput]]) -> dict:
        """Run forward-only pass on a batch of requests."""
        prepared = prepare_model_pass_batch(requests)
        return self.backend.forward(prepared)

method process_sample

process_sample(requests: dict[str, tuple[str, types.SampleInput]]) -> dict

Generate samples for a batch of requests.

Source code in skyrl/tinker/engine.py:522-525
    def process_sample(self, requests: dict[str, tuple[str, types.SampleInput]]) -> dict:
        """Generate samples for a batch of requests."""
        prepared = prepare_sample_batch(requests, self.config.checkpoints_base)
        return self.backend.sample(prepared)

method process_load_weights

process_load_weights(model_id: str, request_data: types.LoadWeightsInput) -> types.LoadWeightsOutput | types.ErrorResponse

Loads a clean, trimmed training checkpoint.

Source code in skyrl/tinker/engine.py:527-540
    def process_load_weights(
        self, model_id: str, request_data: types.LoadWeightsInput
    ) -> types.LoadWeightsOutput | types.ErrorResponse:
        """Loads a clean, trimmed training checkpoint."""
        if not self.backend.has_model(model_id):
            return _model_not_found_error(model_id)

        checkpoint_path = (
            self.config.checkpoints_base / request_data.source_model_id / f"{request_data.checkpoint_id}.tar.gz"
        )

        self.backend.load_checkpoint(checkpoint_path, model_id)

        return types.LoadWeightsOutput(type="load_weights")
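
The checkpoint path construction follows a fixed layout under `checkpoints_base`. A minimal sketch with hypothetical IDs and a hypothetical base directory:

```python
from pathlib import Path

# Hypothetical stand-ins for config.checkpoints_base and the request fields.
checkpoints_base = Path("/tmp/checkpoints")
source_model_id = "model_a"
checkpoint_id = "step_100"

# Training checkpoints live at <base>/<source_model_id>/<checkpoint_id>.tar.gz
checkpoint_path = checkpoints_base / source_model_id / f"{checkpoint_id}.tar.gz"
print(checkpoint_path.as_posix())  # → /tmp/checkpoints/model_a/step_100.tar.gz
```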

method process_save_weights

process_save_weights(model_id: str, request_data: types.SaveWeightsInput) -> types.SaveWeightsOutput | types.ErrorResponse

Saves a clean training checkpoint by converting the trimmed NNX graph to a pure dictionary before serialization, following official Flax docs.

Source code in skyrl/tinker/engine.py:542-562
    def process_save_weights(
        self, model_id: str, request_data: types.SaveWeightsInput
    ) -> types.SaveWeightsOutput | types.ErrorResponse:
        """
        Saves a clean training checkpoint by converting the trimmed NNX graph
        to a pure dictionary before serialization, following official Flax docs.
        """
        if not self.backend.has_model(model_id):
            return _model_not_found_error(model_id)

        checkpoint_id = request_data.path
        output_path = self.config.checkpoints_base / model_id / f"{checkpoint_id}.tar.gz"

        with self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.TRAINING):
            self.backend.save_checkpoint(output_path, model_id)
            logger.info(f"Saved trimmed training checkpoint for model {model_id} to {output_path}")

        return types.SaveWeightsOutput(
            path=f"tinker://{model_id}/weights/{checkpoint_id}",
            type="save_weights",
        )

method process_save_weights_for_sampler

process_save_weights_for_sampler(model_id: str, request_data: types.SaveWeightsForSamplerInput) -> types.SaveWeightsForSamplerOutput | types.ErrorResponse

Process a save_weights_for_sampler request and save model weights.

Source code in skyrl/tinker/engine.py:564-594
    def process_save_weights_for_sampler(
        self, model_id: str, request_data: types.SaveWeightsForSamplerInput
    ) -> types.SaveWeightsForSamplerOutput | types.ErrorResponse:
        """Process a save_weights_for_sampler request and save model weights."""
        if not self.backend.has_model(model_id):
            return _model_not_found_error(model_id)

        # Make sure the user cannot store checkpoints in places like ../../<important file>
        checkpoint_id = Path(request_data.path).name
        output_path = self.config.checkpoints_base / model_id / "sampler_weights" / f"{checkpoint_id}.tar.gz"

        # When the caller provides a sampling_session_seq_id the save is
        # transient — weights only need to reach the inference engines, not
        # disk.  Backends can skip the expensive write in that case.
        persist = request_data.sampling_session_seq_id is None

        with self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.SAMPLER):
            self.backend.save_sampler_checkpoint(output_path, model_id, persist=persist)
            logger.info(f"Saved sampler checkpoint for model {model_id} to {output_path}")

        # Return path=None when using sampling_session_seq_id and seq_id (SDK expects this)
        if request_data.sampling_session_seq_id is not None and request_data.seq_id is not None:
            output_path_str = None
        else:
            output_path_str = f"tinker://{model_id}/{checkpoint_id}"

        return types.SaveWeightsForSamplerOutput(
            path=output_path_str,
            type="save_weights_for_sampler",
            sampling_session_id=request_data.sampling_session_id,
        )

method process_single_request

process_single_request(request_type: types.RequestType, model_id: str, request_data: dict) -> BaseModel

Dispatch a single request to the matching process_* handler based on its request type.

Source code in skyrl/tinker/engine.py:617-634
    def process_single_request(self, request_type: types.RequestType, model_id: str, request_data: dict) -> BaseModel:
        match request_type:
            case types.RequestType.CREATE_MODEL:
                return self.process_create_model(model_id, types.CreateModelInput.model_validate(request_data))
            case types.RequestType.OPTIM_STEP:
                return self.process_optim_step(model_id, types.OptimStepInput.model_validate(request_data))
            case types.RequestType.SAVE_WEIGHTS_FOR_SAMPLER:
                return self.process_save_weights_for_sampler(
                    model_id, types.SaveWeightsForSamplerInput.model_validate(request_data)
                )
            case types.RequestType.SAVE_WEIGHTS:
                return self.process_save_weights(model_id, types.SaveWeightsInput.model_validate(request_data))
            case types.RequestType.LOAD_WEIGHTS:
                return self.process_load_weights(model_id, types.LoadWeightsInput.model_validate(request_data))
            case types.RequestType.UNLOAD_MODEL:
                return self.process_unload_model(model_id, types.UnloadModelInput.model_validate(request_data))
            case _:
                raise ValueError(f"Unknown request type: {request_type}")

method process_single_requests

process_single_requests(requests: dict[str, tuple[str, types.RequestType, dict]])

Process a collection of single (non-batchable) requests.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| requests | dict[str, tuple[str, RequestType, dict]] | Dict mapping request_id to (model_id, request_type, request_data) tuples | required |
Source code in skyrl/tinker/engine.py:636-653
    def process_single_requests(self, requests: dict[str, tuple[str, types.RequestType, dict]]):
        """Process a collection of single (non-batchable) requests.

        Args:
            requests: Dict mapping request_id to (model_id, request_type, request_data) tuples
        """
        if not requests:
            return
        results = {}
        for request_id, (model_id, request_type, request_data) in requests.items():
            with log_timing(f"process_single_request({request_type.value})"):
                try:
                    result = self.process_single_request(request_type, model_id, request_data)
                except Exception as e:
                    logger.exception(f"Error processing request {request_id}: {e}")
                    result = types.ErrorResponse(error=str(e), status="failed")
            results[request_id] = result
        self._complete_futures(results)

method process_batch_requests

process_batch_requests(requests: dict[str, tuple[str, BaseModel]], processor: Callable[[dict[str, tuple[str, BaseModel]]], dict[str, BaseModel]], name: str)

Process a batch of requests with error handling and future completion.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| requests | dict[str, tuple[str, BaseModel]] | Dict mapping request_id to (model_id, request_data) tuples | required |
| processor | Callable[[dict[str, tuple[str, BaseModel]]], dict[str, BaseModel]] | Function that processes requests and returns results dict | required |
| name | str | Name for logging | required |
Source code in skyrl/tinker/engine.py:655-681
    def process_batch_requests(
        self,
        requests: dict[str, tuple[str, BaseModel]],
        processor: Callable[[dict[str, tuple[str, BaseModel]]], dict[str, BaseModel]],
        name: str,
    ):
        """Process a batch of requests with error handling and future completion.

        Args:
            requests: Dict mapping request_id to (model_id, request_data) tuples
            processor: Function that processes requests and returns results dict
            name: Name for logging
        """
        if not requests:
            return
        with log_timing(f"process_batch_requests({name}, n={len(requests)})"):
            try:
                error_results, valid_requests = self._filter_valid_requests(requests)
                if valid_requests:
                    results = processor(valid_requests)
                    results.update(error_results)
                else:
                    results = error_results
            except Exception as e:
                logger.exception(f"Error processing batch: {e}")
                results = {request_id: types.ErrorResponse(error=str(e), status="failed") for request_id in requests}
        self._complete_futures(results)
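
The core contract here is that the processor maps a request dict to a results dict keyed by the same request IDs, and a batch-level failure fans out to an error result per request. A minimal sketch with a hypothetical processor and plain dicts standing in for `types.ErrorResponse`:

```python
# Hypothetical processor: maps request_id -> result for every request.
def processor(requests: dict) -> dict:
    return {rid: f"result-for-{model_id}" for rid, (model_id, _) in requests.items()}

requests = {"req-1": ("model_a", {"x": 1}), "req-2": ("model_b", {"x": 2})}
try:
    results = processor(requests)
except Exception as e:
    # On a batch-level failure, every request gets the same error result.
    results = {rid: {"error": str(e), "status": "failed"} for rid in requests}

print(results)
# → {'req-1': 'result-for-model_a', 'req-2': 'result-for-model_b'}
```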

method process_pending_requests

process_pending_requests()

Main loop to process pending requests.

Source code in skyrl/tinker/engine.py:683-713
    def process_pending_requests(self):
        """Main loop to process pending requests."""
        while True:
            # Query for pending requests and extract data within session context
            with Session(self.db_engine) as session:
                # Use look-ahead scheduling to find batchable forward_backward and forward model passes
                forward_backward_requests = self.find_batchable_model_passes(
                    session, types.RequestType.FORWARD_BACKWARD
                )
                forward_requests = self.find_batchable_model_passes(session, types.RequestType.FORWARD)
                # Find pending sample requests that can be batched
                sample_requests = self.find_batchable_sample(session)
                # Get other pending requests (non forward_backward and non sampling)
                other_requests = self.find_single_requests(session)

            # Process batches outside of session context
            self.process_batch_requests(forward_backward_requests, self.process_forward_backward, "forward_backward")
            self.process_batch_requests(forward_requests, self.process_forward, "forward")
            self.process_batch_requests(sample_requests, self.process_sample, "sample")

            # Process other request types individually (in the future we can also batch independent optim_steps)
            self.process_single_requests(other_requests)

            # Periodically cleanup stale sessions (disabled if either config is negative)
            cleanup_enabled = self.config.session_cleanup_interval_sec >= 0 and self.config.session_timeout_sec >= 0
            if cleanup_enabled and time.time() - self._last_cleanup_time > self.config.session_cleanup_interval_sec:
                _ = self.cleanup_stale_sessions()
                self._last_cleanup_time = time.time()

            # Poll every 100ms
            time.sleep(0.1)
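
The interval-gated cleanup check in the loop above can be sketched in isolation. The config values and the elapsed time below are hypothetical; note that a negative value for either config disables cleanup entirely.

```python
import time

# Hypothetical stand-ins for the engine config and state.
session_cleanup_interval_sec = 60
session_timeout_sec = 600
last_cleanup_time = time.time() - 120  # pretend the last cleanup ran 120s ago

# Cleanup is disabled entirely if either config value is negative.
cleanup_enabled = session_cleanup_interval_sec >= 0 and session_timeout_sec >= 0
should_cleanup = cleanup_enabled and time.time() - last_cleanup_time > session_cleanup_interval_sec
print(should_cleanup)  # → True: 120s have passed since the last cleanup
```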

method run

run()

Entry point to start the engine.

Source code in skyrl/tinker/engine.py:715-718
    def run(self):
        """Entry point to start the engine."""
        logger.info("Starting background engine...")
        self.process_pending_requests()
