SkyRL
API Reference › SkyRL › SkyRL-Train Backend

Data Interface

Data Interface — TensorBatch, TrainingInput, GeneratorInput/Output.

Trainer APIs

class TensorBatch

TensorBatch(*args, **kwargs)

Bases: dict, Generic[DictType]

Base class for training batches

This defines a generic container for a batch of training data (inputs or outputs). Consists of a dictionary of tensors along with some metadata.

Functions:

Name — Description
select — Select a subset of the data batch.
to — Move tensors to device and/or cast to dtype.
contiguous — Make the tensors contiguous.
repeat — Repeat entries in the data batch a specified number of times.
repeat_interleave — Repeat entries in the data batch a specified number of times.
chunk — Split into smaller chunks.
slice — Slice the data batch.
save — Save the data to a pickle file.
load — Load the data from a pickle file.
cat — Concatenate shards.

Attributes:

Name — Type — Description
metadata — Optional[Dict[str, Any]]
batch_size — int — Batch size for the tensors
device — device — Get the device for the tensors
Source code in skyrl/backends/skyrl_train/training_batch.py:15-355
class TensorBatch(dict, Generic[DictType]):
    """Base class for training batches.

    This defines a generic container for a batch of training data (inputs or outputs).
    Consists of a dictionary of tensors (all sharing the same batch/first dimension
    and device) along with some metadata.
    """

    # Arbitrary non-tensor payload carried alongside the batch. It is shared
    # (not copied) by select/slice/chunk/repeat/cat.
    metadata: Optional[Dict[str, Any]] = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._batch_size: Optional[int] = None
        self._device: Optional[torch.device] = None
        self._check_consistency()

    def select(self, keys: List[str], metadata_keys: Optional[List[str]] = None) -> "TensorBatch[DictType]":
        """Select a subset of the data batch.

        Args:
            keys: The keys to select
            metadata_keys: The metadata keys to select. If None, all metadata is kept.

        Returns:
            A new `TensorBatch` object with the selected keys and metadata
        """
        selected_batch_data = {key: self[key] for key in keys}
        if metadata_keys is None:
            selected_metadata = self.metadata
        else:
            selected_metadata = {key: self.metadata[key] for key in metadata_keys}
        new_batch = self.__class__(selected_batch_data)
        new_batch.metadata = selected_metadata
        return new_batch

    def _check_consistency(self):
        """Check consistency of all present (non-None) fields.

        Raises:
            ValueError: If a non-None field is not a tensor, or if batch sizes
                or devices disagree across fields.
        """
        present = [(key, value) for key, value in self.items() if value is not None]
        if not present:
            return
        # Validate types first: `len()` / `.device` below assume tensors.
        for key, value in present:
            if not isinstance(value, torch.Tensor):
                raise ValueError(f"Field {key} must be a tensor, got {type(value)}")
        # Reference batch size and device come from the first non-None tensor
        # (the original used the first key unconditionally, crashing if it was None).
        first_value = present[0][1]
        self._batch_size = len(first_value)
        self._device = first_value.device
        for key, value in present:
            if len(value) != self._batch_size:
                raise ValueError(f"Batch size mismatch in {key}")
            if value.device != self._device:
                raise ValueError(f"Device mismatch in {key}. Expected {self._device}, got {value.device}")

    def __getitem__(self, index) -> "TensorBatch[DictType]":
        # Integer/slice indices return a sub-batch; string keys fall through
        # to plain dict lookup.
        if isinstance(index, slice):
            return self.slice(index.start, index.stop, index.step)
        elif isinstance(index, int):
            return self.slice(index, index + 1)
        else:
            return super().__getitem__(index)

    def __setitem__(self, key: str, value: Optional[torch.Tensor]) -> None:
        if value is None:
            super().__setitem__(key, value)
            return

        if not isinstance(value, torch.Tensor):
            raise ValueError(f"Field {key} must be a tensor, got {type(value)}")

        # `hasattr` guards: during unpickling, items can be set before __init__ ran.
        if hasattr(self, "_batch_size") and self._batch_size is not None and len(value) != self._batch_size:
            raise ValueError(
                f"Batch size mismatch in {key}. Expected tensor to be of size {self._batch_size}, got {len(value)}."
            )

        super().__setitem__(key, value)

        if hasattr(self, "_batch_size") and self._batch_size is None:
            self._batch_size = len(value)

    def to(
        self, device: torch.device = None, dtype: torch.dtype = None, *, non_blocking: bool = False
    ) -> "TensorBatch":
        """Move tensors to device and/or cast to dtype, in place; returns self.

        Args:
            device: The device to move the tensors to (None keeps current device)
            dtype: The dtype to cast the tensors to (None keeps current dtype)
            non_blocking: Whether the operation should be non-blocking
        """
        for key, value in self.items():
            if value is None:
                continue
            assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
            self[key] = value.to(device, dtype, non_blocking=non_blocking)
        return self

    def contiguous(self) -> "TensorBatch":
        """Make the tensors contiguous, in place; returns self."""
        for key, value in self.items():
            if value is None:
                continue
            # some of these asserts are not needed, but it's kept for type safety
            assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
            self[key] = value.contiguous()
        return self

    @property
    def batch_size(self) -> int:
        """Batch size for the tensors"""
        return self._batch_size

    @property
    def device(self) -> torch.device:
        """Get the device for the tensors"""
        return self._device

    def __getstate__(self):
        """Serialize the `TensorBatch` object for pickle protocol.

        Uses fast numpy-based serialization when possible, with fallback to torch.save
        for dtypes not supported by numpy (e.g., bfloat16).
        """
        # Contiguity is required for the raw-bytes (numpy) fast path below.
        self.contiguous()
        if self._device is not None:
            assert self._device == torch.device("cpu"), "Tensors must be on CPU before serialization"
        batch_dict = {}
        for key, value in self.items():
            if value is None:
                batch_dict[key] = None
            else:
                try:
                    # Fast path: direct memory copy via numpy (works for most dtypes)
                    arr = value.numpy()
                    batch_dict[key] = {
                        "format": "numpy",
                        "data": arr.tobytes(),
                        "shape": arr.shape,
                        "dtype": str(arr.dtype),
                    }
                except TypeError:
                    # Fallback for dtypes not supported by numpy (e.g., bfloat16)
                    buffer = io.BytesIO()
                    torch.save(value, buffer)
                    batch_dict[key] = {
                        "format": "torch",
                        "data": buffer.getvalue(),
                    }

        return {
            "batch_dict": batch_dict,
            "batch_size": self._batch_size,
            "device": self._device,
            "metadata": self.metadata,
        }

    def __setstate__(self, state):
        """Deserialize the `TensorBatch` object and load it into memory.

        Handles both numpy-based format (fast path) and torch format (fallback for bfloat16 etc).
        """
        for key, value in state["batch_dict"].items():
            if value is None:
                self[key] = None
            elif value.get("format") == "torch":
                # Fallback path: torch.load for unsupported dtypes
                buffer = io.BytesIO(value["data"])
                self[key] = torch.load(buffer, weights_only=True)
            else:
                # Fast path: reconstruct from numpy bytes
                # Also handles legacy format without "format" key
                arr = np.frombuffer(value["data"], dtype=np.dtype(value["dtype"]))
                arr = arr.reshape(value["shape"])
                # Convert to tensor (makes a copy, which is needed since frombuffer is read-only)
                self[key] = torch.from_numpy(arr.copy())

        self._batch_size = state["batch_size"]
        self._device = state["device"]
        self.metadata = state["metadata"]
        self._check_consistency()
        return self

    def repeat(self, repeats: int):
        """Repeat entries in the data batch a specified number of times.

        This is similar to `torch.repeat` (and `numpy.tile`): the batch is tiled
        `repeats` times along the batch (first) dimension. `metadata` is not repeated.

        Args:
            repeats: The number of times to repeat the data batch

        Returns:
            A new `TensorBatch` object with the data repeated
        """
        new_batch = {}
        for key, value in self.items():
            if value is None:
                new_batch[key] = value
            else:
                assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
                # Tile along dim 0 only; a bare `value.repeat(repeats)` raises
                # for tensors with more than one dim (e.g. [batch, seq_len]).
                new_batch[key] = value.repeat(repeats, *([1] * (value.dim() - 1)))
        new_batch = self.__class__(new_batch)
        new_batch.metadata = self.metadata
        return new_batch

    def repeat_interleave(self, repeats: int):
        """Repeat entries in the data batch a specified number of times.

        This is similar to `torch.repeat_interleave` (and `numpy.repeat`): each
        entry is repeated `repeats` times consecutively along the batch (first)
        dimension. `metadata` is not repeated.

        Args:
            repeats: The number of times to repeat the data batch

        Returns:
            A new `TensorBatch` object with the data repeated
        """
        new_batch = {}
        for key, value in self.items():
            if value is None:
                new_batch[key] = value
            else:
                assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
                # `dim=0` is required: without it torch.repeat_interleave
                # flattens multi-dimensional tensors before repeating.
                new_batch[key] = value.repeat_interleave(repeats, dim=0)
        new_batch = self.__class__(new_batch)
        new_batch.metadata = self.metadata
        return new_batch

    def chunk(self, chunk_size: int) -> List["TensorBatch[DictType]"]:
        """Split into smaller chunks of at most `chunk_size` rows each."""
        chunks = []
        for i in range(0, self.batch_size, chunk_size):
            chunk_data = {}
            for key, value in self.items():
                if value is not None:
                    if isinstance(value, torch.Tensor):
                        chunk_data[key] = value[i : i + chunk_size]
                    else:
                        raise ValueError(f"Unsupported type {type(value)} for key {key}")
                else:
                    # `None` values are not chunked
                    chunk_data[key] = value
            chunk = self.__class__(chunk_data)
            chunk.metadata = self.metadata
            chunks.append(chunk)
        return chunks

    def slice(self, start: int, end: int, step: int = 1) -> "TensorBatch[DictType]":
        """Slice the data batch.

        Args:
            start: The start index
            end: The end index
            step: The step size

        Returns:
            A new `TensorBatch` object with the view of the specified slice.
        """
        sliced_data = {}
        for key, value in self.items():
            if value is not None:
                if isinstance(value, torch.Tensor):
                    sliced_data[key] = value[start:end:step]
                else:
                    raise ValueError(f"Unsupported type {type(value)} for key {key}")
            else:
                # `None` values are not sliced
                sliced_data[key] = value
        sliced_batch = self.__class__(sliced_data)
        sliced_batch.metadata = self.metadata
        return sliced_batch

    def save(self, path: str):
        """Save the data to a pickle file"""
        with open(path, "wb") as f:
            pickle.dump(self, f)

    def load(self, path: str):
        """Load the data from a pickle file.

        NOTE: pickle is not safe on untrusted files — only load files this
        process (or a trusted peer) produced via `save`.
        """
        with open(path, "rb") as f:
            return pickle.load(f)

    @classmethod
    def cat(cls, shards: List["TensorBatch[DictType]"]) -> "TensorBatch[DictType]":
        """Concatenate shards.

        Args:
            shards: The list of `TensorBatch` objects to cat

        Returns:
            A new `TensorBatch` object with the concatenated data. Metadata is
            taken from the first shard.
        """
        assert len(shards) > 0, "Cannot cat an empty list of shards"
        cat_data = {}
        for key, value in shards[0].items():
            if value is not None:
                if isinstance(value, torch.Tensor):
                    cat_data[key] = torch.cat([shard[key] for shard in shards])
                else:
                    raise ValueError(f"Unsupported type {type(value)} for key {key}")
            else:
                # `None` values are not cat'd
                cat_data[key] = value
        metadata = shards[0].metadata
        cat_batch = cls(cat_data)
        cat_batch.metadata = metadata
        return cat_batch

    def __len__(self) -> int:
        """Length of the batch.

        Note that this is the same as the batch size rather than the number of keys
        in the batch. An empty batch has length 0.
        """
        return self._batch_size if self._batch_size is not None else 0

    def __eq__(self, other: Any) -> bool:
        """Check if two `TensorBatch` objects are equal (metadata, keys, tensor values)."""
        if not isinstance(other, TensorBatch):
            return False
        if self.metadata != other.metadata:
            return False
        if set(self.keys()) != set(other.keys()):
            return False
        for k, v in self.items():
            other_v = other[k]
            if v is None or other_v is None:
                # Both must be None to match; torch.equal would crash on None.
                if v is not other_v:
                    return False
            elif not torch.equal(v, other_v):
                return False
        return True

    def __str__(self) -> str:
        """String representation of the `TensorBatch` object"""
        return f"TensorBatch(batch_size={self.batch_size}, device={self.device}, metadata={self.metadata}), items={self.items()}"

    def __repr__(self) -> str:
        """String representation of the `TensorBatch` object"""
        return self.__str__()

attr metadata

metadata: Optional[Dict[str, Any]] = None

method select

select(keys: List[str], metadata_keys: Optional[List[str]] = None) -> TensorBatch[DictType]

Select a subset of the data batch.

Parameters:

NameTypeDescriptionDefault
keysList[str]The keys to selectrequired
metadata_keysOptional[List[str]]The metadata keys to selectNone

Returns:

TypeDescription
TensorBatch[DictType]A new TensorBatch object with the selected keys and metadata
Source code in skyrl/backends/skyrl_train/training_batch.py:30-52
    def select(self, keys: List[str], metadata_keys: Optional[List[str]] = None) -> "TensorBatch[DictType]":
        """Select a subset of the data batch.

        Args:
            keys: The keys to select
            metadata_keys: The metadata keys to select

        Returns:
            A new `TensorBatch` object with the selected keys and metadata
        """
        selected_batch_data = {}
        for key in keys:
            selected_batch_data[key] = self[key]
        selected_metadata = {}
        if metadata_keys is None:
            selected_metadata = self.metadata
        else:
            selected_metadata = {}
            for key in metadata_keys:
                selected_metadata[key] = self.metadata[key]
        new_batch = self.__class__(selected_batch_data)
        new_batch.metadata = selected_metadata
        return new_batch

method to

to(device: torch.device = None, dtype: torch.dtype = None, *, non_blocking: bool = False) -> TensorBatch

Move tensors to device and/or cast to dtype.

Parameters:

NameTypeDescriptionDefault
devicedeviceThe device to move the tensors toNone
dtypedtypeThe dtype to cast the tensors toNone
non_blockingboolWhether the operation should be non-blockingFalse
Source code in skyrl/backends/skyrl_train/training_batch.py:100-115
    def to(
        self, device: torch.device = None, dtype: torch.dtype = None, *, non_blocking: bool = False
    ) -> "TensorBatch":
        """Move tensors to device and/or cast to dtype.

        Args:
            device: The device to move the tensors to
            dtype: The dtype to cast the tensors to
            non_blocking: Whether the operation should be non-blocking
        """
        for key, value in self.items():
            if value is None:
                continue
            assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
            self[key] = value.to(device, dtype, non_blocking=non_blocking)
        return self

method contiguous

contiguous() -> TensorBatch

Make the tensors contiguous

Source code in skyrl/backends/skyrl_train/training_batch.py:117-125
    def contiguous(self) -> "TensorBatch":
        """Make the tensors contiguous"""
        for key, value in self.items():
            if value is None:
                continue
            # some of these asserts are not needed, but it's kept for type safety
            assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
            self[key] = value.contiguous()
        return self

attr property batch_size

batch_size: int

Batch size for the tensors

attr property device

device: torch.device

Get the device for the tensors

method repeat

repeat(repeats: int)

Repeat entries in the data batch a specified number of times.

This is similar to torch.repeat (and numpy.tile). metadata is not repeated.

Parameters:

NameTypeDescriptionDefault
repeatsintThe number of times to repeat the data batchrequired

Returns:

TypeDescription
A new TensorBatch object with the data repeated
Source code in skyrl/backends/skyrl_train/training_batch.py:202-222
    def repeat(self, repeats: int):
        """Repeat entries in the data batch a specified number of times.

        This is similar to `torch.repeat` (and `numpy.tile`). `metadata` is not repeated.

        Args:
            repeats: The number of times to repeat the data batch

        Returns:
            A new `TensorBatch` object with the data repeated
        """
        new_batch = {}
        for key, value in self.items():
            if value is None:
                new_batch[key] = value
            else:
                assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
                new_batch[key] = value.repeat(repeats)
        new_batch = self.__class__(new_batch)
        new_batch.metadata = self.metadata
        return new_batch

method repeat_interleave

repeat_interleave(repeats: int)

Repeat entries in the data batch a specified number of times.

This is similar to torch.repeat_interleave (and numpy.repeat). metadata is not repeated.

Parameters:

NameTypeDescriptionDefault
repeatsintThe number of times to repeat the data batchrequired

Returns:

TypeDescription
A new TensorBatch object with the data repeated
Source code in skyrl/backends/skyrl_train/training_batch.py:224-244
    def repeat_interleave(self, repeats: int):
        """Repeat entries in the data batch a specified number of times.

        This is similar to `torch.repeat_interleave` (and `numpy.repeat`). `metadata` is not repeated.

        Args:
            repeats: The number of times to repeat the data batch

        Returns:
            A new `TensorBatch` object with the data repeated
        """
        new_batch = {}
        for key, value in self.items():
            if value is None:
                new_batch[key] = value
            else:
                assert isinstance(value, torch.Tensor), f"Field {key} must be a tensor, got {type(value)}"
                new_batch[key] = value.repeat_interleave(repeats)
        new_batch = self.__class__(new_batch)
        new_batch.metadata = self.metadata
        return new_batch

method chunk

chunk(chunk_size: int) -> List[TensorBatch[DictType]]

Split into smaller chunks

Source code in skyrl/backends/skyrl_train/training_batch.py:246-263
    def chunk(self, chunk_size: int) -> List["TensorBatch[DictType]"]:
        """Split into smaller chunks"""
        chunks = []
        for i in range(0, self.batch_size, chunk_size):
            chunk_data = {}
            for key, value in self.items():
                if value is not None:
                    if isinstance(value, torch.Tensor):
                        chunk_data[key] = value[i : i + chunk_size]
                    else:
                        raise ValueError(f"Unsupported type {type(value)} for key {key}")
                else:
                    # `None` values are not chunked
                    chunk_data[key] = value
            chunk = self.__class__(chunk_data)
            chunk.metadata = self.metadata
            chunks.append(chunk)
        return chunks

method slice

slice(start: int, end: int, step: int = 1) -> TensorBatch[DictType]

Slice the data batch.

Parameters:

NameTypeDescriptionDefault
startintThe start indexrequired
endintThe end indexrequired
stepintThe step size1

Returns:

TypeDescription
TensorBatch[DictType]A new TensorBatch object with the view of the specified slice.
Source code in skyrl/backends/skyrl_train/training_batch.py:265-289
    def slice(self, start: int, end: int, step: int = 1) -> "TensorBatch[DictType]":
        """Slice the data batch.

        Args:
            start: The start index
            end: The end index
            step: The step size

        Returns:
            A new `TensorBatch` object with the view of the specified slice.
        """
        slice_obj = slice(start, end, step)
        sliced_data = {}
        for key, value in self.items():
            if value is not None:
                if isinstance(value, torch.Tensor):
                    sliced_data[key] = value[slice_obj]
                else:
                    raise ValueError(f"Unsupported type {type(value)} for key {key}")
            else:
                # `None` values are not sliced
                sliced_data[key] = value
        sliced_batch = self.__class__(sliced_data)
        sliced_batch.metadata = self.metadata
        return sliced_batch

method save

save(path: str)

Save the data to a pickle file

Source code in skyrl/backends/skyrl_train/training_batch.py:291-294
    def save(self, path: str):
        """Pickle this batch to the file at `path`."""
        with open(path, "wb") as out_file:
            pickle.dump(self, out_file)

method load

load(path: str)

Load the data from a pickle file

Source code in skyrl/backends/skyrl_train/training_batch.py:296-299
    def load(self, path: str):
        """Unpickle and return a batch from the file at `path`."""
        with open(path, "rb") as in_file:
            return pickle.load(in_file)

method classmethod cat

cat(shards: List[TensorBatch[DictType]]) -> TensorBatch[DictType]

Concatenate shards.

Parameters:

NameTypeDescriptionDefault
shardsList[TensorBatch[DictType]]The list of TensorBatch objects to catrequired

Returns:

TypeDescription
TensorBatch[DictType]A new TensorBatch object with the concatenated data
Source code in skyrl/backends/skyrl_train/training_batch.py:301-325
    @classmethod
    def cat(cls, shards: List["TensorBatch[DictType]"]) -> "TensorBatch[DictType]":
        """Concatenate shards.

        Args:
            shards: The list of `TensorBatch` objects to cat

        Returns:
            A new `TensorBatch` object with the concatenated data
        """
        cat_data = {}
        assert len(shards) > 0, "Cannot cat an empty list of shards"
        for key, value in shards[0].items():
            if value is not None:
                if isinstance(value, torch.Tensor):
                    cat_data[key] = torch.cat([shard[key] for shard in shards])
                else:
                    raise ValueError(f"Unsupported type {type(value)} for key {key}")
            else:
                # `None` values are not cat'd
                cat_data[key] = value
        metadata = shards[0].metadata
        cat_batch = cls(cat_data)
        cat_batch.metadata = metadata
        return cat_batch

class TrainingInput

Bases: TypedDict

Schema for training input batch

Attributes:

NameTypeDescription
sequencesInteger[Tensor, 'batch_size seq_len']
attention_maskInteger[Tensor, 'batch_size seq_len']
loss_maskInteger[Tensor, 'batch_size seq_len']
response_maskInteger[Tensor, 'batch_size seq_len']
action_log_probsFloat[Tensor, 'batch_size seq_len']
base_action_log_probsFloat[Tensor, 'batch_size seq_len']
valuesOptional[Float[Tensor, 'batch_size seq_len']]
returnsFloat[Tensor, 'batch_size seq_len']
advantagesFloat[Tensor, 'batch_size seq_len']
klFloat[Tensor, 'batch_size seq_len']
rewardsOptional[Float[Tensor, 'batch_size seq_len']]
rollout_logprobsOptional[Float[Tensor, 'batch_size seq_len']]
Source code in skyrl/backends/skyrl_train/training_batch.py:358-372
class TrainingInput(TypedDict, total=False):
    """Schema for training input batch.

    Per the jaxtyping annotations, every tensor is batch-major with shape
    [batch_size, seq_len]; `total=False` makes each field optional.
    """

    # Token ids for the sequences — presumably prompt + response concatenated;
    # confirm against the trainer that assembles this batch.
    sequences: Integer[torch.Tensor, "batch_size seq_len"]
    # Attention mask over `sequences` (conventionally 1 = real token, 0 = pad — verify upstream).
    attention_mask: Integer[torch.Tensor, "batch_size seq_len"]
    # Mask selecting which token positions contribute to the loss.
    loss_mask: Integer[torch.Tensor, "batch_size seq_len"]
    # Mask marking response (generated) positions — assumed from the name; confirm against caller.
    response_mask: Integer[torch.Tensor, "batch_size seq_len"]
    # Per-token log-probabilities under the current policy.
    action_log_probs: Float[torch.Tensor, "batch_size seq_len"]
    # Per-token log-probabilities under the base/reference model.
    base_action_log_probs: Float[torch.Tensor, "batch_size seq_len"]
    # Value estimates; Optional — None when no critic/value model is used.
    values: Optional[Float[torch.Tensor, "batch_size seq_len"]]
    # Per-token returns.
    returns: Float[torch.Tensor, "batch_size seq_len"]
    # Per-token advantage estimates.
    advantages: Float[torch.Tensor, "batch_size seq_len"]
    # Per-token KL term — presumably policy vs. reference model; verify upstream.
    kl: Float[torch.Tensor, "batch_size seq_len"]
    # Per-token rewards, if provided.
    rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]]
    # Log-probabilities recorded at rollout time, if available.
    rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]]

attr sequences

sequences: Integer[torch.Tensor, 'batch_size seq_len']

attr attention_mask

attention_mask: Integer[torch.Tensor, 'batch_size seq_len']

attr loss_mask

loss_mask: Integer[torch.Tensor, 'batch_size seq_len']

attr response_mask

response_mask: Integer[torch.Tensor, 'batch_size seq_len']

attr action_log_probs

action_log_probs: Float[torch.Tensor, 'batch_size seq_len']

attr base_action_log_probs

base_action_log_probs: Float[torch.Tensor, 'batch_size seq_len']

attr values

values: Optional[Float[torch.Tensor, 'batch_size seq_len']]

attr returns

returns: Float[torch.Tensor, 'batch_size seq_len']

attr advantages

advantages: Float[torch.Tensor, 'batch_size seq_len']

attr kl

kl: Float[torch.Tensor, 'batch_size seq_len']

attr rewards

rewards: Optional[Float[torch.Tensor, 'batch_size seq_len']]

attr rollout_logprobs

rollout_logprobs: Optional[Float[torch.Tensor, 'batch_size seq_len']]

class TrainingInputBatch

Bases: TensorBatch[TrainingInput]

Training input data

Functions:

NameDescription
selectSelect a subset of the data batch.
toMove tensors to device and/or cast to dtype.
contiguousMake the tensors contiguous
repeatRepeat entries in the data batch a specified number of times.
repeat_interleaveRepeat entries in the data batch a specified number of times.
chunkSplit into smaller chunks
sliceSlice the data batch.
saveSave the data to a pickle file
loadLoad the data from a pickle file
catConcatenate shards.

Attributes:

NameTypeDescription
metadataOptional[Dict[str, Any]]
batch_sizeintBatch size for the tensors
devicedeviceGet the device for the tensors

class TrainingOutputBatch

Bases: TensorBatch[Dict[str, Tensor]]

Training output data

Functions:

NameDescription
selectSelect a subset of the data batch.
toMove tensors to device and/or cast to dtype.
contiguousMake the tensors contiguous
repeatRepeat entries in the data batch a specified number of times.
repeat_interleaveRepeat entries in the data batch a specified number of times.
chunkSplit into smaller chunks
sliceSlice the data batch.
saveSave the data to a pickle file
loadLoad the data from a pickle file
catConcatenate shards.

Attributes:

NameTypeDescription
metadataOptional[Dict[str, Any]]
batch_sizeintBatch size for the tensors
devicedeviceGet the device for the tensors

Generator APIs

class GeneratorInput

Bases: TypedDict

Attributes:

NameTypeDescription
promptsList[ConversationType]
env_classesList[str]
env_extrasOptional[List[Dict[str, Any]]]
sampling_paramsOptional[Dict[str, Any]]
trajectory_idsOptional[List[TrajectoryID]]
batch_metadataOptional[BatchMetadata]
Source code in skyrl/train/generators/base.py:25-31
class GeneratorInput(TypedDict):
    """Input batch for a generator: parallel per-example lists plus batch-level settings."""

    # Chat-style conversations to generate from, one per example.
    prompts: List[ConversationType]
    # Environment class identifier per example — presumably resolved via a
    # registry by the generator; confirm in the generator implementation.
    env_classes: List[str]
    # Optional per-example extra arguments forwarded to the environments.
    env_extras: Optional[List[Dict[str, Any]]]
    # Optional sampling parameters for the inference engine (batch-wide).
    sampling_params: Optional[Dict[str, Any]]
    # Optional trajectory identifiers, parallel to `prompts`.
    trajectory_ids: Optional[List[TrajectoryID]]
    # Optional metadata describing the whole batch.
    batch_metadata: Optional[BatchMetadata]

attr prompts

prompts: List[ConversationType]

attr env_classes

env_classes: List[str]

attr env_extras

env_extras: Optional[List[Dict[str, Any]]]

attr sampling_params

sampling_params: Optional[Dict[str, Any]]

attr trajectory_ids

trajectory_ids: Optional[List[TrajectoryID]]

attr batch_metadata

batch_metadata: Optional[BatchMetadata]

class GeneratorOutput

Bases: TypedDict

Attributes:

NameTypeDescription
prompt_token_idsList[List[int]]
response_idsList[List[int]]
rewardsUnion[List[float], List[List[float]]]
loss_masksList[List[int]]
stop_reasonsOptional[List[str]]
rollout_metricsOptional[Dict[str, Any]]
rollout_logprobsOptional[List[List[float]]]
trajectory_idsOptional[List[TrajectoryID]]
is_last_stepOptional[List[bool]]
Source code in skyrl/train/generators/base.py:34-44
class GeneratorOutput(TypedDict):
    """Output of a generation pass: per-trajectory token ids, rewards, and masks."""

    # Token ids of each prompt.
    prompt_token_ids: List[List[int]]
    # Token ids of each generated response.
    response_ids: List[List[int]]
    # Either one scalar reward per trajectory, or a per-token reward list per trajectory.
    rewards: Union[List[float], List[List[float]]]
    # Per-token loss mask for each response.
    loss_masks: List[List[int]]
    # Why each generation stopped — NOTE(review): exact string values depend on
    # the inference engine; not visible here.
    stop_reasons: Optional[List[str]]
    # Aggregate metrics collected during rollout.
    rollout_metrics: Optional[Dict[str, Any]]
    # Per-token log-probabilities recorded during generation, if available.
    rollout_logprobs: Optional[List[List[float]]]
    # Trajectory identifiers, parallel to the other per-trajectory lists.
    trajectory_ids: Optional[List[TrajectoryID]]
    # Applicable only for step-wise training
    is_last_step: Optional[List[bool]]

attr prompt_token_ids

prompt_token_ids: List[List[int]]

attr response_ids

response_ids: List[List[int]]

attr rewards

rewards: Union[List[float], List[List[float]]]

attr loss_masks

loss_masks: List[List[int]]

attr stop_reasons

stop_reasons: Optional[List[str]]

attr rollout_metrics

rollout_metrics: Optional[Dict[str, Any]]

attr rollout_logprobs

rollout_logprobs: Optional[List[List[float]]]

attr trajectory_ids

trajectory_ids: Optional[List[TrajectoryID]]

attr is_last_step

is_last_step: Optional[List[bool]]

class MetricsOutput

Bases: TypedDict

Attributes:

NameTypeDescription
avg_scoreOptional[float]
pass_at_nOptional[float]
mean_positive_rewardOptional[float]
Source code in skyrl/train/generators/base.py:47-50
class MetricsOutput(TypedDict):
    """Summary metrics computed over a generated batch; each field may be None."""

    # Average score across the batch — scoring source not visible here; confirm upstream.
    avg_score: Optional[float]
    # pass@n-style metric — presumably the fraction of prompts with at least
    # one successful sample; verify the exact definition upstream.
    pass_at_n: Optional[float]
    # Mean over the positive rewards only — confirm exact definition upstream.
    mean_positive_reward: Optional[float]

attr avg_score

avg_score: Optional[float]

attr pass_at_n

pass_at_n: Optional[float]

attr mean_positive_reward

mean_positive_reward: Optional[float]

On this page