API Reference
SkyRL TX Models
Model loading and configuration for the JAX backend.
Model Configuration
class ModelConfig
ModelConfig(config: PretrainedConfig | dict, *, max_lora_adapters: int, max_lora_rank: int, shard_attention_heads: bool, loss_chunk_size: int = 0, gradient_checkpointing: bool = False, mhc_expansion_rate: int = 1)Bases: PretrainedConfig
Configuration for skyrl models with LoRA support.
Wraps a HuggingFace PretrainedConfig with additional parameters for Multi-LoRA training and tensor parallelism.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config | PretrainedConfig | dict | A HuggingFace PretrainedConfig object (e.g., from AutoConfig.from_pretrained()) |
max_lora_adapters | int | Maximum number of concurrent LoRA adapters | required |
max_lora_rank | int | Maximum rank for LoRA adapters | required |
shard_attention_heads | bool | Whether to shard attention across tensor parallel devices | required |
loss_chunk_size | int | Chunk size for cross-entropy loss computation (0 = no chunking) | 0 |
gradient_checkpointing | bool | Recompute activations during backward to save memory | False |
mhc_expansion_rate | int | mHC expansion rate. Connectors are trainable when this is > 1. | 1 |
Functions:
| Name | Description |
|---|---|
get_config | Return text_config when present, otherwise return this config. |
get_text_config | Return a wrapped config built from self.text_config. |
get_num_experts | Return the expert count (num_experts or n_routed_experts) from the active config, if any.
Attributes:
| Name | Type | Description |
|---|---|---|
max_lora_adapters | int | |
max_lora_rank | int | |
shard_attention_heads | bool | |
loss_chunk_size | int | |
gradient_checkpointing | bool | |
mhc_expansion_rate | int |
Source code in skyrl/tx/models/configs.py:6-71
class ModelConfig(PretrainedConfig):
    """Configuration for skyrl models with LoRA support.

    Extends a HuggingFace ``PretrainedConfig`` with the extra settings
    needed for Multi-LoRA training and tensor parallelism.

    Args:
        config: A HuggingFace PretrainedConfig object (e.g., from
            AutoConfig.from_pretrained()), or the equivalent raw dict.
        max_lora_adapters: Maximum number of concurrent LoRA adapters
        max_lora_rank: Maximum rank for LoRA adapters
        shard_attention_heads: Whether to shard attention across tensor parallel devices
        loss_chunk_size: Chunk size for cross-entropy loss computation (0 = no chunking)
        gradient_checkpointing: Recompute activations during backward to save memory
        mhc_expansion_rate: mHC expansion rate. Connectors are trainable when this is > 1.
    """

    # Type hints for the LoRA/parallelism attributes set in __init__.
    max_lora_adapters: int
    max_lora_rank: int
    shard_attention_heads: bool
    loss_chunk_size: int
    gradient_checkpointing: bool
    mhc_expansion_rate: int

    def __init__(
        self,
        config: PretrainedConfig | dict,
        *,
        max_lora_adapters: int,
        max_lora_rank: int,
        shard_attention_heads: bool,
        loss_chunk_size: int = 0,
        gradient_checkpointing: bool = False,
        mhc_expansion_rate: int = 1,
    ):
        # `text_config` can come through as a raw dict from HF configs.
        if isinstance(config, dict):
            base_fields = config
        else:
            base_fields = config.__dict__
        super().__init__(**base_fields)
        # Record the LoRA/parallelism-specific settings on top of the base config.
        self.max_lora_adapters = max_lora_adapters
        self.max_lora_rank = max_lora_rank
        self.shard_attention_heads = shard_attention_heads
        self.loss_chunk_size = loss_chunk_size
        self.gradient_checkpointing = gradient_checkpointing
        self.mhc_expansion_rate = mhc_expansion_rate

    def get_config(self) -> PretrainedConfig:
        """Return `text_config` when present, otherwise return this config."""
        if hasattr(self, "text_config"):
            return self.get_text_config()
        return self

    def get_text_config(self) -> "ModelConfig":
        """Return a wrapped config built from `self.text_config`."""
        # Carry the LoRA/parallelism settings over to the wrapped text config.
        lora_kwargs = dict(
            max_lora_adapters=self.max_lora_adapters,
            max_lora_rank=self.max_lora_rank,
            shard_attention_heads=self.shard_attention_heads,
            loss_chunk_size=self.loss_chunk_size,
            gradient_checkpointing=self.gradient_checkpointing,
            mhc_expansion_rate=self.mhc_expansion_rate,
        )
        return type(self)(self.text_config, **lora_kwargs)

    def get_num_experts(self):
        """Return the expert count (`num_experts` or `n_routed_experts`), if any."""
        # TODO: Change this if there can be different numbers of experts in text_config and vision_config
        config = self.get_config()
        return getattr(config, "num_experts", None) or getattr(config, "n_routed_experts", None)
attr max_lora_adapters
max_lora_adapters: int
attr max_lora_rank
max_lora_rank: int
attr shard_attention_heads
shard_attention_heads: bool
attr loss_chunk_size
loss_chunk_size: int
attr gradient_checkpointing
gradient_checkpointing: bool
attr mhc_expansion_rate
mhc_expansion_rate: int
method get_config
get_config() -> PretrainedConfig
Return text_config when present, otherwise return this config.
Source code in skyrl/tx/models/configs.py:52-54
def get_config(self) -> PretrainedConfig:
    """Return `text_config` when present, otherwise return this config."""
    return self.get_text_config() if hasattr(self, "text_config") else self
method get_text_config
get_text_config() -> ModelConfig
Return a wrapped config built from self.text_config.
Source code in skyrl/tx/models/configs.py:56-66
def get_text_config(self) -> "ModelConfig":
    """Return a wrapped config built from `self.text_config`."""
    return type(self)(
        self.text_config,
        max_lora_adapters=self.max_lora_adapters,
        max_lora_rank=self.max_lora_rank,
        shard_attention_heads=self.shard_attention_heads,
        loss_chunk_size=self.loss_chunk_size,
        gradient_checkpointing=self.gradient_checkpointing,
        mhc_expansion_rate=self.mhc_expansion_rate,
    )
method get_num_experts
get_num_experts()
Source code in skyrl/tx/models/configs.py:68-71
def get_num_experts(self):
    # TODO: Change this if there can be different numbers of experts in text_config and vision_config
    config = self.get_config()
    return getattr(config, "num_experts", None) or getattr(config, "n_routed_experts", None)
Model Interface
class ModelForCausalLM
Functions:
| Name | Description |
|---|---|
get_model_config | Return the model's ModelConfig.
is_lora_param | Return True if a parameter path corresponds to trainable LoRA/connector weights. |
Attributes:
| Name | Type | Description |
|---|---|---|
config | ModelConfig |
Source code in skyrl/tx/models/types.py:13-26
class ModelForCausalLM:
    """Base interface for skyrl causal language models."""

    # Model configuration; attached by concrete model implementations.
    config: ModelConfig

    def get_model_config(self) -> ModelConfig:
        """Return the `ModelConfig` attached to this model."""
        return self.config

    def is_lora_param(self, path: tuple, _value) -> bool:
        """Return True if a parameter path corresponds to trainable LoRA/connector weights."""
        # LoRA adapter weights are always trainable.
        if any(part in path for part in ("lora_A", "lora_B")):
            return True
        # mHC connector weights are trainable only when the expansion rate exceeds 1.
        if self.config.mhc_expansion_rate > 1:
            return any(part in path for part in ("attn_connector", "mlp_connector"))
        return False
attr config
config: ModelConfig
method get_model_config
get_model_config() -> ModelConfig
Source code in skyrl/tx/models/types.py:17-18
def get_model_config(self) -> ModelConfig:
    return self.config
method is_lora_param
is_lora_param(path: tuple, _value) -> bool
Return True if a parameter path corresponds to trainable LoRA/connector weights.
Source code in skyrl/tx/models/types.py:20-26
def is_lora_param(self, path: tuple, _value) -> bool:
    """Return True if a parameter path corresponds to trainable LoRA/connector weights."""
    is_lora = any(name in path for name in ("lora_A", "lora_B"))
    is_connector = self.config.mhc_expansion_rate > 1 and any(
        name in path for name in ("attn_connector", "mlp_connector")
    )
    return is_lora or is_connector
class CausalLMOutput
CausalLMOutput(last_hidden_state: jax.Array, kv_cache: KVCache | None, hidden_states: list[jax.Array] | None = None) -> None
Output type for causal language models like Qwen3ForCausalLM.
Attributes:
| Name | Type | Description |
|---|---|---|
last_hidden_state | Array | The last hidden state from the model. |
kv_cache | KVCache | None |
hidden_states | list[Array] | None |
Source code in skyrl/tx/models/types.py:45-58
# Registered as a JAX pytree node so instances can flow through jit/grad transforms.
@jax.tree_util.register_dataclass
@dataclass
class CausalLMOutput:
    """Output type for causal language models like Qwen3ForCausalLM.

    Attributes:
        last_hidden_state: The last hidden state from the model.
        kv_cache: The updated key-value cache (None during training).
        hidden_states: All hidden states, if output_hidden_states=True.
    """

    # Final hidden state produced by the model.
    last_hidden_state: jax.Array
    # Updated key-value cache; None during training.
    kv_cache: KVCache | None
    # All intermediate hidden states when requested; None otherwise.
    hidden_states: list[jax.Array] | None = None
attr last_hidden_state
last_hidden_state: jax.Array
attr kv_cache
kv_cache: KVCache | None
attr hidden_states
hidden_states: list[jax.Array] | None = None