SkyRL
API Reference

TX Models

Model loading and configuration for the JAX backend.

Model Configuration

class ModelConfig

ModelConfig(config: PretrainedConfig | dict, *, max_lora_adapters: int, max_lora_rank: int, shard_attention_heads: bool, loss_chunk_size: int = 0, gradient_checkpointing: bool = False, mhc_expansion_rate: int = 1)

Bases: PretrainedConfig

Configuration for skyrl models with LoRA support.

Wraps a HuggingFace PretrainedConfig with additional parameters for Multi-LoRA training and tensor parallelism.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `PretrainedConfig \| dict` | A HuggingFace PretrainedConfig object (e.g., from `AutoConfig.from_pretrained()`) | *required* |
| `max_lora_adapters` | `int` | Maximum number of concurrent LoRA adapters | *required* |
| `max_lora_rank` | `int` | Maximum rank for LoRA adapters | *required* |
| `shard_attention_heads` | `bool` | Whether to shard attention across tensor parallel devices | *required* |
| `loss_chunk_size` | `int` | Chunk size for cross-entropy loss computation (0 = no chunking) | `0` |
| `gradient_checkpointing` | `bool` | Recompute activations during backward to save memory | `False` |
| `mhc_expansion_rate` | `int` | mHC expansion rate. Connectors are trainable when this is > 1. | `1` |

Functions:

| Name | Description |
| --- | --- |
| `get_config` | Return `text_config` when present, otherwise return this config. |
| `get_text_config` | Return a wrapped config built from `self.text_config`. |
| `get_num_experts` | Return the expert count (`num_experts` or `n_routed_experts`) from the config, if any. |

Attributes:

Source code in skyrl/tx/models/configs.py:6-71
class ModelConfig(PretrainedConfig):
    """skyrl model configuration with Multi-LoRA support.

    Extends a HuggingFace ``PretrainedConfig`` with the extra settings needed
    for Multi-LoRA training and tensor parallelism.

    Args:
        config: A HuggingFace PretrainedConfig object (e.g., from AutoConfig.from_pretrained())
        max_lora_adapters: Maximum number of concurrent LoRA adapters
        max_lora_rank: Maximum rank for LoRA adapters
        shard_attention_heads: Whether to shard attention across tensor parallel devices
        loss_chunk_size: Chunk size for cross-entropy loss computation (0 = no chunking)
        gradient_checkpointing: Recompute activations during backward to save memory
        mhc_expansion_rate: mHC expansion rate. Connectors are trainable when this is > 1.
    """

    # Annotations for the non-HF attributes assigned in __init__.
    max_lora_adapters: int
    max_lora_rank: int
    shard_attention_heads: bool
    loss_chunk_size: int
    gradient_checkpointing: bool
    mhc_expansion_rate: int

    def __init__(
        self,
        config: PretrainedConfig | dict,
        *,
        max_lora_adapters: int,
        max_lora_rank: int,
        shard_attention_heads: bool,
        loss_chunk_size: int = 0,
        gradient_checkpointing: bool = False,
        mhc_expansion_rate: int = 1,
    ):
        # Nested HF configs (e.g. `text_config`) can arrive as raw dicts.
        if isinstance(config, dict):
            base_attrs = config
        else:
            base_attrs = config.__dict__
        super().__init__(**base_attrs)

        # Layer the LoRA / parallelism settings on top of the HF fields.
        self.max_lora_adapters = max_lora_adapters
        self.max_lora_rank = max_lora_rank
        self.shard_attention_heads = shard_attention_heads
        self.loss_chunk_size = loss_chunk_size
        self.gradient_checkpointing = gradient_checkpointing
        self.mhc_expansion_rate = mhc_expansion_rate

    def get_config(self) -> PretrainedConfig:
        """Return `text_config` when present, otherwise return this config."""
        if hasattr(self, "text_config"):
            return self.get_text_config()
        return self

    def get_text_config(self) -> "ModelConfig":
        """Return a wrapped config built from `self.text_config`."""
        extra = {
            "max_lora_adapters": self.max_lora_adapters,
            "max_lora_rank": self.max_lora_rank,
            "shard_attention_heads": self.shard_attention_heads,
            "loss_chunk_size": self.loss_chunk_size,
            "gradient_checkpointing": self.gradient_checkpointing,
            "mhc_expansion_rate": self.mhc_expansion_rate,
        }
        return type(self)(self.text_config, **extra)

    def get_num_experts(self):
        """Return the expert count (`num_experts` or `n_routed_experts`), or None."""
        # TODO: Change this if there can be different numbers of experts in text_config and vision_config
        cfg = self.get_config()
        experts = getattr(cfg, "num_experts", None)
        return experts if experts else getattr(cfg, "n_routed_experts", None)

attr max_lora_adapters

max_lora_adapters: int = max_lora_adapters

attr max_lora_rank

max_lora_rank: int = max_lora_rank

attr shard_attention_heads

shard_attention_heads: bool = shard_attention_heads

attr loss_chunk_size

loss_chunk_size: int = loss_chunk_size

attr gradient_checkpointing

gradient_checkpointing: bool = gradient_checkpointing

attr mhc_expansion_rate

mhc_expansion_rate: int = mhc_expansion_rate

method get_config

get_config() -> PretrainedConfig

Return text_config when present, otherwise return this config.

Source code in skyrl/tx/models/configs.py:52-54
    def get_config(self) -> PretrainedConfig:
        """Return `text_config` when present, otherwise return this config."""
        return self.get_text_config() if hasattr(self, "text_config") else self

method get_text_config

get_text_config() -> ModelConfig

Return a wrapped config built from self.text_config.

Source code in skyrl/tx/models/configs.py:56-66
    def get_text_config(self) -> "ModelConfig":
        """Return a wrapped config built from `self.text_config`."""
        return type(self)(
            self.text_config,
            max_lora_adapters=self.max_lora_adapters,
            max_lora_rank=self.max_lora_rank,
            shard_attention_heads=self.shard_attention_heads,
            loss_chunk_size=self.loss_chunk_size,
            gradient_checkpointing=self.gradient_checkpointing,
            mhc_expansion_rate=self.mhc_expansion_rate,
        )

method get_num_experts

get_num_experts()
Source code in skyrl/tx/models/configs.py:68-71
    def get_num_experts(self):
        # TODO: Change this if there can be different numbers of experts in text_config and vision_config
        config = self.get_config()
        return getattr(config, "num_experts", None) or getattr(config, "n_routed_experts", None)

Model Interface

class ModelForCausalLM

Functions:

| Name | Description |
| --- | --- |
| `get_model_config` | Return the model's `ModelConfig`. |
| `is_lora_param` | Return True if a parameter path corresponds to trainable LoRA/connector weights. |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `config` | `ModelConfig` | The model's configuration. |
Source code in skyrl/tx/models/types.py:13-26
class ModelForCausalLM:

    config: ModelConfig

    def get_model_config(self) -> ModelConfig:
        """Return this model's `ModelConfig`."""
        return self.config

    def is_lora_param(self, path: tuple, _value) -> bool:
        """Return True if a parameter path corresponds to trainable LoRA/connector weights."""
        # LoRA adapter matrices are always considered trainable.
        if any(part in path for part in ("lora_A", "lora_B")):
            return True
        # mHC connector weights are trainable only when the expansion rate exceeds 1.
        if self.config.mhc_expansion_rate > 1:
            return any(part in path for part in ("attn_connector", "mlp_connector"))
        return False

attr config

config: ModelConfig

method get_model_config

get_model_config() -> ModelConfig
Source code in skyrl/tx/models/types.py:17-18
    def get_model_config(self) -> ModelConfig:
        return self.config

method is_lora_param

is_lora_param(path: tuple, _value) -> bool

Return True if a parameter path corresponds to trainable LoRA/connector weights.

Source code in skyrl/tx/models/types.py:20-26
    def is_lora_param(self, path: tuple, _value) -> bool:
        """Return True if a parameter path corresponds to trainable LoRA/connector weights."""
        is_lora = any(name in path for name in ("lora_A", "lora_B"))
        is_connector = self.config.mhc_expansion_rate > 1 and any(
            name in path for name in ("attn_connector", "mlp_connector")
        )
        return is_lora or is_connector

class CausalLMOutput

CausalLMOutput(last_hidden_state: jax.Array, kv_cache: KVCache | None, hidden_states: list[jax.Array] | None = None) -> None

Output type for causal language models like Qwen3ForCausalLM.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `last_hidden_state` | `jax.Array` | The last hidden state from the model. |
| `kv_cache` | `KVCache \| None` | The updated key-value cache (None during training). |
| `hidden_states` | `list[jax.Array] \| None` | All hidden states, if `output_hidden_states=True`. |
Source code in skyrl/tx/models/types.py:45-58
@jax.tree_util.register_dataclass  # registers the dataclass as a JAX pytree node so it can flow through jax transforms
@dataclass
class CausalLMOutput:
    """Output type for causal language models like Qwen3ForCausalLM.

    Attributes:
        last_hidden_state: The last hidden state from the model.
        kv_cache: The updated key-value cache (None during training).
        hidden_states: All hidden states, if output_hidden_states=True.
    """

    # Final-layer activations; assumes (batch, seq, hidden) layout — TODO confirm against the model.
    last_hidden_state: jax.Array
    # Updated key-value cache; None during training (see docstring).
    kv_cache: KVCache | None
    # Per-layer hidden states, populated only when output_hidden_states=True.
    hidden_states: list[jax.Array] | None = None

attr last_hidden_state

last_hidden_state: jax.Array

attr kv_cache

kv_cache: KVCache | None

attr hidden_states

hidden_states: list[jax.Array] | None = None

On this page