Data Interface
Data Interface — TensorBatch, TrainingInput, GeneratorInput/Output.
Trainer APIs
class TensorBatch
TensorBatch(*args, **kwargs)Bases: dict, Generic[DictType]
Base class for training batches
This defines a generic container for a batch of training data (inputs or outputs). Consists of a dictionary of tensors along with some metadata.
Functions:
| Name | Description |
|---|---|
select | Select a subset of the data batch. |
to | Move tensors to device and/or cast to dtype. |
contiguous | Make the tensors contiguous |
repeat | Repeat entries in the data batch a specified number of times. |
repeat_interleave | Repeat entries in the data batch a specified number of times. |
chunk | Split into smaller chunks |
slice | Slice the data batch. |
save | Save the data to a pickle file |
load | Load the data from a pickle file |
cat | Concatenate shards. |
Attributes:
| Name | Type | Description |
|---|---|---|
metadata | Optional[Dict[str, Any]] | |
batch_size | int | Batch size for the tensors |
device | device | Get the device for the tensors |
Source code in skyrl/backends/skyrl_train/training_batch.py:15-355
class TensorBatch(dict, Generic[DictType]):
"""Base class for training batches
This defines a generic container for a batch of training data (inputs or outputs).
Consists of a dictionary of tensors along with some metadata.
"""
metadata: Optional[Dict[str, Any]] = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._batch_size = None
self._device = None
self._check_consistency()
def select(self, keys: List[str], metadata_keys: Optional[List[str]] = None) -> "TensorBatch[DictType]":
"""Select a subset of the data batch.
Args:
keys: The keys to select
metadata_keys: The metadata keys to select
Returns:
A new `TensorBatch` object with the selected keys and metadata
"""
selected_batch_data = {}
for key in keys:
selected_batch_data[key] = self[key]
selected_metadata = {}
if metadata_keys is None:
selected_metadata = self.metadata
else:
selected_metadata = {}
for key in metadata_keys:
selected_metadata[key] = self.metadata[key]
new_batch = self.__class__(selected_batch_data)
new_batch.metadata = selected_metadata
return new_batch
def _check_consistency(self):
"""Check consistency of all present fields"""
keys = list(self.keys())
if len(keys) == 0:
return
batch_size = len(self[keys[0]])
self._batch_size = batch_size
for key in keys:
value = self[key]
if value is None:
continue
self._device = value.device if self._device is None else self._device
if not isinstance(value, torch.Tensor):
raise ValueError(f"Field {key} must be a tensor, got {type(value)}")
if len(value) != batch_size:
raise ValueError(f"Batch size mismatch in {key}")
if value.device != self._device:
raise ValueError(f"Device mismatch in {key}. Expected {self._device}, got {value.device}")
def __getitem__(self, index) -> "TensorBatch[DictType]":
if isinstance(index, slice):
return self.slice(index.start, index.stop, index.step)
elif isinstance(index, int):
return self.slice(index, index + 1)
else:
return super().__getitem__(index)
def __setitem__(self, key: str, value: Optional[torch.Tensor]) -> None:
if value is None:
super().__setitem__(key, value)
return
if not isinstance(value, torch.Tensor):
raise ValueError(f"Field {key} must be a tensor, got {type(value)}")
if hasattr(self, "_batch_size") and self._batch_size is not None and len(value) != self._batch_size:
raise ValueError(
f"Batch size mismatch in {key}. Expected tensor to be of size {self._batch_size}, got {len(value)}."
)
super().__setitem__(key, value)
if hasattr(self, "_batch_size") and self._batch_size is None:
self._batch_size = len(value)
def to(
self, device: torch.device = None, dtype: torch.dtype = None, *, non_blocking: bool = False
) -> "TensorBatch":
"""Move tensors to device and/or cast to dtype.
Args:
device: The device to move the tensors to
dtype: The dtype to cast the tensors to
non_blocking: Whether the operation should be non-blocking
"""
for key, value in self.items():
if value is None:
continue
assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
self[key] = value.to(device, dtype, non_blocking=non_blocking)
return self
def contiguous(self) -> "TensorBatch":
"""Make the tensors contiguous"""
for key, value in self.items():
if value is None:
continue
# some of these asserts are not needed, but it's kept for type safety
assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
self[key] = value.contiguous()
return self
@property
def batch_size(self) -> int:
"""Batch size for the tensors"""
return self._batch_size
@property
def device(self) -> torch.device:
"""Get the device for the tensors"""
return self._device
def __getstate__(self):
"""Serialize the `TensorBatch` object for pickle protocol.
Uses fast numpy-based serialization when possible, with fallback to torch.save
for dtypes not supported by numpy (e.g., bfloat16).
"""
self.contiguous()
if self._device is not None:
assert self._device == torch.device("cpu"), "Tensors must be on CPU before serialization"
batch_dict = {}
for key, value in self.items():
if value is None:
batch_dict[key] = None
else:
try:
# Fast path: direct memory copy via numpy (works for most dtypes)
arr = value.numpy()
batch_dict[key] = {
"format": "numpy",
"data": arr.tobytes(),
"shape": arr.shape,
"dtype": str(arr.dtype),
}
except TypeError:
# Fallback for dtypes not supported by numpy (e.g., bfloat16)
buffer = io.BytesIO()
torch.save(value, buffer)
batch_dict[key] = {
"format": "torch",
"data": buffer.getvalue(),
}
return {
"batch_dict": batch_dict,
"batch_size": self._batch_size,
"device": self._device,
"metadata": self.metadata,
}
def __setstate__(self, state):
"""Deserialize the `TensorBatch` object and load it into memory.
Handles both numpy-based format (fast path) and torch format (fallback for bfloat16 etc).
"""
for key, value in state["batch_dict"].items():
if value is None:
self[key] = None
elif value.get("format") == "torch":
# Fallback path: torch.load for unsupported dtypes
buffer = io.BytesIO(value["data"])
self[key] = torch.load(buffer, weights_only=True)
else:
# Fast path: reconstruct from numpy bytes
# Also handles legacy format without "format" key
arr = np.frombuffer(value["data"], dtype=np.dtype(value["dtype"]))
arr = arr.reshape(value["shape"])
# Convert to tensor (makes a copy, which is needed since frombuffer is read-only)
self[key] = torch.from_numpy(arr.copy())
self._batch_size = state["batch_size"]
self._device = state["device"]
self.metadata = state["metadata"]
self._check_consistency()
return self
def repeat(self, repeats: int):
"""Repeat entries in the data batch a specified number of times.
This is similar to `torch.repeat` (and `numpy.tile`). `metadata` is not repeated.
Args:
repeats: The number of times to repeat the data batch
Returns:
A new `TensorBatch` object with the data repeated
"""
new_batch = {}
for key, value in self.items():
if value is None:
new_batch[key] = value
else:
assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
new_batch[key] = value.repeat(repeats)
new_batch = self.__class__(new_batch)
new_batch.metadata = self.metadata
return new_batch
def repeat_interleave(self, repeats: int):
"""Repeat entries in the data batch a specified number of times.
This is similar to `torch.repeat_interleave` (and `numpy.repeat`). `metadata` is not repeated.
Args:
repeats: The number of times to repeat the data batch
Returns:
A new `TensorBatch` object with the data repeated
"""
new_batch = {}
for key, value in self.items():
if value is None:
new_batch[key] = value
else:
assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
new_batch[key] = value.repeat_interleave(repeats)
new_batch = self.__class__(new_batch)
new_batch.metadata = self.metadata
return new_batch
def chunk(self, chunk_size: int) -> List["TensorBatch[DictType]"]:
"""Split into smaller chunks"""
chunks = []
for i in range(0, self.batch_size, chunk_size):
chunk_data = {}
for key, value in self.items():
if value is not None:
if isinstance(value, torch.Tensor):
chunk_data[key] = value[i : i + chunk_size]
else:
raise ValueError(f"Unsupported type {type(value)} for key {key}")
else:
# `None` values are not chunked
chunk_data[key] = value
chunk = self.__class__(chunk_data)
chunk.metadata = self.metadata
chunks.append(chunk)
return chunks
def slice(self, start: int, end: int, step: int = 1) -> "TensorBatch[DictType]":
"""Slice the data batch.
Args:
start: The start index
end: The end index
step: The step size
Returns:
A new `TensorBatch` object with the view of the specified slice.
"""
slice_obj = slice(start, end, step)
sliced_data = {}
for key, value in self.items():
if value is not None:
if isinstance(value, torch.Tensor):
sliced_data[key] = value[slice_obj]
else:
raise ValueError(f"Unsupported type {type(value)} for key {key}")
else:
# `None` values are not sliced
sliced_data[key] = value
sliced_batch = self.__class__(sliced_data)
sliced_batch.metadata = self.metadata
return sliced_batch
def save(self, path: str):
"""Save the data to a pickle file"""
with open(path, "wb") as f:
pickle.dump(self, f)
def load(self, path: str):
"""Load the data from a pickle file"""
with open(path, "rb") as f:
return pickle.load(f)
@classmethod
def cat(cls, shards: List["TensorBatch[DictType]"]) -> "TensorBatch[DictType]":
"""Concatenate shards.
Args:
shards: The list of `TensorBatch` objects to cat
Returns:
A new `TensorBatch` object with the concatenated data
"""
cat_data = {}
assert len(shards) > 0, "Cannot cat an empty list of shards"
for key, value in shards[0].items():
if value is not None:
if isinstance(value, torch.Tensor):
cat_data[key] = torch.cat([shard[key] for shard in shards])
else:
raise ValueError(f"Unsupported type {type(value)} for key {key}")
else:
# `None` values are not cat'd
cat_data[key] = value
metadata = shards[0].metadata
cat_batch = cls(cat_data)
cat_batch.metadata = metadata
return cat_batch
def __len__(self) -> int:
"""Length of the batch.
Note that this is the same as the batch size rather than the number of keys in the batch.
"""
return self._batch_size
def __eq__(self, other: Any) -> bool:
"""Check if two `TensorBatch` objects are equal"""
if not isinstance(other, TensorBatch):
return False
if self.metadata != other.metadata:
return False
if len(self) != len(other):
return False
if len(self.items()) != len(other.items()):
return False
for k, v in self.items():
if k not in other or not torch.equal(v, other[k]):
return False
return True
def __str__(self) -> str:
"""String representation of the `TensorBatch` object"""
return f"TensorBatch(batch_size={self.batch_size}, device={self.device}, metadata={self.metadata}), items={self.items()}"
def __repr__(self) -> str:
"""String representation of the `TensorBatch` object"""
return self.__str__()attr metadata
metadata: Optional[Dict[str, Any]] = Nonemethod select
select(keys: List[str], metadata_keys: Optional[List[str]] = None) -> TensorBatch[DictType]Select a subset of the data batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
keys | List[str] | The keys to select | required |
metadata_keys | Optional[List[str]] | The metadata keys to select | None |
Returns:
| Type | Description |
|---|---|
| TensorBatch[DictType] | A new TensorBatch object with the selected keys and metadata |
Source code in skyrl/backends/skyrl_train/training_batch.py:30-52
def select(self, keys: List[str], metadata_keys: Optional[List[str]] = None) -> "TensorBatch[DictType]":
    """Select a subset of the data batch.

    Args:
        keys: The keys to select
        metadata_keys: The metadata keys to select

    Returns:
        A new `TensorBatch` object with the selected keys and metadata
    """
    selected_batch_data = {}
    for key in keys:
        selected_batch_data[key] = self[key]
    selected_metadata = {}
    if metadata_keys is None:
        # No filter: the full metadata dict is shared (by reference) with the new batch.
        selected_metadata = self.metadata
    else:
        selected_metadata = {}
        for key in metadata_keys:
            # NOTE(review): raises TypeError if self.metadata is None — confirm callers guard this.
            selected_metadata[key] = self.metadata[key]
    new_batch = self.__class__(selected_batch_data)
    new_batch.metadata = selected_metadata
    return new_batch

method to
to(device: torch.device = None, dtype: torch.dtype = None, *, non_blocking: bool = False) -> TensorBatchMove tensors to device and/or cast to dtype.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
device | device | The device to move the tensors to | None |
dtype | dtype | The dtype to cast the tensors to | None |
non_blocking | bool | Whether the operation should be non-blocking | False |
Source code in skyrl/backends/skyrl_train/training_batch.py:100-115
def to(
    self, device: torch.device = None, dtype: torch.dtype = None, *, non_blocking: bool = False
) -> "TensorBatch":
    """Move tensors to device and/or cast to dtype.

    Operates in place on each field and returns self; `None` fields are skipped.

    Args:
        device: The device to move the tensors to
        dtype: The dtype to cast the tensors to
        non_blocking: Whether the operation should be non-blocking
    """
    for key, value in self.items():
        if value is None:
            continue
        assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
        self[key] = value.to(device, dtype, non_blocking=non_blocking)
    return self

method contiguous
contiguous() -> TensorBatchMake the tensors contiguous
Source code in skyrl/backends/skyrl_train/training_batch.py:117-125
def contiguous(self) -> "TensorBatch":
    """Make the tensors contiguous (in place; returns self)."""
    for key, value in self.items():
        # `None` placeholders are left untouched.
        if value is None:
            continue
        # some of these asserts are not needed, but it's kept for type safety
        assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
        self[key] = value.contiguous()
    return self

attr property batch_size
batch_size: intBatch size for the tensors
attr property device
device: torch.deviceGet the device for the tensors
method repeat
repeat(repeats: int)Repeat entries in the data batch a specified number of times.
This is similar to torch.repeat (and numpy.tile). metadata is not repeated.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
repeats | int | The number of times to repeat the data batch | required |
Returns:
| Type | Description |
|---|---|
| A new TensorBatch object with the data repeated |
Source code in skyrl/backends/skyrl_train/training_batch.py:202-222
def repeat(self, repeats: int):
    """Repeat entries in the data batch a specified number of times.

    This is similar to `torch.repeat` (and `numpy.tile`). `metadata` is not repeated.

    Args:
        repeats: The number of times to repeat the data batch

    Returns:
        A new `TensorBatch` object with the data repeated
    """
    new_batch = {}
    for key, value in self.items():
        if value is None:
            new_batch[key] = value
        else:
            assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
            # NOTE(review): Tensor.repeat with a single int is only valid for 1-D
            # tensors — confirm callers never pass multi-dim fields here.
            new_batch[key] = value.repeat(repeats)
    new_batch = self.__class__(new_batch)
    new_batch.metadata = self.metadata
    return new_batch

method repeat_interleave
repeat_interleave(repeats: int)Repeat entries in the data batch a specified number of times.
This is similar to torch.repeat_interleave (and numpy.repeat). metadata is not repeated.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
repeats | int | The number of times to repeat the data batch | required |
Returns:
| Type | Description |
|---|---|
| A new TensorBatch object with the data repeated |
Source code in skyrl/backends/skyrl_train/training_batch.py:224-244
def repeat_interleave(self, repeats: int):
    """Repeat entries in the data batch a specified number of times.

    This is similar to `torch.repeat_interleave` (and `numpy.repeat`). `metadata` is not repeated.

    Args:
        repeats: The number of times to repeat the data batch

    Returns:
        A new `TensorBatch` object with the data repeated
    """
    new_batch = {}
    for key, value in self.items():
        if value is None:
            new_batch[key] = value
        else:
            assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
            # NOTE(review): without dim=, repeat_interleave flattens multi-dim
            # input — confirm this is intended for 2-D fields.
            new_batch[key] = value.repeat_interleave(repeats)
    new_batch = self.__class__(new_batch)
    new_batch.metadata = self.metadata
    return new_batch

method chunk
chunk(chunk_size: int) -> List[TensorBatch[DictType]]Split into smaller chunks
Source code in skyrl/backends/skyrl_train/training_batch.py:246-263
def chunk(self, chunk_size: int) -> List["TensorBatch[DictType]"]:
    """Split into smaller chunks of at most `chunk_size` rows; the last chunk
    may be smaller. Each chunk shares `metadata` by reference."""
    chunks = []
    for i in range(0, self.batch_size, chunk_size):
        chunk_data = {}
        for key, value in self.items():
            if value is not None:
                if isinstance(value, torch.Tensor):
                    chunk_data[key] = value[i : i + chunk_size]
                else:
                    raise ValueError(f"Unsupported type {type(value)} for key {key}")
            else:
                # `None` values are not chunked
                chunk_data[key] = value
        chunk = self.__class__(chunk_data)
        chunk.metadata = self.metadata
        chunks.append(chunk)
    return chunks

method slice
slice(start: int, end: int, step: int = 1) -> TensorBatch[DictType]Slice the data batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
start | int | The start index | required |
end | int | The end index | required |
step | int | The step size | 1 |
Returns:
| Type | Description |
|---|---|
| TensorBatch[DictType] | A new TensorBatch object with the view of the specified slice. |
Source code in skyrl/backends/skyrl_train/training_batch.py:265-289
def slice(self, start: int, end: int, step: int = 1) -> "TensorBatch[DictType]":
    """Slice the data batch.

    Args:
        start: The start index
        end: The end index
        step: The step size

    Returns:
        A new `TensorBatch` object with the view of the specified slice.
    """
    # `slice` here is the builtin (the method name does not shadow it
    # inside the method body; class attributes are not in method scope).
    slice_obj = slice(start, end, step)
    sliced_data = {}
    for key, value in self.items():
        if value is not None:
            if isinstance(value, torch.Tensor):
                sliced_data[key] = value[slice_obj]
            else:
                raise ValueError(f"Unsupported type {type(value)} for key {key}")
        else:
            # `None` values are not sliced
            sliced_data[key] = value
    sliced_batch = self.__class__(sliced_data)
    sliced_batch.metadata = self.metadata
    return sliced_batch

method save
save(path: str)Save the data to a pickle file
Source code in skyrl/backends/skyrl_train/training_batch.py:291-294
def save(self, path: str):
    """Save the data to a pickle file (serialized via __getstate__)."""
    with open(path, "wb") as f:
        pickle.dump(self, f)

method load
load(path: str)Load the data from a pickle file
Source code in skyrl/backends/skyrl_train/training_batch.py:296-299
def load(self, path: str):
    """Load the data from a pickle file.

    NOTE(review): this is an instance method but never uses `self` — callers
    must hold an instance just to load; consider a classmethod.
    """
    with open(path, "rb") as f:
        return pickle.load(f)

method classmethod cat
cat(shards: List[TensorBatch[DictType]]) -> TensorBatch[DictType]Concatenate shards.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
shards | List[TensorBatch[DictType]] | The list of TensorBatch objects to cat | required |
Returns:
| Type | Description |
|---|---|
| TensorBatch[DictType] | A new TensorBatch object with the concatenated data |
Source code in skyrl/backends/skyrl_train/training_batch.py:301-325
@classmethod
def cat(cls, shards: List["TensorBatch[DictType]"]) -> "TensorBatch[DictType]":
    """Concatenate shards.

    The first shard defines the key set and supplies the resulting metadata.

    Args:
        shards: The list of `TensorBatch` objects to cat

    Returns:
        A new `TensorBatch` object with the concatenated data
    """
    cat_data = {}
    assert len(shards) > 0, "Cannot cat an empty list of shards"
    for key, value in shards[0].items():
        if value is not None:
            if isinstance(value, torch.Tensor):
                # Concatenates along dim 0 (the batch dimension).
                cat_data[key] = torch.cat([shard[key] for shard in shards])
            else:
                raise ValueError(f"Unsupported type {type(value)} for key {key}")
        else:
            # `None` values are not cat'd
            cat_data[key] = value
    metadata = shards[0].metadata
    cat_batch = cls(cat_data)
    cat_batch.metadata = metadata
    return cat_batch

class TrainingInput
Bases: TypedDict
Schema for training input batch
Attributes:
| Name | Type | Description |
|---|---|---|
sequences | Integer[Tensor, 'batch_size seq_len'] | |
attention_mask | Integer[Tensor, 'batch_size seq_len'] | |
loss_mask | Integer[Tensor, 'batch_size seq_len'] | |
response_mask | Integer[Tensor, 'batch_size seq_len'] | |
action_log_probs | Float[Tensor, 'batch_size seq_len'] | |
base_action_log_probs | Float[Tensor, 'batch_size seq_len'] | |
values | Optional[Float[Tensor, 'batch_size seq_len']] | |
returns | Float[Tensor, 'batch_size seq_len'] | |
advantages | Float[Tensor, 'batch_size seq_len'] | |
kl | Float[Tensor, 'batch_size seq_len'] | |
rewards | Optional[Float[Tensor, 'batch_size seq_len']] | |
rollout_logprobs | Optional[Float[Tensor, 'batch_size seq_len']] |
Source code in skyrl/backends/skyrl_train/training_batch.py:358-372
class TrainingInput(TypedDict, total=False):
    """Schema for training input batch.

    With ``total=False`` every key is optional. All tensor fields declare the
    shape (batch_size, seq_len) via jaxtyping annotations.
    """

    sequences: Integer[torch.Tensor, "batch_size seq_len"]
    attention_mask: Integer[torch.Tensor, "batch_size seq_len"]
    loss_mask: Integer[torch.Tensor, "batch_size seq_len"]
    response_mask: Integer[torch.Tensor, "batch_size seq_len"]
    action_log_probs: Float[torch.Tensor, "batch_size seq_len"]
    base_action_log_probs: Float[torch.Tensor, "batch_size seq_len"]
    # Optional[...] fields may be present but explicitly None.
    values: Optional[Float[torch.Tensor, "batch_size seq_len"]]
    returns: Float[torch.Tensor, "batch_size seq_len"]
    advantages: Float[torch.Tensor, "batch_size seq_len"]
    kl: Float[torch.Tensor, "batch_size seq_len"]
    rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]]
    rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]]

attr sequences
sequences: Integer[torch.Tensor, 'batch_size seq_len']attr attention_mask
attention_mask: Integer[torch.Tensor, 'batch_size seq_len']attr loss_mask
loss_mask: Integer[torch.Tensor, 'batch_size seq_len']attr response_mask
response_mask: Integer[torch.Tensor, 'batch_size seq_len']attr action_log_probs
action_log_probs: Float[torch.Tensor, 'batch_size seq_len']attr base_action_log_probs
base_action_log_probs: Float[torch.Tensor, 'batch_size seq_len']attr values
values: Optional[Float[torch.Tensor, 'batch_size seq_len']]attr returns
returns: Float[torch.Tensor, 'batch_size seq_len']attr advantages
advantages: Float[torch.Tensor, 'batch_size seq_len']attr kl
kl: Float[torch.Tensor, 'batch_size seq_len']attr rewards
rewards: Optional[Float[torch.Tensor, 'batch_size seq_len']]attr rollout_logprobs
rollout_logprobs: Optional[Float[torch.Tensor, 'batch_size seq_len']]class TrainingInputBatch
Bases: TensorBatch[TrainingInput]
Training input data
Functions:
| Name | Description |
|---|---|
select | Select a subset of the data batch. |
to | Move tensors to device and/or cast to dtype. |
contiguous | Make the tensors contiguous |
repeat | Repeat entries in the data batch a specified number of times. |
repeat_interleave | Repeat entries in the data batch a specified number of times. |
chunk | Split into smaller chunks |
slice | Slice the data batch. |
save | Save the data to a pickle file |
load | Load the data from a pickle file |
cat | Concatenate shards. |
Attributes:
| Name | Type | Description |
|---|---|---|
metadata | Optional[Dict[str, Any]] | |
batch_size | int | Batch size for the tensors |
device | device | Get the device for the tensors |
class TrainingOutputBatch
Bases: TensorBatch[Dict[str, Tensor]]
Training output data
Functions:
| Name | Description |
|---|---|
select | Select a subset of the data batch. |
to | Move tensors to device and/or cast to dtype. |
contiguous | Make the tensors contiguous |
repeat | Repeat entries in the data batch a specified number of times. |
repeat_interleave | Repeat entries in the data batch a specified number of times. |
chunk | Split into smaller chunks |
slice | Slice the data batch. |
save | Save the data to a pickle file |
load | Load the data from a pickle file |
cat | Concatenate shards. |
Attributes:
| Name | Type | Description |
|---|---|---|
metadata | Optional[Dict[str, Any]] | |
batch_size | int | Batch size for the tensors |
device | device | Get the device for the tensors |
Generator APIs
class GeneratorInput
Bases: TypedDict
Attributes:
| Name | Type | Description |
|---|---|---|
prompts | List[ConversationType] | |
env_classes | List[str] | |
env_extras | Optional[List[Dict[str, Any]]] | |
sampling_params | Optional[Dict[str, Any]] | |
trajectory_ids | Optional[List[TrajectoryID]] | |
batch_metadata | Optional[BatchMetadata] |
Source code in skyrl/train/generators/base.py:25-31
class GeneratorInput(TypedDict):
    """Schema for a generator input batch (lists are parallel, one entry per prompt)."""

    prompts: List[ConversationType]
    env_classes: List[str]
    # Optional per-prompt extras and shared sampling parameters.
    env_extras: Optional[List[Dict[str, Any]]]
    sampling_params: Optional[Dict[str, Any]]
    trajectory_ids: Optional[List[TrajectoryID]]
    batch_metadata: Optional[BatchMetadata]

attr prompts
prompts: List[ConversationType]attr env_classes
env_classes: List[str]attr env_extras
env_extras: Optional[List[Dict[str, Any]]]attr sampling_params
sampling_params: Optional[Dict[str, Any]]attr trajectory_ids
trajectory_ids: Optional[List[TrajectoryID]]attr batch_metadata
batch_metadata: Optional[BatchMetadata]class GeneratorOutput
Bases: TypedDict
Attributes:
| Name | Type | Description |
|---|---|---|
prompt_token_ids | List[List[int]] | |
response_ids | List[List[int]] | |
rewards | Union[List[float], List[List[float]]] | |
loss_masks | List[List[int]] | |
stop_reasons | Optional[List[str]] | |
rollout_metrics | Optional[Dict[str, Any]] | |
rollout_logprobs | Optional[List[List[float]]] | |
trajectory_ids | Optional[List[TrajectoryID]] | |
is_last_step | Optional[List[bool]] |
Source code in skyrl/train/generators/base.py:34-44
class GeneratorOutput(TypedDict):
    """Schema for a generator output batch (token-id lists are per trajectory)."""

    prompt_token_ids: List[List[int]]
    response_ids: List[List[int]]
    # Either one scalar reward per trajectory or a per-token reward list.
    rewards: Union[List[float], List[List[float]]]
    loss_masks: List[List[int]]
    stop_reasons: Optional[List[str]]
    rollout_metrics: Optional[Dict[str, Any]]
    rollout_logprobs: Optional[List[List[float]]]
    trajectory_ids: Optional[List[TrajectoryID]]
    # Applicable only for step-wise training
    is_last_step: Optional[List[bool]]

attr prompt_token_ids
prompt_token_ids: List[List[int]]attr response_ids
response_ids: List[List[int]]attr rewards
rewards: Union[List[float], List[List[float]]]attr loss_masks
loss_masks: List[List[int]]attr stop_reasons
stop_reasons: Optional[List[str]]attr rollout_metrics
rollout_metrics: Optional[Dict[str, Any]]attr rollout_logprobs
rollout_logprobs: Optional[List[List[float]]]attr trajectory_ids
trajectory_ids: Optional[List[TrajectoryID]]attr is_last_step
is_last_step: Optional[List[bool]]class MetricsOutput
Bases: TypedDict
Attributes:
| Name | Type | Description |
|---|---|---|
avg_score | Optional[float] | |
pass_at_n | Optional[float] | |
mean_positive_reward | Optional[float] |
Source code in skyrl/train/generators/base.py:47-50
class MetricsOutput(TypedDict):
    """Schema for aggregate generation metrics; any metric may be absent (None)."""

    avg_score: Optional[float]
    pass_at_n: Optional[float]
    mean_positive_reward: Optional[float]

attr avg_score
avg_score: Optional[float]attr pass_at_n
pass_at_n: Optional[float]attr mean_positive_reward
mean_positive_reward: Optional[float]