Example: Time-Series Classification with ESN¶
This notebook demonstrates how to use ORC's classifier module to classify time-series data using an Echo State Network (ESN). We generate synthetic time-series belonging to distinct classes, train an ESNClassifier, and evaluate its performance. This is a common task in areas like activity recognition, medical signal analysis, and sensor data classification.
As with other ORC modules, only jax, orc, and equinox are needed. We also import matplotlib for visualization.
import jax
import jax.numpy as jnp
import equinox as eqx
import matplotlib.pyplot as plt
import orc
Generate Synthetic Classification Data¶
We create a simple 3-class classification problem where each class is a sinusoidal time-series with a distinct frequency. A small amount of noise is added to each sample. This mimics the kind of structure found in real-world time-series classification tasks, where different classes exhibit different temporal patterns.
# Dataset parameters
data_dim = 3
n_classes = 3
seq_len = 200
n_train_per_class = 20
n_test_per_class = 10
key = jax.random.PRNGKey(0)
def generate_samples(key, n_per_class):
    """Generate sinusoidal sequences with class-specific frequencies."""
    seqs, labels = [], []
    t = jnp.linspace(0, 4 * jnp.pi, seq_len).reshape(-1, 1)
    for class_idx in range(n_classes):
        freq = (class_idx + 1) * 0.5
        for _ in range(n_per_class):
            key, subkey = jax.random.split(key)
            noise = 0.05 * jax.random.normal(subkey, (seq_len, data_dim))
            seq = jnp.sin(freq * t * jnp.arange(1, data_dim + 1)) + noise
            seqs.append(seq)
            labels.append(class_idx)
    return jnp.stack(seqs), jnp.array(labels, dtype=jnp.int32), key
train_seqs, train_labels, key = generate_samples(key, n_train_per_class)
test_seqs, test_labels, key = generate_samples(key, n_test_per_class)
print(f"Training set: {train_seqs.shape} (n_samples, seq_len, data_dim)")
print(f"Test set: {test_seqs.shape}")
Training set: (60, 200, 3) (n_samples, seq_len, data_dim)
Test set: (30, 200, 3)
Let's visualize one example from each class to see how the temporal patterns differ.
fig, axes = plt.subplots(1, n_classes, figsize=(14, 3), sharey=True)
class_names = ["Low freq", "Medium freq", "High freq"]
for i in range(n_classes):
    idx = i * n_train_per_class
    axes[i].plot(train_seqs[idx, :, 0])
    axes[i].set_title(f"Class {i}: {class_names[i]}")
    axes[i].set_xlabel("Time step")
axes[0].set_ylabel("Amplitude (channel 0)")
plt.tight_layout()
plt.show()
Initialize the ESN Classifier¶
The ESNClassifier follows the same Embedding $\to$ Driver $\to$ Readout architecture as other ORC models. To instantiate a classifier, we specify:
- data_dim: the number of input channels per time-step
- n_classes: the number of target classes
- res_dim: the dimension of the reservoir (latent state)
The state_repr parameter controls how reservoir states are summarized into a feature vector for classification:
- "final" (default): uses only the last reservoir state
- "mean": averages all reservoir states (after an optional spinup period)
res_dim = 500
classifier = orc.classifier.ESNClassifier(
data_dim=data_dim,
n_classes=n_classes,
res_dim=res_dim,
seed=42,
)
print(f"Reservoir dimension: {classifier.res_dim}")
print(f"Input dimension: {classifier.in_dim}")
print(f"Output dimension: {classifier.out_dim}")
print(f"State representation: {classifier.state_repr}")
Reservoir dimension: 500
Input dimension: 3
Output dimension: 3
State representation: final
Train the Classifier¶
Training is performed by orc.classifier.train_ESNClassifier. Under the hood, this:
- Forces each training sequence through the reservoir (in parallel via jax.vmap)
- Extracts a feature vector per sequence (final state or mean state)
- Solves ridge regression from features to one-hot encoded class labels
- Updates the readout weights
The beta parameter controls the strength of the Tikhonov (ridge) regularization.
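To make the closed-form step concrete, here is a minimal sketch of the ridge solve from features to one-hot labels. This illustrates the underlying math only, not ORC's internal code; feats is a random stand-in for the reservoir features, and all names here are ours.

# Illustrative only: closed-form ridge regression from reservoir features
# to one-hot class labels, W = (F^T F + beta * I)^{-1} F^T Y.
feats = jax.random.normal(jax.random.PRNGKey(1), (60, res_dim))  # stand-in (n_samples, res_dim)
onehot = jax.nn.one_hot(train_labels, n_classes)                 # (n_samples, n_classes)
beta = 1e-6
gram = feats.T @ feats + beta * jnp.eye(res_dim)
W_out = jnp.linalg.solve(gram, feats.T @ onehot)                 # (res_dim, n_classes)

ORC performs this solve internally when we call the training function: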
trained_classifier = orc.classifier.train_ESNClassifier(
classifier,
train_seqs=train_seqs,
labels=train_labels,
beta=1e-6,
)
print("Training complete.")
Training complete.
Evaluate on Training Data¶
We can classify individual sequences using classify, which by default zero-initializes the reservoir state and returns a vector of class probabilities (via softmax). The predicted class is the one with the highest probability.
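For a single sequence, the call is direct (this uses only the classify API shown in this notebook):

# Classify one training sequence; classify returns softmax probabilities.
p = trained_classifier.classify(train_seqs[0])
print(f"Probabilities: {p}  Predicted class: {jnp.argmax(p)}")

To score the whole training set at once, we batch the call with jax.vmap: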
probs = jax.vmap(trained_classifier.classify)(train_seqs)
train_preds = jnp.argmax(probs, axis=1)
train_acc = jnp.mean(train_preds == train_labels)
print(f"Training accuracy: {train_acc:.1%}")
Training accuracy: 100.0%
Evaluate on Test Data¶
The real test of the classifier is on unseen data. We classify each test sequence and compute the accuracy.
probs = jax.vmap(trained_classifier.classify)(test_seqs)
test_preds = jnp.argmax(probs, axis=1)
test_acc = jnp.mean(test_preds == test_labels)
print(f"Test accuracy: {test_acc:.1%}")
Test accuracy: 100.0%
Let's look at the predicted probabilities for a few test samples to see how confident the classifier is.
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
for class_idx in range(n_classes):
    sample_idx = class_idx * n_test_per_class  # first test sample of each class
    probs = trained_classifier.classify(test_seqs[sample_idx])
    axes[class_idx].bar(range(n_classes), probs, color=["#4C72B0", "#55A868", "#C44E52"])
    axes[class_idx].set_xticks(range(n_classes))
    axes[class_idx].set_xticklabels(class_names)
    axes[class_idx].set_ylim(0, 1)
    axes[class_idx].set_title(f"True: {class_names[class_idx]}")
    axes[class_idx].set_ylabel("Probability" if class_idx == 0 else "")
plt.suptitle("Predicted Class Probabilities", y=1.02)
plt.tight_layout()
plt.show()
Using Mean State Representation¶
Instead of using only the final reservoir state for classification, we can average the reservoir states over the sequence. This is often more robust when the discriminative information is spread throughout the time-series rather than concentrated at the end — for example, when class identity is determined by frequency content or statistical properties of the entire signal. The spinup parameter discards early transient states before averaging, since the reservoir needs time to synchronize with the input.
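As an illustration of what the "mean" representation computes, here is a sketch under our reading of the description above (not ORC's internal code; states is a random stand-in for a reservoir trajectory):

# Illustrative only: form the "mean" feature from a trajectory of reservoir
# states by dropping the first `spinup` transient steps, then averaging.
states = jax.random.normal(jax.random.PRNGKey(2), (seq_len, res_dim))  # stand-in trajectory
spinup = 20
feature = jnp.mean(states[spinup:], axis=0)  # (res_dim,)

The actual workflow uses the same API as before, with state_repr="mean" at construction and a spinup passed to training: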
classifier_mean = orc.classifier.ESNClassifier(
data_dim=data_dim,
n_classes=n_classes,
res_dim=res_dim,
seed=42,
state_repr="mean",
)
trained_mean = orc.classifier.train_ESNClassifier(
classifier_mean,
train_seqs=train_seqs,
labels=train_labels,
spinup=20,
beta=1e-6,
)
# filter_vmap maps over the batch axis of the array argument (in_seq) while
# passing the non-array arguments (res_state, spinup) through unmapped.
in_seq = test_seqs
res_state = None  # default: zero-initialized reservoir state
spinup = 20
probs = eqx.filter_vmap(trained_mean.classify)(in_seq, res_state, spinup)
mean_preds = jnp.argmax(probs, axis=1)
mean_acc = jnp.mean(mean_preds == test_labels)
print(f"Test accuracy (mean state): {mean_acc:.1%}")
Test accuracy (mean state): 100.0%
Summary¶
This notebook demonstrated the core classification workflow in ORC:
- Initialize an ESNClassifier with orc.classifier.ESNClassifier
- Train with orc.classifier.train_ESNClassifier, which uses ridge regression for a closed-form solution
- Classify new sequences with classify (the reservoir state defaults to zero if not provided)
Key parameters to tune:
- res_dim: larger reservoirs can capture more complex dynamics
- beta: regularization strength (increase if overfitting)
- state_repr: "final" vs "mean", depending on where class-discriminative information lies in the sequence
- spinup: number of transient states to discard when using state_repr="mean"
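As a quick reference, here is a sketch combining these knobs, using only the API calls demonstrated above; the parameter values are arbitrary, not recommendations.

# Illustrative variant configuration; values are placeholders to tune.
clf = orc.classifier.ESNClassifier(
    data_dim=data_dim,
    n_classes=n_classes,
    res_dim=1000,        # larger reservoir for more complex dynamics
    seed=7,
    state_repr="mean",   # average states instead of using only the final one
)
clf = orc.classifier.train_ESNClassifier(
    clf, train_seqs=train_seqs, labels=train_labels, spinup=20, beta=1e-4,
)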