SkyRL

Generator

Generator API — GeneratorInterface, InferenceEngineInterface.

Core APIs

class GeneratorInterface

Bases: ABC

Functions:

| Name | Description |
| --- | --- |
| `generate` | Generate trajectories for the input batch. |
Source code in skyrl/train/generators/base.py:53-65
class GeneratorInterface(ABC):
    @abstractmethod
    async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
        """Generate trajectories for the input batch.

        Returns outputs in the same order as the input batch.

        Args:
            input_batch (GeneratorInput): Input batch
        Returns:
            GeneratorOutput: Generated trajectories
        """
        raise NotImplementedError

method abstractmethod async generate

generate(input_batch: GeneratorInput) -> GeneratorOutput

Generate trajectories for the input batch.

Returns outputs in the same order as the input batch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_batch` | `GeneratorInput` | Input batch | *required* |

Returns: GeneratorOutput: Generated trajectories

Source code in skyrl/train/generators/base.py:54-65
    @abstractmethod
    async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
        """Generate trajectories for the input batch.

        Returns outputs in the same order as the input batch.

        Args:
            input_batch (GeneratorInput): Input batch
        Returns:
            GeneratorOutput: Generated trajectories
        """
        raise NotImplementedError
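A minimal sketch of implementing `GeneratorInterface`. The `EchoGenerator` below is purely illustrative (not part of skyrl), and dict stand-ins replace skyrl's `GeneratorInput`/`GeneratorOutput` types so the sketch runs on its own; a real generator would drive an inference engine and environment to produce trajectories.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict

# Stand-ins for skyrl's GeneratorInput/GeneratorOutput, used so this
# sketch runs without skyrl installed.
GeneratorInput = Dict[str, Any]
GeneratorOutput = Dict[str, Any]


class GeneratorInterface(ABC):
    @abstractmethod
    async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
        raise NotImplementedError


class EchoGenerator(GeneratorInterface):
    """Toy generator: 'responds' by echoing each prompt back."""

    async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
        prompts = input_batch["prompts"]
        # Outputs must be returned in the same order as the input batch.
        return {
            "responses": list(prompts),
            "stop_reasons": ["stop"] * len(prompts),
        }


out = asyncio.run(EchoGenerator().generate({"prompts": ["hello", "world"]}))
```

Because `generate` is `async`, callers drive it with `await` (or `asyncio.run` at the top level, as above).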

class InferenceEngineInterface

Bases: ABC

Functions:

| Name | Description |
| --- | --- |
| `generate` | |
| `sample` | Generate multiple independent samples from a single prompt. |
| `chat_completion` | Handles OpenAI-compatible HTTP endpoint. |
| `completion` | Handles OpenAI-compatible HTTP endpoint. |
| `wake_up` | |
| `sleep` | |
| `init_weight_update_communicator` | Initialize weight update communicator from init info. |
| `update_named_weights` | |
| `teardown` | |
| `reset_prefix_cache` | |
| `tp_size` | Return the tensor parallel size of this inference engine. |
| `pp_size` | Return the pipeline parallel size of this inference engine. |
| `dp_size` | Return the data parallel size of this inference engine. |
| `abort_generation` | Abort all running and waiting requests; ongoing requests return their already-generated tokens with a stop_reason of "abort". |
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:36-168
class InferenceEngineInterface(ABC):

    @abstractmethod
    async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
        raise NotImplementedError

    async def sample(
        self,
        prompt_token_ids: List[int],
        num_samples: int,
        sampling_params: Dict[str, Any],
    ) -> InferenceEngineOutput:
        """Generate multiple independent samples from a single prompt.

        This method provides Tinker-compatible token-in/token-out sampling semantics.

        Args:
            prompt_token_ids: Token IDs for a single prompt.
            num_samples: Number of independent samples to generate.
            sampling_params: Sampling parameters.

        Returns:
            InferenceEngineOutput containing num_samples results:
                - response_ids: List of num_samples token ID lists
                - responses: List of num_samples decoded strings
                - stop_reasons: List of num_samples stop reasons
                - response_logprobs: Optional list of num_samples logprob lists
        """
        all_response_ids = []
        all_responses = []
        all_stop_reasons = []
        all_response_logprobs = []

        for _ in range(num_samples):
            input_batch: InferenceEngineInput = {
                "prompts": None,
                "prompt_token_ids": [prompt_token_ids],  # Wrap in list for batch of 1
                "sampling_params": sampling_params,
                "session_ids": None,
            }
            output = await self.generate(input_batch)

            # Extract single result from batch of 1
            all_response_ids.append(output["response_ids"][0])
            all_responses.append(output["responses"][0])
            all_stop_reasons.append(output["stop_reasons"][0])
            if output.get("response_logprobs") is not None:
                all_response_logprobs.append(output["response_logprobs"][0])

        return {
            "response_ids": all_response_ids,
            "responses": all_responses,
            "stop_reasons": all_stop_reasons,
            "response_logprobs": all_response_logprobs if all_response_logprobs else None,
        }

    @abstractmethod
    async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Handles OpenAI-compatible HTTP endpoint.

        Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
        The request body will be used to construct a ChatCompletionRequest.
        Returns a plain dict, either a ChatCompletionResponse or an ErrorResponse.
        The specific fields of the response/request depend on the engine's backend (e.g. for vllm
        these are defined in vllm.entrypoints.openai.protocol).
        """
        raise NotImplementedError

    @abstractmethod
    async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Handles OpenAI-compatible HTTP endpoint.

        Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
        The request body will be used to construct a CompletionRequest.
        Returns a plain dict, either a CompletionResponse or an ErrorResponse.
        The specific fields of the response/request depend on the engine's backend (e.g. for vllm
        these are defined in vllm.entrypoints.openai.protocol).
        """
        raise NotImplementedError

    @abstractmethod
    async def wake_up(self, *args: Any, **kwargs: Any):
        raise NotImplementedError

    @abstractmethod
    async def sleep(self, *args: Any, **kwargs: Any):
        raise NotImplementedError

    @abstractmethod
    async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
        """Initialize weight update communicator from init info.

        Args:
            init_info: WeightSyncInitInfo from the sender containing all info needed
                to create the appropriate receiver.
        """
        raise NotImplementedError()

    @abstractmethod
    async def update_named_weights(self, request: "WeightUpdateRequest"):
        raise NotImplementedError()

    @abstractmethod
    async def teardown(self):
        raise NotImplementedError

    @abstractmethod
    async def reset_prefix_cache(self):
        raise NotImplementedError

    @abstractmethod
    def tp_size(self) -> int:
        """Return the tensor parallel size of this inference engine."""
        raise NotImplementedError

    @abstractmethod
    def pp_size(self) -> int:
        """Return the pipeline parallel size of this inference engine."""
        raise NotImplementedError

    @abstractmethod
    def dp_size(self) -> int:
        """Return the data parallel size of this inference engine."""
        raise NotImplementedError

    @abstractmethod
    async def abort_generation(self) -> None:
        """
        Abort all running and waiting requests, which make the ongoing requests return the
        already-generated tokens with a stop_reason of "abort". If the request was waiting,
        it returns a response with zero completion tokens.
        """
        raise NotImplementedError

method abstractmethod async generate

generate(input_batch: InferenceEngineInput) -> InferenceEngineOutput
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:38-40
    @abstractmethod
    async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
        raise NotImplementedError

method async sample

sample(prompt_token_ids: List[int], num_samples: int, sampling_params: Dict[str, Any]) -> InferenceEngineOutput

Generate multiple independent samples from a single prompt.

This method provides Tinker-compatible token-in/token-out sampling semantics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `prompt_token_ids` | `List[int]` | Token IDs for a single prompt. | *required* |
| `num_samples` | `int` | Number of independent samples to generate. | *required* |
| `sampling_params` | `Dict[str, Any]` | Sampling parameters. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `InferenceEngineOutput` | `InferenceEngineOutput` containing `num_samples` results: `response_ids` (list of `num_samples` token ID lists), `responses` (list of `num_samples` decoded strings), `stop_reasons` (list of `num_samples` stop reasons), `response_logprobs` (optional list of `num_samples` logprob lists). |
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:42-90
    async def sample(
        self,
        prompt_token_ids: List[int],
        num_samples: int,
        sampling_params: Dict[str, Any],
    ) -> InferenceEngineOutput:
        """Generate multiple independent samples from a single prompt.

        This method provides Tinker-compatible token-in/token-out sampling semantics.

        Args:
            prompt_token_ids: Token IDs for a single prompt.
            num_samples: Number of independent samples to generate.
            sampling_params: Sampling parameters.

        Returns:
            InferenceEngineOutput containing num_samples results:
                - response_ids: List of num_samples token ID lists
                - responses: List of num_samples decoded strings
                - stop_reasons: List of num_samples stop reasons
                - response_logprobs: Optional list of num_samples logprob lists
        """
        all_response_ids = []
        all_responses = []
        all_stop_reasons = []
        all_response_logprobs = []

        for _ in range(num_samples):
            input_batch: InferenceEngineInput = {
                "prompts": None,
                "prompt_token_ids": [prompt_token_ids],  # Wrap in list for batch of 1
                "sampling_params": sampling_params,
                "session_ids": None,
            }
            output = await self.generate(input_batch)

            # Extract single result from batch of 1
            all_response_ids.append(output["response_ids"][0])
            all_responses.append(output["responses"][0])
            all_stop_reasons.append(output["stop_reasons"][0])
            if output.get("response_logprobs") is not None:
                all_response_logprobs.append(output["response_logprobs"][0])

        return {
            "response_ids": all_response_ids,
            "responses": all_responses,
            "stop_reasons": all_stop_reasons,
            "response_logprobs": all_response_logprobs if all_response_logprobs else None,
        }
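The default `sample()` is a thin loop over `generate()`: each iteration wraps the single prompt in a batch of 1 and extracts the single result. The sketch below reproduces that loop against a toy `generate()`; `ToyEngine` and its fixed-token "model" are illustrative stand-ins, not part of skyrl.

```python
import asyncio
from typing import Any, Dict, List


class ToyEngine:
    """Toy stand-in for an inference engine backend."""

    async def generate(self, input_batch: Dict[str, Any]) -> Dict[str, Any]:
        # Expect a batch of exactly one prompt; "generate" a fixed token.
        (prompt,) = input_batch["prompt_token_ids"]
        return {
            "response_ids": [prompt + [999]],
            "responses": ["<decoded>"],
            "stop_reasons": ["stop"],
            "response_logprobs": None,
        }

    async def sample(
        self,
        prompt_token_ids: List[int],
        num_samples: int,
        sampling_params: Dict[str, Any],
    ) -> Dict[str, Any]:
        # Mirrors the default implementation: num_samples independent
        # generate() calls, each on a batch of 1, results re-stacked.
        outs = []
        for _ in range(num_samples):
            outs.append(
                await self.generate(
                    {
                        "prompts": None,
                        "prompt_token_ids": [prompt_token_ids],
                        "sampling_params": sampling_params,
                        "session_ids": None,
                    }
                )
            )
        return {
            "response_ids": [o["response_ids"][0] for o in outs],
            "responses": [o["responses"][0] for o in outs],
            "stop_reasons": [o["stop_reasons"][0] for o in outs],
            "response_logprobs": None,
        }


result = asyncio.run(ToyEngine().sample([1, 2, 3], num_samples=2, sampling_params={}))
```

Note that samples are generated sequentially here, matching the base-class behavior; backends are free to override `sample()` with a batched implementation.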

method abstractmethod async chat_completion

chat_completion(request_payload: Dict[str, Any]) -> Dict[str, Any]

Handles OpenAI-compatible HTTP endpoint.

Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}. The request body will be used to construct a ChatCompletionRequest. Returns a plain dict, either a ChatCompletionResponse or an ErrorResponse. The specific fields of the response/request depend on the engine's backend (e.g. for vllm these are defined in vllm.entrypoints.openai.protocol).

Source code in skyrl/backends/skyrl_train/inference_engines/base.py:92-102
    @abstractmethod
    async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Handles OpenAI-compatible HTTP endpoint.

        Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
        The request body will be used to construct a ChatCompletionRequest.
        Returns a plain dict, either a ChatCompletionResponse or an ErrorResponse.
        The specific fields of the response/request depend on the engine's backend (e.g. for vllm
        these are defined in vllm.entrypoints.openai.protocol).
        """
        raise NotImplementedError
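A hedged example of the payload shape `chat_completion()` expects. The field values (model name, message, header) are illustrative only; the `"json"` body follows the OpenAI /chat/completions schema, with the exact accepted fields defined by the engine's backend (e.g. `vllm.entrypoints.openai.protocol`).

```python
# Hypothetical request payload: {"json": <request-body>, "headers": <headers-dict>}.
request_payload = {
    "json": {
        "model": "my-served-model",  # must match the engine's served model name
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 64,
    },
    "headers": {"Content-Type": "application/json"},
}
# response = await engine.chat_completion(request_payload)  # returns a plain dict
```

`completion()` takes the same `{"json": ..., "headers": ...}` envelope, with the body following the /completions schema instead.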

method abstractmethod async completion

completion(request_payload: Dict[str, Any]) -> Dict[str, Any]

Handles OpenAI-compatible HTTP endpoint.

Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}. The request body will be used to construct a CompletionRequest. Returns a plain dict, either a CompletionResponse or an ErrorResponse. The specific fields of the response/request depend on the engine's backend (e.g. for vllm these are defined in vllm.entrypoints.openai.protocol).

Source code in skyrl/backends/skyrl_train/inference_engines/base.py:104-114
    @abstractmethod
    async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Handles OpenAI-compatible HTTP endpoint.

        Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
        The request body will be used to construct a CompletionRequest.
        Returns a plain dict, either a CompletionResponse or an ErrorResponse.
        The specific fields of the response/request depend on the engine's backend (e.g. for vllm
        these are defined in vllm.entrypoints.openai.protocol).
        """
        raise NotImplementedError

method abstractmethod async wake_up

wake_up(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:116-118
    @abstractmethod
    async def wake_up(self, *args: Any, **kwargs: Any):
        raise NotImplementedError

method abstractmethod async sleep

sleep(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:120-122
    @abstractmethod
    async def sleep(self, *args: Any, **kwargs: Any):
        raise NotImplementedError

method abstractmethod async init_weight_update_communicator

init_weight_update_communicator(init_info: WeightSyncInitInfo)

Initialize weight update communicator from init info.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `init_info` | `WeightSyncInitInfo` | `WeightSyncInitInfo` from the sender containing all info needed to create the appropriate receiver. | *required* |
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:124-132
    @abstractmethod
    async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
        """Initialize weight update communicator from init info.

        Args:
            init_info: WeightSyncInitInfo from the sender containing all info needed
                to create the appropriate receiver.
        """
        raise NotImplementedError()

method abstractmethod async update_named_weights

update_named_weights(request: WeightUpdateRequest)
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:134-136
    @abstractmethod
    async def update_named_weights(self, request: "WeightUpdateRequest"):
        raise NotImplementedError()

method abstractmethod async teardown

teardown()
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:138-140
    @abstractmethod
    async def teardown(self):
        raise NotImplementedError

method abstractmethod async reset_prefix_cache

reset_prefix_cache()
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:142-144
    @abstractmethod
    async def reset_prefix_cache(self):
        raise NotImplementedError

method abstractmethod tp_size

tp_size() -> int

Return the tensor parallel size of this inference engine.

Source code in skyrl/backends/skyrl_train/inference_engines/base.py:146-149
    @abstractmethod
    def tp_size(self) -> int:
        """Return the tensor parallel size of this inference engine."""
        raise NotImplementedError

method abstractmethod pp_size

pp_size() -> int

Return the pipeline parallel size of this inference engine.

Source code in skyrl/backends/skyrl_train/inference_engines/base.py:151-154
    @abstractmethod
    def pp_size(self) -> int:
        """Return the pipeline parallel size of this inference engine."""
        raise NotImplementedError

method abstractmethod dp_size

dp_size() -> int

Return the data parallel size of this inference engine.

Source code in skyrl/backends/skyrl_train/inference_engines/base.py:156-159
    @abstractmethod
    def dp_size(self) -> int:
        """Return the data parallel size of this inference engine."""
        raise NotImplementedError

method abstractmethod async abort_generation

abort_generation() -> None

Abort all running and waiting requests, which makes ongoing requests return their already-generated tokens with a stop_reason of "abort". If a request was still waiting, it returns a response with zero completion tokens.

Source code in skyrl/backends/skyrl_train/inference_engines/base.py:161-168
    @abstractmethod
    async def abort_generation(self) -> None:
        """
        Abort all running and waiting requests, which make the ongoing requests return the
        already-generated tokens with a stop_reason of "abort". If the request was waiting,
        it returns a response with zero completion tokens.
        """
        raise NotImplementedError

class InferenceEngineClient

InferenceEngineClient(engines: List[InferenceEngineInterface], tokenizer: PreTrainedTokenizerBase, model_path: str, lora_cfg: SkyRLLoraConfig, inference_engine_cfg: InferenceEngineConfig)

Bases: InferenceEngineInterface

Client to talk to a set of InferenceEngines.

Note that InferenceEngineClient sub-classes InferenceEngineInterface so it can be used as if talking to a single engine.

Functions:

| Name | Description |
| --- | --- |
| `generate` | |
| `sample` | Generate multiple independent samples from a single prompt. |
| `chat_completion` | |
| `completion` | Handles an OpenAI /completions request. |
| `wake_up` | |
| `sleep` | |
| `init_weight_update_communicator` | Initialize weight update communicator on all engines. |
| `update_named_weights` | |
| `reset_prefix_cache` | |
| `teardown` | |
| `tp_size` | |
| `pp_size` | |
| `dp_size` | |
| `pause_generation` | Pauses generation for all engines, intended for in-flight weight updates and partial rollouts. |
| `resume_generation` | Resumes generation for all engines, intended for in-flight weight updates and partial rollouts. |
| `abort_generation` | |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `engines` | `List[InferenceEngineInterface]` | The inference engines, remote or local. | *required* |
| `tokenizer` | `PreTrainedTokenizerBase` | The tokenizer to use. | *required* |
| `model_path` | `str` | The path to the model. | *required* |
| `lora_cfg` | `SkyRLLoraConfig` | The LoRA configuration. | *required* |
| `inference_engine_cfg` | `InferenceEngineConfig` | The inference engine configuration. | *required* |
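The client routes single requests across engines as described in `_select_engine_idx` below: random load-balancing when no session ID is given, otherwise a stable hash so the same session always lands on the same engine (preserving prefix-cache locality). A self-contained sketch of that routing rule, where the sha256-based hash is an assumption standing in for skyrl's `hash_with_sha256` helper:

```python
import hashlib
import random
from typing import Optional, Union


def select_engine_idx(num_engines: int, session_id: Optional[Union[str, int]] = None) -> int:
    """Random load-balancing without a session ID; consistent hashing with one."""
    if session_id is None:
        return random.randint(0, num_engines - 1)
    # Assumed stand-in for skyrl's hash_with_sha256: a deterministic,
    # process-independent hash of the session ID.
    digest = hashlib.sha256(str(session_id).encode()).hexdigest()
    return int(digest, 16) % num_engines


# The same session ID always routes to the same engine index.
idx1 = select_engine_idx(4, session_id="conv-42")
idx2 = select_engine_idx(4, session_id="conv-42")
```

Python's built-in `hash()` is deliberately avoided here because it is salted per process, which would break cross-process routing consistency.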
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:38-709
class InferenceEngineClient(InferenceEngineInterface):
    """
    Client to talk to a set of InferenceEngines.

    Note that InferenceEngineClient sub-classes InferenceEngineInterface so it can be used as if talking to a single
    engine.
    """

    def __init__(
        self,
        engines: List[InferenceEngineInterface],
        tokenizer: PreTrainedTokenizerBase,
        model_path: str,
        lora_cfg: SkyRLLoraConfig,
        inference_engine_cfg: InferenceEngineConfig,
    ):
        """
        Args:
            engines: List[InferenceEngineInterface] - The inference engines, remote or local.
            tokenizer: PreTrainedTokenizerBase - The tokenizer to use.
            model_path: str - The path to the model.
            lora_cfg: SkyRLLoraConfig - The LoRA configuration.
            inference_engine_cfg: InferenceEngineConfig - The inference engine configuration.
        """
        self.engines = engines
        self.tokenizer = tokenizer
        self.inference_engine_cfg = inference_engine_cfg
        # Use served_model_name if provided, otherwise fall back to model path.
        # served_model_name allows using a different model name for HTTP endpoint validation
        # than the actual model path. See ppo_base_config.yaml for details.
        served_model_name = inference_engine_cfg.served_model_name
        if served_model_name is not None:
            self.model_name = served_model_name
        else:
            self.model_name = model_path
        self.backend = inference_engine_cfg.backend
        self.enable_http_endpoint = inference_engine_cfg.enable_http_endpoint
        self.http_endpoint_host = inference_engine_cfg.http_endpoint_host
        self.http_endpoint_port = inference_engine_cfg.http_endpoint_port
        self.generation_paused_event = threading.Event()
        if self.enable_http_endpoint:
            self._spin_up_http_endpoint()

        logger.info(f"InferenceEngineClient initialized with {len(engines)} engines.")

    async def _run_on_all_engines(self, method_name: str, *args, **kwargs):
        """
        Call a method on all engines concurrently and gather the results.
        """
        assert len(self.engines) > 0, "No engines to call method on"

        awaitables = [getattr(engine, method_name)(*args, **kwargs) for engine in self.engines]
        return await asyncio.gather(*awaitables)

    async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
        # 0. Extract input
        prompts = input_batch.get("prompts")
        prompt_token_ids = input_batch.get("prompt_token_ids")
        session_ids = input_batch.get("session_ids")
        sampling_params = input_batch.get("sampling_params")

        if (prompts is None and prompt_token_ids is None) or (prompts is not None and prompt_token_ids is not None):
            raise ValueError("Either `prompts` or `prompt_token_ids` must be provided, but not both.")
        if prompt_token_ids is None:
            prompt_token_ids = self.tokenizer.apply_chat_template(
                prompts,
                add_generation_prompt=True,
                return_dict=True,
                tokenize=True,
            )["input_ids"]

        num_prompts = len(prompt_token_ids)
        num_inference_engines = len(self.engines)

        # 1. Route prompts to engines
        engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
            num_prompts=num_prompts,
            num_inference_engines=num_inference_engines,
            session_ids=session_ids,
        )

        # We do a shortcut for non-batched requests, which can support pause/continue generation for
        # in-flight weight updates.
        if num_prompts == 1:
            # Route to a single engine for this single prompt and use retry flow.
            assert len(engine_idx_to_prompt_ids) == 1
            ((engine_idx, prompt_ids_list),) = engine_idx_to_prompt_ids.items()
            assert prompt_ids_list == [0], "Single prompt should map to index [0]"
            original_prompt_ids = prompt_token_ids[0]
            return await self._generate_single_with_retry(
                engine_idx=engine_idx,
                original_prompt_ids=original_prompt_ids,
                sampling_params=sampling_params,
            )

        # For batched generate(), pause/continue cannot be supported.
        if self.generation_paused_event.is_set():
            raise RuntimeError("pause_generation is unsupported for batched InferenceEngineClient.generate().")

        # 2. Generate responses concurrently
        tasks: list[asyncio.Task] = []
        indices_list: list[list[int]] = []  # the original prompt indices that each task works on
        for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
            # index prompt_token_ids with prompt_ids
            cur_prompt_token_ids = [prompt_token_ids[i] for i in prompt_ids]
            engine_input = InferenceEngineInput(
                prompt_token_ids=cur_prompt_token_ids,
                sampling_params=sampling_params,
            )
            tasks.append(asyncio.create_task(self.engines[engine_idx].generate(engine_input)))
            indices_list.append(prompt_ids)

        results = await asyncio.gather(*tasks)

        # 3. Reconstruct output in original order
        n = len(prompt_token_ids)
        responses: list[str] = [""] * n
        stop_reasons: list[str] = [""] * n
        response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)]
        response_ids: List[List[int]] = [[] for _ in range(n)]
        # a bit hacky for now
        add_resp_logprobs = False

        for indices, result in zip(indices_list, results):
            for local_idx, original_idx in enumerate(indices):
                responses[original_idx] = result["responses"][local_idx]
                stop_reasons[original_idx] = result["stop_reasons"][local_idx]
                response_ids[original_idx] = result["response_ids"][local_idx]
                if result.get("response_logprobs", None):
                    add_resp_logprobs = True
                    response_logprobs[original_idx] = result["response_logprobs"][local_idx]

        return InferenceEngineOutput(
            responses=responses,
            stop_reasons=stop_reasons,
            response_ids=response_ids,
            response_logprobs=response_logprobs if add_resp_logprobs else None,
        )

    def _select_engine_idx(self, session_id: Optional[Union[str, int]] = None) -> int:
        """Select an engine index for routing a request.

        Args:
            session_id: Optional session ID for consistent routing (e.g., conversation ID for chat).
                       If None, uses random load-balancing.

        Returns:
            Engine index to route the request to.
        """
        if session_id is None:
            return random.randint(0, len(self.engines) - 1)
        else:
            return hash_with_sha256(str(session_id)) % len(self.engines)

    async def sample(
        self,
        prompt_token_ids: List[int],
        num_samples: int,
        sampling_params: Dict[str, Any],
        session_id: Optional[Union[str, int]] = None,
    ) -> InferenceEngineOutput:
        """Generate multiple independent samples from a single prompt.

        This method provides Tinker-compatible token-in/token-out sampling semantics.
        Generates num_samples independent completions from the same prompt.

        Args:
            prompt_token_ids: Token IDs for a single prompt (not batched).
            num_samples: Number of independent samples to generate.
            sampling_params: Sampling parameters (temperature, max_tokens, etc.).
            session_id: Optional session ID for consistent engine routing (e.g., conversation ID).
                       If None, uses random load-balancing. Tinker API should pass None since
                       each sample() call is independent.

        Returns:
            InferenceEngineOutput containing num_samples results.
        """
        # Wait for generation to resume if paused (for weight updates)
        await self._wait_for_generation_to_resume()

        # Select engine (random if session_id is None, consistent hash otherwise)
        engine_idx = self._select_engine_idx(session_id)
        engine = self.engines[engine_idx]

        return await engine.sample(
            prompt_token_ids=prompt_token_ids,
            num_samples=num_samples,
            sampling_params=sampling_params,
        )

    async def _generate_single_with_retry(
        self, engine_idx: int, original_prompt_ids: List[int], sampling_params: Optional[Dict[str, Any]]
    ) -> InferenceEngineOutput:
        """
        Generate a single response with retry mechanism.

        This method is equivalent to `_chat_completion_with_retry()` but for the `generate()` codepath.
        We keep sending `generate` requests (with previous responses accumulated) until the finish_reason
        is not "abort". It is intended to be used in combination with `pause_generation()` and `resume_generation()` for
        in-flight weight updates and partial rollouts.

        This method is equivalent to a single `generate()` call if we do not use `pause_generation()`.

        Since we operate purely in the token space, it is token-in-token-out, unlike `_chat_completion_with_retry()`
        which re-encodes in each new request.

        For subsequent retry requests (`InferenceEngineInput`), we:
        - Update the `InferenceEngineInput.prompt_token_ids` with the accumulated output tokens.
        - Skip accumulating `InferenceEngineOutput.responses` since we decode the final output.
        - Adjust remaining max tokens if `max_tokens` or `max_completion_tokens` is present.

        For the final response, we return `InferenceEngineOutput` with:
        - `responses`: decoded at the end from `response_ids` if generation is completed in > 1 turns, otherwise
            the text response of the first turn.
        - `response_ids`: the accumulated output tokens
        - `stop_reasons`: the stop reason of the final response
        - `response_logprobs`: the accumulated logprobs
        """
        if sampling_params is None:
            sampling_params = {}

        # 1. First determine original max tokens key and value (if any)
        max_key = None
        if "max_tokens" in sampling_params:
            max_key = "max_tokens"
        elif "max_completion_tokens" in sampling_params:
            max_key = "max_completion_tokens"
        original_max_tokens: Optional[int] = sampling_params.get(max_key) if max_key else None

        # 2. Initialize fields we want to accumulate or update in each loop iteration
        accum_response_ids: List[int] = []
        accum_response_logprobs: List[float] = []
        stop_reason: str = "abort"

        # We only use it if generation is completed in one turn to maintain original behavior with no retry.
        text_response: Optional[str] = None
        num_turns = 0

        # 3. Loop until generation is completed.
        while stop_reason == "abort":
            await self._wait_for_generation_to_resume()

            # 3.1. Prepare the request payload.
            cur_sampling_params = sampling_params.copy()
            if original_max_tokens is not None:
                new_max_tokens = original_max_tokens - len(accum_response_ids)
                assert new_max_tokens >= 0, f"Expect new_max_tokens to be non-negative, but got {new_max_tokens}"
                cur_sampling_params[max_key] = new_max_tokens
            new_prompt_ids = original_prompt_ids + accum_response_ids
            engine_input = InferenceEngineInput(
                prompt_token_ids=[new_prompt_ids],
                sampling_params=cur_sampling_params,
            )

            # 3.2. Send the request.
            logger.debug(f"generate() request sent (including potential retries): {engine_input}")
            partial_response: InferenceEngineOutput = await self.engines[engine_idx].generate(engine_input)

            # 3.3. Parse the partial response.
            assert len(partial_response["response_ids"]) == 1, "Expected exactly one response."
            new_response_ids: List[int] = partial_response["response_ids"][0]
            text_response = partial_response["responses"][0]
            stop_reason = partial_response["stop_reasons"][0]
            new_response_logprobs: Optional[List[float]] = None
            new_response_logprobs_list: Optional[List[List[float]]] = partial_response.get("response_logprobs", None)
            if new_response_logprobs_list is not None and len(new_response_logprobs_list) > 0:
                new_response_logprobs = new_response_logprobs_list[0]

            # 3.4 Aborted without generating tokens, so partial_response is useless.
            if stop_reason == "abort" and len(new_response_ids) == 0:
                continue

            # 3.5 Accumulate outputs
            accum_response_ids.extend(new_response_ids)
            if new_response_logprobs is not None:
                accum_response_logprobs.extend(new_response_logprobs)
            num_turns += 1

        # 4. Build the final response and return.
        if num_turns == 1:
            final_text_response = text_response
        else:
            final_text_response = self.tokenizer.decode(accum_response_ids, skip_special_tokens=True)
        return InferenceEngineOutput(
            responses=[final_text_response],
            stop_reasons=[stop_reason],
            response_ids=[accum_response_ids],
            response_logprobs=[accum_response_logprobs] if len(accum_response_logprobs) > 0 else None,
        )

    async def _chat_completion_with_retry(
        self, engine_idx: int, original_request_payload: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Keep sending `chat_completion` requests (with previous responses accumulated) until the finish_reason is not
        "abort".

        The retry mechanism is intended to be used in combination with `pause_generation()` and `resume_generation()`
        for in-flight weight updates and partial rollouts.

        This method is equivalent to a single `chat_completion()` call if we do not use `pause_generation()`.

        For subsequent retry requests, we can reuse the original request with the following exceptions:
        - Update the last assistant message content to the accumulated content, using the role from the
          first non-empty response.
        - Set continue_final_message=True and add_generation_prompt=False.
        - Adjust remaining max tokens if `max_tokens` or `max_completion_tokens` is present.
        - If no tokens have been generated yet, resend the original request unchanged.

        For the final response, we keep all fields from the first non-empty response (i.e. the one that
        prefilled the prompt), with the following exceptions:
        - Accumulate the following across retry requests:
          - `choices[0]["logprobs"]["content"]`
          - `choices[0]["token_ids"]`
          - `choices[0]["message"]["content"]`
        - Use the last response's finish_reason and stop_reason
        """
        original_request_json: Dict[str, Any] = original_request_payload.get("json", {}).copy()
        headers: Dict[str, str] = original_request_payload.get("headers", {}).copy()

        assert not original_request_json.get(
            "continue_final_message", False
        ), "continue_final_message must be False for /chat/completions requests"

        # Accumulated fields for building subsequent requests and final response. It is inplace-updated
        # in `_parse_partial_response_and_inplace_update_accum()`.
        accum = AccumulatedResponse()

        # First non-empty response (i.e. the response that prefilled the prompt) to copy metadata from.
        base_response: Optional[Dict[str, Any]] = None

        # Determine original max tokens key and value (if any)
        max_key = None
        if "max_tokens" in original_request_json:
            max_key = "max_tokens"
        elif "max_completion_tokens" in original_request_json:
            max_key = "max_completion_tokens"
        orig_max_tokens: Optional[int] = original_request_json.get(max_key) if max_key else None

        # Fields to be updated in each loop iteration
        finish_reason: str = "abort"
        stop_reason: Optional[str] = None
        response_role: Optional[str] = None

        # 1. Loop until the generation is completed.
        while finish_reason == "abort":
            await self._wait_for_generation_to_resume()

            # 1.1. Prepare the request payload.
            cur_request_json = _prepare_retry_request(
                original_request_json=original_request_json,
                accum=accum,
                response_role=response_role,
                orig_max_tokens=orig_max_tokens,
                max_key=max_key,
            )

            # 1.2. Send the request.
            logger.debug(f"/chat/completions request sent (including potential retries): {cur_request_json}")
            partial_response = await self.engines[engine_idx].chat_completion(
                {"json": cur_request_json, "headers": headers}
            )

            # 1.2.1. Check for error response from engine (e.g., context length exceeded).
            # Error responses have "error" key instead of "choices", so return them directly
            # for the HTTP endpoint to handle with proper status codes.
            if "error" in partial_response or partial_response.get("object", "") == "error":
                return partial_response

            # 1.3. Parse partial response and in-place update accumulators.
            (
                finish_reason,
                stop_reason,
                response_role,
                aborted_without_generating,
            ) = _parse_partial_response_and_inplace_update_accum(
                partial_response=partial_response,
                accum=accum,
                response_role=response_role,
            )

            # 1.4. If aborted without generating any tokens, partial_response is useless; retry.
            if aborted_without_generating:
                continue

            # At this point, some tokens were generated, the request completed with a non-"abort" finish_reason, or both

            # 1.5. Update base response if it is the first non-empty response
            if base_response is None:
                if finish_reason != "abort":
                    # If we only made one request and it is not aborted, return the partial result directly.
                    # This is the codepath that will hit when we do not use `pause_generation()`
                    # or `resume_generation()`.
                    return partial_response
                # NOTE(Charlie): not doing deepcopy here to avoid copying large logprobs, so be careful when
                # modifying this.
                base_response = partial_response.copy()

        # 2. Build final response by combining fields
        assert base_response is not None, "Expected at least one non-empty response to build final response"
        return _build_final_response(
            base_response=base_response,
            accum=accum,
            finish_reason=finish_reason,
            stop_reason=stop_reason,
        )

    async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        session_id = request_payload["json"].pop("session_id", None)
        if session_id is not None:
            assert isinstance(session_id, (str, int)), "Session ID must be an integer or string for `/chat/completions`"
        engine_idx = self._select_engine_idx(session_id)

        # Always use the retry loop, which also issues the first request itself
        return await self._chat_completion_with_retry(engine_idx, request_payload)

    async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handles an OpenAI /completions request.

        Since `request["prompt"]` can be `Union[list[int], list[list[int]], str, list[str]]`
        (i.e. {batched, single} x {string, token IDs}), we need to route the request to engines
        differently, based on whether it is a single or batched request and whether `request["session_id"]`
        is provided. This is similar to the `generate()` method.

        For a single prompt, we use the same routing logic as `chat_completion()`. For batched prompts,
        we route by `request["session_id"]` if present; otherwise we split evenly across engines.

        In either case, the order is maintained, i.e. `output["choices"][i]` corresponds to `request["prompt"][i]`.
        """
        if self.generation_paused_event.is_set():
            raise RuntimeError("pause_generation is unsupported for /completions requests.")
        body = request_payload.get("json", {})

        # NOTE(Charlie): do not reuse headers here as the single request may become various new requests
        headers = {"Content-Type": "application/json"}

        # 0. Postprocess prompt, session_id, and validate request.
        prompt = body.get("prompt")
        session_id_value = body.pop("session_id", None)
        ret = postprocess_completion_request(prompt, session_id_value)
        session_id_list: Optional[Union[List[int], List[str], ErrorResponse]] = ret[0]
        prompt: Union[List[List[int]], List[str]] = ret[1]
        if isinstance(session_id_list, ErrorResponse):
            return session_id_list.model_dump()

        num_prompts = len(prompt)
        num_inference_engines = len(self.engines)
        assert num_prompts > 0, "Number of prompts must be greater than 0"

        # 1. Route prompts to engines
        engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
            num_prompts=num_prompts,
            num_inference_engines=num_inference_engines,
            session_ids=session_id_list,
        )

        # 2. Generate responses concurrently
        tasks: list[asyncio.Task] = []
        indices_list: list[list[int]] = []  # the original prompt indices that each task works on
        for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
            cur_prompt = [prompt[i] for i in prompt_ids]
            # reuse the exact same request except for the prompt
            cur_json = dict(body)
            cur_json["prompt"] = cur_prompt
            coro = self.engines[engine_idx].completion({"json": cur_json, "headers": headers})
            tasks.append(asyncio.create_task(coro))
            indices_list.append(prompt_ids)

        results = await asyncio.gather(*tasks)

        # 3. Check for errors.
        # results can be ErrorResponse or CompletionResponse. If one of the sub-requests fails, we
        # return an error response. That is, there is no partial success, following vLLM's behavior.
        for result in results:
            if "error" in result or result.get("object", "") == "error":
                error_details = result.get("error", result)
                error_code = error_details["code"]
                error_type = error_details["type"]
                error_message = error_details["message"]
                return ErrorResponse(
                    error=ErrorInfo(
                        message=f"In one of the engines that SkyRL manages, an error occurred: {error_message}",
                        type=error_type,
                        code=error_code,
                    ),
                ).model_dump()

        # 4. Combine choices and preserve original order.
        # If there is only one result, we return it directly.
        if len(results) == 1:
            return results[0]

        # Use the first result as base response. There are some fields that cannot be shared
        # across sub-requests. For now it is just the usage field.
        final_response = dict(results[0])
        final_response["usage"] = aggregate_completion_usage_info(results, self.backend)

        # Aggregate choices. TODO(Charlie): improve logic when we need to support n > 1
        # vLLM sets index positions per sub-batch, so we reset indices to be 0..n-1 for the combined response.
        combined_choices: list[Dict[str, Any]] = [None] * num_prompts
        for indices, result in zip(indices_list, results):
            # indices are the original prompt indices that the task's response corresponds to
            for local_idx, original_idx in enumerate(indices):
                choice = result["choices"][local_idx]
                choice["index"] = original_idx  # overwrite index with the global position
                combined_choices[original_idx] = choice

        # sanity check that the index is correct
        for new_idx in range(len(combined_choices)):
            assert combined_choices[new_idx]["index"] == new_idx

        final_response["choices"] = combined_choices
        return final_response

    async def wake_up(self, *args: Any, **kwargs: Any):
        return await self._run_on_all_engines("wake_up", *args, **kwargs)

    async def sleep(self, *args: Any, **kwargs: Any):
        return await self._run_on_all_engines("sleep", *args, **kwargs)

    async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
        """Initialize weight update communicator on all engines.

        Args:
            init_info: WeightSyncInitInfo from the sender.

        Note:
            Per-engine adjustments (e.g., rank_offset for broadcast) are handled
            by init_info.for_engine().
        """
        tasks = []
        for i, engine in enumerate(self.engines):
            engine_init_info = init_info.for_engine(i, engine.tp_size(), engine.pp_size())
            tasks.append(engine.init_weight_update_communicator(engine_init_info))
        await asyncio.gather(*tasks)

    async def update_named_weights(self, request: WeightUpdateRequest):
        return await self._run_on_all_engines("update_named_weights", request=request)

    async def reset_prefix_cache(self):
        return await self._run_on_all_engines("reset_prefix_cache")

    async def teardown(self):
        return await self._run_on_all_engines("teardown")

    def tp_size(self) -> int:
        raise NotImplementedError("InferenceEngineClient does not implement tp_size()")

    def pp_size(self) -> int:
        raise NotImplementedError("InferenceEngineClient does not implement pp_size()")

    def dp_size(self) -> int:
        raise NotImplementedError("InferenceEngineClient does not implement dp_size()")

    # ----------------------------
    # Generation pause and resume
    # ----------------------------
    async def _wait_for_generation_to_resume(self) -> None:
        """Waits for generation to be resumed, intended for in-flight weight updates and partial rollouts."""
        while self.generation_paused_event.is_set():
            await asyncio.sleep(0.5)

    async def pause_generation(self) -> None:
        """
        Pauses generation for all engines, intended for in-flight weight updates and partial rollouts.

        Currently only supported for `/chat/completions` and not `/completions` or `generate()`.

        Both in-flight and incoming requests will be blocked until `resume_generation` is called.
        1. Set the paused event to prevent new requests from being submitted while aborting requests.
        2. Wait for a grace period to ensure all in-flight requests have entered the engine's
           scheduler and hence can be aborted. Otherwise, requests that were already submitted
           but have not yet entered the scheduler could miss the abort.
        3. Finally, abort requests on all engines. This causes the requests sent from
           InferenceEngineClient to `InferenceEngineClient.engines` to return the already-generated tokens.
           Requests to `InferenceEngineClient` itself will not return until they complete with a
           stop reason other than `abort`.
        """
        if self.generation_paused_event.is_set():
            raise RuntimeError("Generation is already paused, cannot pause again.")
        self.generation_paused_event.set()
        await asyncio.sleep(ABORT_GENERATION_GRACE_PERIOD_SECONDS)
        await self._run_on_all_engines("abort_generation")

    async def resume_generation(self) -> None:
        """
        Resumes generation for all engines, intended for in-flight weight updates and partial rollouts.

        Resume all in-flight requests with the previously-generated tokens, and unblock incoming requests
        that were blocked by `pause_generation()`.
        """
        if not self.generation_paused_event.is_set():
            raise RuntimeError("Generation is not paused, cannot resume.")
        self.generation_paused_event.clear()

    async def abort_generation(self) -> None:
        raise NotImplementedError(
            "InferenceEngineClient does not implement abort_generation(), but calls "
            "`abort_generation` on all engines in `pause_generation()`."
        )

    # ----------------------------
    # HTTP endpoint related methods
    # ----------------------------

    def __del__(self):
        """
        Destructor to shut down the HTTP endpoint if it was started.
        """
        # TODO(Charlie): __del__ is not guaranteed to be called in general. Add to `teardown` method
        # when the `_handle_termination` flow is implemented. See `skyrl_train/workers/worker.py`
        # comments on `_handle_termination` for more details.
        if (
            self.enable_http_endpoint
            and hasattr(
                self, "_server_thread"
            )  # don't want to shut down the server when it is pickled as a ray method argument.
            and self._server_thread is not None
        ):
            try:
                from skyrl.backends.skyrl_train.inference_engines.inference_engine_client_http_endpoint import (
                    shutdown_server,
                )

                shutdown_server(
                    host=self.http_endpoint_host,
                    port=self.http_endpoint_port,
                    max_wait_seconds=10,
                )
                if hasattr(self, "_server_thread") and self._server_thread.is_alive():
                    self._server_thread.join(timeout=10)
            except Exception as e:
                logger.error(f"Error shutting down HTTP endpoint: {e}")

    def __getstate__(self):
        """
        Override to avoid pickling the server thread and the threading.Event object, which are not picklable.
        Needed when passing InferenceEngineClient as an argument to async_run_ray_method(), mainly for
        invoking `init_weight_sync_state()` and `broadcast_to_inference_engines()`, which do
        not need these attributes.
        """
        state = self.__dict__.copy()
        state["_server_thread"] = None
        state["generation_paused_event"] = None
        return state

    def _spin_up_http_endpoint(self):
        from skyrl.backends.skyrl_train.inference_engines.inference_engine_client_http_endpoint import (
            serve,
            wait_for_server_ready,
        )

        self._server_thread = threading.Thread(
            target=serve,
            args=(self,),
            kwargs={
                "host": self.http_endpoint_host,
                "port": self.http_endpoint_port,
                "log_level": "warning",
            },
            daemon=True,
        )
        self._server_thread.start()
        wait_for_server_ready(
            host=self.http_endpoint_host,
            port=self.http_endpoint_port,
            max_wait_seconds=30,
        )
        logger.info(
            f"InferenceEngineClient HTTP endpoint started on {self.http_endpoint_host}:{self.http_endpoint_port}"
        )
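
The accumulation rules in `_chat_completion_with_retry`'s docstring can be sketched as a small standalone merge helper. `merge_partial_responses` is hypothetical (not part of SkyRL), and the response dicts below are fabricated; the real method also accumulates `choices[0]["logprobs"]["content"]`, which is omitted here for brevity.

```python
from typing import Any, Dict, List

def merge_partial_responses(partials: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Combine aborted partial chat completions into one response.

    Follows the documented rules: metadata comes from the first non-empty
    response; content and token_ids are concatenated across retries; and
    finish_reason / stop_reason come from the last response.
    """
    base = dict(partials[0])  # first non-empty response supplies metadata
    content = "".join(p["choices"][0]["message"]["content"] for p in partials)
    token_ids = [t for p in partials for t in p["choices"][0].get("token_ids", [])]
    last = partials[-1]["choices"][0]

    choice = dict(base["choices"][0])
    choice["message"] = dict(choice["message"], content=content)
    choice["token_ids"] = token_ids
    choice["finish_reason"] = last["finish_reason"]
    choice["stop_reason"] = last.get("stop_reason")
    base["choices"] = [choice]
    return base

# Fabricated partial responses: the first was aborted mid-generation,
# the second completed normally after a retry.
p1 = {"choices": [{"message": {"role": "assistant", "content": "Hello, "},
                   "token_ids": [1, 2], "finish_reason": "abort", "stop_reason": None}]}
p2 = {"choices": [{"message": {"role": "assistant", "content": "world!"},
                   "token_ids": [3], "finish_reason": "stop", "stop_reason": None}]}
merged = merge_partial_responses([p1, p2])
```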

attr engines

engines = engines

attr tokenizer

tokenizer = tokenizer

attr inference_engine_cfg

inference_engine_cfg = inference_engine_cfg

attr model_name

model_name = served_model_name

attr backend

backend = inference_engine_cfg.backend

attr enable_http_endpoint

enable_http_endpoint = inference_engine_cfg.enable_http_endpoint

attr http_endpoint_host

http_endpoint_host = inference_engine_cfg.http_endpoint_host

attr http_endpoint_port

http_endpoint_port = inference_engine_cfg.http_endpoint_port

attr generation_paused_event

generation_paused_event = threading.Event()

method async generate

generate(input_batch: InferenceEngineInput) -> InferenceEngineOutput
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:92-175
    async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
        # 0. Extract input
        prompts = input_batch.get("prompts")
        prompt_token_ids = input_batch.get("prompt_token_ids")
        session_ids = input_batch.get("session_ids")
        sampling_params = input_batch.get("sampling_params")

        if (prompts is None and prompt_token_ids is None) or (prompts is not None and prompt_token_ids is not None):
            raise ValueError("Either `prompts` or `prompt_token_ids` must be provided, but not both.")
        if prompt_token_ids is None:
            prompt_token_ids = self.tokenizer.apply_chat_template(
                prompts,
                add_generation_prompt=True,
                return_dict=True,
                tokenize=True,
            )["input_ids"]

        num_prompts = len(prompt_token_ids)
        num_inference_engines = len(self.engines)

        # 1. Route prompts to engines
        engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
            num_prompts=num_prompts,
            num_inference_engines=num_inference_engines,
            session_ids=session_ids,
        )

        # Take a shortcut for non-batched requests, which can support pause/resume generation for
        # in-flight weight updates.
        if num_prompts == 1:
            # Route to a single engine for this single prompt and use retry flow.
            assert len(engine_idx_to_prompt_ids) == 1
            ((engine_idx, prompt_ids_list),) = engine_idx_to_prompt_ids.items()
            assert prompt_ids_list == [0], "Single prompt should map to index [0]"
            original_prompt_ids = prompt_token_ids[0]
            return await self._generate_single_with_retry(
                engine_idx=engine_idx,
                original_prompt_ids=original_prompt_ids,
                sampling_params=sampling_params,
            )

        # For batched generate(), pause/resume cannot be supported.
        if self.generation_paused_event.is_set():
            raise RuntimeError("pause_generation is unsupported for batched InferenceEngineClient.generate().")

        # 2. Generate responses concurrently
        tasks: list[asyncio.Task] = []
        indices_list: list[list[int]] = []  # the original prompt indices that each task works on
        for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
            # index prompt_token_ids with prompt_ids
            cur_prompt_token_ids = [prompt_token_ids[i] for i in prompt_ids]
            engine_input = InferenceEngineInput(
                prompt_token_ids=cur_prompt_token_ids,
                sampling_params=sampling_params,
            )
            tasks.append(asyncio.create_task(self.engines[engine_idx].generate(engine_input)))
            indices_list.append(prompt_ids)

        results = await asyncio.gather(*tasks)

        # 3. Reconstruct output in original order
        n = len(prompt_token_ids)
        responses: list[str] = [""] * n
        stop_reasons: list[str] = [""] * n
        response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)]
        response_ids: List[List[int]] = [[] for _ in range(n)]
        # a bit hacky for now
        add_resp_logprobs = False

        for indices, result in zip(indices_list, results):
            for local_idx, original_idx in enumerate(indices):
                responses[original_idx] = result["responses"][local_idx]
                stop_reasons[original_idx] = result["stop_reasons"][local_idx]
                response_ids[original_idx] = result["response_ids"][local_idx]
                if result.get("response_logprobs", None):
                    add_resp_logprobs = True
                    response_logprobs[original_idx] = result["response_logprobs"][local_idx]

        return InferenceEngineOutput(
            responses=responses,
            stop_reasons=stop_reasons,
            response_ids=response_ids,
            response_logprobs=response_logprobs if add_resp_logprobs else None,
        )
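
The routing-then-reassembly pattern in `generate()` can be illustrated in isolation. `split_round_robin` below is a hypothetical stand-in for `route_prompts_to_engines` without session IDs (the real routing policy may differ); the point is that each engine's results are written back by original prompt index, so output order always matches input order.

```python
from typing import Dict, List

def split_round_robin(num_prompts: int, num_engines: int) -> Dict[int, List[int]]:
    # Hypothetical stand-in for route_prompts_to_engines with no session_ids.
    mapping: Dict[int, List[int]] = {}
    for i in range(num_prompts):
        mapping.setdefault(i % num_engines, []).append(i)
    return mapping

def reassemble(indices_list: List[List[int]], results: List[List[str]], n: int) -> List[str]:
    # Write each engine's outputs back to the original prompt positions.
    out = [""] * n
    for indices, result in zip(indices_list, results):
        for local_idx, original_idx in enumerate(indices):
            out[original_idx] = result[local_idx]
    return out

mapping = split_round_robin(5, 2)
indices_list = list(mapping.values())
# Fake per-engine results: each engine "echoes" the prompt index it received.
results = [[f"resp-{i}" for i in idxs] for idxs in indices_list]
ordered = reassemble(indices_list, results, 5)
```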

method abstractmethod async sample

sample(prompt_token_ids: List[int], num_samples: int, sampling_params: Dict[str, Any], session_id: Optional[Union[str, int]] = None) -> InferenceEngineOutput

Generate multiple independent samples from a single prompt.

This method provides Tinker-compatible token-in/token-out sampling semantics. Generates num_samples independent completions from the same prompt.

Parameters:

NameTypeDescriptionDefault
prompt_token_idsList[int]Token IDs for a single prompt (not batched).required
num_samplesintNumber of independent samples to generate.required
sampling_paramsDict[str, Any]Sampling parameters (temperature, max_tokens, etc.).required
session_idOptional[Union[str, int]]Optional session ID for consistent engine routing (e.g., conversation ID). If None, uses random load-balancing. Tinker API should pass None since each sample() call is independent.None

Returns:

TypeDescription
InferenceEngineOutputInferenceEngineOutput containing num_samples results.
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:192-226
    async def sample(
        self,
        prompt_token_ids: List[int],
        num_samples: int,
        sampling_params: Dict[str, Any],
        session_id: Optional[Union[str, int]] = None,
    ) -> InferenceEngineOutput:
        """Generate multiple independent samples from a single prompt.

        This method provides Tinker-compatible token-in/token-out sampling semantics.
        Generates num_samples independent completions from the same prompt.

        Args:
            prompt_token_ids: Token IDs for a single prompt (not batched).
            num_samples: Number of independent samples to generate.
            sampling_params: Sampling parameters (temperature, max_tokens, etc.).
            session_id: Optional session ID for consistent engine routing (e.g., conversation ID).
                       If None, uses random load-balancing. Tinker API should pass None since
                       each sample() call is independent.

        Returns:
            InferenceEngineOutput containing num_samples results.
        """
        # Wait for generation to resume if paused (for weight updates)
        await self._wait_for_generation_to_resume()

        # Select engine (random if session_id is None, consistent hash otherwise)
        engine_idx = self._select_engine_idx(session_id)
        engine = self.engines[engine_idx]

        return await engine.sample(
            prompt_token_ids=prompt_token_ids,
            num_samples=num_samples,
            sampling_params=sampling_params,
        )
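
A caller-side sketch of `sample()`'s token-in/token-out contract, using a minimal stub in place of a real engine. `StubEngine` and its canned outputs are purely illustrative; a real engine would run the model with the given sampling parameters.

```python
import asyncio
from typing import Any, Dict, List

class StubEngine:
    """Illustrative stand-in for an inference engine; not part of SkyRL."""

    async def sample(self, prompt_token_ids: List[int], num_samples: int,
                     sampling_params: Dict[str, Any]) -> Dict[str, Any]:
        # Return num_samples independent completions for the single prompt.
        return {
            "responses": [f"sample-{i}" for i in range(num_samples)],
            "response_ids": [[100 + i] for i in range(num_samples)],
            "stop_reasons": ["stop"] * num_samples,
        }

async def main() -> Dict[str, Any]:
    engine = StubEngine()
    # One (non-batched) prompt in, num_samples independent completions out.
    return await engine.sample(
        prompt_token_ids=[1, 2, 3],
        num_samples=4,
        sampling_params={"temperature": 0.8, "max_tokens": 16},
    )

out = asyncio.run(main())
```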

method abstractmethod async chat_completion

chat_completion(request_payload: Dict[str, Any]) -> Dict[str, Any]
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:445-452
    async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        session_id = request_payload["json"].pop("session_id", None)
        if session_id is not None:
            assert isinstance(session_id, (str, int)), "Session ID must be an integer or string for `/chat/completions`"
        engine_idx = self._select_engine_idx(session_id)

        # Always use the retry loop, which also issues the first request itself
        return await self._chat_completion_with_retry(engine_idx, request_payload)
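
The request payload that `chat_completion()` expects — a dict with `json` and `headers` keys, with `session_id` carried inside the JSON body — can be shown with plain dicts. The model name, message, and session ID below are placeholders; only the payload shape and the `session_id` pop mirror the code above.

```python
from typing import Any, Dict

# Hypothetical request payload; "session_id" pins the conversation to one engine.
request_payload: Dict[str, Any] = {
    "json": {
        "model": "my-served-model",       # placeholder served model name
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 64,
        "session_id": "conversation-42",  # popped before routing, as in chat_completion()
    },
    "headers": {"Content-Type": "application/json"},
}

# Mimic what chat_completion() does before selecting an engine:
session_id = request_payload["json"].pop("session_id", None)
assert isinstance(session_id, (str, int))
```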

method abstractmethod async completion

completion(request_payload: Dict[str, Any]) -> Dict[str, Any]

Handles an OpenAI /completions request.

Since request["prompt"] can be Union[list[int], list[list[int]], str, list[str]] (i.e. {batched, single} x {string, token IDs}), we need to route the request to engines differently, based on whether it is a single or batched request and whether request["session_id"] is provided. This is similar to the generate() method.

For a single prompt, we use the same routing logic as chat_completion(). For batched prompts, we route by request["session_id"] if present; otherwise we split evenly across engines.

In either case, the order is maintained, i.e. output["choices"][i] corresponds to request["prompt"][i].

Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:454-551
    async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handles an OpenAI /completions request.

        Since `request["prompt"]` can be `Union[list[int], list[list[int]], str, list[str]]`
        (i.e. {batched, single} x {string, token IDs}), we need to route the request to engines
        differently, based on whether it is a single or batched request and whether `request["session_id"]`
        is provided. This is similar to the `generate()` method.

        For a single prompt, we use the same routing logic as `chat_completion()`. For batched prompts,
        we route by `request["session_id"]` if present; otherwise we split evenly across engines.

        In either case, the order is maintained, i.e. `output["choices"][i]` corresponds to `request["prompt"][i]`.
        """
        if self.generation_paused_event.is_set():
            raise RuntimeError("pause_generation is unsupported for /completions requests.")
        body = request_payload.get("json", {})

        # NOTE(Charlie): do not reuse headers here as the single request may become various new requests
        headers = {"Content-Type": "application/json"}

        # 0. Postprocess prompt, session_id, and validate request.
        prompt = body.get("prompt")
        session_id_value = body.pop("session_id", None)
        ret = postprocess_completion_request(prompt, session_id_value)
        session_id_list: Optional[Union[List[int], List[str], ErrorResponse]] = ret[0]
        prompt: Union[List[List[int]], List[str]] = ret[1]
        if isinstance(session_id_list, ErrorResponse):
            return session_id_list.model_dump()

        num_prompts = len(prompt)
        num_inference_engines = len(self.engines)
        assert num_prompts > 0, "Number of prompts must be greater than 0"

        # 1. Route prompts to engines
        engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
            num_prompts=num_prompts,
            num_inference_engines=num_inference_engines,
            session_ids=session_id_list,
        )

        # 2. Generate responses concurrently
        tasks: list[asyncio.Task] = []
        indices_list: list[list[int]] = []  # the original prompt indices that each task works on
        for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
            cur_prompt = [prompt[i] for i in prompt_ids]
            # reuse the exact same request except for the prompt
            cur_json = dict(body)
            cur_json["prompt"] = cur_prompt
            coro = self.engines[engine_idx].completion({"json": cur_json, "headers": headers})
            tasks.append(asyncio.create_task(coro))
            indices_list.append(prompt_ids)

        results = await asyncio.gather(*tasks)

        # 3. Check for errors.
        # results can be ErrorResponse or CompletionResponse. If one of the sub-requests fails, we
        # return an error response. That is, there is no partial success, following vLLM's behavior.
        for result in results:
            if "error" in result or result.get("object", "") == "error":
                error_details = result.get("error", result)
                error_code = error_details["code"]
                error_type = error_details["type"]
                error_message = error_details["message"]
                return ErrorResponse(
                    error=ErrorInfo(
                        message=f"In one of the engines that SkyRL manages, an error occurred: {error_message}",
                        type=error_type,
                        code=error_code,
                    ),
                ).model_dump()

        # 4. Combine choices and preserve original order.
        # If there is only one result, we return it directly.
        if len(results) == 1:
            return results[0]

        # Use the first result as base response. There are some fields that cannot be shared
        # across sub-requests. For now it is just the usage field.
        final_response = dict(results[0])
        final_response["usage"] = aggregate_completion_usage_info(results, self.backend)

        # Aggregate choices. TODO(Charlie): improve logic when we need to support n > 1
        # vLLM sets index positions per sub-batch, so we reset indices to be 0..n-1 for the combined response.
        combined_choices: list[Dict[str, Any]] = [None] * num_prompts
        for indices, result in zip(indices_list, results):
            # indices are the original prompt indices that the task's response corresponds to
            for local_idx, original_idx in enumerate(indices):
                choice = result["choices"][local_idx]
                choice["index"] = original_idx  # overwrite index with the global position
                combined_choices[original_idx] = choice

        # sanity check that the index is correct
        for new_idx in range(len(combined_choices)):
            assert combined_choices[new_idx]["index"] == new_idx

        final_response["choices"] = combined_choices
        return final_response
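The recombination step above can be sketched in isolation. The helper below is a hypothetical, simplified reimplementation (plain dicts, no error handling or usage aggregation) showing how each sub-response's locally-numbered choices are rewritten with their original prompt indices so the merged response preserves input order.

```python
from typing import Any, Dict, List


def merge_choices(
    results: List[Dict[str, Any]], indices_list: List[List[int]], num_prompts: int
) -> List[Dict[str, Any]]:
    """Merge `choices` from per-engine sub-responses back into input order.

    `indices_list[k]` holds the original prompt indices served by `results[k]`;
    each engine numbers its own choices 0..len-1, so we overwrite `index` with
    the global position before placing the choice into the combined list.
    """
    combined: List[Dict[str, Any]] = [None] * num_prompts  # type: ignore[list-item]
    for indices, result in zip(indices_list, results):
        for local_idx, original_idx in enumerate(indices):
            choice = result["choices"][local_idx]
            choice["index"] = original_idx
            combined[original_idx] = choice
    return combined


# Two engines each served two prompts, interleaved across the global batch.
r0 = {"choices": [{"index": 0, "text": "a"}, {"index": 1, "text": "c"}]}
r1 = {"choices": [{"index": 0, "text": "b"}, {"index": 1, "text": "d"}]}
merged = merge_choices([r0, r1], indices_list=[[0, 2], [1, 3]], num_prompts=4)
print([c["text"] for c in merged])  # ['a', 'b', 'c', 'd']
```

Re-indexing is required because each engine's response numbers choices relative to its own sub-batch, not the caller's batch.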

method abstractmethod async wake_up

wake_up(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:553-554
    async def wake_up(self, *args: Any, **kwargs: Any):
        return await self._run_on_all_engines("wake_up", *args, **kwargs)

method abstractmethod async sleep

sleep(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:556-557
    async def sleep(self, *args: Any, **kwargs: Any):
        return await self._run_on_all_engines("sleep", *args, **kwargs)

method abstractmethod async init_weight_update_communicator

init_weight_update_communicator(init_info: 'WeightSyncInitInfo')

Initialize weight update communicator on all engines.

Parameters:

NameTypeDescriptionDefault
init_info'WeightSyncInitInfo'WeightSyncInitInfo from the sender.required

Note:

Per-engine adjustments (e.g., rank_offset for broadcast) are handled by init_info.for_engine().

Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:559-573
    async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
        """Initialize weight update communicator on all engines.

        Args:
            init_info: WeightSyncInitInfo from the sender.

        Note:
            Per-engine adjustments (e.g., rank_offset for broadcast) are handled
            by init_info.for_engine().
        """
        tasks = []
        for i, engine in enumerate(self.engines):
            engine_init_info = init_info.for_engine(i, engine.tp_size(), engine.pp_size())
            tasks.append(engine.init_weight_update_communicator(engine_init_info))
        await asyncio.gather(*tasks)

method abstractmethod async update_named_weights

update_named_weights(request: WeightUpdateRequest)
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:575-576
    async def update_named_weights(self, request: WeightUpdateRequest):
        return await self._run_on_all_engines("update_named_weights", request=request)

method abstractmethod async reset_prefix_cache

reset_prefix_cache()
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:578-579
    async def reset_prefix_cache(self):
        return await self._run_on_all_engines("reset_prefix_cache")

method abstractmethod async teardown

teardown()
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:581-582
    async def teardown(self):
        return await self._run_on_all_engines("teardown")

method abstractmethod tp_size

tp_size() -> int
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:584-585
    def tp_size(self) -> int:
        raise NotImplementedError("InferenceEngineClient does not implement tp_size()")

method abstractmethod pp_size

pp_size() -> int
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:587-588
    def pp_size(self) -> int:
        raise NotImplementedError("InferenceEngineClient does not implement pp_size()")

method abstractmethod dp_size

dp_size() -> int
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:590-591
    def dp_size(self) -> int:
        raise NotImplementedError("InferenceEngineClient does not implement dp_size()")

method async pause_generation

pause_generation() -> None

Pauses generation for all engines, intended for in-flight weight updates and partial rollouts.

Currently only supported for /chat/completions and not /completions or generate().

Both in-flight and incoming requests will be blocked until resume_generation is called.

  1. Set the paused event so that no new requests are submitted while in-flight requests are being aborted.
  2. Wait for a grace period so that all in-flight requests have entered the engine's scheduler and can therefore be aborted. Without this wait, a request that has been submitted but has not yet entered the scheduler could miss the abort.
  3. Abort requests on all engines. Aborting causes the requests sent from InferenceEngineClient to InferenceEngineClient.engines to return with the tokens generated so far. The request made to InferenceEngineClient itself does not return until the underlying requests complete with a stop reason other than abort.
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:601-621
    async def pause_generation(self) -> None:
        """
        Pauses generation for all engines, intended for in-flight weight updates and partial rollouts.

        Currently only supported for `/chat/completions` and not `/completions` or `generate()`.

        Both in-flight and incoming requests will be blocked until `resume_generation` is called.
        1. Set the paused event to avoid new requests from being submitted while aborting requests.
        2. Wait for a grace period to ensure all in-flight requests have entered the engine's
           scheduler and hence can be aborted. Otherwise, there can be requests already submitted
           but not yet entered the scheduler, which can miss the abort request.
        3. Finally, we abort requests on all engines. This will cause the requests sent from
           InferenceEngineClient to `InferenceEngineClient.engines` to return the already-generated tokens.
           The request to `InferenceEngineClient` will not yet return until requests are completed with
           stop reason that is not `abort`.
        """
        if self.generation_paused_event.is_set():
            raise RuntimeError("Generation is already paused, cannot pause again.")
        self.generation_paused_event.set()
        await asyncio.sleep(ABORT_GENERATION_GRACE_PERIOD_SECONDS)
        await self._run_on_all_engines("abort_generation")

method async resume_generation

resume_generation() -> None

Resumes generation for all engines, intended for in-flight weight updates and partial rollouts.

Resume all in-flight requests with the previously-generated tokens, and unblock incoming requests that were blocked by pause_generation().

Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:623-632
    async def resume_generation(self) -> None:
        """
        Resumes generation for all engines, intended for in-flight weight updates and partial rollouts.

        Resume all in-flight requests with the previously-generated tokens, and unblock incoming requests
        that were blocked by `pause_generation()`.
        """
        if not self.generation_paused_event.is_set():
            raise RuntimeError("Generation is not paused, cannot resume.")
        self.generation_paused_event.clear()

method abstractmethod async abort_generation

abort_generation() -> None
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:634-638
    async def abort_generation(self) -> None:
        raise NotImplementedError(
            "InferenceEngineClient does not implement abort_generation(), but calls "
            "`abort_generation` on all engines in `pause_generation()`."
        )
