
Multi-tenancy

A single SkyRL Tinker server can host multiple LoRA adapters concurrently against a shared base model. Each adapter is its own Tinker model_id and its own client session — multiple tinker-cookbook recipes can train and sample in parallel without spinning up a separate server per workload.
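
Concretely, two tenants are just two ordinary tinker clients pointed at the same server. A minimal sketch, assuming the standard tinker client surface (and that ServiceClient accepts the base_url the recipes below pass):

import tinker

# Two tenants, one server: each client owns its own adapter (model_id) but
# shares the base model loaded once on the server.
client_a = tinker.ServiceClient(base_url="http://localhost:8000")
client_b = tinker.ServiceClient(base_url="http://localhost:8000")

trainer_a = client_a.create_lora_training_client(base_model="Qwen/Qwen3-0.6B", rank=32)
trainer_b = client_b.create_lora_training_client(base_model="Qwen/Qwen3-0.6B", rank=32)

# trainer_a and trainer_b now train and sample concurrently; neither needs to
# know which adapter is resident on the GPU at any moment.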

This page describes the design, the operator contract, and quickstarts for SFT (sl_loop.py) and RL (rl_loop.py).

Multi-tenancy is wired on the Megatron backend with vLLM serving per-tenant adapters. FSDP multi-tenancy and multi-tenant full-parameter fine-tuning are not yet supported — see Limitations.

How it works

The base model is loaded once on the policy workers and shared across all tenants. Each tenant gets a per-adapter slot in pinned CPU memory holding its LoRA params, optimizer state, and step count; the live GPU adapter is swapped on demand at the top of every per-model dispatch entry point. Clients never reason about which adapter is currently resident — they just call the Tinker API with their model_id.
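
A minimal sketch of that bookkeeping, with illustrative names (AdapterSlot, swap_in) rather than SkyRL's actual internals:

from dataclasses import dataclass
import torch

@dataclass
class AdapterSlot:
    # Per-tenant state kept in pinned CPU memory.
    lora_params: dict[str, torch.Tensor]    # pinned host copies of the LoRA weights
    optim_state: dict                       # fp32 main params + Adam moments
    step: int

slots: dict[str, AdapterSlot] = {}          # model_id -> slot
resident: str | None = None                 # model_id whose adapter is live on GPU

def swap_in(model_id: str, gpu_lora: dict[str, torch.Tensor]) -> None:
    """Called at the top of each per-model dispatch entry point."""
    global resident
    if resident == model_id:
        return                              # already live, nothing to copy
    if resident is not None:                # stash the outgoing tenant's adapter
        for name, host in slots[resident].lora_params.items():
            host.copy_(gpu_lora[name])      # device -> pinned host
    for name, host in slots[model_id].lora_params.items():
        gpu_lora[name].copy_(host, non_blocking=True)  # host -> device
    torch.cuda.synchronize()                # stand-in for the DP-group barrier
    resident = model_id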

What this means for you:

  • GPU memory is bounded by the base model plus a few small LoRA buffers, regardless of tenant count. Adding a tenant costs CPU memory instead: one slot per adapter, on the order of ~3× lora_param_bytes_per_DP_shard (tens of MB for Qwen3-0.6B at rank 32).
  • Swap cost is small relative to a forward pass — a host→device tensor.copy_() plus a DP-group barrier. You should not see noticeable per-call latency from tenant churn.
  • Per-tenant sampling on vLLM is by model_id. The worker exports each tenant's adapter into lora_sync_path/<model_id>/ on save_weights_for_sampler and registers it on vLLM via load_lora_adapter. Sampling uses model=<model_id> and vLLM routes to the right adapter (sketched after this list).
  • Capacity is bounded by max_cpu_loras, vLLM's CPU LRU cache. If you have more concurrent tenants than slots, vLLM evicts one and the next sample() against it 404s — there is no on-demand reload. Size for your peak.
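
For reference, the registration that save_weights_for_sampler triggers corresponds to vLLM's dynamic LoRA endpoint. A sketch of the equivalent raw calls, assuming a directly reachable vLLM server with runtime LoRA updating enabled (the port and adapter name are illustrative; SkyRL performs this wiring internally):

import requests

VLLM = "http://localhost:8001"          # illustrative; not a documented SkyRL port
model_id = "tenant-a"                   # the Tinker model_id doubles as the adapter name

# Register the exported adapter under its model_id.
requests.post(f"{VLLM}/v1/load_lora_adapter", json={
    "lora_name": model_id,
    "lora_path": f"/path/to/lora_sync_path/{model_id}/",
}).raise_for_status()

# Per-tenant sampling targets the adapter by name on the OpenAI-compatible route.
resp = requests.post(f"{VLLM}/v1/completions", json={
    "model": model_id,
    "prompt": "Hello",
    "max_tokens": 16,
})
resp.raise_for_status()                 # a 404 here means the adapter was evicted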

Operator contract

Required --backend-config keys to run multi-tenant LoRA on Megatron:

{
    "trainer.placement.colocate_all": false,
    "trainer.policy.megatron_config.lora_config.merge_lora": false,
    "trainer.policy.model.lora.max_loras": <max concurrent adapters in a single batch>,
    "trainer.policy.model.lora.max_cpu_loras": <total adapter capacity>
}

All adapters must share the same (rank, alpha, target_modules) signature. Mismatches are hard-rejected at create_model with a LoRA signature mismatch … error.
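
A toy version of that check, with illustrative names (SkyRL's actual implementation will differ):

# The first create_model pins the server-wide signature; later calls must match.
server_signature = None

def check_signature(rank, alpha, target_modules):
    global server_signature
    sig = (rank, alpha, tuple(sorted(target_modules)))
    if server_signature is None:
        server_signature = sig          # captured from the first create_model
    elif sig != server_signature:
        raise ValueError(f"LoRA signature mismatch: {sig} vs {server_signature}")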

The first create_model on a fresh server triggers the policy build and bootstraps the per-tenant adapter slot infrastructure; subsequent create_model calls register additional adapter slots and complete in milliseconds. When the last registered model is unloaded the server tears down the Ray runtime via ray.shutdown(); the next create_model rebuilds it.

Quickstart — Two SL clients

Run two tinker-cookbook sl_loop clients in parallel against one Megatron-backed Tinker server.

1. Start the server

uv run --extra tinker --extra megatron -m skyrl.tinker.api \
    --host 0.0.0.0 \
    --port 8000 \
    --base-model Qwen/Qwen3-0.6B \
    --backend megatron \
    --backend-config '{
        "strategy": "megatron",
        "trainer.placement.policy_num_gpus_per_node": 1,
        "trainer.placement.policy_num_nodes": 1,
        "trainer.placement.colocate_all": false,
        "trainer.policy.megatron_config.tensor_model_parallel_size": 1,
        "trainer.policy.megatron_config.pipeline_model_parallel_size": 1,
        "trainer.policy.megatron_config.lora_config.merge_lora": false,
        "trainer.policy.model.lora.max_loras": 2,
        "trainer.policy.model.lora.max_cpu_loras": 2,
        "trainer.logprobs_chunk_size": null
    }'

Wait for init policy model done after the first client connects.

2. Run two sl_loop clients

In two separate terminals (in the tinker-cookbook repo):

# Terminal 2 — client A
TINKER_API_KEY=tml-dummy uv run --with tinker --with tinker-cookbook --with datasets \
    python -m tinker_cookbook.recipes.sl_loop \
    base_url=http://localhost:8000 \
    model_name="Qwen/Qwen3-0.6B" \
    train_on_what=LAST_ASSISTANT_MESSAGE \
    lora_rank=32 \
    log_path=/tmp/sl_loop_a.log
# Terminal 3 — client B
TINKER_API_KEY=tml-dummy uv run --with tinker --with tinker-cookbook --with datasets \
    python -m tinker_cookbook.recipes.sl_loop \
    base_url=http://localhost:8000 \
    model_name="Qwen/Qwen3-0.6B" \
    train_on_what=LAST_ASSISTANT_MESSAGE \
    lora_rank=32 \
    log_path=/tmp/sl_loop_b.log

Stagger the launches by ~20 s so the second client doesn't race the first policy build. Both clients must use the same lora_rank and model_name.

You should see both clients converge on their respective tasks, with NLL trending independently downward in both sl_loop_a.log and sl_loop_b.log. GPU memory will stay bounded even as the second client connects (single base model + N LoRA slots).

Quickstart — Two RL clients

Two rl_loop clients each train and sample independently against one server. RL exercises the per-tenant save_weights_for_sampler + sample(model=<model_id>) path.
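
Stripped of the recipe scaffolding, the per-tenant path each client exercises looks roughly like this, assuming the standard tinker client surface (exact names and futures may differ by tinker version; the token ids and step name are illustrative):

import tinker
from tinker import types

service = tinker.ServiceClient(base_url="http://localhost:8000")
trainer = service.create_lora_training_client(base_model="Qwen/Qwen3-0.6B", rank=32)

# ... RL training steps for this tenant ...

# Export the current adapter for sampling. On SkyRL this writes it into
# lora_sync_path/<model_id>/ and registers it with vLLM under the model_id.
path = trainer.save_weights_for_sampler(name="step-0").result().path
sampler = service.create_sampling_client(model_path=path)

# sample() now routes to this tenant's adapter; the other client's sampler,
# built the same way, hits its own adapter.
future = sampler.sample(
    prompt=types.ModelInput.from_ints(tokens=[151644]),   # illustrative token ids
    num_samples=1,
    sampling_params=types.SamplingParams(max_tokens=16, temperature=1.0),
)
print(future.result())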

1. Start the server

uv run --extra tinker --extra megatron -m skyrl.tinker.api \
    --host 0.0.0.0 \
    --port 8000 \
    --base-model Qwen/Qwen3-0.6B \
    --backend megatron \
    --backend-config '{
        "strategy": "megatron",
        "trainer.placement.policy_num_gpus_per_node": 4,
        "trainer.placement.policy_num_nodes": 1,
        "trainer.placement.colocate_all": false,
        "trainer.policy.megatron_config.tensor_model_parallel_size": 1,
        "trainer.policy.megatron_config.pipeline_model_parallel_size": 1,
        "trainer.policy.megatron_config.lora_config.merge_lora": false,
        "trainer.micro_train_batch_size_per_gpu": 64,
        "trainer.micro_forward_batch_size_per_gpu": 64,
        "generator.inference_engine.num_engines": 1,
        "generator.inference_engine.tensor_parallel_size": 1,
        "trainer.policy.model.lora.max_loras": 2,
        "trainer.policy.model.lora.max_cpu_loras": 2,
        "trainer.logprobs_chunk_size": null
    }'

Critical knobs vs the SL quickstart:

  • colocate_all: false is required: inference engines and trainer workers must sit on different GPUs so that sampling and training can progress independently across client calls.
  • merge_lora: false is required. With merge_lora: true, vLLM serves the merged base model and sample(model=<adapter>) returns the wrong tenant's weights.
  • max_loras ≥ number of adapters in a single batch (typically equal to the client count).
  • max_cpu_loras must be ≥ the number of adapters you expect to serve concurrently. There is no on-demand reload — if vLLM evicts an adapter, its next sample() 404s.

2. Run two rl_loop clients

# Terminal 2 — client A
TINKER_API_KEY=tml-dummy uv run --with tinker --with tinker-cookbook --with datasets --with torch \
    python -m tinker_cookbook.recipes.rl_loop \
    base_url=http://localhost:8000 \
    model_name="Qwen/Qwen3-0.6B" \
    lora_rank=32 \
    log_path=/tmp/rl_loop_a.log
# Terminal 3 — client B
TINKER_API_KEY=tml-dummy uv run --with tinker --with tinker-cookbook --with datasets --with torch \
    python -m tinker_cookbook.recipes.rl_loop \
    base_url=http://localhost:8000 \
    model_name="Qwen/Qwen3-0.6B" \
    lora_rank=32 \
    log_path=/tmp/rl_loop_b.log

Stagger by ~20 s. Both clients must use the same lora_rank and model_name.

You should see both clients' rewards trending upward independently in rl_loop_a.log and rl_loop_b.log, vLLM logs showing two distinct adapter names registered with sample requests routed to each, and GPU memory staying bounded (single base model, two LoRA adapters, the CPU LRU holding the same two).

Troubleshooting

  • LoRA signature mismatch — clients passed different (rank, alpha, target_modules). All adapters on one server share a signature, captured from the first create_model.
  • sample() 404 on lora_name=… — either save_sampler_checkpoint wasn't called for that model_id before sampling, or max_cpu_loras is too low and vLLM evicted the adapter. Check the vLLM server log.
  • Server hangs on the second create_model — the first policy build hasn't finished. Wait for init policy model done before starting subsequent clients.
  • CPU OOM on the Nth client — each adapter slot holds LoRA params plus fp32 main params and Adam moments, roughly ~3× lora_param_bytes_per_DP_shard. For Qwen3-0.6B at rank 32 this is on the order of tens of MB per slot; for larger models scale accordingly. Reduce concurrent adapters or move to a host with more RAM. See the worked sizing sketch after this list.
  • Sample returns the wrong tenant's output — confirm merge_lora: false is set on the Megatron config; with merge enabled vLLM only sees the merged base.
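
As a back-of-envelope check on that slot-size estimate, here is the arithmetic for Qwen3-0.6B at rank 32. The projection shapes come from the public model config; the attention-only target set and bf16 dtype are assumptions, so treat the output as an order-of-magnitude figure:

# Back-of-envelope slot sizing. Shapes are Qwen3-0.6B attention projections;
# the target-module set and dtype are assumptions -- check your lora_config.
rank = 32
bytes_per_param = 2                     # assume bf16 LoRA params
layers = 28
targets = {                             # (in_features, out_features)
    "q_proj": (1024, 2048),
    "k_proj": (1024, 1024),
    "v_proj": (1024, 1024),
    "o_proj": (2048, 1024),
}
lora_params = layers * sum(rank * (i + o) for (i, o) in targets.values())
slot_bytes = 3 * lora_params * bytes_per_param   # the ~3x rule from above
print(f"{lora_params / 1e6:.1f}M LoRA params -> ~{slot_bytes / 2**20:.0f} MiB per slot")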
