
Generator

Generator API — GeneratorInterface, InferenceEngineInterface, and InferenceEngineClient.

Core APIs

class GeneratorInterface

Bases: ABC

Functions:

Name      Description
generate  Generate trajectories for the input batch.
Source code in skyrl/train/generators/base.py:59-71
class GeneratorInterface(ABC):
    @abstractmethod
    async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
        """Generate trajectories for the input batch.

        Returns outputs in the same order as the input batch.

        Args:
            input_batch (GeneratorInput): Input batch
        Returns:
            GeneratorOutput: Generated trajectories
        """
        raise NotImplementedError

method abstractmethod async generate

generate(input_batch: GeneratorInput) -> GeneratorOutput

Generate trajectories for the input batch.

Returns outputs in the same order as the input batch.

Parameters:

Name         Type            Description  Default
input_batch  GeneratorInput  Input batch  required

Returns:

Type             Description
GeneratorOutput  Generated trajectories

Source code in skyrl/train/generators/base.py:60-71
    @abstractmethod
    async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
        """Generate trajectories for the input batch.

        Returns outputs in the same order as the input batch.

        Args:
            input_batch (GeneratorInput): Input batch
        Returns:
            GeneratorOutput: Generated trajectories
        """
        raise NotImplementedError
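
Implementations subclass GeneratorInterface and override generate. Below is a minimal sketch; the environment-stepping and reward logic of a real generator is elided, and the "prompts"/"responses" field names are assumptions for illustration rather than the library's exact schema:

from skyrl.train.generators.base import GeneratorInterface, GeneratorInput, GeneratorOutput


class EchoGenerator(GeneratorInterface):
    """Toy generator that 'generates' by echoing each prompt back."""

    async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
        # The contract: outputs are returned in the same order as the input batch.
        prompts = input_batch["prompts"]      # assumed field name
        return {"responses": list(prompts)}   # assumed field name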

class InferenceEngineInterface

Bases: ABC

Functions:

Name                             Description
generate
sample                           Generate multiple independent samples from a single prompt.
chat_completion                  Handles OpenAI-compatible HTTP endpoint.
completion                       Handles OpenAI-compatible HTTP endpoint.
wake_up
sleep
init_weight_update_communicator  Initialize weight update communicator from init info.
update_named_weights
teardown
reset_prefix_cache
pause_generation                 Pause generation, freezing in-flight requests so they can be resumed later.
resume_generation                Resume generation after a pause, continuing any frozen in-flight requests.
tp_size                          Return the tensor parallel size of this inference engine.
pp_size                          Return the pipeline parallel size of this inference engine.
dp_size                          Return the data parallel size of this inference engine.
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:49-186
class InferenceEngineInterface(ABC):

    @abstractmethod
    async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
        raise NotImplementedError

    async def sample(
        self,
        prompt_token_ids: List[int],
        num_samples: int,
        sampling_params: Dict[str, Any],
    ) -> InferenceEngineOutput:
        """Generate multiple independent samples from a single prompt.

        This method provides Tinker-compatible token-in/token-out sampling semantics.

        Args:
            prompt_token_ids: Token IDs for a single prompt.
            num_samples: Number of independent samples to generate.
            sampling_params: Sampling parameters.

        Returns:
            InferenceEngineOutput containing num_samples results:
                - response_ids: List of num_samples token ID lists
                - responses: List of num_samples decoded strings
                - stop_reasons: List of num_samples stop reasons
                - response_logprobs: Optional list of num_samples logprob lists
        """
        all_response_ids = []
        all_responses = []
        all_stop_reasons = []
        all_response_logprobs = []
        all_rollout_expert_indices = []

        for _ in range(num_samples):
            input_batch: InferenceEngineInput = {
                "prompts": None,
                "prompt_token_ids": [prompt_token_ids],  # Wrap in list for batch of 1
                "sampling_params": sampling_params,
                "session_ids": None,
            }
            output = await self.generate(input_batch)

            # Extract single result from batch of 1
            all_response_ids.append(output["response_ids"][0])
            all_responses.append(output["responses"][0])
            all_stop_reasons.append(output["stop_reasons"][0])
            if output.get("response_logprobs") is not None:
                all_response_logprobs.append(output["response_logprobs"][0])
            if output.get("rollout_expert_indices") is not None:
                all_rollout_expert_indices.append(output["rollout_expert_indices"][0])

        return {
            "response_ids": all_response_ids,
            "responses": all_responses,
            "stop_reasons": all_stop_reasons,
            "response_logprobs": all_response_logprobs if all_response_logprobs else None,
            "rollout_expert_indices": all_rollout_expert_indices if all_rollout_expert_indices else None,
        }

    @abstractmethod
    async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Handles OpenAI-compatible HTTP endpoint.

        Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
        The request body will be used to construct a ChatCompletionRequest.
        Returns a plain dict, either a ChatCompletionResponse or an ErrorResponse.
        The specific fields of the response/request depend on the engine's backend (e.g. for vllm
        these are defined in vllm.entrypoints.openai.protocol).
        """
        raise NotImplementedError

    @abstractmethod
    async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Handles OpenAI-compatible HTTP endpoint.

        Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
        The request body will be used to construct a CompletionRequest.
        Returns a plain dict, either a CompletionResponse or an ErrorResponse.
        The specific fields of the response/request depend on the engine's backend (e.g. for vllm
        these are defined in vllm.entrypoints.openai.protocol).
        """
        raise NotImplementedError

    @abstractmethod
    async def wake_up(self, *args: Any, **kwargs: Any):
        raise NotImplementedError

    @abstractmethod
    async def sleep(self, *args: Any, **kwargs: Any):
        raise NotImplementedError

    @abstractmethod
    async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
        """Initialize weight update communicator from init info.

        Args:
            init_info: WeightSyncInitInfo from the sender containing all info needed
                to create the appropriate receiver.
        """
        raise NotImplementedError()

    @abstractmethod
    async def update_named_weights(self, request: "WeightUpdateRequest"):
        raise NotImplementedError()

    @abstractmethod
    async def teardown(self):
        raise NotImplementedError

    @abstractmethod
    async def reset_prefix_cache(self):
        raise NotImplementedError

    @abstractmethod
    async def pause_generation(self) -> None:
        """Pause generation, freezing in-flight requests so they can be resumed later."""
        raise NotImplementedError

    @abstractmethod
    async def resume_generation(self) -> None:
        """Resume generation after a pause, continuing any frozen in-flight requests."""
        raise NotImplementedError

    @abstractmethod
    def tp_size(self) -> int:
        """Return the tensor parallel size of this inference engine."""
        raise NotImplementedError

    @abstractmethod
    def pp_size(self) -> int:
        """Return the pipeline parallel size of this inference engine."""
        raise NotImplementedError

    @abstractmethod
    def dp_size(self) -> int:
        """Return the data parallel size of this inference engine."""
        raise NotImplementedError

method abstractmethod async generate

generate(input_batch: InferenceEngineInput) -> InferenceEngineOutput
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:51-53
    @abstractmethod
    async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
        raise NotImplementedError

method async sample

sample(prompt_token_ids: List[int], num_samples: int, sampling_params: Dict[str, Any]) -> InferenceEngineOutput

Generate multiple independent samples from a single prompt.

This method provides Tinker-compatible token-in/token-out sampling semantics.

Parameters:

Name              Type            Description                                  Default
prompt_token_ids  List[int]       Token IDs for a single prompt.               required
num_samples       int             Number of independent samples to generate.   required
sampling_params   Dict[str, Any]  Sampling parameters.                         required

Returns:

Type                   Description
InferenceEngineOutput  InferenceEngineOutput containing num_samples results:
                       - response_ids: List of num_samples token ID lists
                       - responses: List of num_samples decoded strings
                       - stop_reasons: List of num_samples stop reasons
                       - response_logprobs: Optional list of num_samples logprob lists
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:55-107
    async def sample(
        self,
        prompt_token_ids: List[int],
        num_samples: int,
        sampling_params: Dict[str, Any],
    ) -> InferenceEngineOutput:
        """Generate multiple independent samples from a single prompt.

        This method provides Tinker-compatible token-in/token-out sampling semantics.

        Args:
            prompt_token_ids: Token IDs for a single prompt.
            num_samples: Number of independent samples to generate.
            sampling_params: Sampling parameters.

        Returns:
            InferenceEngineOutput containing num_samples results:
                - response_ids: List of num_samples token ID lists
                - responses: List of num_samples decoded strings
                - stop_reasons: List of num_samples stop reasons
                - response_logprobs: Optional list of num_samples logprob lists
        """
        all_response_ids = []
        all_responses = []
        all_stop_reasons = []
        all_response_logprobs = []
        all_rollout_expert_indices = []

        for _ in range(num_samples):
            input_batch: InferenceEngineInput = {
                "prompts": None,
                "prompt_token_ids": [prompt_token_ids],  # Wrap in list for batch of 1
                "sampling_params": sampling_params,
                "session_ids": None,
            }
            output = await self.generate(input_batch)

            # Extract single result from batch of 1
            all_response_ids.append(output["response_ids"][0])
            all_responses.append(output["responses"][0])
            all_stop_reasons.append(output["stop_reasons"][0])
            if output.get("response_logprobs") is not None:
                all_response_logprobs.append(output["response_logprobs"][0])
            if output.get("rollout_expert_indices") is not None:
                all_rollout_expert_indices.append(output["rollout_expert_indices"][0])

        return {
            "response_ids": all_response_ids,
            "responses": all_responses,
            "stop_reasons": all_stop_reasons,
            "response_logprobs": all_response_logprobs if all_response_logprobs else None,
            "rollout_expert_indices": all_rollout_expert_indices if all_rollout_expert_indices else None,
        }
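
As a usage sketch (the sampling_params keys shown are common engine options, not a schema this interface defines):

# Draw 4 independent completions from one tokenized prompt.
output = await engine.sample(
    prompt_token_ids=[101, 2054, 2003, 102],      # illustrative token IDs
    num_samples=4,
    sampling_params={"temperature": 0.8, "max_tokens": 128},   # assumed keys
)
assert len(output["response_ids"]) == 4           # one token-ID list per sample
for text, stop in zip(output["responses"], output["stop_reasons"]):
    print(stop, text)

Note that this default implementation issues num_samples sequential generate calls, each on a batch of one; concrete engines may override it with a backend-native n-sampling path.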

method abstractmethod async chat_completion

chat_completion(request_payload: Dict[str, Any]) -> Dict[str, Any]

Handles OpenAI-compatible HTTP endpoint.

Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}. The request body will be used to construct a ChatCompletionRequest. Returns a plain dict, either a ChatCompletionResponse or an ErrorResponse. The specific fields of the response/request depend on the engine's backend (e.g. for vllm these are defined in vllm.entrypoints.openai.protocol).

Source code in skyrl/backends/skyrl_train/inference_engines/base.py:109-119
    @abstractmethod
    async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Handles OpenAI-compatible HTTP endpoint.

        Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
        The request body will be used to construct a ChatCompletionRequest.
        Returns a plain dict, either a ChatCompletionResponse or an ErrorResponse.
        The specific fields of the response/request depend on the engine's backend (e.g. for vllm
        these are defined in vllm.entrypoints.openai.protocol).
        """
        raise NotImplementedError

method abstractmethod async completion

completion(request_payload: Dict[str, Any]) -> Dict[str, Any]

Handles OpenAI-compatible HTTP endpoint.

Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}. The request body will be used to construct a CompletionRequest. Returns a plain dict, either a CompletionResponse or an ErrorResponse. The specific fields of the response/request depend on the engine's backend (e.g. for vllm these are defined in vllm.entrypoints.openai.protocol).

Source code in skyrl/backends/skyrl_train/inference_engines/base.py:121-131
    @abstractmethod
    async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Handles OpenAI-compatible HTTP endpoint.

        Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
        The request body will be used to construct a CompletionRequest.
        Returns a plain dict, either a CompletionResponse or an ErrorResponse.
        The specific fields of the response/request depend on the engine's backend (e.g. for vllm
        these are defined in vllm.entrypoints.openai.protocol).
        """
        raise NotImplementedError
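
Both endpoints share the same envelope: a dict carrying the JSON request body under "json" and HTTP headers under "headers". A sketch of calling completion directly (the body fields follow the OpenAI /completions schema; which fields are honored depends on the backend, and the model name is illustrative):

payload = {
    "json": {
        "model": "my-model",                       # illustrative model name
        "prompt": "Write a haiku about GPUs.",
        "max_tokens": 64,
    },
    "headers": {"Content-Type": "application/json"},
}
result = await engine.completion(payload)
# Responses are plain dicts; errors come back as an ErrorResponse-shaped dict.
if "error" in result or result.get("object", "") == "error":
    raise RuntimeError(result)
print(result["choices"][0]["text"])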

method abstractmethod async wake_up

wake_up(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:133-135
    @abstractmethod
    async def wake_up(self, *args: Any, **kwargs: Any):
        raise NotImplementedError

method abstractmethod async sleep

sleep(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:137-139
    @abstractmethod
    async def sleep(self, *args: Any, **kwargs: Any):
        raise NotImplementedError

method abstractmethod async init_weight_update_communicator

init_weight_update_communicator(init_info: WeightSyncInitInfo)

Initialize weight update communicator from init info.

Parameters:

Name       Type                Description                                                                                         Default
init_info  WeightSyncInitInfo  WeightSyncInitInfo from the sender containing all info needed to create the appropriate receiver.   required
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:141-149
    @abstractmethod
    async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
        """Initialize weight update communicator from init info.

        Args:
            init_info: WeightSyncInitInfo from the sender containing all info needed
                to create the appropriate receiver.
        """
        raise NotImplementedError()

method abstractmethod async update_named_weights

update_named_weights(request: WeightUpdateRequest)
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:151-153
    @abstractmethod
    async def update_named_weights(self, request: "WeightUpdateRequest"):
        raise NotImplementedError()

method abstractmethod async teardown

teardown()
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:155-157
    @abstractmethod
    async def teardown(self):
        raise NotImplementedError

method abstractmethod async reset_prefix_cache

reset_prefix_cache()
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:159-161
    @abstractmethod
    async def reset_prefix_cache(self):
        raise NotImplementedError

method abstractmethod async pause_generation

pause_generation() -> None

Pause generation, freezing in-flight requests so they can be resumed later.

Source code in skyrl/backends/skyrl_train/inference_engines/base.py:163-166
    @abstractmethod
    async def pause_generation(self) -> None:
        """Pause generation, freezing in-flight requests so they can be resumed later."""
        raise NotImplementedError

method abstractmethod async resume_generation

resume_generation() -> None

Resume generation after a pause, continuing any frozen in-flight requests.

Source code in skyrl/backends/skyrl_train/inference_engines/base.py:168-171
    @abstractmethod
    async def resume_generation(self) -> None:
        """Resume generation after a pause, continuing any frozen in-flight requests."""
        raise NotImplementedError
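
A typical use is freezing rollouts around a weight sync. A sketch, assuming engine is any InferenceEngineInterface and request is a prepared WeightUpdateRequest; resetting the prefix cache after the swap is shown as one reasonable policy, not a requirement:

await engine.pause_generation()         # in-flight requests freeze, not abort
try:
    await engine.update_named_weights(request)
    await engine.reset_prefix_cache()   # drop prefixes cached under the old weights
finally:
    await engine.resume_generation()    # frozen requests pick up where they left off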

method abstractmethod tp_size

tp_size() -> int

Return the tensor parallel size of this inference engine.

Source code in skyrl/backends/skyrl_train/inference_engines/base.py:173-176
    @abstractmethod
    def tp_size(self) -> int:
        """Return the tensor parallel size of this inference engine."""
        raise NotImplementedError

method abstractmethod pp_size

pp_size() -> int

Return the pipeline parallel size of this inference engine.

Source code in skyrl/backends/skyrl_train/inference_engines/base.py:178-181
    @abstractmethod
    def pp_size(self) -> int:
        """Return the pipeline parallel size of this inference engine."""
        raise NotImplementedError

method abstractmethod dp_size

dp_size() -> int

Return the data parallel size of this inference engine.

Source code in skyrl/backends/skyrl_train/inference_engines/base.py:183-186
    @abstractmethod
    def dp_size(self) -> int:
        """Return the data parallel size of this inference engine."""
        raise NotImplementedError
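
A new backend subclasses InferenceEngineInterface and fills in the abstract methods (sample already has a default implementation). A compressed stub, with the type imports taken from the module path shown above; real method bodies are elided:

from typing import Any, Dict
from skyrl.backends.skyrl_train.inference_engines.base import (
    InferenceEngineInterface,
    InferenceEngineInput,
    InferenceEngineOutput,
)


class MyEngine(InferenceEngineInterface):
    async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
        ...  # run the backend; return responses / response_ids / stop_reasons

    async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        ...  # build a ChatCompletionRequest from request_payload["json"]

    async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        ...  # build a CompletionRequest from request_payload["json"]

    async def wake_up(self, *args: Any, **kwargs: Any): ...
    async def sleep(self, *args: Any, **kwargs: Any): ...
    async def init_weight_update_communicator(self, init_info): ...
    async def update_named_weights(self, request): ...
    async def teardown(self): ...
    async def reset_prefix_cache(self): ...
    async def pause_generation(self) -> None: ...
    async def resume_generation(self) -> None: ...
    def tp_size(self) -> int: return 1
    def pp_size(self) -> int: return 1
    def dp_size(self) -> int: return 1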

class InferenceEngineClient

InferenceEngineClient(engines: List[InferenceEngineInterface], tokenizer: PreTrainedTokenizerBase, model_path: str, lora_cfg: SkyRLLoraConfig, inference_engine_cfg: InferenceEngineConfig)

Bases: InferenceEngineInterface

Client to talk to a set of InferenceEngines.

Note that InferenceEngineClient sub-classes InferenceEngineInterface so it can be used as if talking to a single engine.

Functions:

Name                             Description
generate
sample                           Generate multiple independent samples from a single prompt.
chat_completion
completion                       Handles an OpenAI /completions request.
wake_up
sleep
init_weight_update_communicator  Initialize weight update communicator on all engines.
update_named_weights
reset_prefix_cache
teardown
tp_size
pp_size
dp_size
pause_generation                 Pauses generation for all engines using vLLM's native keep mode.
resume_generation                Resumes generation for all engines after a keep-mode pause.

Attributes: engines, tokenizer, inference_engine_cfg, model_name, backend, enable_http_endpoint, http_endpoint_host, http_endpoint_port

Parameters:

Name                  Type                            Description                               Default
engines               List[InferenceEngineInterface]  The inference engines, remote or local.  required
tokenizer             PreTrainedTokenizerBase         The tokenizer to use.                     required
model_path            str                             The path to the model.                    required
lora_cfg              SkyRLLoraConfig                 The LoRA configuration.                   required
inference_engine_cfg  InferenceEngineConfig           The inference engine configuration.       required
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:35-449
class InferenceEngineClient(InferenceEngineInterface):
    """
    Client to talk to a set of InferenceEngines.

    Note that InferenceEngineClient sub-classes InferenceEngineInterface so it can be used as if talking to a single
    engine.
    """

    def __init__(
        self,
        engines: List[InferenceEngineInterface],
        tokenizer: PreTrainedTokenizerBase,
        model_path: str,
        lora_cfg: SkyRLLoraConfig,
        inference_engine_cfg: InferenceEngineConfig,
    ):
        """
        Args:
            engines: List[InferenceEngineInterface] - The inference engines, remote or local.
            tokenizer: PreTrainedTokenizerBase - The tokenizer to use.
            model_path: str - The path to the model.
            lora_cfg: SkyRLLoraConfig - The LoRA configuration.
            inference_engine_cfg: InferenceEngineConfig - The inference engine configuration.
        """
        self.engines = engines
        self.tokenizer = tokenizer
        self.inference_engine_cfg = inference_engine_cfg
        # Use served_model_name if provided, otherwise fall back to model path.
        # served_model_name allows using a different model name for HTTP endpoint validation
        # than the actual model path. See ppo_base_config.yaml for details.
        served_model_name = inference_engine_cfg.served_model_name
        if served_model_name is not None:
            self.model_name = served_model_name
        else:
            self.model_name = model_path
        self.backend = inference_engine_cfg.backend
        self.enable_http_endpoint = inference_engine_cfg.enable_http_endpoint
        self.http_endpoint_host = inference_engine_cfg.http_endpoint_host
        self.http_endpoint_port = inference_engine_cfg.http_endpoint_port

        # we assume that dp_size is same for all engines
        dp_sizes = [engine.dp_size() for engine in self.engines]
        assert len(set(dp_sizes)) <= 1, f"Expected all engines to have the same DP size, got {dp_sizes}"
        if self.enable_http_endpoint:
            self._spin_up_http_endpoint()

        logger.info(f"InferenceEngineClient initialized with {len(engines)} engines.")

    async def _run_on_all_engines(self, method_name: str, *args, **kwargs):
        """
        Call a method on all engines concurrently and gather the results.
        """
        assert len(self.engines) > 0, "No engines to call method on"

        awaitables = [getattr(engine, method_name)(*args, **kwargs) for engine in self.engines]
        return await asyncio.gather(*awaitables)

    async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
        # 0. Extract input
        prompts = input_batch.get("prompts")
        prompt_token_ids = input_batch.get("prompt_token_ids")
        session_ids = input_batch.get("session_ids")
        sampling_params = input_batch.get("sampling_params")

        if (prompts is None and prompt_token_ids is None) or (prompts is not None and prompt_token_ids is not None):
            raise ValueError("Either `prompts` or `prompt_token_ids` must be provided, but not both.")
        if prompt_token_ids is None:
            prompt_token_ids = self.tokenizer.apply_chat_template(
                prompts,
                add_generation_prompt=True,
                return_dict=False,
                tokenize=True,
            )

        num_prompts = len(prompt_token_ids)
        num_inference_engines = len(self.engines)

        # 1. Route prompts to engines
        engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
            num_prompts=num_prompts,
            num_inference_engines=num_inference_engines,
            session_ids=session_ids,
        )

        # 2. Generate responses concurrently
        tasks: list[asyncio.Task] = []
        indices_list: list[list[int]] = []  # the original prompt indices that each task works on
        for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
            # index prompt_token_ids with prompt_ids
            cur_prompt_token_ids = [prompt_token_ids[i] for i in prompt_ids]
            engine_input = InferenceEngineInput(
                prompt_token_ids=cur_prompt_token_ids,
                sampling_params=sampling_params,
            )
            tasks.append(asyncio.create_task(self.engines[engine_idx].generate(engine_input)))
            indices_list.append(prompt_ids)

        results = await asyncio.gather(*tasks)

        # 3. Reconstruct output in original order
        n = len(prompt_token_ids)
        responses: list[str] = [""] * n
        stop_reasons: list[str] = [""] * n
        response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)]
        response_ids: List[List[int]] = [[] for _ in range(n)]
        rollout_expert_indices: List[Optional[List[List[List[int]]]]] = [None for _ in range(n)]
        # a bit hacky for now
        add_resp_logprobs = False
        add_rollout_expert_indices = False

        for indices, result in zip(indices_list, results):
            for local_idx, original_idx in enumerate(indices):
                responses[original_idx] = result["responses"][local_idx]
                stop_reasons[original_idx] = result["stop_reasons"][local_idx]
                response_ids[original_idx] = result["response_ids"][local_idx]
                if result.get("response_logprobs", None):
                    add_resp_logprobs = True
                    response_logprobs[original_idx] = result["response_logprobs"][local_idx]
                if result.get("rollout_expert_indices", None):
                    add_rollout_expert_indices = True
                    rollout_expert_indices[original_idx] = result["rollout_expert_indices"][local_idx]

        return InferenceEngineOutput(
            responses=responses,
            stop_reasons=stop_reasons,
            response_ids=response_ids,
            response_logprobs=response_logprobs if add_resp_logprobs else None,
            rollout_expert_indices=rollout_expert_indices if add_rollout_expert_indices else None,
        )

    def _select_engine_idx(self, session_id: Optional[Union[str, int]] = None) -> int:
        """Select an engine index for routing a request.

        Args:
            session_id: Optional session ID for consistent routing (e.g., conversation ID for chat).
                       If None, uses random load-balancing.

        Returns:
            Engine index to route the request to.
        """
        if session_id is None:
            return random.randint(0, len(self.engines) - 1)
        else:
            return hash_with_sha256(str(session_id)) % len(self.engines)

    async def sample(
        self,
        prompt_token_ids: List[int],
        num_samples: int,
        sampling_params: Dict[str, Any],
        session_id: Optional[Union[str, int]] = None,
    ) -> InferenceEngineOutput:
        """Generate multiple independent samples from a single prompt.

        This method provides Tinker-compatible token-in/token-out sampling semantics.
        Generates num_samples independent completions from the same prompt.

        Args:
            prompt_token_ids: Token IDs for a single prompt (not batched).
            num_samples: Number of independent samples to generate.
            sampling_params: Sampling parameters (temperature, max_tokens, etc.).
            session_id: Optional session ID for consistent engine routing (e.g., conversation ID).
                       If None, uses random load-balancing. Tinker API should pass None since
                       each sample() call is independent.

        Returns:
            InferenceEngineOutput containing num_samples results.
        """
        # Select engine (random if session_id is None, consistent hash otherwise)
        engine_idx = self._select_engine_idx(session_id)
        engine = self.engines[engine_idx]

        return await engine.sample(
            prompt_token_ids=prompt_token_ids,
            num_samples=num_samples,
            sampling_params=sampling_params,
        )

    async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        session_id = request_payload["json"].pop("session_id", None)
        if session_id is not None:
            assert isinstance(session_id, (str, int)), "Session ID must be an integer or string for `/chat/completions`"
        engine_idx = self._select_engine_idx(session_id)

        return await self.engines[engine_idx].chat_completion(request_payload)

    async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handles an OpenAI /completions request.

        Since `request["prompt"]` can be `Union[list[int], list[list[int]], str, list[str]]`,
        (i.e. {batched, single} x {string, token IDs}), we need to route the request to engines
        differently, based on whether it's a single or batched request, and whether `request["session_id"]`
        is provided. This is similar to `generate()` method.

        For single, we do the same routing logic as `chat_completion()`. For batched, we route by
        `request["session_id"]` if present, and if not we split evenly across engines.

        Regardless, the order will be maintained, i.e. `output["choices"][i]` corresponds to `request["prompt"][i]`.
        """
        body = request_payload.get("json", {})

        # NOTE(Charlie): do not reuse headers here as the single request may become various new requests
        headers = {"Content-Type": "application/json"}

        # 1. Postprocess prompt, session_id, and validate request.
        prompt = body.get("prompt")
        session_id_value = body.pop("session_id", None)
        ret = postprocess_completion_request(prompt, session_id_value)
        session_id_list: Optional[Union[List[int], List[str], ErrorResponse]] = ret[0]
        prompt: Union[List[List[int]], List[str]] = ret[1]
        if isinstance(session_id_list, ErrorResponse):
            return session_id_list.model_dump()

        num_prompts = len(prompt)
        num_inference_engines = len(self.engines)
        assert num_prompts > 0, "Number of prompts must be greater than 0"

        # 1. Route prompts to engines
        engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
            num_prompts=num_prompts,
            num_inference_engines=num_inference_engines,
            session_ids=session_id_list,
        )

        # 2. Generate responses concurrently
        tasks: list[asyncio.Task] = []
        indices_list: list[list[int]] = []  # the original prompt indices that each task works on
        for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
            cur_prompt = [prompt[i] for i in prompt_ids]
            # reuse the exact same request except for the prompt
            cur_json = dict(body)
            cur_json["prompt"] = cur_prompt
            coro = self.engines[engine_idx].completion({"json": cur_json, "headers": headers})
            tasks.append(asyncio.create_task(coro))
            indices_list.append(prompt_ids)

        results = await asyncio.gather(*tasks)

        # 3. Check for errors.
        # results can be ErrorResponse or CompletionResponse. If one of the sub-requests fails, we
        # return an error response. That is, there is no partial success, following vLLM's behavior.
        for result in results:
            if "error" in result or result.get("object", "") == "error":
                error_details = result.get("error", result)
                error_code = error_details["code"]
                error_type = error_details["type"]
                error_message = error_details["message"]
                return ErrorResponse(
                    error=ErrorInfo(
                        message=f"In one of the engines that SkyRL manages, an error occurred: {error_message}",
                        type=error_type,
                        code=error_code,
                    ),
                ).model_dump()

        # 4. Combine choices and preserve original order.
        # If there is only one result, we return it directly.
        if len(results) == 1:
            return results[0]

        # Use the first result as base response. There are some fields that cannot be shared
        # across sub-requests. For now it is just the usage field.
        final_response = dict(results[0])
        final_response["usage"] = aggregate_completion_usage_info(results, self.backend)

        # Aggregate choices. TODO(Charlie): improve logic when we need to support n > 1
        # vLLM sets index positions per sub-batch, so we reset indices to be 0..n-1 for the combined response.
        combined_choices: list[Dict[str, Any]] = [None] * num_prompts
        for indices, result in zip(indices_list, results):
            # indices are the original prompt indices that the task's response corresponds to
            for local_idx, original_idx in enumerate(indices):
                choice = result["choices"][local_idx]
                choice["index"] = original_idx  # overwrite index with the global position
                combined_choices[original_idx] = choice

        # sanity check that the index is correct
        for new_idx in range(len(combined_choices)):
            assert combined_choices[new_idx]["index"] == new_idx

        final_response["choices"] = combined_choices
        return final_response

    async def wake_up(self, *args: Any, **kwargs: Any):
        return await self._run_on_all_engines("wake_up", *args, **kwargs)

    async def sleep(self, *args: Any, **kwargs: Any):
        return await self._run_on_all_engines("sleep", *args, **kwargs)

    async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
        """Initialize weight update communicator on all engines.

        Args:
            init_info: WeightSyncInitInfo from the sender.

        Note:
            Per-engine adjustments (e.g., rank_offset for broadcast) are handled
            by init_info.for_engine().
        """
        tasks = []
        for i, engine in enumerate(self.engines):
            # With vLLM, DP ranks are managed as separate engine instances
            # We want the index of truly separate vllm deployments i.e different dist worlds
            engine_idx = i // engine.dp_size()
            engine_init_info = init_info.for_engine(engine_idx, engine.tp_size(), engine.pp_size(), engine.dp_size())
            tasks.append(engine.init_weight_update_communicator(engine_init_info))
        await asyncio.gather(*tasks)

    async def update_named_weights(self, request: WeightUpdateRequest):
        return await self._run_on_all_engines("update_named_weights", request=request)

    async def reset_prefix_cache(self):
        return await self._run_on_all_engines("reset_prefix_cache")

    async def teardown(self):
        return await self._run_on_all_engines("teardown")

    def tp_size(self) -> int:
        raise NotImplementedError("InferenceEngineClient does not implement tp_size()")

    def pp_size(self) -> int:
        raise NotImplementedError("InferenceEngineClient does not implement pp_size()")

    def dp_size(self) -> int:
        raise NotImplementedError("InferenceEngineClient does not implement dp_size()")

    # ----------------------------
    # Generation pause and resume
    # ----------------------------
    async def pause_generation(self) -> None:
        """
        Pauses generation for all engines using vLLM's native keep mode.

        In-flight requests are frozen (not aborted) and will resume from where they left off
        when `resume_generation()` is called. New requests are blocked until resume.
        """
        await self._run_on_all_engines("pause_generation")

    async def resume_generation(self) -> None:
        """
        Resumes generation for all engines after a keep-mode pause.

        Frozen in-flight requests continue from where they left off, and new requests are unblocked.
        """
        await self._run_on_all_engines("resume_generation")

    # ----------------------------
    # HTTP endpoint related methods
    # ----------------------------

    def __del__(self):
        """
        Destructor to shut down the HTTP endpoint if it was started.
        """
        # TODO(Charlie): __del__ is not guaranteed to be called in general. Add to `teardown` method
        # when the `_handle_termination` flow is implemented. See `skyrl_train/workers/worker.py`
        # comments on `_handle_termination` for more details.
        if (
            self.enable_http_endpoint
            and hasattr(
                self, "_server_thread"
            )  # don't want to shut down the server when it is pickled as a ray method argument.
            and self._server_thread is not None
        ):
            try:
                from skyrl.backends.skyrl_train.inference_engines.inference_engine_client_http_endpoint import (
                    shutdown_server,
                )

                shutdown_server(
                    host=self.http_endpoint_host,
                    port=self.http_endpoint_port,
                    max_wait_seconds=10,
                )
                if hasattr(self, "_server_thread") and self._server_thread.is_alive():
                    self._server_thread.join(timeout=10)
            except Exception as e:
                logger.error(f"Error shutting down HTTP endpoint: {e}")

    def __getstate__(self):
        """
        Override to avoid pickling the server thread, which is not picklable.
        Needed when passing InferenceEngineClient as an argument to async_run_ray_method(), mainly for
        invoking `init_weight_sync_state()` and `broadcast_to_inference_engines()`, which do
        not need these attributes.
        """
        state = self.__dict__.copy()
        state["_server_thread"] = None
        return state

    def _spin_up_http_endpoint(self):
        from skyrl.backends.skyrl_train.inference_engines.inference_engine_client_http_endpoint import (
            serve,
            wait_for_server_ready,
        )

        self._server_thread = threading.Thread(
            target=serve,
            args=(self,),
            kwargs={
                "host": self.http_endpoint_host,
                "port": self.http_endpoint_port,
                "log_level": "warning",
            },
            daemon=True,
        )
        self._server_thread.start()
        wait_for_server_ready(
            host=self.http_endpoint_host,
            port=self.http_endpoint_port,
            max_wait_seconds=30,
        )
        logger.info(
            f"InferenceEngineClient HTTP endpoint started on {self.http_endpoint_host}:{self.http_endpoint_port}"
        )

attr engines

engines = engines

attr tokenizer

tokenizer = tokenizer

attr inference_engine_cfg

inference_engine_cfg = inference_engine_cfg

attr model_name

model_name = served_model_name if served_model_name is not None else model_path

attr backend

backend = inference_engine_cfg.backend

attr enable_http_endpoint

enable_http_endpoint = inference_engine_cfg.enable_http_endpoint

attr http_endpoint_host

http_endpoint_host = inference_engine_cfg.http_endpoint_host

attr http_endpoint_port

http_endpoint_port = inference_engine_cfg.http_endpoint_port

method async generate

generate(input_batch: InferenceEngineInput) -> InferenceEngineOutput
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:92-163
    async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
        # 0. Extract input
        prompts = input_batch.get("prompts")
        prompt_token_ids = input_batch.get("prompt_token_ids")
        session_ids = input_batch.get("session_ids")
        sampling_params = input_batch.get("sampling_params")

        if (prompts is None and prompt_token_ids is None) or (prompts is not None and prompt_token_ids is not None):
            raise ValueError("Either `prompts` or `prompt_token_ids` must be provided, but not both.")
        if prompt_token_ids is None:
            prompt_token_ids = self.tokenizer.apply_chat_template(
                prompts,
                add_generation_prompt=True,
                return_dict=False,
                tokenize=True,
            )

        num_prompts = len(prompt_token_ids)
        num_inference_engines = len(self.engines)

        # 1. Route prompts to engines
        engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
            num_prompts=num_prompts,
            num_inference_engines=num_inference_engines,
            session_ids=session_ids,
        )

        # 2. Generate responses concurrently
        tasks: list[asyncio.Task] = []
        indices_list: list[list[int]] = []  # the original prompt indices that each task works on
        for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
            # index prompt_token_ids with prompt_ids
            cur_prompt_token_ids = [prompt_token_ids[i] for i in prompt_ids]
            engine_input = InferenceEngineInput(
                prompt_token_ids=cur_prompt_token_ids,
                sampling_params=sampling_params,
            )
            tasks.append(asyncio.create_task(self.engines[engine_idx].generate(engine_input)))
            indices_list.append(prompt_ids)

        results = await asyncio.gather(*tasks)

        # 3. Reconstruct output in original order
        n = len(prompt_token_ids)
        responses: list[str] = [""] * n
        stop_reasons: list[str] = [""] * n
        response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)]
        response_ids: List[List[int]] = [[] for _ in range(n)]
        rollout_expert_indices: List[Optional[List[List[List[int]]]]] = [None for _ in range(n)]
        # a bit hacky for now
        add_resp_logprobs = False
        add_rollout_expert_indices = False

        for indices, result in zip(indices_list, results):
            for local_idx, original_idx in enumerate(indices):
                responses[original_idx] = result["responses"][local_idx]
                stop_reasons[original_idx] = result["stop_reasons"][local_idx]
                response_ids[original_idx] = result["response_ids"][local_idx]
                if result.get("response_logprobs", None):
                    add_resp_logprobs = True
                    response_logprobs[original_idx] = result["response_logprobs"][local_idx]
                if result.get("rollout_expert_indices", None):
                    add_rollout_expert_indices = True
                    rollout_expert_indices[original_idx] = result["rollout_expert_indices"][local_idx]

        return InferenceEngineOutput(
            responses=responses,
            stop_reasons=stop_reasons,
            response_ids=response_ids,
            response_logprobs=response_logprobs if add_resp_logprobs else None,
            rollout_expert_indices=rollout_expert_indices if add_rollout_expert_indices else None,
        )
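
Exactly one of prompts or prompt_token_ids must be set; supplying both or neither raises ValueError. A sketch with chat-style prompts, which the client tokenizes via apply_chat_template before routing across engines (the sampling_params keys are illustrative):

output = await client.generate(
    {
        "prompts": [
            [{"role": "user", "content": "Explain tensor parallelism."}],
            [{"role": "user", "content": "Explain pipeline parallelism."}],
        ],
        "prompt_token_ids": None,   # exactly one of prompts / prompt_token_ids
        "sampling_params": {"temperature": 0.7, "max_tokens": 256},
        "session_ids": None,        # None: prompts are split evenly across engines
    }
)
# Results come back in input order regardless of which engine served each prompt.
print(output["responses"][0], output["stop_reasons"][0])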

method async sample

sample(prompt_token_ids: List[int], num_samples: int, sampling_params: Dict[str, Any], session_id: Optional[Union[str, int]] = None) -> InferenceEngineOutput

Generate multiple independent samples from a single prompt.

This method provides Tinker-compatible token-in/token-out sampling semantics. Generates num_samples independent completions from the same prompt.

Parameters:

Name              Type                       Description                                                                  Default
prompt_token_ids  List[int]                  Token IDs for a single prompt (not batched).                                 required
num_samples       int                        Number of independent samples to generate.                                   required
sampling_params   Dict[str, Any]             Sampling parameters (temperature, max_tokens, etc.).                         required
session_id        Optional[Union[str, int]]  Optional session ID for consistent engine routing (e.g., conversation ID).   None
                                             If None, uses random load-balancing. Tinker API should pass None since
                                             each sample() call is independent.

Returns:

Type                   Description
InferenceEngineOutput  InferenceEngineOutput containing num_samples results.
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:180-211
    async def sample(
        self,
        prompt_token_ids: List[int],
        num_samples: int,
        sampling_params: Dict[str, Any],
        session_id: Optional[Union[str, int]] = None,
    ) -> InferenceEngineOutput:
        """Generate multiple independent samples from a single prompt.

        This method provides Tinker-compatible token-in/token-out sampling semantics.
        Generates num_samples independent completions from the same prompt.

        Args:
            prompt_token_ids: Token IDs for a single prompt (not batched).
            num_samples: Number of independent samples to generate.
            sampling_params: Sampling parameters (temperature, max_tokens, etc.).
            session_id: Optional session ID for consistent engine routing (e.g., conversation ID).
                       If None, uses random load-balancing. Tinker API should pass None since
                       each sample() call is independent.

        Returns:
            InferenceEngineOutput containing num_samples results.
        """
        # Select engine (random if session_id is None, consistent hash otherwise)
        engine_idx = self._select_engine_idx(session_id)
        engine = self.engines[engine_idx]

        return await engine.sample(
            prompt_token_ids=prompt_token_ids,
            num_samples=num_samples,
            sampling_params=sampling_params,
        )

method async chat_completion

chat_completion(request_payload: Dict[str, Any]) -> Dict[str, Any]
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:213-219
    async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        session_id = request_payload["json"].pop("session_id", None)
        if session_id is not None:
            assert isinstance(session_id, (str, int)), "Session ID must be an integer or string for `/chat/completions`"
        engine_idx = self._select_engine_idx(session_id)

        return await self.engines[engine_idx].chat_completion(request_payload)

method async completion

completion(request_payload: Dict[str, Any]) -> Dict[str, Any]

Handles an OpenAI /completions request.

Since request["prompt"] can be Union[list[int], list[list[int]], str, list[str]], (i.e. {batched, single} x {string, token IDs}), we need to route the request to engines differently, based on whether it's a single or batched request, and whether request["session_id"] is provided. This is similar to generate() method.

For single, we do the same routing logic as chat_completion(). For batched, we route by request["session_id"] if present, and if not we split evenly across engines.

Regardless, the order will be maintained, i.e. output["choices"][i] corresponds to request["prompt"][i].

Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:221-316
    async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handles an OpenAI /completions request.

        Since `request["prompt"]` can be `Union[list[int], list[list[int]], str, list[str]]`,
        (i.e. {batched, single} x {string, token IDs}), we need to route the request to engines
        differently, based on whether it's a single or batched request, and whether `request["session_id"]`
        is provided. This is similar to `generate()` method.

        For single, we do the same routing logic as `chat_completion()`. For batched, we route by
        `request["session_id"]` if present, and if not we split evenly across engines.

        Regardless, the order will be maintained, i.e. `output["choices"][i]` corresponds to `request["prompt"][i]`.
        """
        body = request_payload.get("json", {})

        # NOTE(Charlie): do not reuse headers here as the single request may become various new requests
        headers = {"Content-Type": "application/json"}

        # 1. Postprocess prompt, session_id, and validate request.
        prompt = body.get("prompt")
        session_id_value = body.pop("session_id", None)
        ret = postprocess_completion_request(prompt, session_id_value)
        session_id_list: Optional[Union[List[int], List[str], ErrorResponse]] = ret[0]
        prompt: Union[List[List[int]], List[str]] = ret[1]
        if isinstance(session_id_list, ErrorResponse):
            return session_id_list.model_dump()

        num_prompts = len(prompt)
        num_inference_engines = len(self.engines)
        assert num_prompts > 0, "Number of prompts must be greater than 0"

        # 1. Route prompts to engines
        engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
            num_prompts=num_prompts,
            num_inference_engines=num_inference_engines,
            session_ids=session_id_list,
        )

        # 2. Generate responses concurrently
        tasks: list[asyncio.Task] = []
        indices_list: list[list[int]] = []  # the original prompt indices that each task works on
        for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
            cur_prompt = [prompt[i] for i in prompt_ids]
            # reuse the exact same request except for the prompt
            cur_json = dict(body)
            cur_json["prompt"] = cur_prompt
            coro = self.engines[engine_idx].completion({"json": cur_json, "headers": headers})
            tasks.append(asyncio.create_task(coro))
            indices_list.append(prompt_ids)

        results = await asyncio.gather(*tasks)

        # 3. Check for errors.
        # results can be ErrorResponse or CompletionResponse. If one of the sub-requests fails, we
        # return an error response. That is, there is no partial success, following vLLM's behavior.
        for result in results:
            if "error" in result or result.get("object", "") == "error":
                error_details = result.get("error", result)
                error_code = error_details["code"]
                error_type = error_details["type"]
                error_message = error_details["message"]
                return ErrorResponse(
                    error=ErrorInfo(
                        message=f"In one of the engines that SkyRL manages, an error occurred: {error_message}",
                        type=error_type,
                        code=error_code,
                    ),
                ).model_dump()

        # 4. Combine choices and preserve original order.
        # If there is only one result, we return it directly.
        if len(results) == 1:
            return results[0]

        # Use the first result as base response. There are some fields that cannot be shared
        # across sub-requests. For now it is just the usage field.
        final_response = dict(results[0])
        final_response["usage"] = aggregate_completion_usage_info(results, self.backend)

        # Aggregate choices. TODO(Charlie): improve logic when we need to support n > 1
        # vLLM sets index positions per sub-batch, so we reset indices to be 0..n-1 for the combined response.
        combined_choices: list[Dict[str, Any]] = [None] * num_prompts
        for indices, result in zip(indices_list, results):
            # indices are the original prompt indices that the task's response corresponds to
            for local_idx, original_idx in enumerate(indices):
                choice = result["choices"][local_idx]
                choice["index"] = original_idx  # overwrite index with the global position
                combined_choices[original_idx] = choice

        # sanity check that the index is correct
        for new_idx in range(len(combined_choices)):
            assert combined_choices[new_idx]["index"] == new_idx

        final_response["choices"] = combined_choices
        return final_response
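
As a usage sketch, a batched body fans out across engines while output["choices"][i] still lines up with request["prompt"][i]; the per-prompt session_id list shown is an assumption about the accepted shape, and the model name is illustrative:

payload = {
    "json": {
        "model": "my-model",
        "prompt": ["First prompt.", "Second prompt.", "Third prompt."],
        "max_tokens": 32,
        "session_id": ["conv-a", "conv-b", "conv-a"],   # optional: pins prompts to engines
    },
    "headers": {"Content-Type": "application/json"},
}
result = await client.completion(payload)
for choice in result["choices"]:
    print(choice["index"], choice["text"])   # indices are 0..n-1 in original order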

method async wake_up

wake_up(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:318-319
    async def wake_up(self, *args: Any, **kwargs: Any):
        return await self._run_on_all_engines("wake_up", *args, **kwargs)

method async sleep

sleep(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:321-322
    async def sleep(self, *args: Any, **kwargs: Any):
        return await self._run_on_all_engines("sleep", *args, **kwargs)

method async init_weight_update_communicator

init_weight_update_communicator(init_info: WeightSyncInitInfo)

Initialize weight update communicator on all engines.

Parameters:

Name       Type                Description                          Default
init_info  WeightSyncInitInfo  WeightSyncInitInfo from the sender.  required

Note:

Per-engine adjustments (e.g., rank_offset for broadcast) are handled by init_info.for_engine().

Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:324-341
    async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
        """Initialize weight update communicator on all engines.

        Args:
            init_info: WeightSyncInitInfo from the sender.

        Note:
            Per-engine adjustments (e.g., rank_offset for broadcast) are handled
            by init_info.for_engine().
        """
        tasks = []
        for i, engine in enumerate(self.engines):
            # With vLLM, DP ranks are managed as separate engine instances
            # We want the index of truly separate vllm deployments i.e different dist worlds
            engine_idx = i // engine.dp_size()
            engine_init_info = init_info.for_engine(engine_idx, engine.tp_size(), engine.pp_size(), engine.dp_size())
            tasks.append(engine.init_weight_update_communicator(engine_init_info))
        await asyncio.gather(*tasks)
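
Concretely, because vLLM exposes each DP rank as its own engine handle, consecutive handles map onto one logical deployment:

# Worked example: 4 engine handles, each reporting dp_size() == 2.
# i // dp_size gives the deployment index shared by both DP ranks.
[i // 2 for i in range(4)]   # -> [0, 0, 1, 1]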

method async update_named_weights

update_named_weights(request: WeightUpdateRequest)
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:343-344
    async def update_named_weights(self, request: WeightUpdateRequest):
        return await self._run_on_all_engines("update_named_weights", request=request)

method async reset_prefix_cache

reset_prefix_cache()
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:346-347
    async def reset_prefix_cache(self):
        return await self._run_on_all_engines("reset_prefix_cache")

method async teardown

teardown()
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:349-350
    async def teardown(self):
        return await self._run_on_all_engines("teardown")

method tp_size

tp_size() -> int
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:352-353
    def tp_size(self) -> int:
        raise NotImplementedError("InferenceEngineClient does not implement tp_size()")

method pp_size

pp_size() -> int
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:355-356
    def pp_size(self) -> int:
        raise NotImplementedError("InferenceEngineClient does not implement pp_size()")

method dp_size

dp_size() -> int
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:358-359
    def dp_size(self) -> int:
        raise NotImplementedError("InferenceEngineClient does not implement dp_size()")

method async pause_generation

pause_generation() -> None

Pauses generation for all engines using vLLM's native keep mode.

In-flight requests are frozen (not aborted) and will resume from where they left off when resume_generation() is called. New requests are blocked until resume.

Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:364-371
    async def pause_generation(self) -> None:
        """
        Pauses generation for all engines using vLLM's native keep mode.

        In-flight requests are frozen (not aborted) and will resume from where they left off
        when `resume_generation()` is called. New requests are blocked until resume.
        """
        await self._run_on_all_engines("pause_generation")

method async resume_generation

resume_generation() -> None

Resumes generation for all engines after a keep-mode pause.

Frozen in-flight requests continue from where they left off, and new requests are unblocked.

Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:373-379
    async def resume_generation(self) -> None:
        """
        Resumes generation for all engines after a keep-mode pause.

        Frozen in-flight requests continue from where they left off, and new requests are unblocked.
        """
        await self._run_on_all_engines("resume_generation")
