Mixture of LoRA experts.
A buyer's product handles many narrow tasks. The refund flagger, the PHI redactor, the tone classifier, and the lead scorer all share an inbox. Training one giant LoRA on the union is dilutive. Loading four LoRAs and switching at the application layer is fragile. The router learns the switch.
The shape of the problem
kolm artifacts are domain-specific by construction. The compile loop produces a refund-flagger artifact, a PHI-redactor artifact, and a tone-classifier artifact as three separately signed files. Each has its own K-score on its own eval pack. Each is small (30-50 MB), each is independently revocable, each carries its own audit trail.
At inference time, a single product surface needs to route each incoming prompt to the right expert. The simple option is application-layer routing: the buyer writes a rule such as `if "refund" in subject: use refund_flag.kolm`. This works for two or three obvious lanes. It does not scale to ten lanes with overlapping vocabulary. A tone-and-refund email needs both a tone score and a refund flag, and the PHI redactor has to run first on anything from a healthcare sender.
The MoE router learns the routing function from labeled examples.
The architecture
The router is a one-hidden-layer MLP: prompt embedding → hidden → expert logits.
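```python
import torch.nn as nn

class Router(nn.Module):
    """Maps a pooled prompt embedding to one logit per expert."""
    def __init__(self, hidden_size, n_experts, router_hidden=256):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, router_hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(router_hidden, n_experts)

    def forward(self, h):  # h: [batch, hidden_size]
        return self.fc2(self.act(self.fc1(h)))  # [batch, n_experts]
```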
The prompt embedding is the base model's mean-pooled last hidden state. The base is loaded read-only and shared across experts at inference; only the router is trained. A trained router has O(hidden_size * router_hidden + router_hidden * n_experts) parameters, roughly 1-5 MB on disk for typical kolm setups.
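As a rough check on that size estimate, the sketch below counts the router's parameters and shows a padding-aware mean pool. It assumes a base with hidden_size=2048 (roughly the width of a 3B-class model) and three experts. `masked_mean_pool` and `router_param_count` are hypothetical helper names, and averaging only over real tokens via the attention mask is an assumption about how pooling treats padded batches.

```python
import torch
import torch.nn as nn

def masked_mean_pool(last_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average the last hidden state over real (non-padding) tokens.

    last_hidden: [batch, seq_len, hidden]; attention_mask: [batch, seq_len] of 0/1.
    """
    mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)  # [batch, seq_len, 1]
    summed = (last_hidden * mask).sum(dim=1)                   # [batch, hidden]
    counts = mask.sum(dim=1).clamp(min=1.0)                    # avoid divide-by-zero
    return summed / counts

def router_param_count(hidden_size: int, n_experts: int, router_hidden: int = 256) -> int:
    # Same shape as the Router above: Linear -> GELU -> Linear, biases included.
    router = nn.Sequential(
        nn.Linear(hidden_size, router_hidden),
        nn.GELU(),
        nn.Linear(router_hidden, n_experts),
    )
    return sum(p.numel() for p in router.parameters())

n_params = router_param_count(hidden_size=2048, n_experts=3)
print(n_params, f"{n_params * 4 / 1e6:.1f} MB in fp32")  # 525315 2.1 MB in fp32
```

Storing the weights in bf16 would halve that figure. Almost all of the parameters sit in the first layer, so the size tracks the base model's hidden width far more than the number of experts.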
Top-1 versus top-k
| Mode | Inference cost | Lineage | When to use |
|---|---|---|---|
| top-1 (Switch) | 1 expert runs per input | Fedus 2022 | Experts are mutually exclusive (refund vs PHI vs tone) |
| top-k (Mixtral) | k experts run; outputs weighted by router probabilities | Jiang 2024 | Experts overlap (legal-finance, code-math); k=2 is the canonical choice |
The top-k aggregation is a weighted sum of the selected experts' outputs. The weights come from a softmax over the router's top-k logits: Mixtral renormalizes over the chosen k, so the weights sum to 1. The Mixtral 8x7B model runs k=2 of 8 experts per token; we apply the same idea at per-sequence granularity. Higher k buys quality at an inference cost that grows linearly with k. Lower k is cheaper and easier to attribute (the receipt records exactly which expert ran).
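The following is a minimal sketch of per-sequence top-k mixing. It assumes each expert is a callable that maps one input vector to one output vector; a real expert forward pass is far more involved. Following Mixtral, the softmax is taken over only the k selected logits, so the mixing weights sum to 1 and the unselected experts never run. `mix_top_k` is a hypothetical name.

```python
import torch
from typing import Callable, List

def mix_top_k(logits: torch.Tensor,
              experts: List[Callable[[torch.Tensor], torch.Tensor]],
              x: torch.Tensor,
              k: int = 2):
    """logits: [batch, n_experts]; x: [batch, d_in]. Returns outputs, chosen indices, weights."""
    top_vals, top_idx = torch.topk(logits, k, dim=-1)  # [batch, k]
    weights = torch.softmax(top_vals, dim=-1)          # renormalize over the chosen k only
    outputs = []
    for b in range(x.size(0)):
        # Only the k selected experts run for this sequence.
        y = sum(w * experts[i](x[b]) for w, i in zip(weights[b], top_idx[b].tolist()))
        outputs.append(y)
    return torch.stack(outputs), top_idx, weights

# Toy experts that scale their input; k=1 reduces to Switch-style top-1 routing.
experts = [lambda v: v * 1.0, lambda v: v * 2.0, lambda v: v * 3.0]
out, idx, w = mix_top_k(torch.tensor([[0.0, 0.0, -10.0]]), experts, torch.tensor([[2.0]]), k=2)
```

In the toy call, the first two experts tie, each gets weight 0.5, and the output is 0.5 * 2 + 0.5 * 4 = 3.0.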
The two auxiliary losses
Cross-entropy on the supervised routing label is necessary but not sufficient. Two extra terms keep training stable:
- Router z-loss (Zoph et al. 2022, ST-MoE, arXiv:2202.08906). Penalizes large logit magnitudes via $L_z = \frac{1}{B}\sum_{b=1}^{B}\big(\log\sum_{j=1}^{N} e^{z_{b,j}}\big)^2$, the batch mean of the squared log-sum-exp of the router logits $z$. Without it the softmax saturates: one expert hits probability ~1, gradients flow only to that expert, and the router stops learning. Weight 1e-3 is the published default.
- Load-balance loss (introduced in Shazeer 2017; the form here is the Switch Transformer one, Fedus 2022). Penalizes the router for sending most inputs to a single expert. Computed as $L_{\text{bal}} = N \sum_{i=1}^{N} f_i P_i$, where $f_i$ is the empirical fraction of inputs routed to expert $i$ and $P_i$ is the mean router probability for expert $i$. Minimized when both are uniform across experts. Weight 1e-2 is the typical setting.
Either term can be disabled by setting its weight to zero. The kolm default keeps both on. The z-loss is there for stability. The load-balance term is there because labeled routing data is usually imbalanced (for example, 70% of the buyer's rows belong to one expert's class), and it broadens the router's coverage without sacrificing accuracy on the majority class.
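The sketch below shows the combined objective in PyTorch. It assumes a batch of router logits and integer routing labels. The z-loss is the batch mean of the squared log-sum-exp. The load-balance term uses hard top-1 assignments for $f_i$, which carry no gradient, and softmax means for $P_i$, so gradient flows through $P_i$. The default weights mirror the ones above. `router_loss` is a hypothetical name, not the kolm API.

```python
import torch
import torch.nn.functional as F

def router_loss(logits: torch.Tensor, labels: torch.Tensor,
                z_weight: float = 1e-3, lb_weight: float = 1e-2):
    """logits: [batch, n_experts] router logits; labels: [batch] supervised expert indices."""
    n_experts = logits.size(-1)
    ce = F.cross_entropy(logits, labels)                  # supervised routing label
    # Router z-loss: batch mean of (logsumexp over experts)^2.
    z = torch.logsumexp(logits, dim=-1).pow(2).mean()
    # Load-balance loss: N * sum_i f_i * P_i.
    probs = torch.softmax(logits, dim=-1)
    f = F.one_hot(probs.argmax(dim=-1), n_experts).float().mean(dim=0)  # fraction routed (no grad)
    P = probs.mean(dim=0)                                 # mean router probability per expert
    lb = n_experts * (f * P).sum()
    total = ce + z_weight * z + lb_weight * lb
    return total, {"ce": ce, "z": z, "lb": lb}
```

With a perfectly balanced batch the load-balance term equals 1. If every input collapses onto one confident expert, it approaches N, which is the penalty the weight scales.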
Routing granularity
There are two choices: route at the sequence level (one expert per input prompt) or at the token level (one expert per generated token). The kolm default is sequence level. Token level is supported through the runtime adapter pool but adds two costs:
- Per-token forward passes through the base while sampling, so the router can re-decide between tokens. The multi-LoRA serving runtime already handles concurrent adapters; per-token routing is the logical extension.
- Per-token auxiliary loss aggregation, since one sequence visits multiple experts and the load-balance must sum over both batch and time dimensions.
For the typical kolm artifact (a domain expert), the domain rarely switches mid-utterance, so sequence level is the right choice. The Mixtral and Switch Transformer recipes route at the token level because their experts are feed-forward sub-layers inside the language model, not domain specializations; the unit of routing matches the unit of generation.
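For the token-level variant, the load-balance statistics have to be pooled over both batch and time, and padded positions should not count. The following is a sketch under those assumptions. The mask convention (1 = real token) and the `token_load_balance` name are illustrative.

```python
import torch
import torch.nn.functional as F

def token_load_balance(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """logits: [batch, time, n_experts] per-token router logits; mask: [batch, time], 1 = real token."""
    n_experts = logits.size(-1)
    probs = torch.softmax(logits, dim=-1)
    m = mask.unsqueeze(-1).to(probs.dtype)                # [batch, time, 1]
    n_tokens = m.sum().clamp(min=1.0)                     # real tokens across batch and time
    assigned = F.one_hot(probs.argmax(dim=-1), n_experts).to(probs.dtype)
    f = (assigned * m).sum(dim=(0, 1)) / n_tokens         # fraction of tokens sent to each expert
    P = (probs * m).sum(dim=(0, 1)) / n_tokens            # mean probability per expert
    return n_experts * (f * P).sum()
```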
Inference: route, then forward
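```python
import torch
from transformers import AutoModel, AutoTokenizer

@torch.no_grad()
def route(router_ckpt, base_model, prompt, k=1):
    tok, base = AutoTokenizer.from_pretrained(base_model), AutoModel.from_pretrained(base_model)  # cache in production
    ckpt = torch.load(router_ckpt)  # assumed keys: router_state_dict, expert_names, router_hidden
    router = Router(base.config.hidden_size, len(ckpt['expert_names']), ckpt['router_hidden'])
    router.load_state_dict(ckpt['router_state_dict'])
    h = base(**tok(prompt, return_tensors='pt')).last_hidden_state.mean(dim=1).float()  # mean-pool one unpadded prompt
    top = torch.topk(torch.softmax(router(h), dim=-1)[0], k)
    return [(ckpt['expert_names'][i], p) for p, i in zip(top.values.tolist(), top.indices.tolist())]
```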
The application layer reads the result, loads the chosen adapter from the registry (or from the in-process adapter pool), and runs inference. The receipt of the response records both the routing distribution (which expert won, and the runners-up) and the CID of the expert that actually ran.
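The sketch below shows one way the application layer might turn the router's output into the routing portion of a response receipt. The field names (`winner`, `runners_up`, `expert_cid`), the `routing_record` helper, and the placeholder CIDs are hypothetical. Only the idea of recording the winner, the runners-up, and the CID of the expert that ran comes from the text above.

```python
from typing import Dict, List, Tuple

def routing_record(probs: Dict[str, float], expert_cids: Dict[str, str],
                   n_runners_up: int = 2) -> dict:
    """probs: expert name -> router probability; expert_cids: expert name -> pinned CID."""
    ranked: List[Tuple[str, float]] = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    winner, p_win = ranked[0]
    return {
        "winner": winner,
        "winner_prob": round(p_win, 4),
        "runners_up": [{"expert": n, "prob": round(p, 4)} for n, p in ranked[1:1 + n_runners_up]],
        "expert_cid": expert_cids[winner],  # CID of the adapter that actually ran
    }

record = routing_record(
    {"refund_flag": 0.81, "tone_classify": 0.15, "phi_redactor": 0.04},
    {"refund_flag": "cid-A", "phi_redactor": "cid-B", "tone_classify": "cid-C"},  # placeholder CIDs
)
```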
What ships
The trained router lands as a small *.router.kolm artifact alongside the existing expert artifacts. The manifest pins:
- The base model CID (shared across experts and consulted for embedding).
- The expert names and their CIDs (so the binder can re-fetch each expert and re-verify its independent signature).
- The router weights (a few MB at most) and the auxiliary-loss weights used during training.
- Held-out routing accuracy and the per-expert load distribution from the eval pack.
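The manifest fields listed above could be modeled roughly as follows. The class and field names are illustrative assumptions, not kolm's actual schema. The weights themselves are assumed to be stored alongside the manifest and are omitted here.

```python
from dataclasses import dataclass, field
from typing import Dict

@dataclass(frozen=True)
class RouterManifest:
    base_model_cid: str                          # shared base, consulted for embeddings
    expert_cids: Dict[str, str]                  # expert name -> CID of its signed artifact
    router_hidden: int = 256
    z_loss_weight: float = 1e-3
    load_balance_weight: float = 1e-2
    eval_routing_accuracy: float = 0.0           # held-out routing accuracy
    eval_load: Dict[str, float] = field(default_factory=dict)  # per-expert share of eval rows

    def check(self) -> None:
        if not self.expert_cids:
            raise ValueError("router must reference at least one expert")
        if self.eval_load and set(self.eval_load) != set(self.expert_cids):
            raise ValueError("load distribution must cover exactly the referenced experts")
        if self.eval_load and abs(sum(self.eval_load.values()) - 1.0) > 1e-6:
            raise ValueError("load distribution must sum to 1")
```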
None of the experts are altered. Each one is still a separately signed file with its own receipt; the router is a thin layer on top that does not affect any individual expert's audit trail.
What the receipt records
```json
"router": {
  "method": "moe_lora_router",
  "base_model": "Qwen/Qwen2.5-3B-Instruct",
  "experts": {
    "refund_flag": "cidv1:sha256:8e1a...",
    "phi_redactor": "cidv1:sha256:1bcf...",
    "tone_classify": "cidv1:sha256:c4a9..."
  },
  "config": {
    "routing": "top_1",
    "k": 1,
    "z_loss_weight": 1e-3,
    "load_balance_weight": 1e-2,
    "router_hidden": 256
  },
  "n_train_rows": 4200,
  "n_eval_rows": 460,
  "n_experts": 3,
  "loss_final": 0.092,
  "eval_accuracy": 0.943,
  "papers": [
    "arXiv:1701.06538",
    "arXiv:2006.16668",
    "arXiv:2101.03961",
    "arXiv:2401.04088",
    "arXiv:2310.18339"
  ]
}
```
The auditor can re-fetch each expert by CID, replay the eval set, and recompute the routing accuracy. The expert CIDs are the binding contract: a router signed against a set of CIDs only works with those CIDs; swapping in a different refund-flagger artifact invalidates the receipt.
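The binding rule can be enforced before any expert runs: resolve each pinned CID, hash the fetched bytes, and refuse on a mismatch. This sketch assumes a CID of the form `cidv1:sha256:<hex SHA-256 of the artifact bytes>`, matching the prefix in the receipt above. That encoding is an assumption. A real verifier would also check each expert's own signature. `cid_for`, `verify_bindings`, and the `fetch` callable are hypothetical.

```python
import hashlib
from typing import Callable, Dict

def cid_for(data: bytes) -> str:
    # Assumed CID encoding: "cidv1:sha256:" + hex SHA-256 of the artifact bytes.
    return "cidv1:sha256:" + hashlib.sha256(data).hexdigest()

def verify_bindings(pinned: Dict[str, str], fetch: Callable[[str], bytes]) -> None:
    """pinned: expert name -> CID from the router receipt; fetch: CID -> artifact bytes."""
    for name, cid in pinned.items():
        actual = cid_for(fetch(cid))
        if actual != cid:
            raise ValueError(f"expert {name!r}: fetched bytes hash to {actual}, receipt pins {cid}")
```

Swapping in a different refund-flagger build changes its bytes, so its hash no longer matches the pinned CID and routing stops before inference.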
Edge cases worth naming
Cold start with few examples per expert. A router needs enough labeled traffic per expert to learn the boundary. Below ~50 examples per expert, the load-balance loss tilts the router toward uniform and held-out accuracy collapses. The remedy is upstream: capture more traffic, or run the simple application-layer rule (substring match) until the router catches up.
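The cold-start gate and the substring fallback can be sketched together. The ~50-example floor comes from the paragraph above. The keyword rules, the default expert, and the function names are hypothetical and buyer-written.

```python
from collections import Counter
from typing import Iterable, List, Sequence, Tuple

MIN_EXAMPLES_PER_EXPERT = 50  # rough floor from the text above

def router_ready(labels: Iterable[str], experts: Sequence[str]) -> bool:
    """True once every expert has enough labeled routing examples to train on."""
    counts = Counter(labels)
    return all(counts[e] >= MIN_EXAMPLES_PER_EXPERT for e in experts)

def substring_route(subject: str, rules: List[Tuple[str, str]], default: str) -> str:
    """First matching (keyword, expert) rule wins; otherwise fall back to the default expert."""
    lowered = subject.lower()
    for keyword, expert in rules:
        if keyword in lowered:
            return expert
    return default
```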
Adding an expert after training. The router's final layer has a fixed output dimension. Adding a new expert means retraining the router from scratch on the union of the old labels plus the new expert's labels; the old experts do not need to be touched. This is cheap because the router is only a few MB.
Drift. The router's training distribution can diverge from production traffic. The receipt records the routing distribution on the eval pack. A captured-traffic monitor compares live routing against the receipt and flags drift before it shows up as a quality regression. The eval-set drift page covers the broader pattern.
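A captured-traffic drift check can be as simple as comparing live top-1 routing shares with the shares recorded in the receipt. This sketch uses total variation distance with a 0.1 threshold. Both the metric and the threshold are assumptions, not kolm defaults, and the function names are hypothetical.

```python
from collections import Counter
from typing import Dict, Iterable

def routing_shares(winners: Iterable[str], experts: Iterable[str]) -> Dict[str, float]:
    """Fraction of live requests whose top-1 expert was each listed expert."""
    counts = Counter(winners)
    total = sum(counts.values()) or 1
    return {e: counts[e] / total for e in experts}

def total_variation(p: Dict[str, float], q: Dict[str, float]) -> float:
    keys = set(p) | set(q)
    return 0.5 * sum(abs(p.get(k, 0.0) - q.get(k, 0.0)) for k in keys)

def drifted(live_winners: Iterable[str], receipt_shares: Dict[str, float],
            threshold: float = 0.1) -> bool:
    live = routing_shares(live_winners, receipt_shares.keys())
    return total_variation(live, receipt_shares) > threshold
```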
Adversarial prompts. A prompt crafted to confuse the router (mixing refund language and medical language to force an ambiguous route) goes to whichever expert wins by a thin margin. For high-stakes domains the runtime can look past the top-1 pick at the top-k distribution and apply a margin-of-confidence check. It refuses when the top two probabilities are within a threshold of each other.
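Here is a sketch of the margin-of-confidence check on a single prompt's router logits. The 0.2 probability margin is an arbitrary assumption. Returning `None` stands in for whatever refusal or escalation path the runtime uses, and `route_with_margin` is a hypothetical name.

```python
import torch
from typing import List, Optional

def route_with_margin(logits: torch.Tensor, names: List[str],
                      min_margin: float = 0.2) -> Optional[str]:
    """logits: [n_experts] for one prompt. Returns the top-1 expert, or None when too close to call."""
    probs = torch.softmax(logits, dim=-1)
    top2 = torch.topk(probs, 2)
    margin = (top2.values[0] - top2.values[1]).item()
    if margin < min_margin:
        return None  # ambiguous: refuse or escalate instead of guessing
    return names[top2.indices[0].item()]
```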
Where this fits in the kolm compile loop
The router is a layer above the existing artifact pipeline. Each expert is compiled normally (capture → train → K-score → sign). The router is a separate compile step that consumes the experts' CIDs and a routing-labeled capture. The deployment unit is the router plus its referenced experts; the runtime resolves CIDs at request time. The multi-LoRA serving page covers the runtime's adapter pool and concurrent execution.
Citations
Shazeer, N. et al. Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. arXiv:1701.06538, 2017. The original MoE layer and the load-balance auxiliary loss.
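Lepikhin, D. et al. GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding. arXiv:2006.16668, 2020. Top-2 gating with an auxiliary load-balance loss at scale.
Zoph, B. et al. ST-MoE: Designing Stable and Transferable Sparse Expert Models. arXiv:2202.08906, 2022. Router z-loss.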
Fedus, W., Zoph, B. & Shazeer, N. Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. arXiv:2101.03961, 2022. Top-1 routing, simplified.
Jiang, A. Q. et al. Mixtral of Experts. arXiv:2401.04088, 2024. Top-2 routing with weighted-sum aggregation.
Liu, X. et al. MoLE: Mixture of LoRA Experts. arXiv:2310.18339, 2023. The first published LoRA-MoE recipe with learnable temperature.
Wang, Y. et al. AdaMix: Mixture-of-Adaptations for Parameter-efficient Model Tuning. arXiv:2205.12410, 2022. The mixture-of-adapters precursor; informs the per-sequence routing choice here.