TX Models
Model loading and configuration for the JAX backend.
Model Configuration
class ModelConfig
ModelConfig(config: PretrainedConfig | dict, *, max_lora_adapters: int, max_lora_rank: int, shard_attention_heads: bool, loss_chunk_size: int = 0, gradient_checkpointing: bool = False, mhc_expansion_rate: int = 1)

Bases: PretrainedConfig
Configuration for skyrl models with LoRA support.
Wraps a HuggingFace PretrainedConfig with additional parameters for Multi-LoRA training and tensor parallelism.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
config | PretrainedConfig | dict | A HuggingFace PretrainedConfig object (e.g., from AutoConfig.from_pretrained()) | required |
max_lora_adapters | int | Maximum number of concurrent LoRA adapters | required |
max_lora_rank | int | Maximum rank for LoRA adapters | required |
shard_attention_heads | bool | Whether to shard attention across tensor parallel devices | required |
loss_chunk_size | int | Chunk size for cross-entropy loss computation (0 = no chunking) | 0 |
gradient_checkpointing | bool | Recompute activations during backward to save memory | False |
mhc_expansion_rate | int | mHC expansion rate. Connectors are trainable when this is > 1. | 1 |
Functions:
| Name | Description |
|---|---|
get_config | Return text_config when present, otherwise return this config. |
get_text_config | Return a wrapped config built from self.text_config. |
get_num_experts | Return num_experts or n_routed_experts from the active config, or None. |
Attributes:
| Name | Type | Description |
|---|---|---|
rope_parameters | ||
max_lora_adapters | int | |
max_lora_rank | int | |
shard_attention_heads | bool | |
loss_chunk_size | int | |
gradient_checkpointing | bool | |
mhc_expansion_rate | int |
Source code in skyrl/tx/models/configs.py:6-83
class ModelConfig(PretrainedConfig):
    """Configuration for skyrl models with LoRA support.

    Wraps a HuggingFace PretrainedConfig with additional parameters
    for Multi-LoRA training and tensor parallelism.

    Args:
        config: A HuggingFace PretrainedConfig object (e.g., from AutoConfig.from_pretrained())
        max_lora_adapters: Maximum number of concurrent LoRA adapters
        max_lora_rank: Maximum rank for LoRA adapters
        shard_attention_heads: Whether to shard attention across tensor parallel devices
        loss_chunk_size: Chunk size for cross-entropy loss computation (0 = no chunking)
        gradient_checkpointing: Recompute activations during backward to save memory
        mhc_expansion_rate: mHC expansion rate. Connectors are trainable when this is > 1.
    """

    # Type hints for config attributes
    max_lora_adapters: int
    max_lora_rank: int
    shard_attention_heads: bool
    loss_chunk_size: int
    gradient_checkpointing: bool
    mhc_expansion_rate: int

    def __init__(
        self,
        config: PretrainedConfig | dict,
        *,
        max_lora_adapters: int,
        max_lora_rank: int,
        shard_attention_heads: bool,
        loss_chunk_size: int = 0,
        gradient_checkpointing: bool = False,
        mhc_expansion_rate: int = 1,
    ):
        # Copy all attributes of the wrapped config (or dict) onto this one.
        super().__init__(**(config if isinstance(config, dict) else config.__dict__))
        # In transformers v5, rope_parameters may not contain rope_theta
        # even when it exists as a top-level config attribute (e.g. DeepSeek v3).
        # Inject it so model code can always use config.rope_parameters["rope_theta"].
        rope_params = getattr(self, "rope_parameters", None) or {}
        if "rope_theta" not in rope_params:
            rope_theta = getattr(self, "rope_theta", None)
            if rope_theta is not None:
                rope_params["rope_theta"] = rope_theta
        if rope_params:
            self.rope_parameters = rope_params
        self.max_lora_adapters = max_lora_adapters
        self.max_lora_rank = max_lora_rank
        self.shard_attention_heads = shard_attention_heads
        self.loss_chunk_size = loss_chunk_size
        self.gradient_checkpointing = gradient_checkpointing
        self.mhc_expansion_rate = mhc_expansion_rate

    def get_config(self) -> PretrainedConfig:
        """Return `text_config` when present, otherwise return this config."""
        return self.get_text_config() if hasattr(self, "text_config") else self

    def get_text_config(self, decoder=None, encoder=None) -> "ModelConfig":
        """Return a wrapped config built from `self.text_config`."""
        text_cfg = super().get_text_config(decoder=decoder, encoder=encoder)
        # Already wrapped (or no separate text config): return unchanged.
        if text_cfg is self or isinstance(text_cfg, ModelConfig):
            return text_cfg
        # Re-wrap the bare text config, propagating the skyrl-specific settings.
        return type(self)(
            text_cfg,
            max_lora_adapters=self.max_lora_adapters,
            max_lora_rank=self.max_lora_rank,
            shard_attention_heads=self.shard_attention_heads,
            loss_chunk_size=self.loss_chunk_size,
            gradient_checkpointing=self.gradient_checkpointing,
            mhc_expansion_rate=self.mhc_expansion_rate,
        )

    def get_num_experts(self):
        """Return num_experts or n_routed_experts from the active config, or None."""
        # TODO: Change this if there can be different numbers of experts in text_config and vision_config
        config = self.get_config()
        return getattr(config, "num_experts", None) or getattr(config, "n_routed_experts", None)
attr rope_parameters

rope_parameters = rope_params

attr max_lora_adapters

max_lora_adapters: int = max_lora_adapters

attr max_lora_rank

max_lora_rank: int = max_lora_rank

attr shard_attention_heads

shard_attention_heads: bool = shard_attention_heads

attr loss_chunk_size

loss_chunk_size: int = loss_chunk_size

attr gradient_checkpointing

gradient_checkpointing: bool = gradient_checkpointing

attr mhc_expansion_rate

mhc_expansion_rate: int = mhc_expansion_rate

method get_config
get_config() -> PretrainedConfig

Return text_config when present, otherwise return this config.
Source code in skyrl/tx/models/configs.py:61-63
def get_config(self) -> PretrainedConfig:
    """Return `text_config` when present, otherwise return this config."""
    return self.get_text_config() if hasattr(self, "text_config") else self
method get_text_config

get_text_config(decoder=None, encoder=None) -> ModelConfig

Return a wrapped config built from self.text_config.
Source code in skyrl/tx/models/configs.py:65-78
def get_text_config(self, decoder=None, encoder=None) -> "ModelConfig":
    """Return a wrapped config built from `self.text_config`."""
    text_cfg = super().get_text_config(decoder=decoder, encoder=encoder)
    # Already wrapped (or no separate text config): return unchanged.
    if text_cfg is self or isinstance(text_cfg, ModelConfig):
        return text_cfg
    # Re-wrap the bare text config, propagating the skyrl-specific settings.
    return type(self)(
        text_cfg,
        max_lora_adapters=self.max_lora_adapters,
        max_lora_rank=self.max_lora_rank,
        shard_attention_heads=self.shard_attention_heads,
        loss_chunk_size=self.loss_chunk_size,
        gradient_checkpointing=self.gradient_checkpointing,
        mhc_expansion_rate=self.mhc_expansion_rate,
    )
method get_num_experts

get_num_experts()

Source code in skyrl/tx/models/configs.py:80-83
def get_num_experts(self):
    """Return num_experts or n_routed_experts from the active config, or None."""
    # TODO: Change this if there can be different numbers of experts in text_config and vision_config
    config = self.get_config()
    return getattr(config, "num_experts", None) or getattr(config, "n_routed_experts", None)
Model Interface

class ModelForCausalLM
Functions:
| Name | Description |
|---|---|
get_model_config | |
get_decode_layers | Return pre-extracted per-layer parameters for decode. |
is_lora_param | Return True if a parameter path corresponds to trainable LoRA/connector weights. |
Attributes:
| Name | Type | Description |
|---|---|---|
config | ModelConfig |
Source code in skyrl/tx/models/types.py:13-36
class ModelForCausalLM:
    """Base interface for causal LM models; subclasses provide ``config``."""

    # Model configuration (set by concrete subclasses).
    config: ModelConfig

    def get_model_config(self) -> ModelConfig:
        """Return this model's ModelConfig."""
        return self.config

    def get_decode_layers(self):
        """Return pre-extracted per-layer parameters for decode.

        Called once outside the while_loop; the result is passed as the
        ``decode_layers`` argument to every decode-step ``model(...)`` call.

        Override in subclasses that benefit from hoisting work out of the
        loop (e.g. pre-extracting stacked layer parameters).
        """
        return None

    def is_lora_param(self, path: tuple, _value) -> bool:
        """Return True if a parameter path corresponds to trainable LoRA/connector weights."""
        is_lora = any(name in path for name in ("lora_A", "lora_B"))
        # Connector weights are only trainable when mHC expansion is enabled.
        is_connector = self.config.mhc_expansion_rate > 1 and any(
            name in path for name in ("attn_connector", "mlp_connector")
        )
        return is_lora or is_connector
attr config

config: ModelConfig

method get_model_config

get_model_config() -> ModelConfig

Source code in skyrl/tx/models/types.py:17-18
def get_model_config(self) -> ModelConfig:
    """Return this model's ModelConfig."""
    return self.config
method get_decode_layers

get_decode_layers()

Return pre-extracted per-layer parameters for decode.
Called once outside the while_loop; the result is passed as the
decode_layers argument to every decode-step model(...) call.
Override in subclasses that benefit from hoisting work out of the
loop (e.g. pre-extracting stacked layer parameters).
Source code in skyrl/tx/models/types.py:20-28
def get_decode_layers(self):
    """Return pre-extracted per-layer parameters for decode.

    Called once outside the while_loop; the result is passed as the
    ``decode_layers`` argument to every decode-step ``model(...)`` call.

    Override in subclasses that benefit from hoisting work out of the
    loop (e.g. pre-extracting stacked layer parameters).
    """
    return None
method is_lora_param

is_lora_param(path: tuple, _value) -> bool

Return True if a parameter path corresponds to trainable LoRA/connector weights.
Source code in skyrl/tx/models/types.py:30-36
def is_lora_param(self, path: tuple, _value) -> bool:
    """Return True if a parameter path corresponds to trainable LoRA/connector weights."""
    is_lora = any(name in path for name in ("lora_A", "lora_B"))
    # Connector weights are only trainable when mHC expansion is enabled.
    is_connector = self.config.mhc_expansion_rate > 1 and any(
        name in path for name in ("attn_connector", "mlp_connector")
    )
    return is_lora or is_connector
class CausalLMOutput

CausalLMOutput(last_hidden_state: jax.Array, kv_cache: KVCache | None, hidden_states: list[jax.Array] | None = None) -> None

Output type for causal language models like Qwen3ForCausalLM.
Attributes:
| Name | Type | Description |
|---|---|---|
last_hidden_state | Array | The last hidden state from the model. |
kv_cache | KVCache | None | The updated key-value cache (None during training). |
hidden_states | list[Array] | None | All hidden states, if output_hidden_states=True. |
Source code in skyrl/tx/models/types.py:55-68
@jax.tree_util.register_dataclass
@dataclass
class CausalLMOutput:
    """Output type for causal language models like Qwen3ForCausalLM.

    Attributes:
        last_hidden_state: The last hidden state from the model.
        kv_cache: The updated key-value cache (None during training).
        hidden_states: All hidden states, if output_hidden_states=True.
    """

    last_hidden_state: jax.Array
    kv_cache: KVCache | None
    hidden_states: list[jax.Array] | None = None
attr last_hidden_state

last_hidden_state: jax.Array

attr kv_cache

kv_cache: KVCache | None

attr hidden_states

hidden_states: list[jax.Array] | None = None