
Drivers API Reference

This module contains reservoir drivers that define the dynamics of the reservoir.

Driver Base Class

DriverBase

DriverBase(res_dim, dtype=float64)

Bases: Module, ABC

Base class defining the API for all reservoir drivers.

Attributes:

Name Type Description
res_dim int

Reservoir dimension.

dtype Float

Dtype for model, jnp.float64 or jnp.float32.

Methods:

Name Description
advance

Advance the reservoir according to proj_vars.

batch_advance

Advance a batch of reservoir states according to proj_vars.

Ensure the reservoir dimension and dtype have the correct types.

Source code in src/orc/drivers.py
def __init__(self, res_dim, dtype=jnp.float64):
    """Ensure reservoir dim and dtype are correct type."""
    self.res_dim = res_dim
    if not isinstance(res_dim, int):
        raise TypeError("Reservoir dimension res_dim must be an integer.")
    self.dtype = dtype
    if not (dtype == jnp.float64 or dtype == jnp.float32):
        raise TypeError("dtype must be jnp.float64 or jnp.float32.")
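As a minimal sketch of the checks above, the hypothetical helper check_driver_args below reproduces them in plain NumPy so it runs without orc or JAX installed. np.float64 and np.float32 stand in for jnp.float64 and jnp.float32; the helper name is an assumption for illustration, not part of orc.

```python
import numpy as np

# Hypothetical mirror of the DriverBase.__init__ checks (not orc's API).
# np.float64 / np.float32 stand in for jnp.float64 / jnp.float32.
ALLOWED_DTYPES = (np.float64, np.float32)


def check_driver_args(res_dim, dtype=np.float64):
    """Raise TypeError unless res_dim is an int and dtype is float64/float32."""
    if not isinstance(res_dim, int):
        raise TypeError("Reservoir dimension res_dim must be an integer.")
    if dtype not in ALLOWED_DTYPES:
        raise TypeError("dtype must be float64 or float32.")
    return res_dim, dtype


check_driver_args(500)              # accepted, default float64
check_driver_args(500, np.float32)  # accepted
try:
    check_driver_args(500.0)        # a float dimension is rejected
except TypeError as err:
    print(err)
```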

advance abstractmethod

advance(proj_vars: Array, res_state: Array) -> Array

Advance the reservoir given projected inputs and current state.

Parameters:

Name Type Description Default
proj_vars Array

Projected inputs to reservoir.

required
res_state Array

Initial reservoir state.

required

Returns:

Type Description
Array

Updated reservoir state, (shape=(res_dim,)).

Source code in src/orc/drivers.py
@abstractmethod
def advance(self, proj_vars: Array, res_state: Array) -> Array:
    """Advance the reservoir given projected inputs and current state.

    Parameters
    ----------
    proj_vars : Array
        Projected inputs to reservoir.
    res_state : Array
        Initial reservoir state.

    Returns
    -------
    Array
        Updated reservoir state, (shape=(res_dim,)).
    """
    pass
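Because advance is an abstract method, a concrete driver must override it before it can be instantiated. The sketch below shows that contract with the standard abc module only; MiniDriver and DecayDriver are hypothetical stand-ins (no Equinox, no JAX), and the toy update rule is illustrative, not an ESN.

```python
from abc import ABC, abstractmethod

import numpy as np


class MiniDriver(ABC):
    """Hypothetical stand-in for DriverBase, without Equinox or JAX."""

    def __init__(self, res_dim):
        self.res_dim = res_dim

    @abstractmethod
    def advance(self, proj_vars, res_state):
        """Return the updated reservoir state, shape (res_dim,)."""


class DecayDriver(MiniDriver):
    """Toy driver: r_next = 0.5 * r + tanh(u). Illustrative only."""

    def advance(self, proj_vars, res_state):
        return 0.5 * res_state + np.tanh(proj_vars)


# The abstract base cannot be instantiated: advance is not implemented.
try:
    MiniDriver(4)
except TypeError as err:
    print("abstract:", err)

driver = DecayDriver(4)
r = driver.advance(np.ones(4), np.zeros(4))
print(r.shape)  # (4,)
```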

batch_advance

batch_advance(proj_vars: Array, res_state: Array) -> Array

Batch advance the reservoir given projected inputs and current state.

Parameters:

Name Type Description Default
proj_vars Array

Reservoir projected inputs.

required
res_state Array

Reservoir state.

required

Returns:

Type Description
Array

Updated reservoir states, with advance applied independently to each slice along the leading axis.

Source code in src/orc/drivers.py
@eqx.filter_jit
def batch_advance(self, proj_vars: Array, res_state: Array) -> Array:
    """
    Batch advance the reservoir given projected inputs and current state.

    Parameters
    ----------
    proj_vars : Array
        Reservoir projected inputs.
    res_state : Array
        Reservoir state.

    Returns
    -------
    Array
        Updated reservoir state.
    """
    return eqx.filter_vmap(self.advance)(proj_vars, res_state)
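eqx.filter_vmap maps advance over the leading axis of both proj_vars and res_state, so each slice is advanced from its own res_state and no state is carried from one slice to the next. The NumPy sketch below makes that explicit with a loop; the elementwise toy advance is an assumption chosen so the result can be checked against a direct vectorized call.

```python
import numpy as np


def advance(proj_vars, res_state):
    """Toy single-sample update (assumed for illustration), shape (res_dim,)."""
    return np.tanh(proj_vars + 0.5 * res_state)


def batch_advance(proj_vars, res_state):
    """Loop equivalent of vmapping advance over the leading axis."""
    return np.stack([advance(u, r) for u, r in zip(proj_vars, res_state)])


rng = np.random.default_rng(0)
u_batch = rng.normal(size=(3, 5))  # 3 slices, res_dim = 5
r_batch = rng.normal(size=(3, 5))

out = batch_advance(u_batch, r_batch)
print(out.shape)  # (3, 5)

# Slice 1 depends only on u_batch[1] and r_batch[1], not on slice 0's output.
r_changed = r_batch.copy()
r_changed[0] += 10.0
assert np.allclose(batch_advance(u_batch, r_changed)[1], out[1])
```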

__call__

__call__(proj_vars: Array, res_state: Array) -> Array

Advance reservoir state.

Parameters:

Name Type Description Default
proj_vars Array

Reservoir projected inputs, (shape=(chunks, res_dim) or shape=(seq_len, chunks, res_dim)).

required
res_state Array

Current reservoir state, (shape=(chunks, res_dim) or shape=(seq_len, chunks, res_dim)).

required

Returns:

Type Description
Array

Updated reservoir state(s), (shape=(chunks, res_dim) or shape=(seq_len, chunks, res_dim)). With a leading seq_len axis, each slice is advanced independently via batch_advance.

Source code in src/orc/drivers.py
def __call__(self, proj_vars: Array, res_state: Array) -> Array:
    """Advance reservoir state.

    Parameters
    ----------
    proj_vars : Array
        Reservoir projected inputs, (shape=(chunks, res_dim) or
        shape=(seq_len, chunks, res_dim)).
    res_state : Array
        Current reservoir state, (shape=(chunks, res_dim) or
        shape=(seq_len, chunks, res_dim)).

    Returns
    -------
    Array
        Sequence of reservoir states, (shape=(chunks, res_dim,) or
        shape=(seq_len, chunks, res_dim)).
    """
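    # Unsupported ranks raise explicitly instead of leaving to_ret unbound.
    if self.chunks > 0:
        if len(proj_vars.shape) == 2:
            to_ret = self.advance(proj_vars, res_state)
        elif len(proj_vars.shape) == 3:
            to_ret = self.batch_advance(proj_vars, res_state)
        else:
            raise ValueError(
                f"Expected proj_vars with 2 or 3 dims, got shape {proj_vars.shape}"
            )
    else:
        if len(proj_vars.shape) == 1:
            to_ret = self.advance(proj_vars, res_state)
        elif len(proj_vars.shape) == 2:
            to_ret = self.batch_advance(proj_vars, res_state)
        else:
            raise ValueError(
                f"Expected proj_vars with 1 or 2 dims, got shape {proj_vars.shape}"
            )
    return to_ret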
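The dispatch rule in __call__ depends only on the rank of proj_vars relative to a single input: (chunks, res_dim) when chunks > 0, otherwise (res_dim,); one extra leading axis selects the batched path. The NumPy mirror below is a hypothetical sketch of that rule (dispatch and the toy update are not part of orc), and it raises a ValueError for any other rank, which is a choice made in this sketch.

```python
import numpy as np


def dispatch(proj_vars, res_state, chunks, advance, batch_advance):
    """Hypothetical mirror of DriverBase.__call__: pick single or batched update."""
    single_ndim = 2 if chunks > 0 else 1  # rank of one unbatched input
    if proj_vars.ndim == single_ndim:
        return advance(proj_vars, res_state)
    if proj_vars.ndim == single_ndim + 1:
        return batch_advance(proj_vars, res_state)
    raise ValueError(
        f"Expected proj_vars with {single_ndim} or {single_ndim + 1} dims, "
        f"got shape {proj_vars.shape}"
    )


def advance(u, r):
    """Toy update, elementwise (assumed for illustration)."""
    return np.tanh(u + r)


def batch_advance(u, r):
    """Apply advance independently along the leading axis."""
    return np.stack([advance(ui, ri) for ui, ri in zip(u, r)])


chunks, res_dim = 2, 4
single = dispatch(np.zeros((chunks, res_dim)), np.zeros((chunks, res_dim)),
                  chunks, advance, batch_advance)
batch = dispatch(np.zeros((7, chunks, res_dim)), np.zeros((7, chunks, res_dim)),
                 chunks, advance, batch_advance)
print(single.shape, batch.shape)  # (2, 4) (7, 2, 4)
```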

ESN Driver

ParallelESNDriver

ParallelESNDriver(res_dim: int, leak: float = 0.6, spectral_radius: float = 0.8, density: float = 0.02, bias: float = 1.6, dtype: Float = float64, chunks: int = 1, mode: str = 'discrete', time_const: float = 50.0, *, seed: int, use_sparse_eigs: bool = True, eigenval_batch_size: int = None)

Bases: DriverBase

Standard implementation of an echo state network (ESN) reservoir with tanh nonlinearity, split into chunks independent parallel reservoirs.

Attributes:

Name Type Description
res_dim int

Reservoir dimension.

wr Array

Reservoir update matrix, (shape=(chunks, res_dim, res_dim,)).

leak float

Leak rate parameter.

spectral_radius float

Spectral radius of wr.

density float

Density of wr.

bias float

Additive bias in tanh nonlinearity.

chunks int

Number of parallel reservoirs.

mode str

Mode of reservoir update, either "discrete" or "continuous".

time_const float

Time constant for continuous mode.

dtype Float

Dtype, default jnp.float64.

Methods:

Name Description
advance

Advance the reservoir state.

__call__

Batched or single update to reservoir state.
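Reading the advance source below, the two modes can be written as follows for each chunk c, with leak α, time_const τ, bias b, and projected input u = proj_vars. The argument of φ and the choice φ = tanh are assumptions based on the attribute descriptions, since the helper _sparse_ops is not shown on this page. Note that τ multiplies the bracket, so a larger time_const makes the continuous dynamics faster.

```latex
% Discrete mode (leaky integrator), per chunk c
\mathbf{r}^{(c)}_{t+1} = \alpha\,\phi\!\left(W_r^{(c)}\mathbf{r}^{(c)}_t + \mathbf{u}^{(c)}_t + b\,\mathbf{1}\right) + (1-\alpha)\,\mathbf{r}^{(c)}_t

% Continuous mode: advance returns the right-hand side of the reservoir ODE
\frac{d\mathbf{r}^{(c)}}{dt} = \tau\left(-\mathbf{r}^{(c)} + \phi\!\left(W_r^{(c)}\mathbf{r}^{(c)} + \mathbf{u}^{(c)} + b\,\mathbf{1}\right)\right),
\qquad \phi = \tanh \ \text{(assumed)}
```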

Initialize weight matrices.

Parameters:

Name Type Description Default
res_dim int

Reservoir dimension.

required
leak float

Leak rate parameter.

0.6
spectral_radius float

Spectral radius of wr.

0.8
density float

Density of wr.

0.02
bias float

Additive bias in tanh nonlinearity.

1.6
chunks int

Number of parallel reservoirs.

1
mode str

Mode of reservoir update, either "discrete" or "continuous".

'discrete'
time_const float

Time constant for continuous mode. Ignored in discrete mode.

50.0
dtype Float

Dtype, default jnp.float64.

float64
seed int

Random seed for generating the PRNG key for the reservoir computer.

required
use_sparse_eigs bool

Whether to use a sparse eigensolver when setting the spectral radius of wr. Default is True, which is recommended to save memory and compute time. If False, a dense eigensolver is used, which may be more accurate. When res_dim < 100, the dense eigensolver is always used and a warning is issued.

True
eigenval_batch_size int

Size of batches used for batched eigenvalue computation. Default is None, which means eigenvalues are not computed in batches.

None
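The construction of wr combines two steps: sample one sparse Gaussian matrix per chunk with a fraction density of nonzero entries, then rescale each so its largest eigenvalue magnitude equals spectral_radius. The sketch below does this densely in NumPy with a dense eigensolver; random_reservoir is a hypothetical stand-in for sparse.random_bcoo followed by _spec_rad_normalization, whose internals are not shown here.

```python
import numpy as np


def random_reservoir(chunks, res_dim, density, spectral_radius, seed=0):
    """Hypothetical dense sketch: one sparse Gaussian matrix per chunk,
    rescaled so that its spectral radius equals spectral_radius."""
    rng = np.random.default_rng(seed)
    # A fraction `density` of the res_dim * res_dim entries is nonzero.
    nnz = max(1, int(round(density * res_dim * res_dim)))
    wr = np.zeros((chunks, res_dim, res_dim))
    for c in range(chunks):
        flat = np.zeros(res_dim * res_dim)
        idx = rng.choice(res_dim * res_dim, size=nnz, replace=False)
        flat[idx] = rng.normal(size=nnz)
        w = flat.reshape(res_dim, res_dim)
        # Dense eigensolver: spectral radius = largest eigenvalue magnitude.
        rho = np.max(np.abs(np.linalg.eigvals(w)))
        if rho == 0:
            raise ValueError("Sampled matrix has zero spectral radius; try another seed.")
        wr[c] = w * (spectral_radius / rho)
    return wr


wr = random_reservoir(chunks=2, res_dim=200, density=0.02, spectral_radius=0.8)
print(wr.shape)  # (2, 200, 200)
```

Rescaling by spectral_radius / rho scales every eigenvalue by the same factor, so the sparsity pattern is preserved while the spectral radius is set exactly.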
Source code in src/orc/drivers.py
def __init__(
    self,
    res_dim: int,
    leak: float = 0.6,
    spectral_radius: float = 0.8,
    density: float = 0.02,
    bias: float = 1.6,
    dtype: Float = jnp.float64,
    chunks: int = 1,
    mode: str = "discrete",
    time_const: float = 50.0,
    *,
    seed: int,
    use_sparse_eigs: bool = True,
    eigenval_batch_size: int = None,
) -> None:
    """Initialize weight matrices.

    Parameters
    ----------
    res_dim : int
        Reservoir dimension.
    leak : float
        Leak rate parameter.
    spectral_radius : float
        Spectral radius of wr.
    density : float
        Density of wr.
    bias : float
        Additive bias in tanh nonlinearity.
    chunks : int
        Number of parallel reservoirs.
    mode : str
        Mode of reservoir update, either "discrete" or "continuous".
    time_const : float
        Time constant for continuous mode. Ignored in discrete mode.
    dtype : Float
        Dtype, default jnp.float64.
    seed : int
        Random seed for generating the PRNG key for the reservoir computer.
    use_sparse_eigs : bool
        Whether to use a sparse eigensolver when setting the spectral radius of
        wr. Default is True, which is recommended to save memory and compute
        time. If False, a dense eigensolver is used, which may be more accurate.
    eigenval_batch_size : int
        Size of batches used for batched eigenvalue computation. Default is
        None, which means eigenvalues are not computed in batches.
    """
    super().__init__(res_dim=res_dim, dtype=dtype)
    self.res_dim = res_dim
    self.leak = leak
    self.spectral_radius = spectral_radius
    self.density = density
    self.bias = bias
    self.dtype = dtype
    self.mode = mode
    self.time_const = time_const
    key = jax.random.key(seed)
    if spectral_radius <= 0:
        raise ValueError("Spectral radius must be positive.")
    if leak < 0 or leak > 1:
        raise ValueError("Leak rate must satisfy 0 <= leak <= 1.")
    if density < 0 or density > 1:
        raise ValueError("Density must satisfy 0 <= density <= 1.")
    if mode not in ["discrete", "continuous"]:
        raise ValueError("Mode must be either 'discrete' or 'continuous'.")
    if time_const <= 0:
        raise ValueError("Time constant must be positive.")
    subkey, wr_key = jax.random.split(key)

    # check res_dim size for eigensolver choice
    if res_dim < 100 and use_sparse_eigs:
        use_sparse_eigs = False
        warnings.warn(
            "Reservoir dimension is less than 100, using dense "
            "eigensolver for spectral radius.",
            stacklevel=2,
        )

    # generate all wr matrices
    sp_mat = sparse.random_bcoo(
        key=wr_key,
        shape=(chunks, res_dim, res_dim),
        n_batch=1,
        nse=density,
        dtype=dtype,
        generator=jax.random.normal,
    )

    self.wr = _spec_rad_normalization(
        sp_mat,
        spectral_radius=spectral_radius,
        eigenval_batch_size=eigenval_batch_size,
        use_sparse_eigs=use_sparse_eigs,
        chunks=chunks,
    )
    self.chunks = chunks
    self.dtype = dtype
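The constructor validates its hyperparameters before building wr, and it silently switches small reservoirs to the dense eigensolver with a warning. The hypothetical helper validate_esn_args below mirrors those checks in plain Python (it is not part of orc); note that the bounds on leak and density are inclusive in the source, so leak = 1.0 is accepted.

```python
import warnings


def validate_esn_args(res_dim, leak=0.6, spectral_radius=0.8, density=0.02,
                      mode="discrete", time_const=50.0, use_sparse_eigs=True):
    """Hypothetical mirror of the ParallelESNDriver.__init__ checks.

    Returns the (possibly overridden) use_sparse_eigs flag.
    """
    if spectral_radius <= 0:
        raise ValueError("Spectral radius must be positive.")
    if not 0 <= leak <= 1:
        raise ValueError("Leak rate must satisfy 0 <= leak <= 1.")
    if not 0 <= density <= 1:
        raise ValueError("Density must satisfy 0 <= density <= 1.")
    if mode not in ("discrete", "continuous"):
        raise ValueError("Mode must be either 'discrete' or 'continuous'.")
    if time_const <= 0:
        raise ValueError("Time constant must be positive.")
    # Small reservoirs fall back to the dense eigensolver, with a warning.
    if res_dim < 100 and use_sparse_eigs:
        warnings.warn("Reservoir dimension is less than 100, using dense "
                      "eigensolver for spectral radius.", stacklevel=2)
        use_sparse_eigs = False
    return use_sparse_eigs


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    small_flag = validate_esn_args(50)
print(small_flag, len(caught))  # False 1
```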

advance

advance(proj_vars: Array, res_state: Array) -> Array

Advance the reservoir state.

Parameters:

Name Type Description Default
proj_vars Array

Reservoir projected inputs, (shape=(chunks, res_dim,)).

required
res_state Array

Reservoir state, (shape=(chunks, res_dim,)).

required

Returns:

Name Type Description
res_next Array

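Next reservoir state in discrete mode; in continuous mode, the time derivative of the reservoir state (the right-hand side of the reservoir ODE), (shape=(chunks, res_dim,)).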

Source code in src/orc/drivers.py
@eqx.filter_jit
def advance(self, proj_vars: Array, res_state: Array) -> Array:
    """Advance the reservoir state.

    Parameters
    ----------
    proj_vars : Array
        Reservoir projected inputs, (shape=(chunks, res_dim,)).
    res_state : Array
        Reservoir state, (shape=(chunks, res_dim,)).

    Returns
    -------
    res_next : Array
        Reservoir state, (shape=(chunks, res_dim,)).
    """
    if proj_vars.shape != (self.chunks, self.res_dim):
        raise ValueError(f"Incorrect proj_var dimension, got {proj_vars.shape}")
    if self.mode == "continuous":
        return self.time_const * (
            -res_state
            + _sparse_ops(
                self.wr, res_state, proj_vars, self.bias * jnp.ones_like(proj_vars)
            )
        )
    else:
        return (
            self.leak
            * _sparse_ops(
                self.wr, res_state, proj_vars, self.bias * jnp.ones_like(proj_vars)
            )
            + (1 - self.leak) * res_state
        )
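To see what advance returns in each mode, here is a NumPy sketch of the same arithmetic. The helper phi assumes that _sparse_ops computes tanh(wr @ res_state + proj_vars + bias) separately for each chunk; that form is inferred from the attribute descriptions, not shown on this page. The explicit Euler step at the end is a hypothetical way to consume the continuous-mode output, since no integrator appears here.

```python
import numpy as np


def phi(wr, res_state, proj_vars, bias):
    """Assumed form of _sparse_ops: tanh(W_r r + u + b), applied per chunk.

    wr: (chunks, res_dim, res_dim); res_state, proj_vars: (chunks, res_dim).
    """
    return np.tanh(np.einsum("cij,cj->ci", wr, res_state) + proj_vars + bias)


def esn_advance(wr, proj_vars, res_state, leak=0.6, bias=1.6,
                mode="discrete", time_const=50.0):
    if mode == "continuous":
        # Right-hand side of the reservoir ODE, not a next state.
        return time_const * (-res_state + phi(wr, res_state, proj_vars, bias))
    # Leaky-integrator discrete update.
    return leak * phi(wr, res_state, proj_vars, bias) + (1 - leak) * res_state


rng = np.random.default_rng(0)
chunks, res_dim = 2, 6
wr = 0.1 * rng.normal(size=(chunks, res_dim, res_dim))
u = rng.normal(size=(chunks, res_dim))
r = np.zeros((chunks, res_dim))

r_next = esn_advance(wr, u, r)                   # discrete step
drdt = esn_advance(wr, u, r, mode="continuous")  # time derivative
dt = 0.001
r_euler = r + dt * drdt  # one explicit Euler step (hypothetical integrator)
print(r_next.shape, r_euler.shape)  # (2, 6) (2, 6)
```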