Generator
Generator API — GeneratorInterface, InferenceEngineInterface, InferenceEngineClient.
Core APIs
class GeneratorInterface
Bases: ABC
Functions:
| Name | Description |
|---|---|
generate | Generate trajectories for the input batch. |
Source code in skyrl/train/generators/base.py:59-71
class GeneratorInterface(ABC):
@abstractmethod
async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
"""Generate trajectories for the input batch.
Returns outputs in the same order as the input batch.
Args:
input_batch (GeneratorInput): Input batch
Returns:
GeneratorOutput: Generated trajectories
"""
raise NotImplementedError
method abstractmethod async generate
generate(input_batch: GeneratorInput) -> GeneratorOutput
Generate trajectories for the input batch.
Returns outputs in the same order as the input batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
input_batch | GeneratorInput | Input batch | required |
Returns: GeneratorOutput: Generated trajectories
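As a sketch of how this interface is typically implemented, the hypothetical single-turn generator below delegates the whole batch to an inference engine and repackages the result. The field names on GeneratorInput/GeneratorOutput are assumptions for illustration, and the import path is inferred from the source location shown below.

```python
# Hypothetical sketch only: field names on GeneratorInput/GeneratorOutput are
# assumed, and the import path is inferred from the "Source code in" location.
from skyrl.train.generators.base import GeneratorInterface


class SingleTurnGenerator(GeneratorInterface):
    def __init__(self, engine):
        # `engine` is any InferenceEngineInterface implementation (see below).
        self.engine = engine

    async def generate(self, input_batch):
        # Delegate the batch; the engine returns outputs in input order,
        # which satisfies this method's ordering contract.
        engine_output = await self.engine.generate(
            {
                "prompts": input_batch["prompts"],  # assumed field name
                "prompt_token_ids": None,
                "sampling_params": None,
                "session_ids": None,
            }
        )
        # Repackage engine responses as generator trajectories (assumed fields).
        return {
            "responses": engine_output["responses"],
            "response_ids": engine_output["response_ids"],
            "stop_reasons": engine_output["stop_reasons"],
        }
```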
Source code in skyrl/train/generators/base.py:60-71
@abstractmethod
async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
"""Generate trajectories for the input batch.
Returns outputs in the same order as the input batch.
Args:
input_batch (GeneratorInput): Input batch
Returns:
GeneratorOutput: Generated trajectories
"""
raise NotImplementedError
class InferenceEngineInterface
Bases: ABC
Functions:
| Name | Description |
|---|---|
generate | |
sample | Generate multiple independent samples from a single prompt. |
chat_completion | Handles the OpenAI-compatible /chat/completions endpoint. |
completion | Handles the OpenAI-compatible /completions endpoint. |
wake_up | |
sleep | |
init_weight_update_communicator | Initialize weight update communicator from init info. |
update_named_weights | |
teardown | |
reset_prefix_cache | |
pause_generation | Pause generation, freezing in-flight requests so they can be resumed later. |
resume_generation | Resume generation after a pause, continuing any frozen in-flight requests. |
tp_size | Return the tensor parallel size of this inference engine. |
pp_size | Return the pipeline parallel size of this inference engine. |
dp_size | Return the data parallel size of this inference engine. |
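Concrete backends implement every abstract method below; note that sample() already has a default implementation built on generate() (shown in the source) and usually need not be overridden. For orientation, a minimal echo-style skeleton might look like the following. This is a hypothetical sketch, not a real backend; the import path is inferred from the "Source code in" location below.

```python
# Hypothetical skeleton for illustration; not a real backend.
from skyrl.backends.skyrl_train.inference_engines.base import InferenceEngineInterface


class EchoEngine(InferenceEngineInterface):
    async def generate(self, input_batch):
        ids = input_batch["prompt_token_ids"]
        # Echo each prompt's tokens back as its "response".
        return {
            "response_ids": [list(p) for p in ids],
            "responses": [f"<echo of {len(p)} tokens>" for p in ids],
            "stop_reasons": ["stop"] * len(ids),
            "response_logprobs": None,
            "rollout_expert_indices": None,
        }

    async def chat_completion(self, request_payload):
        raise NotImplementedError("HTTP endpoints not supported in this sketch")

    async def completion(self, request_payload):
        raise NotImplementedError("HTTP endpoints not supported in this sketch")

    # Lifecycle and weight-sync hooks are no-ops in this sketch.
    async def wake_up(self, *args, **kwargs): ...
    async def sleep(self, *args, **kwargs): ...
    async def init_weight_update_communicator(self, init_info): ...
    async def update_named_weights(self, request): ...
    async def teardown(self): ...
    async def reset_prefix_cache(self): ...
    async def pause_generation(self) -> None: ...
    async def resume_generation(self) -> None: ...

    # A single-process sketch has no parallelism.
    def tp_size(self) -> int:
        return 1

    def pp_size(self) -> int:
        return 1

    def dp_size(self) -> int:
        return 1
```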
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:49-186
class InferenceEngineInterface(ABC):
@abstractmethod
async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
raise NotImplementedError
async def sample(
self,
prompt_token_ids: List[int],
num_samples: int,
sampling_params: Dict[str, Any],
) -> InferenceEngineOutput:
"""Generate multiple independent samples from a single prompt.
This method provides Tinker-compatible token-in/token-out sampling semantics.
Args:
prompt_token_ids: Token IDs for a single prompt.
num_samples: Number of independent samples to generate.
sampling_params: Sampling parameters.
Returns:
InferenceEngineOutput containing num_samples results:
- response_ids: List of num_samples token ID lists
- responses: List of num_samples decoded strings
- stop_reasons: List of num_samples stop reasons
- response_logprobs: Optional list of num_samples logprob lists
"""
all_response_ids = []
all_responses = []
all_stop_reasons = []
all_response_logprobs = []
all_rollout_expert_indices = []
for _ in range(num_samples):
input_batch: InferenceEngineInput = {
"prompts": None,
"prompt_token_ids": [prompt_token_ids], # Wrap in list for batch of 1
"sampling_params": sampling_params,
"session_ids": None,
}
output = await self.generate(input_batch)
# Extract single result from batch of 1
all_response_ids.append(output["response_ids"][0])
all_responses.append(output["responses"][0])
all_stop_reasons.append(output["stop_reasons"][0])
if output.get("response_logprobs") is not None:
all_response_logprobs.append(output["response_logprobs"][0])
if output.get("rollout_expert_indices") is not None:
all_rollout_expert_indices.append(output["rollout_expert_indices"][0])
return {
"response_ids": all_response_ids,
"responses": all_responses,
"stop_reasons": all_stop_reasons,
"response_logprobs": all_response_logprobs if all_response_logprobs else None,
"rollout_expert_indices": all_rollout_expert_indices if all_rollout_expert_indices else None,
}
@abstractmethod
async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""Handles OpenAI-compatible HTTP endpoint.
Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
The request body will be used to construct a ChatCompletionRequest.
Returns a plain dict, either a ChatCompletionResponse or an ErrorResponse.
The specific fields of the response/request depend on the engine's backend (e.g. for vllm
these are defined in vllm.entrypoints.openai.protocol).
"""
raise NotImplementedError
@abstractmethod
async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""Handles OpenAI-compatible HTTP endpoint.
Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
The request body will be used to construct a CompletionRequest.
Returns a plain dict, either a CompletionResponse or an ErrorResponse.
The specific fields of the response/request depend on the engine's backend (e.g. for vllm
these are defined in vllm.entrypoints.openai.protocol).
"""
raise NotImplementedError
@abstractmethod
async def wake_up(self, *args: Any, **kwargs: Any):
raise NotImplementedError
@abstractmethod
async def sleep(self, *args: Any, **kwargs: Any):
raise NotImplementedError
@abstractmethod
async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
"""Initialize weight update communicator from init info.
Args:
init_info: WeightSyncInitInfo from the sender containing all info needed
to create the appropriate receiver.
"""
raise NotImplementedError()
@abstractmethod
async def update_named_weights(self, request: "WeightUpdateRequest"):
raise NotImplementedError()
@abstractmethod
async def teardown(self):
raise NotImplementedError
@abstractmethod
async def reset_prefix_cache(self):
raise NotImplementedError
@abstractmethod
async def pause_generation(self) -> None:
"""Pause generation, freezing in-flight requests so they can be resumed later."""
raise NotImplementedError
@abstractmethod
async def resume_generation(self) -> None:
"""Resume generation after a pause, continuing any frozen in-flight requests."""
raise NotImplementedError
@abstractmethod
def tp_size(self) -> int:
"""Return the tensor parallel size of this inference engine."""
raise NotImplementedError
@abstractmethod
def pp_size(self) -> int:
"""Return the pipeline parallel size of this inference engine."""
raise NotImplementedError
@abstractmethod
def dp_size(self) -> int:
"""Return the data parallel size of this inference engine."""
raise NotImplementedError
method abstractmethod async generate
generate(input_batch: InferenceEngineInput) -> InferenceEngineOutput
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:51-53
@abstractmethod
async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
raise NotImplementedError
method async sample
sample(prompt_token_ids: List[int], num_samples: int, sampling_params: Dict[str, Any]) -> InferenceEngineOutput
Generate multiple independent samples from a single prompt.
This method provides Tinker-compatible token-in/token-out sampling semantics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
prompt_token_ids | List[int] | Token IDs for a single prompt. | required |
num_samples | int | Number of independent samples to generate. | required |
sampling_params | Dict[str, Any] | Sampling parameters. | required |
Returns:
| Type | Description |
|---|---|
| InferenceEngineOutput | InferenceEngineOutput containing num_samples results: - response_ids: List of num_samples token ID lists - responses: List of num_samples decoded strings - stop_reasons: List of num_samples stop reasons - response_logprobs: Optional list of num_samples logprob lists |
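A hypothetical call, assuming `engine` is a concrete implementation and that the backend accepts common sampling keys such as temperature and max_tokens:

```python
# Inside an async function: draw 4 independent samples from one tokenized
# prompt (a single prompt, not a batch).
output = await engine.sample(
    prompt_token_ids=[151644, 872, 198],  # example token IDs
    num_samples=4,
    sampling_params={"temperature": 1.0, "max_tokens": 128},  # assumed keys
)
assert len(output["response_ids"]) == 4  # one entry per sample
```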
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:55-107
async def sample(
self,
prompt_token_ids: List[int],
num_samples: int,
sampling_params: Dict[str, Any],
) -> InferenceEngineOutput:
"""Generate multiple independent samples from a single prompt.
This method provides Tinker-compatible token-in/token-out sampling semantics.
Args:
prompt_token_ids: Token IDs for a single prompt.
num_samples: Number of independent samples to generate.
sampling_params: Sampling parameters.
Returns:
InferenceEngineOutput containing num_samples results:
- response_ids: List of num_samples token ID lists
- responses: List of num_samples decoded strings
- stop_reasons: List of num_samples stop reasons
- response_logprobs: Optional list of num_samples logprob lists
"""
all_response_ids = []
all_responses = []
all_stop_reasons = []
all_response_logprobs = []
all_rollout_expert_indices = []
for _ in range(num_samples):
input_batch: InferenceEngineInput = {
"prompts": None,
"prompt_token_ids": [prompt_token_ids], # Wrap in list for batch of 1
"sampling_params": sampling_params,
"session_ids": None,
}
output = await self.generate(input_batch)
# Extract single result from batch of 1
all_response_ids.append(output["response_ids"][0])
all_responses.append(output["responses"][0])
all_stop_reasons.append(output["stop_reasons"][0])
if output.get("response_logprobs") is not None:
all_response_logprobs.append(output["response_logprobs"][0])
if output.get("rollout_expert_indices") is not None:
all_rollout_expert_indices.append(output["rollout_expert_indices"][0])
return {
"response_ids": all_response_ids,
"responses": all_responses,
"stop_reasons": all_stop_reasons,
"response_logprobs": all_response_logprobs if all_response_logprobs else None,
"rollout_expert_indices": all_rollout_expert_indices if all_rollout_expert_indices else None,
}
method abstractmethod async chat_completion
chat_completion(request_payload: Dict[str, Any]) -> Dict[str, Any]
Handles OpenAI-compatible HTTP endpoint.
Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}. The request body will be used to construct a ChatCompletionRequest. Returns a plain dict, either a ChatCompletionResponse or an ErrorResponse. The specific fields of the response/request depend on the engine's backend (e.g. for vllm these are defined in vllm.entrypoints.openai.protocol).
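Concretely, a payload might look like the sketch below; the fields under "json" follow the OpenAI Chat Completions schema, and the exact set accepted depends on the backend. The same {"json": ..., "headers": ...} wrapping applies to completion() next.

```python
# Hypothetical payload; "model" should match the engine's served model name.
payload = {
    "json": {
        "model": "my-model",
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.7,
    },
    "headers": {"Content-Type": "application/json"},
}
# Inside an async function:
response = await engine.chat_completion(payload)
# `response` is a plain dict: a ChatCompletionResponse on success,
# or an ErrorResponse on failure.
```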
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:109-119
@abstractmethod
async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""Handles OpenAI-compatible HTTP endpoint.
Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
The request body will be used to construct a ChatCompletionRequest.
Returns a plain dict, either a ChatCompletionResponse or an ErrorResponse.
The specific fields of the response/request depend on the engine's backend (e.g. for vllm
these are defined in vllm.entrypoints.openai.protocol).
"""
raise NotImplementedError
method abstractmethod async completion
completion(request_payload: Dict[str, Any]) -> Dict[str, Any]
Handles OpenAI-compatible HTTP endpoint.
Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}. The request body will be used to construct a CompletionRequest. Returns a plain dict, either a CompletionResponse or an ErrorResponse. The specific fields of the response/request depend on the engine's backend (e.g. for vllm these are defined in vllm.entrypoints.openai.protocol).
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:121-131
@abstractmethod
async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""Handles OpenAI-compatible HTTP endpoint.
Accepts a JSON payload: {"json": <request-body>, "headers": <headers-dict>}.
The request body will be used to construct a CompletionRequest.
Returns a plain dict, either a CompletionResponse or an ErrorResponse.
The specific fields of the response/request depend on the engine's backend (e.g. for vllm
these are defined in vllm.entrypoints.openai.protocol).
"""
raise NotImplementedError
method abstractmethod async wake_up
wake_up(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:133-135
@abstractmethod
async def wake_up(self, *args: Any, **kwargs: Any):
raise NotImplementedError
method abstractmethod async sleep
sleep(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:137-139
@abstractmethod
async def sleep(self, *args: Any, **kwargs: Any):
raise NotImplementedError
method abstractmethod async init_weight_update_communicator
init_weight_update_communicator(init_info: WeightSyncInitInfo)
Initialize weight update communicator from init info.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
init_info | WeightSyncInitInfo | WeightSyncInitInfo from the sender containing all info needed to create the appropriate receiver. | required |
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:141-149
@abstractmethod
async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
"""Initialize weight update communicator from init info.
Args:
init_info: WeightSyncInitInfo from the sender containing all info needed
to create the appropriate receiver.
"""
raise NotImplementedError()
method abstractmethod async update_named_weights
update_named_weights(request: WeightUpdateRequest)
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:151-153
@abstractmethod
async def update_named_weights(self, request: "WeightUpdateRequest"):
raise NotImplementedError()
method abstractmethod async teardown
teardown()
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:155-157
@abstractmethod
async def teardown(self):
raise NotImplementedError
method abstractmethod async reset_prefix_cache
reset_prefix_cache()
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:159-161
@abstractmethod
async def reset_prefix_cache(self):
raise NotImplementedError
method abstractmethod async pause_generation
pause_generation() -> None
Pause generation, freezing in-flight requests so they can be resumed later.
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:163-166
@abstractmethod
async def pause_generation(self) -> None:
"""Pause generation, freezing in-flight requests so they can be resumed later."""
raise NotImplementedError
method abstractmethod async resume_generation
resume_generation() -> None
Resume generation after a pause, continuing any frozen in-flight requests.
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:168-171
@abstractmethod
async def resume_generation(self) -> None:
"""Resume generation after a pause, continuing any frozen in-flight requests."""
raise NotImplementedError
method abstractmethod tp_size
tp_size() -> int
Return the tensor parallel size of this inference engine.
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:173-176
@abstractmethod
def tp_size(self) -> int:
"""Return the tensor parallel size of this inference engine."""
raise NotImplementedError
method abstractmethod pp_size
pp_size() -> int
Return the pipeline parallel size of this inference engine.
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:178-181
@abstractmethod
def pp_size(self) -> int:
"""Return the pipeline parallel size of this inference engine."""
raise NotImplementedError
method abstractmethod dp_size
dp_size() -> int
Return the data parallel size of this inference engine.
Source code in skyrl/backends/skyrl_train/inference_engines/base.py:183-186
@abstractmethod
def dp_size(self) -> int:
"""Return the data parallel size of this inference engine."""
raise NotImplementedError
class InferenceEngineClient
InferenceEngineClient(engines: List[InferenceEngineInterface], tokenizer: PreTrainedTokenizerBase, model_path: str, lora_cfg: SkyRLLoraConfig, inference_engine_cfg: InferenceEngineConfig)
Bases: InferenceEngineInterface
Client to talk to a set of InferenceEngines.
Note that InferenceEngineClient sub-classes InferenceEngineInterface so it can be used as if talking to a single engine.
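A typical call goes through generate(), passing exactly one of prompts (chat-format messages, tokenized via apply_chat_template) or prompt_token_ids. A hypothetical usage, assuming a constructed client (see the construction sketch after the Parameters table below):

```python
# Inside an async function:
output = await client.generate(
    {
        "prompts": [[{"role": "user", "content": "Hi there"}]],  # chat-format messages
        "prompt_token_ids": None,  # pass token IDs here instead, but not both
        "sampling_params": {"temperature": 0.8},  # assumed keys
        "session_ids": None,  # optional: pin prompts to specific engines
    }
)
print(output["responses"][0], output["stop_reasons"][0])
```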
Functions:
| Name | Description |
|---|---|
generate | |
sample | Generate multiple independent samples from a single prompt. |
chat_completion | |
completion | Handles an OpenAI /completions request. |
wake_up | |
sleep | |
init_weight_update_communicator | Initialize weight update communicator on all engines. |
update_named_weights | |
reset_prefix_cache | |
teardown | |
tp_size | |
pp_size | |
dp_size | |
pause_generation | Pauses generation for all engines using vLLM's native keep mode. |
resume_generation | Resumes generation for all engines after a keep-mode pause. |
Attributes:
| Name | Type | Description |
|---|---|---|
engines | ||
tokenizer | ||
inference_engine_cfg | ||
model_name | ||
backend | ||
enable_http_endpoint | ||
http_endpoint_host | ||
http_endpoint_port |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
engines | List[InferenceEngineInterface] | List[InferenceEngineInterface] - The inference engines, remote or local. | required |
tokenizer | PreTrainedTokenizerBase | PreTrainedTokenizerBase - The tokenizer to use. | required |
model_path | str | str - The path to the model. | required |
lora_cfg | SkyRLLoraConfig | SkyRLLoraConfig - The LoRA configuration. | required |
inference_engine_cfg | InferenceEngineConfig | InferenceEngineConfig - The inference engine configuration. | required |
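A construction sketch (hypothetical; in a real run the engines, LoRA config, and engine config come from the training entrypoint and its YAML config):

```python
from transformers import AutoTokenizer

model_path = "Qwen/Qwen2.5-1.5B-Instruct"  # example model
tokenizer = AutoTokenizer.from_pretrained(model_path)

client = InferenceEngineClient(
    engines=engines,                  # List[InferenceEngineInterface], built elsewhere
    tokenizer=tokenizer,
    model_path=model_path,
    lora_cfg=lora_cfg,                # SkyRLLoraConfig from your config
    inference_engine_cfg=engine_cfg,  # InferenceEngineConfig from your config
)
```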
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:35-449
class InferenceEngineClient(InferenceEngineInterface):
"""
Client to talk to a set of InferenceEngines.
Note that InferenceEngineClient sub-classes InferenceEngineInterface so it can be used as if talking to a single
engine.
"""
def __init__(
self,
engines: List[InferenceEngineInterface],
tokenizer: PreTrainedTokenizerBase,
model_path: str,
lora_cfg: SkyRLLoraConfig,
inference_engine_cfg: InferenceEngineConfig,
):
"""
Args:
engines: List[InferenceEngineInterface] - The inference engines, remote or local.
tokenizer: PreTrainedTokenizerBase - The tokenizer to use.
model_path: str - The path to the model.
lora_cfg: SkyRLLoraConfig - The LoRA configuration.
inference_engine_cfg: InferenceEngineConfig - The inference engine configuration.
"""
self.engines = engines
self.tokenizer = tokenizer
self.inference_engine_cfg = inference_engine_cfg
# Use served_model_name if provided, otherwise fall back to model path.
# served_model_name allows using a different model name for HTTP endpoint validation
# than the actual model path. See ppo_base_config.yaml for details.
served_model_name = inference_engine_cfg.served_model_name
if served_model_name is not None:
self.model_name = served_model_name
else:
self.model_name = model_path
self.backend = inference_engine_cfg.backend
self.enable_http_endpoint = inference_engine_cfg.enable_http_endpoint
self.http_endpoint_host = inference_engine_cfg.http_endpoint_host
self.http_endpoint_port = inference_engine_cfg.http_endpoint_port
# we assume that dp_size is same for all engines
dp_sizes = [engine.dp_size() for engine in self.engines]
assert len(set(dp_sizes)) <= 1, f"Expected all engines to have the same DP size, got {dp_sizes}"
if self.enable_http_endpoint:
self._spin_up_http_endpoint()
logger.info(f"InferenceEngineClient initialized with {len(engines)} engines.")
async def _run_on_all_engines(self, method_name: str, *args, **kwargs):
"""
Call a method on all engines concurrently and gather the results.
"""
assert len(self.engines) > 0, "No engines to call method on"
awaitables = [getattr(engine, method_name)(*args, **kwargs) for engine in self.engines]
return await asyncio.gather(*awaitables)
async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
# 0. Extract input
prompts = input_batch.get("prompts")
prompt_token_ids = input_batch.get("prompt_token_ids")
session_ids = input_batch.get("session_ids")
sampling_params = input_batch.get("sampling_params")
if (prompts is None and prompt_token_ids is None) or (prompts is not None and prompt_token_ids is not None):
raise ValueError("Either `prompts` or `prompt_token_ids` must be provided, but not both.")
if prompt_token_ids is None:
prompt_token_ids = self.tokenizer.apply_chat_template(
prompts,
add_generation_prompt=True,
return_dict=False,
tokenize=True,
)
num_prompts = len(prompt_token_ids)
num_inference_engines = len(self.engines)
# 1. Route prompts to engines
engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
num_prompts=num_prompts,
num_inference_engines=num_inference_engines,
session_ids=session_ids,
)
# 2. Generate responses concurrently
tasks: list[asyncio.Task] = []
indices_list: list[list[int]] = [] # the original prompt indices that each task works on
for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
# index prompt_token_ids with prompt_ids
cur_prompt_token_ids = [prompt_token_ids[i] for i in prompt_ids]
engine_input = InferenceEngineInput(
prompt_token_ids=cur_prompt_token_ids,
sampling_params=sampling_params,
)
tasks.append(asyncio.create_task(self.engines[engine_idx].generate(engine_input)))
indices_list.append(prompt_ids)
results = await asyncio.gather(*tasks)
# 3. Reconstruct output in original order
n = len(prompt_token_ids)
responses: list[str] = [""] * n
stop_reasons: list[str] = [""] * n
response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)]
response_ids: List[List[int]] = [[] for _ in range(n)]
rollout_expert_indices: List[Optional[List[List[List[int]]]]] = [None for _ in range(n)]
# a bit hacky for now
add_resp_logprobs = False
add_rollout_expert_indices = False
for indices, result in zip(indices_list, results):
for local_idx, original_idx in enumerate(indices):
responses[original_idx] = result["responses"][local_idx]
stop_reasons[original_idx] = result["stop_reasons"][local_idx]
response_ids[original_idx] = result["response_ids"][local_idx]
if result.get("response_logprobs", None):
add_resp_logprobs = True
response_logprobs[original_idx] = result["response_logprobs"][local_idx]
if result.get("rollout_expert_indices", None):
add_rollout_expert_indices = True
rollout_expert_indices[original_idx] = result["rollout_expert_indices"][local_idx]
return InferenceEngineOutput(
responses=responses,
stop_reasons=stop_reasons,
response_ids=response_ids,
response_logprobs=response_logprobs if add_resp_logprobs else None,
rollout_expert_indices=rollout_expert_indices if add_rollout_expert_indices else None,
)
def _select_engine_idx(self, session_id: Optional[Union[str, int]] = None) -> int:
"""Select an engine index for routing a request.
Args:
session_id: Optional session ID for consistent routing (e.g., conversation ID for chat).
If None, uses random load-balancing.
Returns:
Engine index to route the request to.
"""
if session_id is None:
return random.randint(0, len(self.engines) - 1)
else:
return hash_with_sha256(str(session_id)) % len(self.engines)
async def sample(
self,
prompt_token_ids: List[int],
num_samples: int,
sampling_params: Dict[str, Any],
session_id: Optional[Union[str, int]] = None,
) -> InferenceEngineOutput:
"""Generate multiple independent samples from a single prompt.
This method provides Tinker-compatible token-in/token-out sampling semantics.
Generates num_samples independent completions from the same prompt.
Args:
prompt_token_ids: Token IDs for a single prompt (not batched).
num_samples: Number of independent samples to generate.
sampling_params: Sampling parameters (temperature, max_tokens, etc.).
session_id: Optional session ID for consistent engine routing (e.g., conversation ID).
If None, uses random load-balancing. Tinker API should pass None since
each sample() call is independent.
Returns:
InferenceEngineOutput containing num_samples results.
"""
# Select engine (random if session_id is None, consistent hash otherwise)
engine_idx = self._select_engine_idx(session_id)
engine = self.engines[engine_idx]
return await engine.sample(
prompt_token_ids=prompt_token_ids,
num_samples=num_samples,
sampling_params=sampling_params,
)
async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
session_id = request_payload["json"].pop("session_id", None)
if session_id is not None:
assert isinstance(session_id, (str, int)), "Session ID must be an integer or string for `/chat/completions`"
engine_idx = self._select_engine_idx(session_id)
return await self.engines[engine_idx].chat_completion(request_payload)
async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""
Handles an OpenAI /completions request.
Since `request["prompt"]` can be `Union[list[int], list[list[int]], str, list[str]]`,
(i.e. {batched, single} x {string, token IDs}), we need to route the request to engines
differently, based on whether it's a single or batched request, and whether `request["session_id"]`
is provided. This is similar to `generate()` method.
For single, we do the same routing logic as `chat_completion()`. For batched, we route by
`request["session_id"]` if present, and if not we split evenly across engines.
Regardless, the order will be maintained, i.e. `output["choices"][i]` corresponds to `request["prompt"][i]`.
"""
body = request_payload.get("json", {})
# NOTE(Charlie): do not reuse headers here as the single request may become various new requests
headers = {"Content-Type": "application/json"}
# 1. Postprocess prompt, session_id, and validate request.
prompt = body.get("prompt")
session_id_value = body.pop("session_id", None)
ret = postprocess_completion_request(prompt, session_id_value)
session_id_list: Optional[Union[List[int], List[str], ErrorResponse]] = ret[0]
prompt: Union[List[List[int]], List[str]] = ret[1]
if isinstance(session_id_list, ErrorResponse):
return session_id_list.model_dump()
num_prompts = len(prompt)
num_inference_engines = len(self.engines)
assert num_prompts > 0, "Number of prompts must be greater than 0"
# 1. Route prompts to engines
engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
num_prompts=num_prompts,
num_inference_engines=num_inference_engines,
session_ids=session_id_list,
)
# 2. Generate responses concurrently
tasks: list[asyncio.Task] = []
indices_list: list[list[int]] = [] # the original prompt indices that each task works on
for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
cur_prompt = [prompt[i] for i in prompt_ids]
# reuse the exact same request except for the prompt
cur_json = dict(body)
cur_json["prompt"] = cur_prompt
coro = self.engines[engine_idx].completion({"json": cur_json, "headers": headers})
tasks.append(asyncio.create_task(coro))
indices_list.append(prompt_ids)
results = await asyncio.gather(*tasks)
# 3. Check for errors.
# results can be ErrorResponse or CompletionResponse. If one of the sub-requests fails, we
# return an error response. That is, there is no partial success, following vLLM's behavior.
for result in results:
if "error" in result or result.get("object", "") == "error":
error_details = result.get("error", result)
error_code = error_details["code"]
error_type = error_details["type"]
error_message = error_details["message"]
return ErrorResponse(
error=ErrorInfo(
message=f"In one of the engines that SkyRL manages, an error occurred: {error_message}",
type=error_type,
code=error_code,
),
).model_dump()
# 4. Combine choices and preserve original order.
# If there is only one result, we return it directly.
if len(results) == 1:
return results[0]
# Use the first result as base response. There are some fields that cannot be shared
# across sub-requests. For now it is just the usage field.
final_response = dict(results[0])
final_response["usage"] = aggregate_completion_usage_info(results, self.backend)
# Aggregate choices. TODO(Charlie): improve logic when we need to support n > 1
# vLLM sets index positions per sub-batch, so we reset indices to be 0..n-1 for the combined response.
combined_choices: list[Dict[str, Any]] = [None] * num_prompts
for indices, result in zip(indices_list, results):
# indices are the original prompt indices that the task's response corresponds to
for local_idx, original_idx in enumerate(indices):
choice = result["choices"][local_idx]
choice["index"] = original_idx # overwrite index with the global position
combined_choices[original_idx] = choice
# sanity check that the index is correct
for new_idx in range(len(combined_choices)):
assert combined_choices[new_idx]["index"] == new_idx
final_response["choices"] = combined_choices
return final_response
async def wake_up(self, *args: Any, **kwargs: Any):
return await self._run_on_all_engines("wake_up", *args, **kwargs)
async def sleep(self, *args: Any, **kwargs: Any):
return await self._run_on_all_engines("sleep", *args, **kwargs)
async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
"""Initialize weight update communicator on all engines.
Args:
init_info: WeightSyncInitInfo from the sender.
Note:
Per-engine adjustments (e.g., rank_offset for broadcast) are handled
by init_info.for_engine().
"""
tasks = []
for i, engine in enumerate(self.engines):
# With vLLM, DP ranks are managed as separate engine instances
# We want the index of truly separate vllm deployments i.e different dist worlds
engine_idx = i // engine.dp_size()
engine_init_info = init_info.for_engine(engine_idx, engine.tp_size(), engine.pp_size(), engine.dp_size())
tasks.append(engine.init_weight_update_communicator(engine_init_info))
await asyncio.gather(*tasks)
async def update_named_weights(self, request: WeightUpdateRequest):
return await self._run_on_all_engines("update_named_weights", request=request)
async def reset_prefix_cache(self):
return await self._run_on_all_engines("reset_prefix_cache")
async def teardown(self):
return await self._run_on_all_engines("teardown")
def tp_size(self) -> int:
raise NotImplementedError("InferenceEngineClient does not implement tp_size()")
def pp_size(self) -> int:
raise NotImplementedError("InferenceEngineClient does not implement pp_size()")
def dp_size(self) -> int:
raise NotImplementedError("InferenceEngineClient does not implement dp_size()")
# ----------------------------
# Generation pause and resume
# ----------------------------
async def pause_generation(self) -> None:
"""
Pauses generation for all engines using vLLM's native keep mode.
In-flight requests are frozen (not aborted) and will resume from where they left off
when `resume_generation()` is called. New requests are blocked until resume.
"""
await self._run_on_all_engines("pause_generation")
async def resume_generation(self) -> None:
"""
Resumes generation for all engines after a keep-mode pause.
Frozen in-flight requests continue from where they left off, and new requests are unblocked.
"""
await self._run_on_all_engines("resume_generation")
# ----------------------------
# HTTP endpoint related methods
# ----------------------------
def __del__(self):
"""
Destructor to shut down the HTTP endpoint if it was started.
"""
# TODO(Charlie): __del__ is not guaranteed to be called in general. Add to `teardown` method
# when the `_handle_termination` flow is implemented. See `skyrl_train/workers/worker.py`
# comments on `_handle_termination` for more details.
if (
self.enable_http_endpoint
and hasattr(
self, "_server_thread"
) # don't want to shut down the server when it is pickled as a ray method argument.
and self._server_thread is not None
):
try:
from skyrl.backends.skyrl_train.inference_engines.inference_engine_client_http_endpoint import (
shutdown_server,
)
shutdown_server(
host=self.http_endpoint_host,
port=self.http_endpoint_port,
max_wait_seconds=10,
)
if hasattr(self, "_server_thread") and self._server_thread.is_alive():
self._server_thread.join(timeout=10)
except Exception as e:
logger.error(f"Error shutting down HTTP endpoint: {e}")
def __getstate__(self):
"""
Override to avoid pickling the server thread, which is not picklable.
Needed when passing InferenceEngineClient as an argument to async_run_ray_method(), mainly for
invoking `init_weight_sync_state()` and `broadcast_to_inference_engines()`, which do
not need these attributes.
"""
state = self.__dict__.copy()
state["_server_thread"] = None
return state
def _spin_up_http_endpoint(self):
from skyrl.backends.skyrl_train.inference_engines.inference_engine_client_http_endpoint import (
serve,
wait_for_server_ready,
)
self._server_thread = threading.Thread(
target=serve,
args=(self,),
kwargs={
"host": self.http_endpoint_host,
"port": self.http_endpoint_port,
"log_level": "warning",
},
daemon=True,
)
self._server_thread.start()
wait_for_server_ready(
host=self.http_endpoint_host,
port=self.http_endpoint_port,
max_wait_seconds=30,
)
logger.info(
f"InferenceEngineClient HTTP endpoint started on {self.http_endpoint_host}:{self.http_endpoint_port}"
)
attr engines
engines = engines
attr tokenizer
tokenizer = tokenizer
attr inference_engine_cfg
inference_engine_cfg = inference_engine_cfg
attr model_name
model_name = served_model_name
attr backend
backend = inference_engine_cfg.backend
attr enable_http_endpoint
enable_http_endpoint = inference_engine_cfg.enable_http_endpoint
attr http_endpoint_host
http_endpoint_host = inference_engine_cfg.http_endpoint_host
attr http_endpoint_port
http_endpoint_port = inference_engine_cfg.http_endpoint_port
method async generate
generate(input_batch: InferenceEngineInput) -> InferenceEngineOutput
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:92-163
async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
# 0. Extract input
prompts = input_batch.get("prompts")
prompt_token_ids = input_batch.get("prompt_token_ids")
session_ids = input_batch.get("session_ids")
sampling_params = input_batch.get("sampling_params")
if (prompts is None and prompt_token_ids is None) or (prompts is not None and prompt_token_ids is not None):
raise ValueError("Either `prompts` or `prompt_token_ids` must be provided, but not both.")
if prompt_token_ids is None:
prompt_token_ids = self.tokenizer.apply_chat_template(
prompts,
add_generation_prompt=True,
return_dict=False,
tokenize=True,
)
num_prompts = len(prompt_token_ids)
num_inference_engines = len(self.engines)
# 1. Route prompts to engines
engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
num_prompts=num_prompts,
num_inference_engines=num_inference_engines,
session_ids=session_ids,
)
# 2. Generate responses concurrently
tasks: list[asyncio.Task] = []
indices_list: list[list[int]] = [] # the original prompt indices that each task works on
for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
# index prompt_token_ids with prompt_ids
cur_prompt_token_ids = [prompt_token_ids[i] for i in prompt_ids]
engine_input = InferenceEngineInput(
prompt_token_ids=cur_prompt_token_ids,
sampling_params=sampling_params,
)
tasks.append(asyncio.create_task(self.engines[engine_idx].generate(engine_input)))
indices_list.append(prompt_ids)
results = await asyncio.gather(*tasks)
# 3. Reconstruct output in original order
n = len(prompt_token_ids)
responses: list[str] = [""] * n
stop_reasons: list[str] = [""] * n
response_logprobs: List[Optional[List[float]]] = [None for _ in range(n)]
response_ids: List[List[int]] = [[] for _ in range(n)]
rollout_expert_indices: List[Optional[List[List[List[int]]]]] = [None for _ in range(n)]
# a bit hacky for now
add_resp_logprobs = False
add_rollout_expert_indices = False
for indices, result in zip(indices_list, results):
for local_idx, original_idx in enumerate(indices):
responses[original_idx] = result["responses"][local_idx]
stop_reasons[original_idx] = result["stop_reasons"][local_idx]
response_ids[original_idx] = result["response_ids"][local_idx]
if result.get("response_logprobs", None):
add_resp_logprobs = True
response_logprobs[original_idx] = result["response_logprobs"][local_idx]
if result.get("rollout_expert_indices", None):
add_rollout_expert_indices = True
rollout_expert_indices[original_idx] = result["rollout_expert_indices"][local_idx]
return InferenceEngineOutput(
responses=responses,
stop_reasons=stop_reasons,
response_ids=response_ids,
response_logprobs=response_logprobs if add_resp_logprobs else None,
rollout_expert_indices=rollout_expert_indices if add_rollout_expert_indices else None,
)
method async sample
sample(prompt_token_ids: List[int], num_samples: int, sampling_params: Dict[str, Any], session_id: Optional[Union[str, int]] = None) -> InferenceEngineOutput
Generate multiple independent samples from a single prompt.
This method provides Tinker-compatible token-in/token-out sampling semantics. Generates num_samples independent completions from the same prompt.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
prompt_token_ids | List[int] | Token IDs for a single prompt (not batched). | required |
num_samples | int | Number of independent samples to generate. | required |
sampling_params | Dict[str, Any] | Sampling parameters (temperature, max_tokens, etc.). | required |
session_id | Optional[Union[str, int]] | Optional session ID for consistent engine routing (e.g., conversation ID). If None, uses random load-balancing. Tinker API should pass None since each sample() call is independent. | None |
Returns:
| Type | Description |
|---|---|
| InferenceEngineOutput | InferenceEngineOutput containing num_samples results. |
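Hypothetical usage, showing both routing modes:

```python
# Inside an async function.
# session_id=None: each call is load-balanced to a random engine.
out = await client.sample(
    prompt_token_ids=prompt_ids,
    num_samples=8,
    sampling_params={"temperature": 1.0},
)

# A stable session_id always hashes to the same engine, so related calls
# can reuse engine state such as the prefix cache.
out = await client.sample(
    prompt_token_ids=prompt_ids,
    num_samples=1,
    sampling_params={"temperature": 1.0},
    session_id="conversation-42",
)
```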
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:180-211
async def sample(
self,
prompt_token_ids: List[int],
num_samples: int,
sampling_params: Dict[str, Any],
session_id: Optional[Union[str, int]] = None,
) -> InferenceEngineOutput:
"""Generate multiple independent samples from a single prompt.
This method provides Tinker-compatible token-in/token-out sampling semantics.
Generates num_samples independent completions from the same prompt.
Args:
prompt_token_ids: Token IDs for a single prompt (not batched).
num_samples: Number of independent samples to generate.
sampling_params: Sampling parameters (temperature, max_tokens, etc.).
session_id: Optional session ID for consistent engine routing (e.g., conversation ID).
If None, uses random load-balancing. Tinker API should pass None since
each sample() call is independent.
Returns:
InferenceEngineOutput containing num_samples results.
"""
# Select engine (random if session_id is None, consistent hash otherwise)
engine_idx = self._select_engine_idx(session_id)
engine = self.engines[engine_idx]
return await engine.sample(
prompt_token_ids=prompt_token_ids,
num_samples=num_samples,
sampling_params=sampling_params,
)
method async chat_completion
chat_completion(request_payload: Dict[str, Any]) -> Dict[str, Any]
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:213-219
async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
session_id = request_payload["json"].pop("session_id", None)
if session_id is not None:
assert isinstance(session_id, (str, int)), "Session ID must be an integer or string for `/chat/completions`"
engine_idx = self._select_engine_idx(session_id)
return await self.engines[engine_idx].chat_completion(request_payload)
method async completion
completion(request_payload: Dict[str, Any]) -> Dict[str, Any]
Handles an OpenAI /completions request.
Since request["prompt"] can be Union[list[int], list[list[int]], str, list[str]],
(i.e. {batched, single} x {string, token IDs}), we need to route the request to engines
differently, based on whether it's a single or batched request, and whether request["session_id"]
is provided. This is similar to generate() method.
For single, we do the same routing logic as chat_completion(). For batched, we route by
request["session_id"] if present, and if not we split evenly across engines.
Regardless, the order will be maintained, i.e. output["choices"][i] corresponds to request["prompt"][i].
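For example, a batched request could look like this sketch (the "model" and prompt values are hypothetical):

```python
payload = {
    "json": {
        "model": "my-model",
        "prompt": ["Once upon a time", "The capital of France is"],  # batched
        "max_tokens": 32,
    },
    "headers": {"Content-Type": "application/json"},
}
# Inside an async function:
response = await client.completion(payload)
# Order is preserved: response["choices"][i] corresponds to
# payload["json"]["prompt"][i], even when prompts are split across engines.
```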
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:221-316
async def completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""
Handles an OpenAI /completions request.
Since `request["prompt"]` can be `Union[list[int], list[list[int]], str, list[str]]`,
(i.e. {batched, single} x {string, token IDs}), we need to route the request to engines
differently, based on whether it's a single or batched request, and whether `request["session_id"]`
is provided. This is similar to `generate()` method.
For single, we do the same routing logic as `chat_completion()`. For batched, we route by
`request["session_id"]` if present, and if not we split evenly across engines.
Regardless, the order will be maintained, i.e. `output["choices"][i]` corresponds to `request["prompt"][i]`.
"""
body = request_payload.get("json", {})
# NOTE(Charlie): do not reuse headers here as the single request may become various new requests
headers = {"Content-Type": "application/json"}
# 1. Postprocess prompt, session_id, and validate request.
prompt = body.get("prompt")
session_id_value = body.pop("session_id", None)
ret = postprocess_completion_request(prompt, session_id_value)
session_id_list: Optional[Union[List[int], List[str], ErrorResponse]] = ret[0]
prompt: Union[List[List[int]], List[str]] = ret[1]
if isinstance(session_id_list, ErrorResponse):
return session_id_list.model_dump()
num_prompts = len(prompt)
num_inference_engines = len(self.engines)
assert num_prompts > 0, "Number of prompts must be greater than 0"
# 1. Route prompts to engines
engine_idx_to_prompt_ids: dict[int, list[int]] = route_prompts_to_engines(
num_prompts=num_prompts,
num_inference_engines=num_inference_engines,
session_ids=session_id_list,
)
# 2. Generate responses concurrently
tasks: list[asyncio.Task] = []
indices_list: list[list[int]] = [] # the original prompt indices that each task works on
for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items():
cur_prompt = [prompt[i] for i in prompt_ids]
# reuse the exact same request except for the prompt
cur_json = dict(body)
cur_json["prompt"] = cur_prompt
coro = self.engines[engine_idx].completion({"json": cur_json, "headers": headers})
tasks.append(asyncio.create_task(coro))
indices_list.append(prompt_ids)
results = await asyncio.gather(*tasks)
# 3. Check for errors.
# results can be ErrorResponse or CompletionResponse. If one of the sub-requests fails, we
# return an error response. That is, there is no partial success, following vLLM's behavior.
for result in results:
if "error" in result or result.get("object", "") == "error":
error_details = result.get("error", result)
error_code = error_details["code"]
error_type = error_details["type"]
error_message = error_details["message"]
return ErrorResponse(
error=ErrorInfo(
message=f"In one of the engines that SkyRL manages, an error occurred: {error_message}",
type=error_type,
code=error_code,
),
).model_dump()
# 4. Combine choices and preserve original order.
# If there is only one result, we return it directly.
if len(results) == 1:
return results[0]
# Use the first result as base response. There are some fields that cannot be shared
# across sub-requests. For now it is just the usage field.
final_response = dict(results[0])
final_response["usage"] = aggregate_completion_usage_info(results, self.backend)
# Aggregate choices. TODO(Charlie): improve logic when we need to support n > 1
# vLLM sets index positions per sub-batch, so we reset indices to be 0..n-1 for the combined response.
combined_choices: list[Dict[str, Any]] = [None] * num_prompts
for indices, result in zip(indices_list, results):
# indices are the original prompt indices that the task's response corresponds to
for local_idx, original_idx in enumerate(indices):
choice = result["choices"][local_idx]
choice["index"] = original_idx # overwrite index with the global position
combined_choices[original_idx] = choice
# sanity check that the index is correct
for new_idx in range(len(combined_choices)):
assert combined_choices[new_idx]["index"] == new_idx
final_response["choices"] = combined_choices
return final_response
method async wake_up
wake_up(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:318-319
async def wake_up(self, *args: Any, **kwargs: Any):
return await self._run_on_all_engines("wake_up", *args, **kwargs)
method async sleep
sleep(*args: Any, **kwargs: Any)
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:321-322
async def sleep(self, *args: Any, **kwargs: Any):
return await self._run_on_all_engines("sleep", *args, **kwargs)
method async init_weight_update_communicator
init_weight_update_communicator(init_info: 'WeightSyncInitInfo')
Initialize weight update communicator on all engines.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
init_info | 'WeightSyncInitInfo' | WeightSyncInitInfo from the sender. | required |
Note:
Per-engine adjustments (e.g., rank_offset for broadcast) are handled by init_info.for_engine().
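For intuition, the deployment-index computation in the source below groups DP replicas of the same deployment together; e.g., with four engine handles that each report dp_size() == 2, there are two real deployments (a worked sketch of the logic, not library code):

```python
dp_size = 2
for i in range(4):              # four engine handles
    engine_idx = i // dp_size   # yields 0, 0, 1, 1
    print(i, "->", engine_idx)  # handles 0-1 and 2-3 map to deployments 0 and 1
```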
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:324-341
async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo"):
"""Initialize weight update communicator on all engines.
Args:
init_info: WeightSyncInitInfo from the sender.
Note:
Per-engine adjustments (e.g., rank_offset for broadcast) are handled
by init_info.for_engine().
"""
tasks = []
for i, engine in enumerate(self.engines):
# With vLLM, DP ranks are managed as separate engine instances
# We want the index of truly separate vllm deployments i.e different dist worlds
engine_idx = i // engine.dp_size()
engine_init_info = init_info.for_engine(engine_idx, engine.tp_size(), engine.pp_size(), engine.dp_size())
tasks.append(engine.init_weight_update_communicator(engine_init_info))
await asyncio.gather(*tasks)
method async update_named_weights
update_named_weights(request: WeightUpdateRequest)
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:343-344
async def update_named_weights(self, request: WeightUpdateRequest):
return await self._run_on_all_engines("update_named_weights", request=request)
method async reset_prefix_cache
reset_prefix_cache()
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:346-347
async def reset_prefix_cache(self):
return await self._run_on_all_engines("reset_prefix_cache")
method async teardown
teardown()
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:349-350
async def teardown(self):
return await self._run_on_all_engines("teardown")
method tp_size
tp_size() -> int
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:352-353
def tp_size(self) -> int:
raise NotImplementedError("InferenceEngineClient does not implement tp_size()")method abstractmethod pp_size
pp_size() -> intSource code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:355-356
def pp_size(self) -> int:
raise NotImplementedError("InferenceEngineClient does not implement pp_size()")method abstractmethod dp_size
dp_size() -> intSource code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:358-359
def dp_size(self) -> int:
raise NotImplementedError("InferenceEngineClient does not implement dp_size()")method abstractmethod async pause_generation
pause_generation() -> NonePauses generation for all engines using vLLM's native keep mode.
In-flight requests are frozen (not aborted) and will resume from where they left off
when resume_generation() is called. New requests are blocked until resume.
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:364-371
async def pause_generation(self) -> None:
"""
Pauses generation for all engines using vLLM's native keep mode.
In-flight requests are frozen (not aborted) and will resume from where they left off
when `resume_generation()` is called. New requests are blocked until resume.
"""
await self._run_on_all_engines("pause_generation")
method async resume_generation
resume_generation() -> None
Resumes generation for all engines after a keep-mode pause.
Frozen in-flight requests continue from where they left off, and new requests are unblocked.
Source code in skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:373-379
async def resume_generation(self) -> None:
"""
Resumes generation for all engines after a keep-mode pause.
Frozen in-flight requests continue from where they left off, and new requests are unblocked.
"""
await self._run_on_all_engines("resume_generation")