Training & fine-tuning

How to fine-tune a served model with LoRA adapters: the entry points, the data formats, every config flag, and the methodology behind the knobs.

mlx-bun trains LoRA adapters (the base quantized weights stay frozen), supports SFT and DPO, and runs on a single Apple-Silicon GPU. The output is a PEFT-compatible adapter you hot-swap into the server — see adapters-end-to-end for the serving side and segmented-backward-training for the long-context memory mechanism.

Source of truth: the config schema is FinetuneSubmit in src/train/job.ts; defaults are DEFAULT_TRAIN_CONFIG in src/train/trainer.ts. This doc is generated against those — if they drift, the code wins.

There is no mlx-bun train CLI verb. Training runs as a subprocess job, reachable five ways:

| Path | How | Use when |
| --- | --- | --- |
| Web UI | mlx-bun serve, open /finetune: pick model → dataset → hyperparameters → train; watch live train/val loss; merge/export the adapter | Interactive, the default |
| HTTP API | POST /api/finetune/submit (job id + SSE events); POST /api/finetune/inspect-dataset to probe a file; POST /api/finetune/merge to fold an adapter into base weights | Scripted / remote |
| Script | scripts/chunk-finetune.ts, an env-driven wrapper that calls the runner directly | Repeatable CLI runs |
| Shell recipe | scripts/ft-e4b-v2.sh probe\|train, the actual e4b run we use; sets the required env (see What we actually run) | Reproducing our e4b fine-tune |
| Library | import { finetuneRunner } from "./src/train/job" and call it with a config + emitter | Embedding training in your own TS |

Example submission over the HTTP API (data_dir is a directory containing train.jsonl, plus an optional valid.jsonl):

curl -s localhost:8090/api/finetune/submit -X POST -H 'content-type: application/json' -d '{
  "model_dir": "/path/to/snapshot",
  "data_dir": "/path/to/dataset",
  "adapter_path": "/path/to/output-adapter",
  "method": "sft",
  "rank": 16, "iters": 300, "learning_rate": 2e-4, "max_seq_length": 2048
}'
# → { "jobId": "..." } then stream events from the jobs SSE endpoint

model_dir, data_dir, and adapter_path are required; everything else falls back to the defaults below.
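
The same submission can be sketched from TypeScript. The helper below only assembles the request; buildSubmitRequest, SubmitConfig, and SubmitRequest are illustrative names that do not exist in mlx-bun, and the field list is limited to what the curl example uses rather than the full FinetuneSubmit schema.

```typescript
// Hypothetical helper: assembles the POST /api/finetune/submit request.
// The field names come from the curl example; the types are an illustration,
// not the real FinetuneSubmit schema in src/train/job.ts.
interface SubmitConfig {
  model_dir: string;
  data_dir: string; // directory with train.jsonl (+ optional valid.jsonl)
  adapter_path: string;
  method?: "sft" | "dpo";
  rank?: number;
  iters?: number;
  learning_rate?: number;
  max_seq_length?: number;
}

interface SubmitRequest {
  url: string;
  init: { method: "POST"; headers: Record<string, string>; body: string };
}

function buildSubmitRequest(baseUrl: string, cfg: SubmitConfig): SubmitRequest {
  // Only these three fields are required; the rest fall back to defaults.
  for (const key of ["model_dir", "data_dir", "adapter_path"] as const) {
    if (!cfg[key]) throw new Error(`missing required field: ${key}`);
  }
  return {
    url: `${baseUrl.replace(/\/+$/, "")}/api/finetune/submit`,
    init: {
      method: "POST",
      headers: { "content-type": "application/json" },
      body: JSON.stringify(cfg),
    },
  };
}

const req = buildSubmitRequest("http://localhost:8090", {
  model_dir: "/path/to/snapshot",
  data_dir: "/path/to/dataset",
  adapter_path: "/path/to/output-adapter",
  method: "sft",
  rank: 16,
  iters: 300,
  learning_rate: 2e-4,
  max_seq_length: 2048,
});
```

Passing req.url and req.init to fetch should then return the job id described above. Keeping request assembly separate from the network call makes the required-field check easy to test without a running server.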

scripts/chunk-finetune.ts is the worked example (MiniCPM5 on the chunking task). It calls finetuneRunner directly, driven by env vars:

# MODEL unset → defaults to the MiniCPM5-1B-OptiQ-4bit snapshot
DATA=/path/to/chunk ITERS=300 RANK=16 SEQ=2048 SEG=4 \
bun scripts/chunk-finetune.ts

Env knobs — note the script applies its own task-tuned defaults, which differ from the trainer/API defaults in the table below:

| Env | Maps to | Script default |
| --- | --- | --- |
| MODEL | model_dir | MiniCPM5-1B-OptiQ-4bit snapshot (a path if set, not a name) |
| DATA | data_dir | the lucien chunk dataset path |
| SEQ | max_seq_length | 8192 |
| ITERS | iters | 2 (probe; use 300 for a real run) |
| RANK | rank | 16 |
| LR | learning_rate | 1e-5 |
| SCALE | scale | 20 |
| SEG | segment_size | 0 (off) |
| EVAL_EVERY | steps_per_eval | auto from ITERS |
| ADAPTER | adapter_path | ~/.cache/mlx-bun-finetunes/minicpm5-chunk-seq<SEQ> |
| CKPT | save_checkpoints | on (CKPT=0 disables) |
| GRAD_CKPT | grad_checkpoint | off (GRAD_CKPT=1 enables) |

The script hard-codes method=sft, batch_size=1, steps_per_report=1, and uses the default ops.sdpa training attention (set MLX_BUN_TRAIN_ATTN=flash to override — but flash crashes e4b at multi-K; see Methodology).
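
To make the env-to-config mapping concrete, here is a minimal sketch of how such a wrapper could translate the knobs into a config object. It is not the code in scripts/chunk-finetune.ts: configFromEnv is a hypothetical name, the default model and dataset paths are passed in by the caller because they are not spelled out here, and EVAL_EVERY is omitted because its "auto from ITERS" rule is not documented.

```typescript
type Env = Record<string, string | undefined>;

// Hypothetical re-creation of the env knobs and script defaults listed above.
function configFromEnv(env: Env, fallback: { model_dir: string; data_dir: string }) {
  const num = (key: string, dflt: number): number => {
    const raw = env[key];
    return raw === undefined || raw === "" ? dflt : Number(raw);
  };
  const seq = num("SEQ", 8192);
  return {
    model_dir: env["MODEL"] ?? fallback.model_dir,
    data_dir: env["DATA"] ?? fallback.data_dir,
    // "~" is left unexpanded here; a real script would resolve it.
    adapter_path: env["ADAPTER"] ?? `~/.cache/mlx-bun-finetunes/minicpm5-chunk-seq${seq}`,
    max_seq_length: seq,
    iters: num("ITERS", 2),
    rank: num("RANK", 16),
    learning_rate: num("LR", 1e-5),
    scale: num("SCALE", 20),
    segment_size: num("SEG", 0),
    save_checkpoints: env["CKPT"] !== "0", // on unless CKPT=0
    grad_checkpoint: env["GRAD_CKPT"] === "1", // off unless GRAD_CKPT=1
    // Values the script hard-codes rather than reading from the environment.
    method: "sft" as const,
    batch_size: 1,
    steps_per_report: 1,
  };
}

const probeCfg = configFromEnv(
  { ITERS: "300", SEG: "4" },
  { model_dir: "/path/to/snapshot", data_dir: "/path/to/chunk" },
);
```

Taking the environment as a parameter instead of reading process.env directly keeps the mapping a pure function, so the defaults can be checked in isolation.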

Each row of train.jsonl (and optional valid.jsonl) is auto-detected by its keys (src/train/dataset.ts):

| Format | Shape | Loss boundary |
| --- | --- | --- |
| messages | {"messages": [{"role","content"}, …]} | response-only: loss on the final turn, prompt = chat-template render of all prior turns |
| prompt-completion | {"prompt": "...", "completion": "..."} | loss on the completion only |
| text | {"text": "..."} | full-sequence (no prompt mask) |
| dpo (method=dpo) | {"prompt", "chosen", "rejected"} | preference loss on chosen vs rejected |

Probe a file before submitting: POST /api/finetune/inspect-dataset with {"path": "..."} returns { ok, n_train, n_valid, format }.
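
Key-based detection of the four row shapes can be sketched as follows. This is an illustration of the idea, not the logic in src/train/dataset.ts: detectFormat is a hypothetical name, and the order in which the keys are checked is an assumption.

```typescript
type RowFormat = "messages" | "prompt-completion" | "text" | "dpo";

// Hypothetical detector: classifies one parsed JSONL row by the keys it carries.
// The precedence order (dpo, messages, prompt-completion, text) is assumed.
function detectFormat(row: Record<string, unknown>): RowFormat {
  if ("prompt" in row && "chosen" in row && "rejected" in row) return "dpo";
  if (Array.isArray(row["messages"])) return "messages";
  if ("prompt" in row && "completion" in row) return "prompt-completion";
  if (typeof row["text"] === "string") return "text";
  throw new Error(`unrecognized row keys: ${Object.keys(row).join(", ")}`);
}

const sampleLines = [
  '{"messages": [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]}',
  '{"prompt": "2+2=", "completion": "4"}',
  '{"text": "plain corpus line"}',
  '{"prompt": "q", "chosen": "good", "rejected": "bad"}',
];
const detected = sampleLines.map((line) => detectFormat(JSON.parse(line)));
```

Running a check like this over the first few lines of train.jsonl is a cheap local complement to the inspect-dataset endpoint.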

  • SFT (method: "sft", default) — supervised fine-tune; response-only cross-entropy. Default LR 2e-4. For instruction-following, formatting, task adaptation.
  • DPO (method: "dpo") — Direct Preference Optimization on chosen/rejected pairs; loss -log σ(β·((π_c − ref_c) − (π_r − ref_r))) with reference log-probs computed at LoRA scale 0. Default LR 5e-5. Tune with dpo_beta, dpo_warmup_iters, dpo_lr_schedule.
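
The DPO loss above can be written out for a single preference pair. The sketch below is a scalar illustration, not the trainer's implementation: dpoLoss and the PairLogProbs field names are hypothetical, and the inputs are assumed to be summed sequence log-probabilities of the chosen and rejected completions under the policy and the reference model.

```typescript
// Hypothetical per-pair inputs: summed log-probs of each completion.
interface PairLogProbs {
  policyChosen: number; // π_c
  policyRejected: number; // π_r
  refChosen: number; // ref_c
  refRejected: number; // ref_r
}

// loss = -log σ(β · ((π_c − ref_c) − (π_r − ref_r)))
function dpoLoss(lp: PairLogProbs, beta = 0.1): number {
  const margin = (lp.policyChosen - lp.refChosen) - (lp.policyRejected - lp.refRejected);
  const z = beta * margin;
  // -log σ(z) equals softplus(-z); this form avoids overflow for large |z|.
  return Math.max(-z, 0) + Math.log1p(Math.exp(-Math.abs(z)));
}

// At step 0 the policy equals the reference, so the margin is 0 and the loss is ln 2.
const initialLoss = dpoLoss({ policyChosen: -12, policyRejected: -14, refChosen: -12, refRejected: -14 });

// Once the policy favors the chosen completion more than the reference does, the loss drops.
const improvedLoss = dpoLoss({ policyChosen: -10, policyRejected: -15, refChosen: -12, refRejected: -14 });
```

A larger dpo_beta makes the loss react more sharply to the same log-probability margin, which is why it is described as the DPO strength.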

All fields optional except model_dir / data_dir / adapter_path. Defaults are DEFAULT_TRAIN_CONFIG (trainer.ts:89).

| Field (API) | Type | Default | Effect |
| --- | --- | --- | --- |
| method | sft \| dpo | sft | Training objective (see above) |
| rank | int ≥2 | 8 | LoRA rank per adapted linear |
| scale | float >0 | 1.0 | LoRA α (effective update = α·BA) |
| rank_scaling | constant \| by_bits \| by_kl | by_bits | Per-layer rank policy (see Methodology) |
| target_modules | string[] | q,k,v,o,gate,up,down _proj | Which linears get adapters |
| num_layers | int | -1 | -1 = all layers; N = last N only |
| iters | int >0 | 100 | Total training steps |
| learning_rate | float >0 | 2e-4 (sft) / 5e-5 (dpo) | AdamW LR |
| max_seq_length | int >0 | 512 | Truncate/pad sequences to this |
| batch_size | int ≥1 | 1 | Rows per step (B=1 is the safe path; B>1 length-sorts + pads to 32) |
| grad_accumulation_steps | int ≥1 | 1 | Accumulate grads over N micro-steps |
| seed | int | 0 | RNG for shuffling + LoRA init |
| steps_per_report | int >0 | 10 | Emit a train-loss metric every N steps |
| steps_per_eval | int >0 | 50 | Eval on valid.jsonl every N steps |
| weight_decay | float ≥0 | 0.01 | AdamW weight decay (β = [0.9, 0.999], fixed) |
| grad_checkpoint | bool | false | Recompute layer activations in backward (memory↔compute; bit-identical) |
| segment_size | int | 0 (off) | >0 enables segmented backward; value is layers per segment (see below) |
| save_checkpoints | bool | false | Save every eval-step checkpoint + write metrics.json |
| dpo_beta | float >0 | 0.1 | DPO strength (dpo only) |
| dpo_warmup_iters | int ≥0 | 0 | DPO LR warmup (dpo only) |
| dpo_lr_schedule | constant \| cosine | cosine | DPO LR schedule (dpo only) |

| Env var | Set for training | Default | Why |
| --- | --- | --- | --- |
| MLX_BUN_FUSED_GELU | 0 (required for Gemma) | on | The fused GeGLU is a CustomKernel with no gradient (vjp) (fused-geglu-kernel.ts); the Gemma forward uses it (gemma4.ts:277), so a Gemma (e4b/12B/26B) backward fails unless it’s off. MiniCPM5 (Llama-arch SwiGLU) never hits it, so the .ts script’s MiniCPM5 default doesn’t need it. |
| MLX_BUN_PERF_KERNEL | 0 for training | on | The fused quantized-decode kernel likewise has no vjp. It only fires at decode L=1 (rare in the L>1 training forward), but the e4b recipe sets 0 to be safe. |
| MLX_BUN_TRAIN_ATTN | leave unset | unset → ops.sdpa | Default ops.sdpa is mlx’s fused flash-attention kernel: the correct, working path. flash selects a different hand-rolled custom kernel that crashes e4b at multi-K; do not set it. |
| MLX_BUN_MEM_LOG | 1 to profile | off | Print per-step peak/active/cache memory |

Important: the trainer itself (trainer.ts) only reads MLX_BUN_TRAIN_ATTN and MLX_BUN_MEM_LOG — but the model forward it runs reads MLX_BUN_FUSED_GELU / MLX_BUN_PERF_KERNEL. The trainer does not disable those itself, so the caller must export them. The e4b recipe does this; if you train a Gemma model by hand, set MLX_BUN_FUSED_GELU=0 yourself.
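
For callers that launch training themselves, the env requirements can be captured in a small helper. This is a sketch under assumptions: trainingEnv and the ModelFamily type are hypothetical names, and the two-way split between Gemma and everything else simplifies the rules stated above.

```typescript
type ModelFamily = "gemma" | "other";

// Hypothetical helper: env overrides to export before a training run.
// MLX_BUN_TRAIN_ATTN is deliberately absent so the default ops.sdpa path is used.
function trainingEnv(family: ModelFamily, profileMemory = false): Record<string, string> {
  const env: Record<string, string> = {
    MLX_BUN_PERF_KERNEL: "0", // fused quantized-decode kernel has no vjp
  };
  if (family === "gemma") {
    env["MLX_BUN_FUSED_GELU"] = "0"; // fused GeGLU has no vjp; Gemma backward needs it off
  }
  if (profileMemory) {
    env["MLX_BUN_MEM_LOG"] = "1"; // per-step peak/active/cache memory
  }
  return env;
}

const gemmaEnv = trainingEnv("gemma", true);
const llamaEnv = trainingEnv("other");
```

The returned map is meant to be merged into the environment of the process that runs the model forward, for example when spawning the training subprocess, since the trainer does not set these variables itself.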

Everything above is the full surface (what you can do). In practice the fine-tune we run is scripts/ft-e4b-v2.sh: e4b (gemma-4-e4b-it-OptiQ-4bit, pinned snapshot) on the lucien chunk-v2-500 curated set (450 train convs) through the segmented-backward trainer. It wraps chunk-finetune.ts with the e4b-required env and a two-step workflow:

scripts/ft-e4b-v2.sh probe # 2-iter memory/stability check (~1 min) — RUN FIRST
scripts/ft-e4b-v2.sh train # the real run (~900 iters ≈ 2 epochs, batch_size 1)
ITERS=750 SEQ=4096 SEG=1 scripts/ft-e4b-v2.sh train # override any knob inline

What the recipe pins (and why it differs from the bare defaults):

| Knob | Recipe value | Why |
| --- | --- | --- |
| model | e4b OptiQ-4bit, pinned snapshot fcdb12d7… | the validated e4b snapshot |
| data | chunk-v2-500 (450 train convs) | the curated chunk set |
| SEQ | 8192 | long context |
| SEG | 4 | segmented backward, 4 layers/segment, so 8K-ctx activations fit |
| RANK / SCALE / LR | 16 / 20 / 1e-5 | task-tuned |
| ITERS | 2 (probe) / 900 (train) | ~2 epochs over 450 examples |
| MLX_BUN_PERF_KERNEL / MLX_BUN_FUSED_GELU | 0 / 0 | required: the fused kernels have no vjp (see env table) |
| attention | default ops.sdpa | mlx’s fused flash kernel; not MLX_BUN_TRAIN_ATTN=flash (that one crashes e4b) |

The two non-negotiables for e4b: segmented backward (SEG>0, so the long-context activations fit) and the fused kernels off (so the backward has gradients). Always run probe before train.

Adapters attach to the target linears; A is initialized uniform, B is zeros, so the adapted model equals the base model at step 0. Only A/B are differentiated — base quantized weights are frozen. Default targets are the 7 attention+MLP projections per block (q/k/v/o_proj, gate/up/down_proj), following Unsloth. See src/train/lora-params.ts.
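
The step-0 property follows directly from the zero-initialized B. Here is a toy sketch with plain arrays: it is not the code in src/train/lora-params.ts, the names initLora and loraForward are hypothetical, and the uniform range used for A is an assumption.

```typescript
type Matrix = number[][];

function matVec(m: Matrix, x: number[]): number[] {
  return m.map((row) => row.reduce((acc, w, j) => acc + w * x[j], 0));
}

// A: rank × dIn, uniform in [-1/sqrt(dIn), 1/sqrt(dIn)] (assumed range).
// B: dOut × rank, all zeros.
function initLora(dIn: number, dOut: number, rank: number): { a: Matrix; b: Matrix } {
  const bound = 1 / Math.sqrt(dIn);
  const a = Array.from({ length: rank }, () =>
    Array.from({ length: dIn }, () => (Math.random() * 2 - 1) * bound),
  );
  const b = Array.from({ length: dOut }, () => new Array<number>(rank).fill(0));
  return { a, b };
}

// y = W·x + scale · B·(A·x); W stays frozen, only A and B would be trained.
function loraForward(w: Matrix, lora: { a: Matrix; b: Matrix }, scale: number, x: number[]): number[] {
  const base = matVec(w, x);
  const delta = matVec(lora.b, matVec(lora.a, x));
  return base.map((v, i) => v + scale * delta[i]);
}

const frozenW: Matrix = [
  [0.2, -0.5, 0.1],
  [0.7, 0.3, -0.4],
];
const input = [1.0, 2.0, -1.0];
const adapter = initLora(3, 2, 4);
const baseOut = matVec(frozenW, input);
const adaptedOut = loraForward(frozenW, adapter, 20, input);
```

Because B is all zeros, the adapter term contributes nothing regardless of A or scale, so adaptedOut matches baseOut until the first optimizer step updates B.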

  • constant: every target gets rank.
  • by_bits (default): rank × (bits / 4), clamped ≥2; gives wider adapters to lower-bit (optiq mixed-precision) layers. Needs the model’s per-layer bits map.
  • by_kl: scales by per-layer KL importance, clamped to [0.5×, 2×]; falls back to by_bits if no KL map is present.

Long-context memory: segmented backward vs gradient checkpointing

At long max_seq_length, activation memory dominates. Two levers:

  • grad_checkpoint: true recomputes each layer’s activations during backward. Bit-identical; trades compute for memory.
  • segment_size: N enables segmented backward: run the layer stack forward, detaching the residual stream into graph-free boundary leaves every N layers, then backprop segment by segment via mlx_vjp. The cotangent is passed directly rather than through a surrogate-loss value_and_grad, which leaked. Only one segment’s activations are live at a time. This is the path to multi-K context; full mechanism, proofs, and measured peaks (e.g. MiniCPM5 10.91→3.29 GB @2048) are in segmented-backward-training.
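
The segment-by-segment idea can be illustrated on a toy stack of scalar layers h ← tanh(w·h) with loss equal to the final activation. This is a sketch only: it uses no mlx API, all function names are hypothetical, and it shows just the bookkeeping (keep boundary values, then pull a cotangent back through one segment at a time).

```typescript
function layer(w: number, h: number): number {
  return Math.tanh(w * h);
}

// VJP of one segment: rebuild its activations from the boundary input,
// then pull the incoming cotangent back to the weights and to the input.
function segmentVjp(ws: number[], hIn: number, cotangent: number) {
  const hs = [hIn];
  for (const w of ws) hs.push(layer(w, hs[hs.length - 1]));
  const gradW = new Array<number>(ws.length).fill(0);
  let g = cotangent;
  for (let i = ws.length - 1; i >= 0; i--) {
    const d = 1 - hs[i + 1] ** 2; // derivative of tanh at the pre-activation
    gradW[i] = g * d * hs[i];
    g = g * d * ws[i];
  }
  return { gradW, gradIn: g };
}

function segmentedBackward(ws: number[], x: number, segmentSize: number): number[] {
  // Forward pass: store only the activation entering each segment.
  const boundaries: number[] = [];
  let h = x;
  for (let i = 0; i < ws.length; i++) {
    if (i % segmentSize === 0) boundaries.push(h);
    h = layer(ws[i], h);
  }
  // Backward pass, last segment first; loss = final h, so the cotangent starts at 1.
  const grads = new Array<number>(ws.length).fill(0);
  let g = 1;
  for (let s = boundaries.length - 1; s >= 0; s--) {
    const start = s * segmentSize;
    const { gradW, gradIn } = segmentVjp(ws.slice(start, start + segmentSize), boundaries[s], g);
    gradW.forEach((v, k) => { grads[start + k] = v; });
    g = gradIn;
  }
  return grads;
}

const weights = [0.5, -1.2, 0.8, 1.1, -0.3];
const oneSegment = segmentedBackward(weights, 0.7, weights.length); // ordinary backprop
const twoLayerSegments = segmentedBackward(weights, 0.7, 2);
```

Both calls should produce the same gradients; the difference is that the segmented run holds the activations of at most two layers at once, plus one stored boundary value per segment.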

Training attention kernel (MLX_BUN_TRAIN_ATTN)

  • default ops.sdpa — mlx’s fused SDPA; correct (0.00% vs autograd), O(L²) backward memory. Use this.
  • flash — opt-in O(L) memory path, but the hand-rolled dK kernel is slow and crashes e4b at multi-K (≥2K); do not use it for e4b LoRA training. Detail in segmented-backward-training §6.

A finished run writes a PEFT-compatible adapter directory:

  • adapters.safetensors — the lora_a / lora_b tensors
  • optiq_lora_config.json — mlx-bun/optiq adapter metadata (per-layer ranks)
  • adapter_config.json — PEFT-compatible config

When save_checkpoints: true, each eval step also writes checkpoints/step-<NNNNN>-val<loss>/ and a durable metrics.json (config, wall seconds, peak GB, final/best train+val loss, full val trajectory).

Serving the adapter: hot-swap it into a running server via the adapter API and select it per-request — see adapters-end-to-end and the adapter endpoints in server-api. Or fold it into the base weights with POST /api/finetune/merge.

  • Start at batch_size: 1 (the no-padding path); raise only with headroom.
  • OOM → lower max_seq_length, set segment_size (e.g. 2–4), or reduce rank / num_layers.
  • Set MLX_BUN_MEM_LOG=1 to watch per-step peak memory.
  • grad_accumulation_steps raises effective batch without the memory cost.
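
The arithmetic behind the last tip, and behind the "900 iters ≈ 2 epochs" figure in the recipe, can be sketched as below. The function names are hypothetical, and the sketch assumes that one iteration is one optimizer step consuming batch_size × grad_accumulation_steps rows; check the trainer if you rely on that.

```typescript
// Rows contributing to each optimizer step (assumed definition, see above).
function effectiveBatch(batchSize: number, gradAccumulationSteps: number): number {
  return batchSize * gradAccumulationSteps;
}

// Approximate passes over the training set for a given iteration count.
function epochsFor(iters: number, nTrain: number, batchSize = 1, gradAccumulationSteps = 1): number {
  return (iters * effectiveBatch(batchSize, gradAccumulationSteps)) / nTrain;
}

// The e4b recipe: 900 iters over 450 training conversations at batch_size 1.
const recipeEpochs = epochsFor(900, 450);

// Accumulating over 8 micro-steps at batch_size 1 gives an effective batch of 8
// while peak activation memory stays that of a single row.
const accumulatedBatch = effectiveBatch(1, 8);
```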