Driver Layers¶
Overview¶
The driver layer implements the reservoir's temporal dynamics. It takes the embedded input from the embedding layer and the current reservoir state, then computes the next reservoir state according to the specific reservoir dynamics (e.g., Echo State Network equations).
All driver layers inherit from the DriverBase abstract class, ensuring a consistent API across different implementations.
Base Class: DriverBase¶
The DriverBase class defines the core interface that all driver implementations must follow:
Key Attributes¶
res_dim: Reservoir dimensiondtype: JAX array dtype (jnp.float32 or jnp.float64)
Core Methods¶
advance(proj_vars, res_state): Advance single reservoir state by one time stepbatch_advance(proj_vars, res_state): Advance batch of reservoir states by one time step__call__(proj_vars, res_state): Flexible interface supporting both single and batch updates
ParallelESNDriver¶
The ParallelESNDriver class implements standard Echo State Network dynamics with tanh nonlinearity and support for both discrete and continuous time modes.
Basic Usage¶
import jax.numpy as jnp
from orc.drivers import ParallelESNDriver
# Create a basic ESN driver
driver = ParallelESNDriver(
res_dim=100, # Reservoir dimension
leak=0.6, # Leak rate (0 < leak <= 1)
spectral_radius=0.8, # Spectral radius of reservoir matrix
density=0.02, # Connection density (0 < density <= 1)
bias=1.6, # Bias term in tanh nonlinearity
seed=42 # Random seed for reproducibility
)
# Single time step advance
proj_input = jnp.ones((1, 100)) # Shape: (chunks, res_dim)
res_state = jnp.zeros((1, 100)) # Shape: (chunks, res_dim)
next_state = driver(proj_input, res_state)
Discrete vs Continuous Time Modes¶
ParallelESNDriver supports both discrete and continuous time dynamics:
Discrete Mode (Default)¶
driver = ParallelESNDriver(
res_dim=100,
mode="discrete", # Default mode
leak=0.6,
seed=42
)
# Updates: x[n+1] = leak * tanh(A x[n] + proj_vars[n] + bias) + (1-leak) * x[n]
Continuous Mode¶
driver = ParallelESNDriver(
res_dim=100,
mode="continuous",
time_const=50.0, # Time constant τ
seed=42
)
# Updates: dx/dt = τ * (-x + tanh(W * x + proj_vars + bias))
Parallel Reservoirs¶
ParallelESNDriver supports multiple parallel reservoirs for processing spatial or high-dimensional data:
# Multiple parallel reservoirs
driver = ParallelESNDriver(
res_dim=500, # Dimension per reservoir
chunks=10, # Number of parallel reservoirs
spectral_radius=0.9,
density=0.05,
seed=42
)
# Input and state shapes: (chunks, res_dim)
proj_input = jnp.random.normal(jax.random.key(0), (10, 500))
res_state = jnp.zeros((10, 500))
next_state = driver(proj_input, res_state) # Shape: (10, 500)
Key Parameters¶
Reservoir Matrix Parameters¶
- spectral_radius: Controls stability and memory (typically < 1)
- density: Sparsity of connections
- chunks: Number of parallel reservoirs
Nonlinearity Parameters¶
- bias: Shifts tanh activation (affects reservoir dynamics)
- leak: Memory retention in discrete mode (0 < leak ≤ 1)
- time_const: Time scale in continuous mode (τ > 0)
Memory and Performance Options¶
Sparse Eigenvalue Computation¶
driver = ParallelESNDriver(
res_dim=1000,
use_sparse_eigs=True, # Use sparse eigensolvers (default)
eigenval_batch_size=100, # Batch eigenvalue computation
seed=42
)
Memory Management¶
- use_sparse_eigs=True: Uses sparse eigensolvers for large reservoirs (default)
- eigenval_batch_size: Processes eigenvalues in batches to reduce memory usage
- Automatic fallback to dense eigensolvers for small reservoirs (res_dim < 100)
Custom Drivers¶
To create your own driver layer:
from orc.drivers import DriverBase
import equinox as eqx
class CustomDriver(DriverBase):
def __init__(self, res_dim, **kwargs):
super().__init__(res_dim)
# Initialize your parameters
def advance(self, proj_vars, res_state):
# Implement your reservoir dynamics
# Return shape depends on parallel reservoir support
pass
Requirements¶
- Inherit from
DriverBase - Implement the abstract
advance()method - Choose your parallel reservoir support level:
Single Reservoir Only¶
For drivers that don't support parallel reservoirs:
def advance(self, proj_vars, res_state):
# proj_vars shape: (res_dim,)
# res_state shape: (res_dim,)
next_state = your_dynamics(proj_vars, res_state)
return next_state # Shape: (res_dim,)
Parallel Reservoir Support¶
For drivers that support parallel reservoirs like ParallelESNDriver:
def advance(self, proj_vars, res_state):
# proj_vars shape: (chunks, res_dim)
# res_state shape: (chunks, res_dim)
next_state = your_parallel_dynamics(proj_vars, res_state)
return next_state # Shape: (chunks, res_dim)
def __call__(self, proj_vars, res_state):
# Override to handle both single and batch inputs
# Handle shapes: (chunks, res_dim) and (seq_len, chunks, res_dim)
pass
Shape Summary¶
- Single reservoir:
advance()expects/returns(res_dim,) - Parallel reservoirs:
advance()expects/returns(chunks, res_dim)