JAX JIT Compatibility
TL;DR
Use eqx.filter_{vmap, grad, jit} instead of jax.{vmap, grad, jit} unless you have a specific reason to need the plain JAX version.
ORC models are Equinox modules, which means they are registered as JAX pytrees. This notebook covers the sharp bits around using jax.jit and eqx.filter_jit with ORC models and shows the correct pattern for each case.
Why eqx.filter_{jit, vmap, grad}
Many ORC layers use JAX sparse BCOO matrices to accelerate computations, and BCOO matrices are unhashable. When you write jax.jit(esn.forecast), JAX closes over esn and tries to hash it as a static compile-time constant. This fails because esn contains a BCOO.
eqx.filter_jit avoids this entirely by treating the model as a pytree. Array leaves, including the data and index arrays inside each BCOO, are traced. Non-array structural fields are treated as static. Only those static parts are hashed, so the BCOO itself never is. All ORC inference and training methods are already decorated with @eqx.filter_jit, so calling them directly compiles them on the first call. You rarely need to add eqx.filter_jit explicitly.
When jax.jit works
If you need fine-grained control over which arguments are static for recompilation purposes, jax.jit can work provided you:
- Pass the model as an explicit pytree argument (not closed over), and
- Mark integer scalar arguments that are used as jax.lax.scan loop lengths as static via static_argnums/static_argnames.
import functools
import equinox as eqx
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import orc
1. Forecasting
Train an ESNForecaster on the Lorenz-63 system.
U, t = orc.data.lorenz63(tN=30, dt=0.01)
split = int(0.8 * len(U))
U_train, U_test = U[:split], U[split:]
esn = orc.forecaster.ESNForecaster(data_dim=3, res_dim=500, seed=0)
esn, R = orc.forecaster.train_RCForecaster(esn, U_train)
fcast_len = len(U_test)
print(f"Reservoir states shape: {R.shape}")
Reservoir states shape: (2400, 1, 500)
Pattern 1: closed-over jax.jit (does not work)
jax.jit(esn.forecast) closes over esn. JAX tries to hash the model as a compile-time static constant, which fails because wr is an unhashable BCOO sparse matrix.
try:
    U_pred = jax.jit(esn.forecast)(fcast_len=fcast_len, res_state=R[-1])
except TypeError as e:
    print(f"TypeError: {e}")
TypeError: unhashable type: 'BCOO'
Pattern 2: eqx.filter_jit (recommended)
eqx.filter_jit treats the model as a pytree, tracing its array leaves and keeping non-array fields static. Only those static parts are hashed, so the BCOO is never hashed.
Because forecast is already decorated with @eqx.filter_jit internally, calling esn.forecast(...) directly compiles on first use — wrapping it again in eqx.filter_jit is equivalent and harmless.
# Calling the method directly compiles on first use (already @eqx.filter_jit)
U_pred = esn.forecast(fcast_len=fcast_len, res_state=R[-1])
print(f"Forecast shape: {U_pred.shape}")
# Explicit eqx.filter_jit wrap is equivalent:
U_pred = eqx.filter_jit(esn.forecast)(fcast_len=fcast_len, res_state=R[-1])
print(f"Forecast shape (explicit filter_jit): {U_pred.shape}")
Forecast shape: (600, 3)
Forecast shape (explicit filter_jit): (600, 3)
Pattern 3: jax.jit with explicit pytree argument
Pass the model as an explicit argument so JAX traces through it as a pytree instead of hashing it. The fcast_len argument is used as a jax.lax.scan loop length, so it must be a concrete value at trace time — mark it static.
@functools.partial(jax.jit, static_argnums=(1,)) # fcast_len must be static
def jit_forecast(model, fcast_len, res_state):
    return model.forecast(fcast_len, res_state)
U_pred = jit_forecast(esn, fcast_len, R[-1])
print(f"Forecast shape: {U_pred.shape}")
Forecast shape: (600, 3)
Training with jax.jit
train_RCForecaster and similar training functions work with jax.jit when the model is passed as an explicit pytree argument. Scalar parameters left at their defaults (spinup=0, batch_size=None) need no special handling. jax.jit only traces the arguments you actually pass, so defaults stay concrete Python values. If you pass one of them explicitly, declare it static with static_argnames.
esn_fresh = orc.forecaster.ESNForecaster(data_dim=3, res_dim=500, seed=1)
# Default arguments — works directly
esn_trained, R_trained = jax.jit(orc.forecaster.train_RCForecaster)(esn_fresh, U_train)
print("Training with jax.jit (defaults): OK")
# Non-default spinup — must be declared static (appears in a Python conditional)
train_fn = jax.jit(
orc.forecaster.train_RCForecaster,
static_argnames=("spinup",),
)
esn_trained, R_trained = train_fn(esn_fresh, U_train, spinup=50)
print("Training with jax.jit (explicit spinup): OK")
Training with jax.jit (defaults): OK
Training with jax.jit (explicit spinup): OK
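The same rule can be shown in isolation with a hypothetical train_readout ridge-regression function (not ORC's train_RCForecaster) whose spinup argument feeds a Python if. Leaving spinup at its default works under plain jax.jit.

Passing spinup explicitly without declaring it static makes it a tracer. The if then raises a concretization error (TracerBoolConversionError, a subclass of jax.errors.ConcretizationTypeError).

```python
import jax
import jax.numpy as jnp


def train_readout(states, targets, spinup=0, beta=1e-6):
    """Hypothetical ridge-regression readout (not ORC's training code)."""
    if spinup > 0:  # Python conditional: spinup must be concrete at trace time
        states, targets = states[spinup:], targets[spinup:]
    gram = states.T @ states + beta * jnp.eye(states.shape[1])
    return jnp.linalg.solve(gram, states.T @ targets)


states = jax.random.normal(jax.random.PRNGKey(0), (200, 3))
true_w = jnp.array([[1.0, 0.0], [0.0, 2.0], [-1.0, 1.0]])
targets = states @ true_w

# Defaults are not passed, so they stay plain Python values inside the trace.
w_default = jax.jit(train_readout)(states, targets)

# Passed explicitly but not declared static: spinup becomes a tracer and the `if` fails.
failed = False
try:
    jax.jit(train_readout)(states, targets, spinup=50)
except jax.errors.ConcretizationTypeError as err:
    failed = True
    print(type(err).__name__)

# Declared static: spinup is a concrete int again (each new value recompiles).
w_spinup = jax.jit(train_readout, static_argnames=("spinup",))(states, targets, spinup=50)
print(w_spinup.shape)  # (3, 2)
```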
2. Classification
The same rules apply to ESNClassifier. Generate a simple multi-class sinusoidal dataset and train a classifier.
n_classes, data_dim, seq_len = 3, 3, 200
key = jax.random.PRNGKey(42)
def make_sequences(key, n_per_class):
    t = jnp.linspace(0, 4 * jnp.pi, seq_len).reshape(-1, 1)
    seqs, labels = [], []
    for cls in range(n_classes):
        freq = (cls + 1) * 0.5
        for _ in range(n_per_class):
            key, sub = jax.random.split(key)
            noise = 0.05 * jax.random.normal(sub, (seq_len, data_dim))
            seqs.append(jnp.sin(freq * t * jnp.arange(1, data_dim + 1)) + noise)
            labels.append(cls)
    return jnp.stack(seqs), jnp.array(labels, dtype=jnp.int32)
train_seqs, train_labels = make_sequences(key, 20)
test_seqs, test_labels = make_sequences(jax.random.PRNGKey(99), 10)
clf = orc.classifier.ESNClassifier(data_dim=data_dim, n_classes=n_classes, res_dim=500, seed=0)
clf = orc.classifier.train_RCClassifier(clf, train_seqs, train_labels)
print("Classifier trained.")
Classifier trained.
Pattern 1: closed-over jax.jit (does not work)
try:
    probs = jax.jit(clf.classify)(test_seqs[0])
except TypeError as e:
    print(f"TypeError: {e}")
TypeError: unhashable type: 'BCOO'
Pattern 2: eqx.filter_jit (recommended)
# Single sample
probs = clf.classify(test_seqs[0])
print(f"Predicted class: {jnp.argmax(probs)}, True class: {test_labels[0]}")
# Batch over test set with vmap
all_probs = jax.vmap(clf.classify)(test_seqs)
accuracy = jnp.mean(jnp.argmax(all_probs, axis=1) == test_labels)
print(f"Test accuracy: {accuracy:.1%}")
Predicted class: 0, True class: 0
Test accuracy: 100.0%
Pattern 3: jax.jit with explicit pytree argument
classify has no integer loop-length arguments, so no static_argnums is needed; just pass the model explicitly.
@jax.jit
def jit_classify(model, in_seq):
    return model.classify(in_seq)
probs = jit_classify(clf, test_seqs[0])
print(f"Predicted class: {jnp.argmax(probs)}, True class: {test_labels[0]}")
Predicted class: 0, True class: 0
3. Control
The same rules apply to ESNController. We use a simple 2D linear dynamical system as a plant: two states, one scalar control input.
# Simple stable 2D linear plant
A_plant = jnp.array([[0.95, 0.1], [-0.1, 0.95]])
B_plant = jnp.array([[0.1], [0.0]])
def simulate(u_seq, x0=None):
    x = x0 if x0 is not None else jnp.zeros(2)
    def step(x, u):
        x_next = A_plant @ x + B_plant @ jnp.atleast_1d(u)
        return x_next, x_next
    _, states = jax.lax.scan(step, x, u_seq)
    return states
key = jax.random.PRNGKey(0)
u_train = jax.random.normal(key, (600, 1)) * 0.3
train_outputs = simulate(u_train)
ctrl = orc.control.ESNController(data_dim=2, control_dim=1, res_dim=500, seed=0)
ctrl, R_ctrl = orc.control.train_RCController(ctrl, train_outputs[:-1], u_train[1:])
print(f"Controller trained. Reservoir states shape: {R_ctrl.shape}")
Controller trained. Reservoir states shape: (599, 500)
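The code comment calls the plant stable. A discrete-time linear system x[k+1] = A x[k] + B u[k] is asymptotically stable when every eigenvalue of A lies strictly inside the unit circle.

A matrix of the form [[a, b], [-b, a]] has eigenvalues a ± bi. Here they are 0.95 ± 0.1i, with modulus sqrt(0.9125) ≈ 0.955. A quick numerical check:

```python
import jax.numpy as jnp

# Same plant matrix as above: x[k+1] = A x[k] + B u[k]
A_plant = jnp.array([[0.95, 0.1], [-0.1, 0.95]])

# Asymptotically stable iff all eigenvalues lie strictly inside the unit circle.
eigvals = jnp.linalg.eigvals(A_plant)
spectral_radius = float(jnp.max(jnp.abs(eigvals)))
print(f"{spectral_radius:.4f}")  # 0.9552
```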
Pattern 1: closed-over jax.jit (does not work)
control_guess = jnp.zeros((100, 1))
try:
    output_traj = jax.jit(ctrl.apply_control)(control_guess, R_ctrl[-1])
except TypeError as e:
    print(f"TypeError: {e}")
TypeError: unhashable type: 'BCOO'
Pattern 2: eqx.filter_jit (recommended)
# apply_control is already @eqx.filter_jit
output_traj = ctrl.apply_control(control_guess, R_ctrl[-1])
print(f"Controlled trajectory shape: {output_traj.shape}")
Controlled trajectory shape: (100, 2)
Pattern 3: jax.jit with explicit pytree argument
@jax.jit
def jit_apply_control(model, control_seq, res_state):
    return model.apply_control(control_seq, res_state)
output_traj = jit_apply_control(ctrl, control_guess, R_ctrl[-1])
print(f"Controlled trajectory shape: {output_traj.shape}")
Controlled trajectory shape: (100, 2)