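Types
Request/response types for the Tinker API.
Enums
class RequestType
Bases: str, Enum
Types of requests that can be processed.
Attributes:
| Name | Type | Description |
|---|---|---|
| CREATE_MODEL | | Value `"create_model"`. |
| FORWARD_BACKWARD | | Value `"forward_backward"`. |
| FORWARD | | Value `"forward"`. |
| OPTIM_STEP | | Value `"optim_step"`. |
| SAVE_WEIGHTS_FOR_SAMPLER | | Value `"save_weights_for_sampler"`. |
| SAVE_WEIGHTS | | Value `"save_weights"`. |
| LOAD_WEIGHTS | | Value `"load_weights"`. |
| SAMPLE | | Value `"sample"`. |
| UNLOAD_MODEL | | Value `"unload_model"`. |
| EXTERNAL | | Value `"external"`; an external request that should not be processed by the engine. |
Source code in skyrl/tinker/types.py:15-29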
class RequestType(str, Enum):
    """Types of requests that can be processed."""

    # Each value is the member name in lowercase; because the enum also subclasses
    # `str`, every member compares equal to its plain string value
    # (e.g. RequestType.SAMPLE == "sample").
    CREATE_MODEL = "create_model"
    FORWARD_BACKWARD = "forward_backward"
    FORWARD = "forward"
    OPTIM_STEP = "optim_step"
    SAVE_WEIGHTS_FOR_SAMPLER = "save_weights_for_sampler"
    SAVE_WEIGHTS = "save_weights"
    LOAD_WEIGHTS = "load_weights"
    SAMPLE = "sample"
    UNLOAD_MODEL = "unload_model"
    # External request that should not be processed by the engine
    EXTERNAL = "external"
attr CREATE_MODEL
CREATE_MODEL = 'create_model'
attr FORWARD_BACKWARD
FORWARD_BACKWARD = 'forward_backward'
attr FORWARD
FORWARD = 'forward'
attr OPTIM_STEP
OPTIM_STEP = 'optim_step'
attr SAVE_WEIGHTS_FOR_SAMPLER
SAVE_WEIGHTS_FOR_SAMPLER = 'save_weights_for_sampler'
attr SAVE_WEIGHTS
SAVE_WEIGHTS = 'save_weights'
attr LOAD_WEIGHTS
LOAD_WEIGHTS = 'load_weights'
attr SAMPLE
SAMPLE = 'sample'
attr UNLOAD_MODEL
UNLOAD_MODEL = 'unload_model'
attr EXTERNAL
EXTERNAL = 'external'
class CheckpointType
Bases: str, Enum
Type of checkpoint.
Attributes:
| Name | Type | Description |
|---|---|---|
| TRAINING | | Value `"training"`. |
| SAMPLER | | Value `"sampler"`. |
Source code in skyrl/tinker/types.py:32-36
class CheckpointType(str, Enum):
    """Type of checkpoint."""

    # str-valued members: each compares equal to its plain string value.
    # NOTE(review): presumably TRAINING tags checkpoints written by SAVE_WEIGHTS and
    # SAMPLER tags those written by SAVE_WEIGHTS_FOR_SAMPLER — confirm at the call sites.
    TRAINING = "training"
    SAMPLER = "sampler"
attr TRAINING
TRAINING = 'training'
attr SAMPLER
SAMPLER = 'sampler'
Configuration
class TinkerPath
Bases: BaseModel
Functions:
| Name | Description |
|---|---|
| parse | Parse a URL string into a TinkerPath object. |
Attributes:
| Name | Type | Description |
|---|---|---|
| primary_id | str | Network-location (host) part of the `tinker://` URL. |
| kind | str | Middle path segment; empty string when the URL path has only one segment. |
| secondary_id | str | Last path segment of the URL. |
Source code in skyrl/tinker/types.py:39-55
class TinkerPath(BaseModel):
    # Structured form of a `tinker://` URL, as built by `parse` below:
    #   tinker://<primary_id>/<secondary_id>          -> kind == ""
    #   tinker://<primary_id>/<kind>/<secondary_id>
    primary_id: str  # the URL's netloc
    kind: str  # middle path segment, or "" for the single-segment form
    secondary_id: str  # last path segment

    @classmethod
    def parse(cls, url: str) -> TinkerPath | None:
        """Parse a URL string into a TinkerPath object."""
        parsed = urlparse(url)
        # A non-empty path from urlparse starts with "/", so split("/") yields a
        # leading "" — hence the literal "" second element in both patterns.
        # Query string and fragment are not part of `path` and are ignored.
        match (parsed.scheme, *parsed.path.split("/")):
            case ("tinker", "", secondary_id):
                return cls(primary_id=parsed.netloc, kind="", secondary_id=secondary_id)
            case ("tinker", "", kind, secondary_id):
                return cls(primary_id=parsed.netloc, kind=kind, secondary_id=secondary_id)
            case _:
                # Any other scheme, or a path with a different number of segments.
                # NOTE(review): empty components are not rejected — "tinker://run/"
                # gives secondary_id == "" and "tinker:///x" gives primary_id == "";
                # confirm that is intended.
                return None
attr primary_id
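primary_id: str
attr kind
kind: str
attr secondary_id
secondary_id: str
method classmethod parse
parse(url: str) -> TinkerPath | None
Parse a URL string into a TinkerPath object.
Accepts `tinker://<primary_id>/<secondary_id>` (with `kind` set to `""`) and `tinker://<primary_id>/<kind>/<secondary_id>`; any other scheme or number of path segments returns `None`.
Source code in skyrl/tinker/types.py:44-55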
@classmethod
def parse(cls, url: str) -> TinkerPath | None:
    """Parse a URL string into a TinkerPath object."""
    parsed = urlparse(url)
    # A non-empty path from urlparse starts with "/", so split("/") yields a
    # leading "" — hence the literal "" second element in both patterns.
    match (parsed.scheme, *parsed.path.split("/")):
        case ("tinker", "", secondary_id):
            # tinker://<primary_id>/<secondary_id>
            return cls(primary_id=parsed.netloc, kind="", secondary_id=secondary_id)
        case ("tinker", "", kind, secondary_id):
            # tinker://<primary_id>/<kind>/<secondary_id>
            return cls(primary_id=parsed.netloc, kind=kind, secondary_id=secondary_id)
        case _:
            # Any other scheme or segment count.
            # NOTE(review): empty components (e.g. "tinker://run/") are accepted as "".
            return None
class AdamParams
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| learning_rate | float | |
| beta1 | float | |
| beta2 | float | |
| eps | float | |
| weight_decay | float | |
Source code in skyrl/tinker/types.py:58-63
class AdamParams(BaseModel):
    # Optimizer hyperparameters carried by OptimStepInput.adam_params; all required.
    # NOTE(review): names follow the usual Adam(W) conventions (beta1/beta2 moment
    # decay rates, eps stability term) — confirm against the consuming optimizer.
    learning_rate: float
    beta1: float
    beta2: float
    eps: float
    weight_decay: float
attr learning_rate
learning_rate: float
attr beta1
beta1: float
attr beta2
beta2: float
attr eps
eps: float
attr weight_decay
weight_decay: float
class LoraConfig
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| rank | int | |
| alpha | float | |
| seed | int | |
| train_attn | bool | Defaults to `True`. |
| train_mlp | bool | Defaults to `True`. |
| train_unembed | bool | Defaults to `False`. |
Source code in skyrl/tinker/types.py:66-72
class LoraConfig(BaseModel):
    # LoRA adapter settings; referenced by CreateModelInput, CreateModelOutput
    # and ModelMetadata.
    rank: int
    alpha: float
    seed: int
    # NOTE(review): the train_* flags presumably choose which module groups
    # (attention / MLP / unembedding) get adapters — confirm in the backend.
    train_attn: bool = True
    train_mlp: bool = True
    train_unembed: bool = False
attr rank
rank: int
attr alpha
alpha: float
attr seed
seed: int
attr train_attn
train_attn: bool = True
attr train_mlp
train_mlp: bool = True
attr train_unembed
train_unembed: bool = False
class SamplingParams
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| temperature | float | |
| max_tokens | int | |
| seed | int | |
| stop_tokens | list[int] \| None | Defaults to `None`. |
| stop_strings | list[str] \| None | Defaults to `None`. |
| top_k | int | Defaults to `-1` (no limit). |
| top_p | float | Defaults to `1.0` (no filtering). |
Source code in skyrl/tinker/types.py:175-182
class SamplingParams(BaseModel):
    # Decoding settings for a sample request; PreparedSampleBatch keeps one of
    # these per sample in all_sampling_params.
    temperature: float
    max_tokens: int
    seed: int
    # Optional stop conditions; None when not supplied.
    stop_tokens: list[int] | None = None
    stop_strings: list[str] | None = None
    top_k: int = -1 # -1 for no limit
    top_p: float = 1.0 # 1.0 for no filtering
attr temperature
temperature: float
attr max_tokens
max_tokens: int
attr seed
seed: int
attr stop_tokens
stop_tokens: list[int] | None = None
attr stop_strings
stop_strings: list[str] | None = None
attr top_k
top_k: int = -1
attr top_p
top_p: float = 1.0
class ModelMetadata
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| adapter_index | int | |
| lora_config | LoraConfig | |
| loaded_checkpoint_id | str \| None | Defaults to `None`. |
Source code in skyrl/tinker/types.py:185-188
class ModelMetadata(BaseModel):
    # Bookkeeping record for one model/adapter.
    # NOTE(review): adapter_index presumably identifies the adapter slot in the
    # backend — confirm.
    adapter_index: int
    lora_config: LoraConfig
    # Defaults to None; presumably set once weights are loaded — confirm.
    loaded_checkpoint_id: str | None = None
attr adapter_index
adapter_index: int
attr lora_config
lora_config: LoraConfig
attr loaded_checkpoint_id
loaded_checkpoint_id: str | None = None
Model Lifecycle
class CreateModelInput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| lora_config | LoraConfig | |
Source code in skyrl/tinker/types.py:75-76
class CreateModelInput(BaseModel):
    # Request payload for creating a model; carries only the LoRA configuration.
    # NOTE(review): presumably the body of a RequestType.CREATE_MODEL request — confirm.
    lora_config: LoraConfig
attr lora_config
lora_config: LoraConfig
class CreateModelOutput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| model_id | str | |
| base_model | str | |
| lora_config | LoraConfig | |
Source code in skyrl/tinker/types.py:79-82
class CreateModelOutput(BaseModel):
    # Result of model creation: the model's id, its base model name, and a LoraConfig.
    model_id: str
    base_model: str
    lora_config: LoraConfig
attr model_id
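model_id: str
attr base_model
base_model: str
attr lora_config
lora_config: LoraConfig
class UnloadModelInput
Bases: BaseModel
class UnloadModelOutput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| model_id | str | |
| status | str | |
| type | str | Defaults to `"unload_model"`. |
Source code in skyrl/tinker/types.py:89-92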
class UnloadModelOutput(BaseModel):
    # Result of unloading a model.
    model_id: str
    status: str
    # Defaults to the same string as RequestType.UNLOAD_MODEL's value.
    type: str = "unload_model"
attr model_id
model_id: str
attr status
status: str
attr type
type: str = 'unload_model'
Forward / Backward
class ModelInputChunk
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| tokens | list[int] | |
Source code in skyrl/tinker/types.py:95-96
class ModelInputChunk(BaseModel):
    # One chunk of a ModelInput: a list of integer tokens (presumably token ids).
    tokens: list[int]
attr tokens
tokens: list[int]
class ModelInput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| chunks | list[ModelInputChunk] | |
Source code in skyrl/tinker/types.py:99-100
class ModelInput(BaseModel):
    # Model input as an ordered list of token chunks.
    # NOTE(review): chunks are presumably concatenated into one token sequence
    # (cf. PreparedModelPassBatch.all_input_ids) — confirm.
    chunks: list[ModelInputChunk]
attr chunks
chunks: list[ModelInputChunk]
class TensorData
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| data | list[int] \| list[float] | |
Source code in skyrl/tinker/types.py:103-104
class TensorData(BaseModel):
    # Flat list payload (ints or floats) used for each LossFnInputs field.
    data: list[int] | list[float]
attr data
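data: list[int] | list[float]
class LossFnInputs
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| target_tokens | TensorData | |
| weights | TensorData | |
| advantages | TensorData | |
| logprobs | TensorData | |
Source code in skyrl/tinker/types.py:107-111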
class LossFnInputs(BaseModel):
    # Per-datum inputs to the loss function; all four fields are required.
    # NOTE(review): these presumably feed PreparedModelPassBatch.all_targets,
    # all_token_weights, all_advantages and all_sampling_logprobs — confirm.
    target_tokens: TensorData
    weights: TensorData
    advantages: TensorData
    logprobs: TensorData
attr target_tokens
target_tokens: TensorData
attr weights
weights: TensorData
attr advantages
advantages: TensorData
attr logprobs
logprobs: TensorData
class Datum
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| loss_fn_inputs | LossFnInputs | |
| model_input | ModelInput | |
Source code in skyrl/tinker/types.py:114-116
class Datum(BaseModel):
    # One example for ForwardBackwardInput: model input plus its loss-function inputs.
    loss_fn_inputs: LossFnInputs
    model_input: ModelInput
attr loss_fn_inputs
loss_fn_inputs: LossFnInputs
attr model_input
model_input: ModelInput
class ForwardBackwardInput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| data | list[Datum] | |
| loss_fn | Literal['cross_entropy', 'importance_sampling', 'ppo', 'cispo'] | |
| loss_fn_config | dict[str, float] \| None | Defaults to `None`. |
Source code in skyrl/tinker/types.py:119-122
class ForwardBackwardInput(BaseModel):
    # A batch of datums and the loss function to evaluate on them.
    data: list[Datum]
    # Validation accepts only these four loss names.
    loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo"]
    # Optional float-valued settings for the selected loss; None when omitted.
    loss_fn_config: dict[str, float] | None = None
attr data
data: list[Datum]
attr loss_fn
loss_fn: Literal['cross_entropy', 'importance_sampling', 'ppo', 'cispo']
attr loss_fn_config
loss_fn_config: dict[str, float] | None = None
class ForwardBackwardOutput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| loss_fn_output_type | str | |
| loss_fn_outputs | list[dict] | |
| metrics | dict | |
Source code in skyrl/tinker/types.py:125-128
class ForwardBackwardOutput(BaseModel):
    # Result of a forward / forward-backward pass.
    loss_fn_output_type: str
    # Untyped dicts. NOTE(review): presumably one per input Datum — confirm.
    loss_fn_outputs: list[dict]
    metrics: dict
attr loss_fn_output_type
loss_fn_output_type: str
attr loss_fn_outputs
loss_fn_outputs: list[dict]
attr property metrics
metrics: dict
Optimization
class OptimStepInput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| adam_params | AdamParams | |
Source code in skyrl/tinker/types.py:136-137
class OptimStepInput(BaseModel):
    # Request payload for an optimizer step; carries the Adam hyperparameters.
    adam_params: AdamParams
attr adam_params
adam_params: AdamParams
class OptimStepOutput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| metrics | dict[str, float] \| None | Defaults to `None`. |
Source code in skyrl/tinker/types.py:140-141
class OptimStepOutput(BaseModel):
    # Result of an optimizer step; metrics are optional and default to None.
    metrics: dict[str, float] | None = None
attr property metrics
metrics: dict[str, float] | None = None
Checkpointing
class SaveWeightsForSamplerInput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| path | str \| None | Defaults to `None`. |
| sampling_session_seq_id | int \| None | Defaults to `None`. |
| seq_id | int \| None | Defaults to `None`. |
| sampling_session_id | str \| None | Defaults to `None`. |
Source code in skyrl/tinker/types.py:144-148
class SaveWeightsForSamplerInput(BaseModel):
    # Request payload for saving sampler weights; every field is optional (None).
    # NOTE(review): presumably callers supply either `path` or the sampling-session
    # identifiers — the model itself does not enforce any combination.
    path: str | None = None
    sampling_session_seq_id: int | None = None
    seq_id: int | None = None
    sampling_session_id: str | None = None
attr path
path: str | None = None
attr sampling_session_seq_id
sampling_session_seq_id: int | None = None
attr seq_id
seq_id: int | None = None
attr sampling_session_id
sampling_session_id: str | None = None
class SaveWeightsForSamplerOutput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| path | str \| None | Defaults to `None`. |
| type | str | Required. |
| sampling_session_id | str \| None | Defaults to `None`. |
Source code in skyrl/tinker/types.py:151-154
class SaveWeightsForSamplerOutput(BaseModel):
    # Result of saving sampler weights.
    path: str | None = None
    # Required even though it follows a defaulted field (pydantic models allow this).
    type: str
    sampling_session_id: str | None = None
attr path
path: str | None = None
attr type
type: str
attr sampling_session_id
sampling_session_id: str | None = None
class SaveWeightsInput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| path | str | |
Source code in skyrl/tinker/types.py:157-158
class SaveWeightsInput(BaseModel):
    # Request payload for saving weights; the destination path is required.
    path: str
attr path
path: str
class SaveWeightsOutput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| path | str | |
| type | str | |
Source code in skyrl/tinker/types.py:161-163
class SaveWeightsOutput(BaseModel):
    # Result of saving weights; both fields are required.
    path: str
    type: str
attr path
path: str
attr type
type: str
class LoadWeightsInput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| source_model_id | str | |
| checkpoint_id | str | |
Source code in skyrl/tinker/types.py:166-168
class LoadWeightsInput(BaseModel):
    # Request payload for loading weights: which model's checkpoint to load.
    source_model_id: str
    checkpoint_id: str
attr source_model_id
source_model_id: str
attr checkpoint_id
checkpoint_id: str
class LoadWeightsOutput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| type | str | |
Source code in skyrl/tinker/types.py:171-172
class LoadWeightsOutput(BaseModel):
    # Result of loading weights; only a required `type` string.
    type: str
attr type
type: str
Sampling
class SampleInput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| base_model | str \| None | Defaults to `None`. |
| prompt | ModelInput | |
| sampling_params | SamplingParams | |
| num_samples | int | |
| checkpoint_id | str | |
| prompt_logprobs | bool | Required. |
Source code in skyrl/tinker/types.py:191-197
class SampleInput(BaseModel):
    # Request payload for sampling; everything except base_model is required.
    base_model: str | None = None
    prompt: ModelInput
    sampling_params: SamplingParams
    num_samples: int
    checkpoint_id: str
    # Whether prompt log-probabilities are wanted (no default; must be given).
    prompt_logprobs: bool
attr base_model
base_model: str | None = None
attr prompt
prompt: ModelInput
attr sampling_params
sampling_params: SamplingParams
attr num_samples
num_samples: int
attr checkpoint_id
checkpoint_id: str
attr prompt_logprobs
prompt_logprobs: bool
class GeneratedSequence
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| stop_reason | Literal['length', 'stop'] | |
| tokens | list[int] | |
| logprobs | list[float] | |
Source code in skyrl/tinker/types.py:200-203
class GeneratedSequence(BaseModel):
    # One sampled completion.
    # NOTE(review): "length" presumably means max_tokens was reached and "stop" that
    # a stop token/string was hit — confirm in the sampler.
    stop_reason: Literal["length", "stop"]
    tokens: list[int]
    logprobs: list[float]
attr stop_reason
stop_reason: Literal['length', 'stop']
attr tokens
tokens: list[int]
attr logprobs
logprobs: list[float]
class SampleOutput
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| sequences | list[GeneratedSequence] | |
| prompt_logprobs | list[float] \| None | Defaults to `None`. |
Source code in skyrl/tinker/types.py:206-208
class SampleOutput(BaseModel):
    # Result of a sample request.
    sequences: list[GeneratedSequence]
    # Defaults to None. NOTE(review): presumably only filled when
    # SampleInput.prompt_logprobs was True — confirm.
    prompt_logprobs: list[float] | None = None
attr sequences
sequences: list[GeneratedSequence]
attr prompt_logprobs
prompt_logprobs: list[float] | None = None
Engine Types
class ErrorResponse
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| error | str | |
| status | str | |
Source code in skyrl/tinker/types.py:131-133
class ErrorResponse(BaseModel):
    # Error payload: a message string and a status string, both required.
    error: str
    status: str
attr error
error: str
attr status
status: str
class EngineMetrics
Bases: BaseModel
Attributes:
| Name | Type | Description |
|---|---|---|
| train_seq_len_jit_times | dict[int, float] | Defaults to an empty dict. |
| sample_seq_len_jit_times | dict[int, float] | Defaults to an empty dict. |
Source code in skyrl/tinker/types.py:212-214
class EngineMetrics(BaseModel):
    # Engine timing metrics keyed by an int.
    # NOTE(review): presumably sequence length -> JIT compile time in seconds — confirm.
    # The `{}` defaults are copied per instance by pydantic, so they are not shared
    # (verify for the pinned pydantic version).
    train_seq_len_jit_times: dict[int, float] = {}
    sample_seq_len_jit_times: dict[int, float] = {}
attr train_seq_len_jit_times
train_seq_len_jit_times: dict[int, float] = {}
attr sample_seq_len_jit_times
sample_seq_len_jit_times: dict[int, float] = {}
class PreparedModelPassBatch
Bases: BaseModel
Prepared batch data for forward/forward_backward operations.
The engine extracts this from requests; the backend converts it to JAX arrays and runs the computation.
Attributes:
| Name | Type | Description |
|---|---|---|
| all_input_ids | list[list[int]] | Per-example data. |
| all_targets | list[list[int]] | Per-example data. |
| all_token_weights | list[list[float]] | Per-example data. |
| all_sampling_logprobs | list[list[float]] | Per-example data. |
| all_advantages | list[list[float]] | Per-example data. |
| all_model_ids | list[str] | Per-example scalar. |
| all_loss_fns | list[str] | Per-example scalar. |
| all_loss_fn_configs | list[dict[str, float] \| None] | Per-example scalar. |
| request_batch_slices | list[tuple[str, str, int, int]] | Mapping from examples back to requests: `(request_id, model_id, start_idx, end_idx)`. |
Source code in skyrl/tinker/types.py:221-240
class PreparedModelPassBatch(BaseModel):
    """Prepared batch data for forward/forward_backward operations.
    Engine extracts this from requests, backend converts to JAX arrays and computes.
    """
    # NOTE(review): the all_* lists are presumably parallel (index i = example i
    # in every list) — the model does not validate equal lengths; confirm in the engine.
    # Per-example data (list of lists)
    all_input_ids: list[list[int]]
    all_targets: list[list[int]]
    all_token_weights: list[list[float]]
    all_sampling_logprobs: list[list[float]]
    all_advantages: list[list[float]]
    # Per-example scalars
    all_model_ids: list[str]
    all_loss_fns: list[str]
    all_loss_fn_configs: list[dict[str, float] | None]
    # Mapping from examples back to requests: (request_id, model_id, start_idx, end_idx)
    # NOTE(review): presumably a half-open [start_idx, end_idx) range — confirm.
    request_batch_slices: list[tuple[str, str, int, int]]
attr all_input_ids
all_input_ids: list[list[int]]
attr all_targets
all_targets: list[list[int]]
attr all_token_weights
all_token_weights: list[list[float]]
attr all_sampling_logprobs
all_sampling_logprobs: list[list[float]]
attr all_advantages
all_advantages: list[list[float]]
attr all_model_ids
all_model_ids: list[str]
attr all_loss_fns
all_loss_fns: list[str]
attr all_loss_fn_configs
all_loss_fn_configs: list[dict[str, float] | None]
attr request_batch_slices
request_batch_slices: list[tuple[str, str, int, int]]
class PreparedSampleBatch
Bases: BaseModel
Prepared batch data for sample operations.
The engine extracts this from requests; the backend converts it to JAX arrays and runs the computation.
Attributes:
| Name | Type | Description |
|---|---|---|
| all_prompts | list[list[int]] | Per-sample data. |
| all_sampling_params | list[SamplingParams] | Per-sample data. |
| all_model_ids | list[str] | Per-sample data. |
| all_checkpoint_ids | list[str] | Per-sample data. |
| all_checkpoint_paths | list[str] | Per-sample data. |
| needs_prompt_logprobs | bool | Whether any request needs prompt logprobs. |
| request_batch_slices | list[tuple[str, str, int, int, bool]] | Mapping from samples back to requests: `(request_id, model_id, start_idx, end_idx, prompt_logprobs_requested)`. |
Source code in skyrl/tinker/types.py:243-260
class PreparedSampleBatch(BaseModel):
    """Prepared batch data for sample operations.
    Engine extracts this from requests, backend converts to JAX arrays and computes.
    """
    # NOTE(review): the all_* lists are presumably parallel (index i = sample i
    # in every list) — not validated here; confirm in the engine.
    # Per-sample data
    all_prompts: list[list[int]]
    all_sampling_params: list[SamplingParams]
    all_model_ids: list[str]
    all_checkpoint_ids: list[str]
    all_checkpoint_paths: list[str]
    # Whether any request needs prompt logprobs
    needs_prompt_logprobs: bool
    # Mapping from samples back to requests: (request_id, model_id, start_idx, end_idx, prompt_logprobs_requested)
    request_batch_slices: list[tuple[str, str, int, int, bool]]
attr all_prompts
all_prompts: list[list[int]]
attr all_sampling_params
all_sampling_params: list[SamplingParams]
attr all_model_ids
all_model_ids: list[str]
attr all_checkpoint_ids
all_checkpoint_ids: list[str]
attr all_checkpoint_paths
all_checkpoint_paths: list[str]
attr needs_prompt_logprobs
needs_prompt_logprobs: bool
attr request_batch_slices
request_batch_slices: list[tuple[str, str, int, int, bool]]