Generator
Generator API: GeneratorInterface, InferenceEngineInterface, InferenceEngineClient.
Core APIs
class GeneratorInterface
Bases: ABC
Functions:
| Name | Description |
|---|---|
generate | Generate trajectories for the input batch. |
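For orientation, here is a minimal implementation sketch of a custom generator. It assumes GeneratorInput and GeneratorOutput are importable from skyrl.train.generators.base alongside GeneratorInterface, and that the generator delegates rollouts to an inference engine; the rollout and reward logic is elided and must follow the actual field definitions in that module.

```python
# Illustrative sketch only: a generator that delegates rollout to an inference engine.
# The concrete fields of GeneratorInput/GeneratorOutput are defined in
# skyrl/train/generators/base.py and are not spelled out here.
from skyrl.train.generators.base import GeneratorInterface, GeneratorInput, GeneratorOutput


class MyGenerator(GeneratorInterface):
    def __init__(self, inference_engine_client):
        # Any InferenceEngineInterface works here; InferenceEngineClient is the usual choice.
        self.inference_engine_client = inference_engine_client

    async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
        # 1. Run rollouts with the inference engine.
        # 2. Compute rewards / loss masks for each trajectory.
        # 3. Return a GeneratorOutput in the same order as input_batch.
        raise NotImplementedError  # fill in with your environment/rollout logic
```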
Source code in skyrl/train/generators/base.py:53-65
class GeneratorInterface(ABC):
@abstractmethod
async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
"""Generate trajectories for the input batch.
Returns outputs in the same order as the input batch.
Args:
input_batch (GeneratorInput): Input batch
Returns:
GeneratorOutput: Generated trajectories
"""
raise NotImplementedError
method async generate
generate(input_batch: GeneratorInput) -> GeneratorOutput
Generate trajectories for the input batch.
Returns outputs in the same order as the input batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
input_batch | GeneratorInput | Input batch | required |
Returns: GeneratorOutput: Generated trajectories
Source code in skyrl/train/generators/base.py:54-65
@abstractmethod
async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
"""Generate trajectories for the input batch.
Returns outputs in the same order as the input batch.
Args:
input_batch (GeneratorInput): Input batch
Returns:
GeneratorOutput: Generated trajectories
"""
raise NotImplementedError
class InferenceEngineInterface
Bases: ABC
Functions:
| Name | Description |
|---|---|
generate | |
sample | Generate multiple independent samples from a single prompt. |
chat_completion | Handles the OpenAI-compatible /chat/completions HTTP endpoint.
completion | Handles the OpenAI-compatible /completions HTTP endpoint.
wake_up | |
sleep | |
init_weight_update_communicator | Initialize weight update communicator from init info. |
update_named_weights | |
teardown | |
reset_prefix_cache | |
tp_size | Return the tensor parallel size of this inference engine. |
pp_size | Return the pipeline parallel size of this inference engine. |
dp_size | Return the data parallel size of this inference engine. |
abort_generation | Abort all running and waiting requests, making the ongoing requests return the already-generated tokens with a stop_reason of "abort".
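A usage sketch of the batched generate() call, run inside an async context. The input and output fields shown (prompt_token_ids, sampling_params, responses, response_ids, stop_reasons) follow the source below; `engine` stands for any concrete implementation and the token IDs are placeholders.

```python
# Usage sketch: batched token-in generation against any concrete engine.
input_batch = {
    "prompts": None,
    "prompt_token_ids": [[1, 2, 3], [4, 5, 6]],  # one token-ID list per prompt (placeholders)
    "sampling_params": {"temperature": 0.8, "max_tokens": 128},
    "session_ids": None,
}
output = await engine.generate(input_batch)
# Outputs are returned in the same order as the input batch.
for ids, text, reason in zip(output["response_ids"], output["responses"], output["stop_reasons"]):
    print(len(ids), reason, text[:40])
```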
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:36-168
class InferenceEngineInterface(ABC):
@abstractmethod
async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
raise NotImplementedError
async def sample(
self,
prompt_token_ids: List[int],
num_samples: int,
sampling_params: Dict[str, Any],
) -> InferenceEngineOutput:
"""Generate multiple independent samples from a single prompt.
This method provides Tinker-compatible token-in/token-out sampling semantics.
Args:
prompt_token_ids: Token IDs for a single prompt.
num_samples: Number of independent samples to generate.
sampling_params: Sampling parameters.
Returns:
InferenceEngineOutput containing num_samples results:
- response_ids: List of num_samples token ID lists
- responses: List of num_samples decoded strings
- stop_reasons: List of num_samples stop reasons
- response_logprobs: Optional list of num_samples logprob lists
"""
all_response_ids = []
all_responses = []
all_stop_reasons = []
all_response_logprobs = []
for _ in range(num_samples):
input_batch: InferenceEngineInput = {
"prompts": None,
"prompt_token_ids": [prompt_token_ids], # Wrap in list for batch of 1
"sampling_params": sampling_params,
"session_ids": None,
}
output = await self.generate(input_batch)
# Extract single result from batch of 1
all_response_ids.append(output["response_ids"][0])
all_responses.append(output["responses"][0])
all_stop_reasons.append(output["stop_reasons"][0])
if output.get("response_logprobs") is not None:
all_response_logprobs.append(output["response_logprobs"][0])
return {
"response_ids": all_response_ids,
"responses": all_responses,
"stop_reasons": all_stop_reasons,
"response_logprobs": all_response_logprobs if all_response_logprobs else None,
}
@abstractmethod
async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""Handles OpenAI-compatible HTTP endpoint.
Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
The request body will be used to construct a ChatCompletionRequest.
Returns a plain dict, either a ChatCompletionResponse or an ErrorResponse.
The specific fields of the response/request depend on the engine's backend (e.g. for vllm
these are defined in vllm.entrypoints.openai.protocol).
"""
raise NotImplementedError
@abstractmethod
async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""Handles OpenAI-compatible HTTP endpoint.
Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
The request body will be used to construct a CompletionRequest.
Returns a plain dict, either a CompletionResponse or an ErrorResponse.
The specific fields of the response/request depend on the engine's backend (e.g. for vllm
these are defined in vllm.entrypoints.openai.protocol).
"""
raise NotImplementedError
@abstractmethod
async def wake_up(self, *args: Any, **kwargs: Any):
raise NotImplementedError
@abstractmethod
async def sleep(self, *args: Any, **kwargs: Any):
raise NotImplementedError
@abstractmethod
async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
"""Initialize weight update communicator from init info.
Args:
init_info: WeightSyncInitInfo from the sender containing all info needed
to create the appropriate receiver.
"""
raise NotImplementedError()
@abstractmethod
async def update_named_weights(self, request: "WeightUpdateRequest"):
raise NotImplementedError()
@abstractmethod
async def teardown(self):
raise NotImplementedError
@abstractmethod
async def reset_prefix_cache(self):
raise NotImplementedError
@abstractmethod
def tp_size(self) -> int:
"""Return the tensor parallel size of this inference engine."""
raise NotImplementedError
@abstractmethod
def pp_size(self) -> int:
"""Return the pipeline parallel size of this inference engine."""
raise NotImplementedError
@abstractmethod
def dp_size(self) -> int:
"""Return the data parallel size of this inference engine."""
raise NotImplementedError
@abstractmethod
async def abort_generation(self) -> None:
"""
Abort all running and waiting requests, making the ongoing requests return the
already-generated tokens with a stop_reason of "abort". If the request was waiting,
it returns a response with zero completion tokens.
"""
raise NotImplementedError
method async generate
generate(input_batch: InferenceEngineInput) -> InferenceEngineOutput
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:38-40
@abstractmethod
async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
raise NotImplementedError
method async sample
sample(prompt_token_ids: List[int], num_samples: int, sampling_params: Dict[str, Any]) -> InferenceEngineOutput
Generate multiple independent samples from a single prompt.
This method provides Tinker-compatible token-in/token-out sampling semantics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
prompt_token_ids | List[int] | Token IDs for a single prompt. | required |
num_samples | int | Number of independent samples to generate. | required |
sampling_params | Dict[str, Any] | Sampling parameters. | required |
Returns:
| Type | Description |
|---|---|
| InferenceEngineOutput | InferenceEngineOutput containing num_samples results: - response_ids: List of num_samples token ID lists - responses: List of num_samples decoded strings - stop_reasons: List of num_samples stop reasons - response_logprobs: Optional list of num_samples logprob lists |
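A usage sketch of sample() inside an async context. The model name and prompt are placeholders, and `engine` is any concrete InferenceEngineInterface.

```python
# Draw 4 independent completions from one tokenized prompt.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
prompt_ids = tokenizer.encode("Write a haiku about the sea.")

output = await engine.sample(
    prompt_token_ids=prompt_ids,
    num_samples=4,
    sampling_params={"temperature": 0.8, "max_tokens": 64},
)
assert len(output["response_ids"]) == 4  # one token-ID list per sample
```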
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:42-90
async def sample(
self,
prompt_token_ids: List[int],
num_samples: int,
sampling_params: Dict[str, Any],
) -> InferenceEngineOutput:
"""Generate multiple independent samples from a single prompt.
This method provides Tinker-compatible token-in/token-out sampling semantics.
Args:
prompt_token_ids: Token IDs for a single prompt.
num_samples: Number of independent samples to generate.
sampling_params: Sampling parameters.
Returns:
InferenceEngineOutput containing num_samples results:
- response_ids: List of num_samples token ID lists
- responses: List of num_samples decoded strings
- stop_reasons: List of num_samples stop reasons
- response_logprobs: Optional list of num_samples logprob lists
"""
all_response_ids = []
all_responses = []
all_stop_reasons = []
all_response_logprobs = []
for _ in range(num_samples):
input_batch: InferenceEngineInput = {
"prompts": None,
"prompt_token_ids": [prompt_token_ids], # Wrap in list for batch of 1
"sampling_params": sampling_params,
"session_ids": None,
}
output = await self.generate(input_batch)
# Extract single result from batch of 1
all_response_ids.append(output["response_ids"][0])
all_responses.append(output["responses"][0])
all_stop_reasons.append(output["stop_reasons"][0])
if output.get("response_logprobs") is not None:
all_response_logprobs.append(output["response_logprobs"][0])
return {
"response_ids": all_response_ids,
"responses": all_responses,
"stop_reasons": all_stop_reasons,
"response_logprobs": all_response_logprobs if all_response_logprobs else None,
}
method abstractmethod async chat_completion
chat_completion(request_payload: Dict[str, Any]) -> Dict[str, Any]
Handles OpenAI-compatible HTTP endpoint.
Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}. The request body will be used to construct a ChatCompletionRequest. Returns a plain dict, either a ChatCompletionResponse or an ErrorResponse. The specific fields of the response/request depend on the engine's backend (e.g. for vllm these are defined in vllm.entrypoints.openai.protocol).
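A payload sketch for chat_completion(), run in an async context. The request body follows the backend's OpenAI Chat Completions schema (for vLLM, vllm.entrypoints.openai.protocol); the model name and message are placeholders.

```python
payload = {
    "json": {
        "model": "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model name
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 64,
    },
    "headers": {"Content-Type": "application/json"},
}
response = await engine.chat_completion(payload)
# On success this is a ChatCompletionResponse-shaped dict; on failure an ErrorResponse.
print(response["choices"][0]["message"]["content"])
```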
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:92-102
@abstractmethod
async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""Handles OpenAI-compatible HTTP endpoint.
Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
The request body will be used to construct a ChatCompletionRequest.
Returns a plain dict, either a ChatCompletionResponse or an ErrorResponse.
The specific fields of the response/request depend on the engine's backend (e.g. for vllm
these are defined in vllm.entrypoints.openai.protocol).
"""
raise NotImplementedError
method abstractmethod async completion
completion(request_payload: Dict[str, Any]) -> Dict[str, Any]
Handles OpenAI-compatible HTTP endpoint.
Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}. The request body will be used to construct a CompletionRequest. Returns a plain dict, either a CompletionResponse or an ErrorResponse. The specific fields of the response/request depend on the engine's backend (e.g. for vllm these are defined in vllm.entrypoints.openai.protocol).
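The same payload pattern applies to completion(), with a prompt instead of messages; all names below are placeholders.

```python
payload = {
    "json": {
        "model": "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model name
        "prompt": "The capital of France is",
        "max_tokens": 8,
    },
    "headers": {"Content-Type": "application/json"},
}
response = await engine.completion(payload)
print(response["choices"][0]["text"])
```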
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:104-114
@abstractmethod
async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""Handles OpenAI-compatible HTTP endpoint.
Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
The request body will be used to construct a CompletionRequest.
Returns a plain dict, either a CompletionResponse or an ErrorResponse.
The specific fields of the response/request depend on the engine's backend (e.g. for vllm
these are defined in vllm.entrypoints.openai.protocol).
"""
raise NotImplementedError
method abstractmethod async wake_up
wake_up(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:116-118
@abstractmethod
async def wake_up(self, *args: Any, **kwargs: Any):
raise NotImplementedError
method abstractmethod async sleep
sleep(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:120-122
@abstractmethod
async def sleep(self, *args: Any, **kwargs: Any):
raise NotImplementedError
method abstractmethod async init_weight_update_communicator
init_weight_update_communicator(init_info: WeightSyncInitInfo)
Initialize weight update communicator from init info.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
init_info | WeightSyncInitInfo | WeightSyncInitInfo from the sender containing all info needed to create the appropriate receiver. | required |
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:124-132
@abstractmethod
async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
"""Initialize weight update communicator from init info.
Args:
init_info: WeightSyncInitInfo from the sender containing all info needed
to create the appropriate receiver.
"""
raise NotImplementedError()
method abstractmethod async update_named_weights
update_named_weights(request: WeightUpdateRequest)
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:134-136
@abstractmethod
async def update_named_weights(self, request: "WeightUpdateRequest"):
raise NotImplementedError()
method abstractmethod async teardown
teardown()
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:138-140
@abstractmethod
async def teardown(self):
raise NotImplementedError
method abstractmethod async reset_prefix_cache
reset_prefix_cache()
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:142-144
@abstractmethod
async def reset_prefix_cache(self):
raise NotImplementedError
method abstractmethod tp_size
tp_size() -> int
Return the tensor parallel size of this inference engine.
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:146-149
@abstractmethod
def tp_size(self) -> int:
"""Return the tensor parallel size of this inference engine."""
raise NotImplementedError
method abstractmethod pp_size
pp_size() -> int
Return the pipeline parallel size of this inference engine.
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:151-154
@abstractmethod
def pp_size(self) -> int:
"""Return the pipeline parallel size of this inference engine."""
raise NotImplementedError
method abstractmethod dp_size
dp_size() -> int
Return the data parallel size of this inference engine.
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:156-159
@abstractmethod
def dp_size(self) -> int:
"""Return the data parallel size of this inference engine."""
raise NotImplementedError
method abstractmethod async abort_generation
abort_generation() -> None
Abort all running and waiting requests, making the ongoing requests return the already-generated tokens with a stop_reason of "abort". If the request was waiting, it returns a response with zero completion tokens.
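A behavior sketch in an async context, assuming a request is already running on the same engine: aborting makes it return its partial result with stop_reason "abort".

```python
import asyncio

# Launch a request, then abort it mid-generation.
task = asyncio.create_task(engine.generate(input_batch))
await asyncio.sleep(1.0)            # let some tokens be generated
await engine.abort_generation()
partial = await task
assert partial["stop_reasons"][0] == "abort"  # partial tokens are in response_ids
```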
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:161-168
@abstractmethod
async def abort_generation(self) -> None:
"""
Abort all running and waiting requests, making the ongoing requests return the
already-generated tokens with a stop_reason of "abort". If the request was waiting,
it returns a response with zero completion tokens.
"""
raise NotImplementedError
class InferenceEngineClient
InferenceEngineClient(engines: List[InferenceEngineInterface], tokenizer: PreTrainedTokenizerBase, model_path: str, lora_cfg: SkyRLLoraConfig, inference_engine_cfg: InferenceEngineConfig)
Bases: InferenceEngineInterface
Client to talk to a set of InferenceEngines.
Note that InferenceEngineClient sub-classes InferenceEngineInterface so it can be used as if talking to a single engine.
Functions:
| Name | Description |
|---|---|
generate | |
sample | Generate multiple independent samples from a single prompt. |
chat_completion | |
completion | Handles an OpenAI /completions request. |
wake_up | |
sleep | |
init_weight_update_communicator | Initialize weight update communicator on all engines. |
update_named_weights | |
reset_prefix_cache | |
teardown | |
tp_size | |
pp_size | |
dp_size | |
pause_generation | Pauses generation for all engines, intended for in-flight weight updates and partial rollouts. |
resume_generation | Resumes generation for all engines, intended for in-flight weight updates and partial rollouts. |
abort_generation | |
Attributes:
| Name | Type | Description |
|---|---|---|
engines | ||
tokenizer | ||
inference_engine_cfg | ||
model_name | ||
backend | ||
enable_http_endpoint | ||
http_endpoint_host | ||
http_endpoint_port | ||
generation_paused_event |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
engines | List[InferenceEngineInterface] | List[InferenceEngineInterface] - The inference engines, remote or local. | required |
tokenizer | PreTrainedTokenizerBase | PreTrainedTokenizerBase - The tokenizer to use. | required |
model_path | str | str - The path to the model. | required |
lora_cfg | SkyRLLoraConfig | SkyRLLoraConfig - The LoRA configuration. | required |
inference_engine_cfg | InferenceEngineConfig | InferenceEngineConfig - The inference engine configuration. | required |
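A construction sketch, assuming the engine list, LoRA config, and inference-engine config are built elsewhere (e.g. from the training config); the model path is a placeholder.

```python
from transformers import AutoTokenizer
from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
client = InferenceEngineClient(
    engines=engines,                   # List[InferenceEngineInterface], remote or local
    tokenizer=tokenizer,
    model_path="Qwen/Qwen2.5-0.5B-Instruct",
    lora_cfg=lora_cfg,                 # SkyRLLoraConfig
    inference_engine_cfg=engine_cfg,   # InferenceEngineConfig
)
# The client can then be used anywhere a single InferenceEngineInterface is expected.
```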
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:38-709
class InferenceEngineClient(InferenceEngineInterface):
"""
Client to talk to a set of InferenceEngines.
Note that InferenceEngineClient sub-classes InferenceEngineInterface so it can be used as if talking to a single
engine.
"""
def __init__(
self,
engines: List[InferenceEngineInterface],
tokenizer: PreTrainedTokenizerBase,
model_path: str,
lora_cfg: SkyRLLoraConfig,
inference_engine_cfg: InferenceEngineConfig,
):
"""
Args:
engines: List[InferenceEngineInterface] - The inference engines, remote or local.
tokenizer: PreTrainedTokenizerBase - The tokenizer to use.
model_path: str - The path to the model.
lora_cfg: SkyRLLoraConfig - The LoRA configuration.
inference_engine_cfg: InferenceEngineConfig - The inference engine configuration.
"""
self.engines = engines
self.tokenizer = tokenizer
self.inference_engine_cfg = inference_engine_cfg
# Use served_model_name if provided, otherwise fall back to model path.
# served_model_name allows using a different model name for HTTP endpoint validation
# than the actual model path. See ppo_base_config.yaml for details.
served_model_name = inference_engine_cfg.served_model_name
if served_model_name is not None:
self.model_name = served_model_name
else:
self.model_name = model_path
self.backend = inference_engine_cfg.backend
self.enable_http_endpoint = inference_engine_cfg.enable_http_endpoint
self.http_endpoint_host = inference_engine_cfg.http_endpoint_host
self.http_endpoint_port = inference_engine_cfg.http_endpoint_port
self.generation_paused_event = threading.Event()
if self.enable_http_endpoint:
self._spin_up_http_endpoint()
logger.info(f"InferenceEngineClient initialized with {len(engines)} engines.")
async def _run_on_all_engines(self, method_name: str, *args, **kwargs):
"""
Call a method on all engines concurrently and gather the results.
"""
assert len(self.engines) > 0, "No engines to call method on"
awaitables = [getattr(engine, method_name)(*args, **kwargs) for engine in self.engines]
return await asyncio.gather(*awaitables)
async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
# 0. Extract input
prompts = input_batch.get("prompts")
prompt_token_ids = input_batch.get("prompt_token_ids")
session_ids = input_batch.get("session_ids")
sampling_params = input_batch.get("sampling_params")
if (prompts is None and prompt_token_ids is None) or (prompts is not None and prompt_token_ids is not None):
raise ValueError("Either `prompts` or `prompt_token_ids` must be provided, but not both.")
if prompt_token_ids is None:
prompt_token_ids = self.tokenizer.apply_chat_template(
prompts,
add_generation_prompt=True,
return_dict=True,
tokenize=True,
)["input_ids"]
num_prompts = len(prompt_token_ids)
num_inference_engines = len(self.engines)
# 1. Route prompts to engines
engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
num_prompts=num_prompts,
num_inference_engines=num_inference_engines,
session_ids=session_ids,
)
# We do a shortcut for non-batched requests, which can support pause/continue generation for
# in-flight weight updates.
if num_prompts == 1:
# Route to a single engine for this single prompt and use retry flow.
assert len(engine_idx_to_prompt_ids) == 1
((engine_idx, prompt_ids_list),) = engine_idx_to_prompt_ids.items()
assert prompt_ids_list == [0], "Single prompt should map to index [0]"
original_prompt_ids = prompt_token_ids[0]
return await self._generate_single_with_retry(
engine_idx=engine_idx,
original_prompt_ids=original_prompt_ids,
sampling_params=sampling_params,
)
# For batched generate(), pause/continue cannot be supported.
if self.generation_paused_event.is_set():
raise RuntimeError("pause_generation is unsupported for batched InferenceEngineClient.generate().")
# 2. Generate responses concurrently
tasks: list[asyncio.Task] = []
indices_list: list[list[int]] = [] # the original prompt indices that each task works on
for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
# index prompt_token_ids with prompt_ids
cur_prompt_token_ids = [prompt_token_ids[i] for i in prompt_ids]
engine_input = InferenceEngineInput(
prompt_token_ids=cur_prompt_token_ids,
sampling_params=sampling_params,
)
tasks.append(asyncio.create_task(self.engines[engine_idx].generate(engine_input)))
indices_list.append(prompt_ids)
results = await asyncio.gather(*tasks)
# 3. Reconstruct output in original order
n = len(prompt_token_ids)
responses: list[str] = [""] * n
stop_reasons: list[str] = [""] * n
response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)]
response_ids: List[List[int]] = [[] for _ in range(n)]
# a bit hacky for now
add_resp_logprobs = False
for indices, result in zip(indices_list, results):
for local_idx, original_idx in enumerate(indices):
responses[original_idx] = result["responses"][local_idx]
stop_reasons[original_idx] = result["stop_reasons"][local_idx]
response_ids[original_idx] = result["response_ids"][local_idx]
if result.get("response_logprobs", None):
add_resp_logprobs = True
response_logprobs[original_idx] = result["response_logprobs"][local_idx]
return InferenceEngineOutput(
responses=responses,
stop_reasons=stop_reasons,
response_ids=response_ids,
response_logprobs=response_logprobs if add_resp_logprobs else None,
)
def _select_engine_idx(self, session_id: Optional[Union[str, int]] = None) -> int:
"""Select an engine index for routing a request.
Args:
session_id: Optional session ID for consistent routing (e.g., conversation ID for chat).
If None, uses random load-balancing.
Returns:
Engine index to route the request to.
"""
if session_id is None:
return random.randint(0, len(self.engines) - 1)
else:
return hash_with_sha256(str(session_id)) % len(self.engines)
async def sample(
self,
prompt_token_ids: List[int],
num_samples: int,
sampling_params: Dict[str, Any],
session_id: Optional[Union[str, int]] = None,
) -> InferenceEngineOutput:
"""Generate multiple independent samples from a single prompt.
This method provides Tinker-compatible token-in/token-out sampling semantics.
Generates num_samples independent completions from the same prompt.
Args:
prompt_token_ids: Token IDs for a single prompt (not batched).
num_samples: Number of independent samples to generate.
sampling_params: Sampling parameters (temperature, max_tokens, etc.).
session_id: Optional session ID for consistent engine routing (e.g., conversation ID).
If None, uses random load-balancing. Tinker API should pass None since
each sample() call is independent.
Returns:
InferenceEngineOutput containing num_samples results.
"""
# Wait for generation to resume if paused (for weight updates)
await self._wait_for_generation_to_resume()
# Select engine (random if session_id is None, consistent hash otherwise)
engine_idx = self._select_engine_idx(session_id)
engine = self.engines[engine_idx]
return await engine.sample(
prompt_token_ids=prompt_token_ids,
num_samples=num_samples,
sampling_params=sampling_params,
)
async def _generate_single_with_retry(
self, engine_idx: int, original_prompt_ids: List[int], sampling_params: Optional[Dict[str, Any]]
) -> InferenceEngineOutput:
"""
Generate a single response with retry mechanism.
This method is equivalent to `_chat_completion_with_retry()` but for the `generate()` codepath.
We keep sending `generate` requests (with previous responses accumulated) until the finish_reason
is not "abort". It is intended to be used in combination with `pause_generation()` and `resume_generation()` for
in-flight weight updates and partial rollouts.
This method is equivalent to a single `generate()` call if we do not use `pause_generation()`.
Since we operate purely in the token space, it is token-in-token-out, unlike `_chat_completion_with_retry()`
which re-encodes in each new request.
For subsequent retry requests (`InferenceEngineInput`), we:
- Update the `InferenceEngineInput.prompt_token_ids` with the accumulated output tokens.
- Skip accumulating `InferenceEngineOutput.responses` since we decode the final output.
- Adjust remaining max tokens if `max_tokens` or `max_completion_tokens` is present.
For the final response, we return `InferenceEngineOutput` with:
- `responses`: decoded at the end from `response_ids` if generation is completed in > 1 turns, otherwise
the text response of the first turn.
- `response_ids`: the accumulated output tokens
- `stop_reasons`: the stop reason of the final response
- `response_logprobs`: the accumulated logprobs
"""
if sampling_params is None:
sampling_params = {}
# 1. First determine original max tokens key and value (if any)
max_key = None
if "max_tokens" in sampling_params:
max_key = "max_tokens"
elif "max_completion_tokens" in sampling_params:
max_key = "max_completion_tokens"
original_max_tokens: Optional[int] = sampling_params.get(max_key) if max_key else None
# 2. Initialize fields we want to accumulate or update in each loop iteration
accum_response_ids: List[int] = []
accum_response_logprobs: List[float] = []
stop_reason: str = "abort"
# We only use it if generation is completed in one turn to maintain original behavior with no retry.
text_response: Optional[str] = None
num_turns = 0
# 3. Loop until generation is completed.
while stop_reason == "abort":
await self._wait_for_generation_to_resume()
# 3.1. Prepare the request payload.
cur_sampling_params = sampling_params.copy()
if original_max_tokens is not None:
new_max_tokens = original_max_tokens - len(accum_response_ids)
assert new_max_tokens >= 0, f"Expect new_max_tokens to be non-negative, but got {new_max_tokens}"
cur_sampling_params[max_key] = new_max_tokens
new_prompt_ids = original_prompt_ids + accum_response_ids
engine_input = InferenceEngineInput(
prompt_token_ids=[new_prompt_ids],
sampling_params=cur_sampling_params,
)
# 3.2. Send the request.
logger.debug(f"generate() request sent (including potential retries): {engine_input}")
partial_response: InferenceEngineOutput = await self.engines[engine_idx].generate(engine_input)
# 3.3. Parse the partial response.
assert len(partial_response["response_ids"]) == 1, "Expected exactly one response."
new_response_ids: List[int] = partial_response["response_ids"][0]
text_response = partial_response["responses"][0]
stop_reason = partial_response["stop_reasons"][0]
new_response_logprobs: Optional[List[float]] = None
new_response_logprobs_list: Optional[List[List[float]]] = partial_response.get("response_logprobs", None)
if new_response_logprobs_list is not None and len(new_response_logprobs_list) > 0:
new_response_logprobs = new_response_logprobs_list[0]
# 3.4 Aborted without generating tokens, so partial_response is useless.
if stop_reason == "abort" and len(new_response_ids) == 0:
continue
# 3.5 Accumulate outputs
accum_response_ids.extend(new_response_ids)
if new_response_logprobs is not None:
accum_response_logprobs.extend(new_response_logprobs)
num_turns += 1
# 4. Build the final response and return.
if num_turns == 1:
final_text_response = text_response
else:
final_text_response = self.tokenizer.decode(accum_response_ids, skip_special_tokens=True)
return InferenceEngineOutput(
responses=[final_text_response],
stop_reasons=[stop_reason],
response_ids=[accum_response_ids],
response_logprobs=[accum_response_logprobs] if len(accum_response_logprobs) > 0 else None,
)
async def _chat_completion_with_retry(
self, engine_idx: int, original_request_payload: Dict[str, Any]
) -> Dict[str, Any]:
"""
Keep sending `chat_completion` requests (with previous responses accumulated) until the finish_reason is not
"abort".
The retry mechanism is intended to be used in combination with `pause_generation()` and `resume_generation()`
for in-flight weight updates and partial rollouts.
This method is equivalent to a single `chat_completion()` call if we do not use `pause_generation()`.
For subsequent retry requests, we can reuse the original request with the following exceptions:
- Update the last assistant message content to accumulated content, where the role uses the first non-empty
response's role.
- Set continue_final_message=True and add_generation_prompt=False.
- Adjust remaining max tokens if `max_tokens` or `max_completion_tokens` is present.
- If no tokens have been generated yet, resend the original request unchanged.
For the final response, we maintain all the first non-empty response's fields (i.e. prefilled already),
with the following exceptions:
- Accumulate the following across retry requests:
- `choices[0]["logprobs"]["content"]`
- `choices[0]["token_ids"]`
- `choices[0]["message"]["content"]`
- Use the last response's finish_reason and stop_reason
"""
original_request_json: Dict[str, Any] = original_request_payload.get("json", {}).copy()
headers: Dict[str, str] = original_request_payload.get("headers", {}).copy()
assert not original_request_json.get(
"continue_final_message", False
), "continue_final_message must be False for /chat/completions requests"
# Accumulated fields for building subsequent requests and final response. It is inplace-updated
# in `_parse_partial_response_and_inplace_update_accum()`.
accum = AccumulatedResponse()
# First non-empty response (i.e. the response that prefilled the prompt) to copy meta from.
base_response: Optional[Dict[str, Any]] = None
# Determine original max tokens key and value (if any)
max_key = None
if "max_tokens" in original_request_json:
max_key = "max_tokens"
elif "max_completion_tokens" in original_request_json:
max_key = "max_completion_tokens"
orig_max_tokens: Optional[int] = original_request_json.get(max_key) if max_key else None
# Fields to be updated in each loop iteration
finish_reason: str = "abort"
stop_reason: Optional[str] = None
response_role: Optional[str] = None
# 1. Loop until the generation is completed.
while finish_reason == "abort":
await self._wait_for_generation_to_resume()
# 1.1. Prepare the request payload.
cur_request_json = _prepare_retry_request(
original_request_json=original_request_json,
accum=accum,
response_role=response_role,
orig_max_tokens=orig_max_tokens,
max_key=max_key,
)
# 1.2. Send the request.
logger.debug(f"/chat/completions request sent (including potential retries): {cur_request_json}")
partial_response = await self.engines[engine_idx].chat_completion(
{"json": cur_request_json, "headers": headers}
)
# 1.2.1. Check for error response from engine (e.g., context length exceeded).
# Error responses have "error" key instead of "choices", so return them directly
# for the HTTP endpoint to handle with proper status codes.
if "error" in partial_response or partial_response.get("object", "") == "error":
return partial_response
# 1.3. Parse partial response and in-place update accumulators.
(
finish_reason,
stop_reason,
response_role,
aborted_without_generating,
) = _parse_partial_response_and_inplace_update_accum(
partial_response=partial_response,
accum=accum,
response_role=response_role,
)
# 1.4. Aborted without generating tokens, so partial_response is useless.
if aborted_without_generating:
continue
# At this point, either some tokens were generated and/or request completed with a non-"abort" finish_reason
# 1.5. Update base response if it is the first non-empty response
if base_response is None:
if finish_reason != "abort":
# If we only made one request and it is not aborted, return the partial result directly.
# This is the codepath that will hit when we do not use `pause_generation()`
# or `resume_generation()`.
return partial_response
# NOTE(Charlie): not doing deepcopy here to avoid copying large logprobs, so be careful when
# modifying this.
base_response = partial_response.copy()
# 2. Build final response by combining fields
assert base_response is not None, "Expected at least one non-empty response to build final response"
return _build_final_response(
base_response=base_response,
accum=accum,
finish_reason=finish_reason,
stop_reason=stop_reason,
)
async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
session_id = request_payload["json"].pop("session_id", None)
if session_id is not None:
assert isinstance(session_id, (str, int)), "Session ID must be an integer or string for `/chat/completions`"
engine_idx = self._select_engine_idx(session_id)
# Always use the retry loop which also issues the first request inside
return await self._chat_completion_with_retry(engine_idx, request_payload)
async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""
Handles an OpenAI /completions request.
Since `request["prompt"]` can be `Union[list[int], list[list[int]], str, list[str]]`,
(i.e. {batched, single} x {string, token IDs}), we need to route the request to engines
differently, based on whether it's a single or batched request, and whether `request["session_id"]`
is provided. This is similar to `generate()` method.
For single, we do the same routing logic as `chat_completion()`. For batched, we route by
`request["session_id"]` if present, and if not we split evenly across engines.
Regardless, the order will be maintained, i.e. `output["choices"][i]` corresponds to `request["prompt"][i]`.
"""
if self.generation_paused_event.is_set():
raise RuntimeError("pause_generation is unsupported for /completions requests.")
body = request_payload.get("json", {})
# NOTE(Charlie): do not reuse headers here as the single request may become various new requests
headers = {"Content-Type": "application/json"}
# 1. Postprocess prompt, session_id, and validate request.
prompt = body.get("prompt")
session_id_value = body.pop("session_id", None)
ret = postprocess_completion_request(prompt, session_id_value)
session_id_list: Optional[Union[List[int], List[str], ErrorResponse]] = ret[0]
prompt: Union[List[List[int]], List[str]] = ret[1]
if isinstance(session_id_list, ErrorResponse):
return session_id_list.model_dump()
num_prompts = len(prompt)
num_inference_engines = len(self.engines)
assert num_prompts > 0, "Number of prompts must be greater than 0"
# 1. Route prompts to engines
engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
num_prompts=num_prompts,
num_inference_engines=num_inference_engines,
session_ids=session_id_list,
)
# 2. Generate responses concurrently
tasks: list[asyncio.Task] = []
indices_list: list[list[int]] = [] # the original prompt indices that each task works on
for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
cur_prompt = [prompt[i] for i in prompt_ids]
# reuse the exact same request except for the prompt
cur_json = dict(body)
cur_json["prompt"] = cur_prompt
coro = self.engines[engine_idx].completion({"json": cur_json, "headers": headers})
tasks.append(asyncio.create_task(coro))
indices_list.append(prompt_ids)
results = await asyncio.gather(*tasks)
# 3. Check for errors.
# results can be ErrorResponse or CompletionResponse. If one of the sub-requests fails, we
# return an error response. That is, there is no partial success, following vLLM's behavior.
for result in results:
if "error" in result or result.get("object", "") == "error":
error_details = result.get("error", result)
error_code = error_details["code"]
error_type = error_details["type"]
error_message = error_details["message"]
return ErrorResponse(
error=ErrorInfo(
message=f"In one of the engines that SkyRL manages, an error occurred: {error_message}",
type=error_type,
code=error_code,
),
).model_dump()
# 4. Combine choices and preserve original order.
# If there is only one result, we return it directly.
if len(results) == 1:
return results[0]
# Use the first result as base response. There are some fields that cannot be shared
# across sub-requests. For now it is just the usage field.
final_response = dict(results[0])
final_response["usage"] = aggregate_completion_usage_info(results, self.backend)
# Aggregate choices. TODO(Charlie): improve logic when we need to support n > 1
# vLLM sets index positions per sub-batch, so we reset indices to be 0..n-1 for the combined response.
combined_choices: list[Dict[str, Any]] = [None] * num_prompts
for indices, result in zip(indices_list, results):
# indices are the original prompt indices that the task's response corresponds to
for local_idx, original_idx in enumerate(indices):
choice = result["choices"][local_idx]
choice["index"] = original_idx # overwrite index with the global position
combined_choices[original_idx] = choice
# sanity check that the index is correct
for new_idx in range(len(combined_choices)):
assert combined_choices[new_idx]["index"] == new_idx
final_response["choices"] = combined_choices
return final_response
async def wake_up(self, *args: Any, **kwargs: Any):
return await self._run_on_all_engines("wake_up", *args, **kwargs)
async def sleep(self, *args: Any, **kwargs: Any):
return await self._run_on_all_engines("sleep", *args, **kwargs)
async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
"""Initialize weight update communicator on all engines.
Args:
init_info: WeightSyncInitInfo from the sender.
Note:
Per-engine adjustments (e.g., rank_offset for broadcast) are handled
by init_info.for_engine().
"""
tasks = []
for i, engine in enumerate(self.engines):
engine_init_info = init_info.for_engine(i, engine.tp_size(), engine.pp_size())
tasks.append(engine.init_weight_update_communicator(engine_init_info))
await asyncio.gather(*tasks)
async def update_named_weights(self, request: WeightUpdateRequest):
return await self._run_on_all_engines("update_named_weights", request=request)
async def reset_prefix_cache(self):
return await self._run_on_all_engines("reset_prefix_cache")
async def teardown(self):
return await self._run_on_all_engines("teardown")
def tp_size(self) -> int:
raise NotImplementedError("InferenceEngineClient does not implement tp_size()")
def pp_size(self) -> int:
raise NotImplementedError("InferenceEngineClient does not implement pp_size()")
def dp_size(self) -> int:
raise NotImplementedError("InferenceEngineClient does not implement dp_size()")
# ----------------------------
# Generation pause and resume
# ----------------------------
async def _wait_for_generation_to_resume(self) -> None:
"""Waits for generation to be resumed, intended for in-flight weight updates and partial rollouts."""
while self.generation_paused_event.is_set():
await asyncio.sleep(0.5)
async def pause_generation(self) -> None:
"""
Pauses generation for all engines, intended for in-flight weight updates and partial rollouts.
Currently only supported for `/chat/completions` and not `/completions` or `generate()`.
Both in-flight and incoming requests will be blocked until `resume_generation` is called.
1. Set the paused event to avoid new requests from being submitted while aborting requests.
2. Wait for a grace period to ensure all in-flight requests have entered the engine's
scheduler and hence can be aborted. Otherwise, there can be requests already submitted
but not yet entered the scheduler, which can miss the abort request.
3. Finally, we abort requests on all engines. This will cause the requests sent from
InferenceEngineClient to `InferenceEngineClient.engines` to return the already-generated tokens.
The request to `InferenceEngineClient` will not yet return until requests are completed with
stop reason that is not `abort`.
"""
if self.generation_paused_event.is_set():
raise RuntimeError("Generation is already paused, cannot pause again.")
self.generation_paused_event.set()
await asyncio.sleep(ABORT_GENERATION_GRACE_PERIOD_SECONDS)
await self._run_on_all_engines("abort_generation")
async def resume_generation(self) -> None:
"""
Resumes generation for all engines, intended for in-flight weight updates and partial rollouts.
Resume all in-flight requests with the previously-generated tokens, and unblock incoming requests
that were blocked by `pause_generation()`.
"""
if not self.generation_paused_event.is_set():
raise RuntimeError("Generation is not paused, cannot resume.")
self.generation_paused_event.clear()
async def abort_generation(self) -> None:
raise NotImplementedError(
"InferenceEngineClient does not implement abort_generation(), but calls "
"`abort_generation` on all engines in `pause_generation()`."
)
# ----------------------------
# HTTP endpoint related methods
# ----------------------------
def __del__(self):
"""
Destructor to shut down the HTTP endpoint if it was started.
"""
# TODO(Charlie): __del__ is not guaranteed to be called in general. Add to `teardown` method
# when the `_handle_termination` flow is implemented. See `skyrl_train/workers/worker.py`
# comments on `_handle_termination` for more details.
if (
self.enable_http_endpoint
and hasattr(
self, "_server_thread"
) # don't want to shut down the server when it is pickled as a ray method argument.
and self._server_thread is not None
):
try:
from skyrl.backends.skyrl_train.inference_engines.inference_engine_client_http_endpoint import (
shutdown_server,
)
shutdown_server(
host=self.http_endpoint_host,
port=self.http_endpoint_port,
max_wait_seconds=10,
)
if hasattr(self, "_server_thread") and self._server_thread.is_alive():
self._server_thread.join(timeout=10)
except Exception as e:
logger.error(f"Error shutting down HTTP endpoint: {e}")
def __getstate__(self):
"""
Override to avoid pickling the server thread and the threading.Event object, which are not picklable.
Needed when passing InferenceEngineClient as an argument to async_run_ray_method(), mainly for
invoking `init_weight_sync_state()` and `broadcast_to_inference_engines()`, which do
not need these attributes.
"""
state = self.__dict__.copy()
state["_server_thread"] = None
state["generation_paused_event"] = None
return state
def _spin_up_http_endpoint(self):
from skyrl.backends.skyrl_train.inference_engines.inference_engine_client_http_endpoint import (
serve,
wait_for_server_ready,
)
self._server_thread = threading.Thread(
target=serve,
args=(self,),
kwargs={
"host": self.http_endpoint_host,
"port": self.http_endpoint_port,
"log_level": "warning",
},
daemon=True,
)
self._server_thread.start()
wait_for_server_ready(
host=self.http_endpoint_host,
port=self.http_endpoint_port,
max_wait_seconds=30,
)
logger.info(
f"InferenceEngineClient HTTP endpoint started on {self.http_endpoint_host}:{self.http_endpoint_port}"
)
attr engines
engines = engines
attr tokenizer
tokenizer = tokenizer
attr inference_engine_cfg
inference_engine_cfg = inference_engine_cfg
attr model_name
model_name = served_model_name
attr backend
backend = inference_engine_cfg.backend
attr enable_http_endpoint
enable_http_endpoint = inference_engine_cfg.enable_http_endpoint
attr http_endpoint_host
http_endpoint_host = inference_engine_cfg.http_endpoint_host
attr http_endpoint_port
http_endpoint_port = inference_engine_cfg.http_endpoint_port
attr generation_paused_event
generation_paused_event = threading.Event()
method async generate
generate(input_batch: InferenceEngineInput) -> InferenceEngineOutput
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:92-175
async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
# 0. Extract input
prompts = input_batch.get("prompts")
prompt_token_ids = input_batch.get("prompt_token_ids")
session_ids = input_batch.get("session_ids")
sampling_params = input_batch.get("sampling_params")
if (prompts is None and prompt_token_ids is None) or (prompts is not None and prompt_token_ids is not None):
raise ValueError("Either `prompts` or `prompt_token_ids` must be provided, but not both.")
if prompt_token_ids is None:
prompt_token_ids = self.tokenizer.apply_chat_template(
prompts,
add_generation_prompt=True,
return_dict=True,
tokenize=True,
)["input_ids"]
num_prompts = len(prompt_token_ids)
num_inference_engines = len(self.engines)
# 1. Route prompts to engines
engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
num_prompts=num_prompts,
num_inference_engines=num_inference_engines,
session_ids=session_ids,
)
# We do a shortcut for non-batched requests, which can support pause/continue generation for
# in-flight weight updates.
if num_prompts == 1:
# Route to a single engine for this single prompt and use retry flow.
assert len(engine_idx_to_prompt_ids) == 1
((engine_idx, prompt_ids_list),) = engine_idx_to_prompt_ids.items()
assert prompt_ids_list == [0], "Single prompt should map to index [0]"
original_prompt_ids = prompt_token_ids[0]
return await self._generate_single_with_retry(
engine_idx=engine_idx,
original_prompt_ids=original_prompt_ids,
sampling_params=sampling_params,
)
# For batched generate(), pause/continue cannot be supported.
if self.generation_paused_event.is_set():
raise RuntimeError("pause_generation is unsupported for batched InferenceEngineClient.generate().")
# 2. Generate responses concurrently
tasks: list[asyncio.Task] = []
indices_list: list[list[int]] = [] # the original prompt indices that each task works on
for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
# index prompt_token_ids with prompt_ids
cur_prompt_token_ids = [prompt_token_ids[i] for i in prompt_ids]
engine_input = InferenceEngineInput(
prompt_token_ids=cur_prompt_token_ids,
sampling_params=sampling_params,
)
tasks.append(asyncio.create_task(self.engines[engine_idx].generate(engine_input)))
indices_list.append(prompt_ids)
results = await asyncio.gather(*tasks)
# 3. Reconstruct output in original order
n = len(prompt_token_ids)
responses: list[str] = [""] * n
stop_reasons: list[str] = [""] * n
response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)]
response_ids: List[List[int]] = [[] for _ in range(n)]
# a bit hacky for now
add_resp_logprobs = False
for indices, result in zip(indices_list, results):
for local_idx, original_idx in enumerate(indices):
responses[original_idx] = result["responses"][local_idx]
stop_reasons[original_idx] = result["stop_reasons"][local_idx]
response_ids[original_idx] = result["response_ids"][local_idx]
if result.get("response_logprobs", None):
add_resp_logprobs = True
response_logprobs[original_idx] = result["response_logprobs"][local_idx]
return InferenceEngineOutput(
responses=responses,
stop_reasons=stop_reasons,
response_ids=response_ids,
response_logprobs=response_logprobs if add_resp_logprobs else None,
)
method async sample
sample(prompt_token_ids: List[int], num_samples: int, sampling_params: Dict[str, Any], session_id: Optional[Union[str, int]] = None) -> InferenceEngineOutput
Generate multiple independent samples from a single prompt.
This method provides Tinker-compatible token-in/token-out sampling semantics. Generates num_samples independent completions from the same prompt.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
prompt_token_ids | List[int] | Token IDs for a single prompt (not batched). | required |
num_samples | int | Number of independent samples to generate. | required |
sampling_params | Dict[str, Any] | Sampling parameters (temperature, max_tokens, etc.). | required |
session_id | Optional[Union[str, int]] | Optional session ID for consistent engine routing (e.g., conversation ID). If None, uses random load-balancing. Tinker API should pass None since each sample() call is independent. | None |
Returns:
| Type | Description |
|---|---|
| InferenceEngineOutput | InferenceEngineOutput containing num_samples results. |
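A usage sketch in an async context; prompt_ids is a tokenized prompt and the session ID is a placeholder.

```python
output = await client.sample(
    prompt_token_ids=prompt_ids,
    num_samples=2,
    sampling_params={"temperature": 1.0, "max_tokens": 128},
    session_id="conversation-42",  # optional; None means random load-balancing
)
```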
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:192-226
async def sample(
self,
prompt_token_ids: List[int],
num_samples: int,
sampling_params: Dict[str, Any],
session_id: Optional[Union[str, int]] = None,
) -> InferenceEngineOutput:
"""Generate multiple independent samples from a single prompt.
This method provides Tinker-compatible token-in/token-out sampling semantics.
Generates num_samples independent completions from the same prompt.
Args:
prompt_token_ids: Token IDs for a single prompt (not batched).
num_samples: Number of independent samples to generate.
sampling_params: Sampling parameters (temperature, max_tokens, etc.).
session_id: Optional session ID for consistent engine routing (e.g., conversation ID).
If None, uses random load-balancing. Tinker API should pass None since
each sample() call is independent.
Returns:
InferenceEngineOutput containing num_samples results.
"""
# Wait for generation to resume if paused (for weight updates)
await self._wait_for_generation_to_resume()
# Select engine (random if session_id is None, consistent hash otherwise)
engine_idx = self._select_engine_idx(session_id)
engine = self.engines[engine_idx]
return await engine.sample(
prompt_token_ids=prompt_token_ids,
num_samples=num_samples,
sampling_params=sampling_params,
)
method async chat_completion
chat_completion(request_payload: Dict[str, Any]) -> Dict[str, Any]
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:445-452
async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
session_id = request_payload["json"].pop("session_id", None)
if session_id is not None:
assert isinstance(session_id, (str, int)), "Session ID must be an integer or string for `/chat/completions`"
engine_idx = self._select_engine_idx(session_id)
# Always use the retry loop which also issues the first request inside
return await self._chat_completion_with_retry(engine_idx, request_payload)
method async completion
completion(request_payload: Dict[str, Any]) -> Dict[str, Any]
Handles an OpenAI /completions request.
Since request["prompt"] can be Union[list[int], list[list[int]], str, list[str]],
(i.e. {batched, single} x {string, token IDs}), we need to route the request to engines
differently, based on whether it's a single or batched request, and whether request["session_id"]
is provided. This is similar to generate() method.
For single, we do the same routing logic as chat_completion(). For batched, we route by
request["session_id"] if present, and if not we split evenly across engines.
Regardless, the order will be maintained, i.e. output["choices"][i] corresponds to request["prompt"][i].
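A batched /completions sketch against the client in an async context; the model name and prompts are placeholders, and choices[i] corresponds to prompt[i].

```python
payload = {
    "json": {
        "model": "Qwen/Qwen2.5-0.5B-Instruct",   # placeholder model name
        "prompt": ["1 + 1 =", "The capital of France is"],
        "max_tokens": 8,
    },
    "headers": {"Content-Type": "application/json"},
}
result = await client.completion(payload)
print([choice["text"] for choice in result["choices"]])  # same order as the prompts
```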
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:454-551
async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""
Handles an OpenAI /completions request.
Since `request["prompt"]` can be `Union[list[int], list[list[int]], str, list[str]]`,
(i.e. {batched, single} x {string, token IDs}), we need to route the request to engines
differently, based on whether it's a single or batched request, and whether `request["session_id"]`
is provided. This is similar to `generate()` method.
For single, we do the same routing logic as `chat_completion()`. For batched, we route by
`request["session_id"]` if present, and if not we split evenly across engines.
Regardless, the order will be maintained, i.e. `output["choices"][i]` corresponds to `request["prompt"][i]`.
"""
if self.generation_paused_event.is_set():
raise RuntimeError("pause_generation is unsupported for /completions requests.")
body = request_payload.get("json", {})
# NOTE(Charlie): do not reuse headers here as the single request may become various new requests
headers = {"Content-Type": "application/json"}
# 1. Postprocess prompt, session_id, and validate request.
prompt = body.get("prompt")
session_id_value = body.pop("session_id", None)
ret = postprocess_completion_request(prompt, session_id_value)
session_id_list: Optional[Union[List[int], List[str], ErrorResponse]] = ret[0]
prompt: Union[List[List[int]], List[str]] = ret[1]
if isinstance(session_id_list, ErrorResponse):
return session_id_list.model_dump()
num_prompts = len(prompt)
num_inference_engines = len(self.engines)
assert num_prompts > 0, "Number of prompts must be greater than 0"
# 1. Route prompts to engines
engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
num_prompts=num_prompts,
num_inference_engines=num_inference_engines,
session_ids=session_id_list,
)
# 2. Generate responses concurrently
tasks: list[asyncio.Task] = []
indices_list: list[list[int]] = [] # the original prompt indices that each task works on
for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
cur_prompt = [prompt[i] for i in prompt_ids]
# reuse the exact same request except for the prompt
cur_json = dict(body)
cur_json["prompt"] = cur_prompt
coro = self.engines[engine_idx].completion({"json": cur_json, "headers": headers})
tasks.append(asyncio.create_task(coro))
indices_list.append(prompt_ids)
results = await asyncio.gather(*tasks)
# 3. Check for errors.
# results can be ErrorResponse or CompletionResponse. If one of the sub-requests fails, we
# return an error response. That is, there is no partial success, following vLLM's behavior.
for result in results:
if "error" in result or result.get("object", "") == "error":
error_details = result.get("error", result)
error_code = error_details["code"]
error_type = error_details["type"]
error_message = error_details["message"]
return ErrorResponse(
error=ErrorInfo(
message=f"In one of the engines that SkyRL manages, an error occurred: {error_message}",
type=error_type,
code=error_code,
),
).model_dump()
# 4. Combine choices and preserve original order.
# If there is only one result, we return it directly.
if len(results) == 1:
return results[0]
# Use the first result as base response. There are some fields that cannot be shared
# across sub-requests. For now it is just the usage field.
final_response = dict(results[0])
final_response["usage"] = aggregate_completion_usage_info(results, self.backend)
# Aggregate choices. TODO(Charlie): improve logic when we need to support n > 1
# vLLM sets index positions per sub-batch, so we reset indices to be 0..n-1 for the combined response.
combined_choices: list[Dict[str, Any]] = [None] * num_prompts
for indices, result in zip(indices_list, results):
# indices are the original prompt indices that the task's response corresponds to
for local_idx, original_idx in enumerate(indices):
choice = result["choices"][local_idx]
choice["index"] = original_idx # overwrite index with the global position
combined_choices[original_idx] = choice
# sanity check that the index is correct
for new_idx in range(len(combined_choices)):
assert combined_choices[new_idx]["index"] == new_idx
final_response["choices"] = combined_choices
return final_response
method async wake_up
wake_up(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:553-554
async def wake_up(self, *args: Any, **kwargs: Any):
return await self._run_on_all_engines("wake_up", *args, **kwargs)method abstractmethod async sleep
sleep(*args: Any, **kwargs: Any)Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:556-557
async def sleep(self, *args: Any, **kwargs: Any):
return await self._run_on_all_engines("sleep", *args, **kwargs)method abstractmethod async init_weight_update_communicator
init_weight_update_communicator(init_info: 'WeightSyncInitInfo')Initialize weight update communicator on all engines.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
init_info | 'WeightSyncInitInfo' | WeightSyncInitInfo from the sender. | required |
Note:
Per-engine adjustments (e.g., rank_offset for broadcast) are handled by init_info.for_engine().
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:559-573
async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
"""Initialize weight update communicator on all engines.
Args:
init_info: WeightSyncInitInfo from the sender.
Note:
Per-engine adjustments (e.g., rank_offset for broadcast) are handled
by init_info.for_engine().
"""
tasks = []
for i, engine in enumerate(self.engines):
engine_init_info = init_info.for_engine(i, engine.tp_size(), engine.pp_size())
tasks.append(engine.init_weight_update_communicator(engine_init_info))
await asyncio.gather(*tasks)
method async update_named_weights
update_named_weights(request: WeightUpdateRequest)
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:575-576
async def update_named_weights(self, request: WeightUpdateRequest):
return await self._run_on_all_engines("update_named_weights", request=request)method abstractmethod async reset_prefix_cache
reset_prefix_cache()Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:578-579
async def reset_prefix_cache(self):
return await self._run_on_all_engines("reset_prefix_cache")method abstractmethod async teardown
teardown()Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:581-582
async def teardown(self):
return await self._run_on_all_engines("teardown")method abstractmethod tp_size
tp_size() -> intSource code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:584-585
def tp_size(self) -> int:
raise NotImplementedError("InferenceEngineClient does not implement tp_size()")method abstractmethod pp_size
pp_size() -> intSource code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:587-588
def pp_size(self) -> int:
raise NotImplementedError("InferenceEngineClient does not implement pp_size()")method abstractmethod dp_size
dp_size() -> intSource code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:590-591
def dp_size(self) -> int:
raise NotImplementedError("InferenceEngineClient does not implement dp_size()")method async pause_generation
pause_generation() -> NonePauses generation for all engines, intended for in-flight weight updates and partial rollouts.
Currently only supported for /chat/completions and not /completions or generate().
Both in-flight and incoming requests will be blocked until resume_generation is called.
- Set the paused event to avoid new requests from being submitted while aborting requests.
- Wait for a grace period to ensure all in-flight requests have entered the engine's scheduler and hence can be aborted. Otherwise, there can be requests already submitted but not yet entered the scheduler, which can miss the abort request.
- Finally, we abort requests on all engines. This will cause the requests sent from InferenceEngineClient to InferenceEngineClient.engines to return the already-generated tokens. The request to InferenceEngineClient will not return until requests are completed with a stop reason that is not abort.
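A sketch of the intended in-flight weight-update flow; `request` is a WeightUpdateRequest assumed to be prepared by the training side.

```python
# In-flight /chat/completions requests are aborted, keep their partial tokens in the
# client's retry loop, and continue generating after resume_generation().
await client.pause_generation()
await client.update_named_weights(request)   # request: WeightUpdateRequest
await client.resume_generation()
```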
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:601-621
async def pause_generation(self) -> None:
"""
Pauses generation for all engines, intended for in-flight weight updates and partial rollouts.
Currently only supported for `/chat/completions` and not `/completions` or `generate()`.
Both in-flight and incoming requests will be blocked until `resume_generation` is called.
1. Set the paused event to avoid new requests from being submitted while aborting requests.
2. Wait for a grace period to ensure all in-flight requests have entered the engine's
scheduler and hence can be aborted. Otherwise, there can be requests already submitted
but not yet entered the scheduler, which can miss the abort request.
3. Finally, we abort requests on all engines. This will cause the requests sent from
InferenceEngineClient to `InferenceEngineClient.engines` to return the already-generated tokens.
The request to `InferenceEngineClient` will not yet return until requests are completed with
stop reason that is not `abort`.
"""
if self.generation_paused_event.is_set():
raise RuntimeError("Generation is already paused, cannot pause again.")
self.generation_paused_event.set()
await asyncio.sleep(ABORT_GENERATION_GRACE_PERIOD_SECONDS)
await self._run_on_all_engines("abort_generation")method async resume_generation
resume_generation() -> NoneResumes generation for all engines, intended for in-flight weight updates and partial rollouts.
Resume all in-flight requests with the previously-generated tokens, and unblock incoming requests
that were blocked by pause_generation().
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:623-632
async def resume_generation(self) -> None:
"""
Resumes generation for all engines, intended for in-flight weight updates and partial rollouts.
Resume all in-flight requests with the previously-generated tokens, and unblock incoming requests
that were blocked by `pause_generation()`.
"""
if not self.generation_paused_event.is_set():
raise RuntimeError("Generation is not paused, cannot resume.")
self.generation_paused_event.clear()
method async abort_generation
abort_generation() -> None
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:634-638
async def abort_generation(self) -> None:
raise NotImplementedError(
"InferenceEngineClient does not implement abort_generation(), but calls "
"`abort_generation` on all engines in `pause_generation()`."
)