
Classifier API Reference

The orc.classifier subpackage provides reservoir computing models for time series classification.

Base Classes

RCClassifierBase

RCClassifierBase(driver: DriverBase, readout: ReadoutBase, embedding: EmbedBase, n_classes: int, state_repr: str = 'final', dtype: Float = float64, seed: int = 0)

Bases: Module, ABC

Base class for reservoir computer classifiers.

Defines the interface for the reservoir computer classifier, which includes the driver, readout, and embedding layers. The classifier forces input sequences through the reservoir, extracts a feature vector from the reservoir states, and applies a trained readout to produce class probabilities.

Attributes:

Name Type Description
driver DriverBase

Driver layer of the reservoir computer.

readout ReadoutBase

Readout layer of the reservoir computer. Output dimension equals n_classes.

embedding EmbedBase

Embedding layer of the reservoir computer.

in_dim int

Dimension of the input data.

out_dim int

Dimension of the output data (equals n_classes).

res_dim int

Dimension of the reservoir.

n_classes int

Number of classification classes.

state_repr str

Reservoir state representation for classification. "final" uses the last reservoir state, "mean" averages states after spinup.

dtype type

Data type of the reservoir computer (jnp.float64 is highly recommended).

seed int

Random seed for generating the PRNG key for the reservoir computer.

Methods:

Name Description
force

Teacher forces the reservoir with the input sequence.

classify

Classify an input sequence, returning class probabilities.

set_readout

Replaces the readout layer of the reservoir computer.

set_embedding

Replaces the embedding layer of the reservoir computer.

Initialize RCClassifierBase.

Parameters:

Name Type Description Default
driver DriverBase

Driver layer of the reservoir computer.

required
readout ReadoutBase

Readout layer of the reservoir computer.

required
embedding EmbedBase

Embedding layer of the reservoir computer.

required
n_classes int

Number of classification classes.

required
state_repr str

Reservoir state representation for classification. "final" uses the last reservoir state, "mean" averages states after spinup.

'final'
dtype type

Data type of the reservoir computer (jnp.float64 is highly recommended).

float64
seed int

Random seed for generating the PRNG key for the reservoir computer.

0
Source code in src/orc/classifier/base.py
def __init__(
    self,
    driver: DriverBase,
    readout: ReadoutBase,
    embedding: EmbedBase,
    n_classes: int,
    state_repr: str = "final",
    dtype: Float = jnp.float64,
    seed: int = 0,
) -> None:
    """Initialize RCClassifier Base.

    Parameters
    ----------
    driver : DriverBase
        Driver layer of the reservoir computer.
    readout : ReadoutBase
        Readout layer of the reservoir computer.
    embedding : EmbedBase
        Embedding layer of the reservoir computer.
    n_classes : int
        Number of classification classes.
    state_repr : str
        Reservoir state representation for classification.
        "final" uses the last reservoir state, "mean" averages states after spinup.
    dtype : type
        Data type of the reservoir computer (jnp.float64 is highly recommended).
    seed : int
        Random seed for generating the PRNG key for the reservoir computer.
    """
    if state_repr not in ("final", "mean"):
        raise ValueError(
            f"state_repr must be 'final' or 'mean', got '{state_repr}'."
        )
    self.driver = driver
    self.readout = readout
    self.embedding = embedding
    self.in_dim = self.embedding.in_dim
    self.out_dim = self.readout.out_dim
    self.res_dim = self.driver.res_dim
    self.n_classes = n_classes
    self.state_repr = state_repr
    self.dtype = dtype
    self.seed = seed

force

force(in_seq: Array, res_state: Array) -> Array

Teacher forces the reservoir with the input sequence.

Parameters:

Name Type Description Default
in_seq Array

Input sequence to force the reservoir, (shape=(seq_len, in_dim)).

required
res_state Array

Initial reservoir state, (shape=(res_dim,)).

required

Returns:

Type Description
Array

Forced reservoir sequence, (shape=(seq_len, res_dim)).

Source code in src/orc/classifier/base.py
@eqx.filter_jit
def force(self, in_seq: Array, res_state: Array) -> Array:
    """Teacher forces the reservoir with the input sequence.

    Parameters
    ----------
    in_seq : Array
        Input sequence to force the reservoir, (shape=(seq_len, in_dim)).
    res_state : Array
        Initial reservoir state, (shape=(res_dim,)).

    Returns
    -------
    Array
        Forced reservoir sequence, (shape=(seq_len, res_dim)).
    """

    def scan_fn(state, in_vars):
        proj_vars = self.embedding.embed(in_vars)
        res_state = self.driver.advance(proj_vars, state)
        return (res_state, res_state)

    _, res_seq = jax.lax.scan(scan_fn, res_state, in_seq)
    return res_seq
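The `jax.lax.scan` in `force` carries the reservoir state from step to step and stacks every new state. The plain-loop NumPy sketch below shows the same recurrence. The `embed` and `advance` functions are hypothetical stand-ins for `embedding.embed` and `driver.advance`, and the tanh update is an assumption used only for illustration.

```python
import numpy as np

def force_sketch(in_seq, res_state, embed, advance):
    """Loop equivalent of the jax.lax.scan used in `force`."""
    states = []
    state = res_state
    for in_vars in in_seq:                  # iterate over time steps (axis 0)
        proj_vars = embed(in_vars)          # project the input into reservoir space
        state = advance(proj_vars, state)   # one reservoir update
        states.append(state)                # scan stacks each *new* state
    return np.stack(states)                 # shape (seq_len, res_dim)

# Hypothetical toy layers (not the orc implementations)
rng = np.random.default_rng(0)
in_dim, res_dim, seq_len = 2, 5, 7
W_in = rng.uniform(-0.1, 0.1, size=(res_dim, in_dim))
W_r = 0.1 * rng.normal(size=(res_dim, res_dim))

def embed(u):
    return W_in @ u

def advance(proj, r):
    return np.tanh(W_r @ r + proj)

in_seq = rng.normal(size=(seq_len, in_dim))
res_seq = force_sketch(in_seq, np.zeros(res_dim), embed, advance)
print(res_seq.shape)  # (7, 5)
```

Note that the initial state is not part of the output: `res_seq[0]` is the state after the first input has been applied.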

classify

classify(in_seq: Array, res_state: Array | None = None, spinup: int = 0) -> Array

Classify an input sequence.

Forces the reservoir with the input sequence, extracts a feature vector from the reservoir states, and returns softmax class probabilities.

Parameters:

Name Type Description Default
in_seq Array

Input sequence to classify, (shape=(seq_len, in_dim)).

required
res_state Array

Initial reservoir state, (shape=(res_dim,)). If None, a zero state is used.

None
spinup int

Number of initial reservoir states to discard before extracting features. Only used when state_repr="mean".

0

Returns:

Type Description
Array

Class probabilities, (shape=(n_classes,)).

Source code in src/orc/classifier/base.py
@eqx.filter_jit
def classify(
    self, in_seq: Array, res_state: Array | None = None, spinup: int = 0
) -> Array:
    """Classify an input sequence.

    Forces the reservoir with the input sequence, extracts a feature vector
    from the reservoir states, and returns softmax class probabilities.

    Parameters
    ----------
    in_seq : Array
        Input sequence to classify, (shape=(seq_len, in_dim)).
    res_state : Array | None
        Initial reservoir state, (shape=(res_dim,)). If None, a zero state is used.
    spinup : int
        Number of initial reservoir states to discard before extracting
        features. Only used when state_repr="mean".

    Returns
    -------
    Array
        Class probabilities, (shape=(n_classes,)).
    """
    if res_state is None:
        res_state = jnp.zeros(self.res_dim, dtype=self.dtype)

    res_seq = self.force(in_seq, res_state)

    if self.state_repr == "final":
        feature = res_seq[-1]
    else:  # "mean"
        feature = jnp.mean(res_seq[spinup:], axis=0)

    logits = self.readout.readout(feature)
    return jax.nn.softmax(logits)
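The two `state_repr` branches and the final softmax can be checked by hand on a tiny array. This is a NumPy sketch of the same logic; `extract_feature`, `softmax`, and the readout matrix `W_out` are hypothetical names introduced for illustration, not part of orc.

```python
import numpy as np

def extract_feature(res_seq, state_repr="final", spinup=0):
    # Mirrors the branch in `classify`: last state, or mean after discarding spinup states
    if state_repr == "final":
        return res_seq[-1]
    return res_seq[spinup:].mean(axis=0)

def softmax(logits):
    # Subtracting the max keeps exp() stable without changing the result
    e = np.exp(logits - logits.max())
    return e / e.sum()

res_seq = np.array([[0.0, 1.0],
                    [2.0, 3.0],
                    [4.0, 5.0]])          # (seq_len=3, res_dim=2)

print(extract_feature(res_seq, "final"))           # [4. 5.]
print(extract_feature(res_seq, "mean", spinup=1))  # [3. 4.]

W_out = np.array([[1.0, 0.0],
                  [0.0, 1.0],
                  [-1.0, 0.0]])            # hypothetical readout, (n_classes=3, res_dim=2)
probs = softmax(W_out @ extract_feature(res_seq))
```

`spinup` has no effect when `state_repr="final"`. With `state_repr="mean"`, `spinup` should be smaller than `seq_len`; otherwise the mean is taken over an empty slice and the feature vector is NaN.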

__call__

__call__(in_seq: Array, res_state: Array, spinup: int = 0) -> Array

Classify an input sequence; a thin wrapper around the classify method.

Parameters:

Name Type Description Default
in_seq Array

Input sequence to classify, (shape=(seq_len, in_dim)).

required
res_state Array

Initial reservoir state, (shape=(res_dim,)).

required
spinup int

Number of initial reservoir states to discard before extracting features. Only used when state_repr="mean".

0

Returns:

Type Description
Array

Class probabilities, (shape=(n_classes,)).

Source code in src/orc/classifier/base.py
def __call__(self, in_seq: Array, res_state: Array, spinup: int = 0) -> Array:
    """Classify an input sequence, wrapper for `classify` method.

    Parameters
    ----------
    in_seq : Array
        Input sequence to classify, (shape=(seq_len, in_dim)).
    res_state : Array
        Initial reservoir state, (shape=(res_dim,)).
    spinup : int
        Number of initial reservoir states to discard before extracting
        features. Only used when state_repr="mean".

    Returns
    -------
    Array
        Class probabilities, (shape=(n_classes,)).
    """
    return self.classify(in_seq, res_state, spinup)

set_readout

set_readout(readout: ReadoutBase) -> RCClassifierBase

Replace readout layer.

Parameters:

Name Type Description Default
readout ReadoutBase

New readout layer.

required

Returns:

Type Description
RCClassifierBase

Updated model with new readout layer.

Source code in src/orc/classifier/base.py
def set_readout(self, readout: ReadoutBase) -> "RCClassifierBase":
    """Replace readout layer.

    Parameters
    ----------
    readout : ReadoutBase
        New readout layer.

    Returns
    -------
    RCClassifierBase
        Updated model with new readout layer.
    """

    def where(m: "RCClassifierBase"):
        return m.readout

    new_model = eqx.tree_at(where, self, readout)
    return new_model
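Because Equinox modules are immutable, `eqx.tree_at` returns a new model instead of modifying `self`, so the result of `set_readout` (and likewise `set_embedding`) must be rebound, e.g. `model = model.set_readout(new_readout)`. The sketch below illustrates the same out-of-place pattern with a frozen standard-library dataclass; `ToyModel` is a hypothetical stand-in, and `dataclasses.replace` is used only as an analogy for `eqx.tree_at`.

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class ToyModel:
    # Hypothetical stand-in for an immutable eqx.Module with three layers
    driver: str
    readout: str
    embedding: str

old = ToyModel(driver="esn", readout="untrained", embedding="linear")
# Analogous to eqx.tree_at(lambda m: m.readout, old, "trained")
new = replace(old, readout="trained")

print(old.readout)  # untrained  (original is left unchanged)
print(new.readout)  # trained
```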

set_embedding

set_embedding(embedding: EmbedBase) -> RCClassifierBase

Replace embedding layer.

Parameters:

Name Type Description Default
embedding EmbedBase

New embedding layer.

required

Returns:

Type Description
RCClassifierBase

Updated model with new embedding layer.

Source code in src/orc/classifier/base.py
def set_embedding(self, embedding: EmbedBase) -> "RCClassifierBase":
    """Replace embedding layer.

    Parameters
    ----------
    embedding : EmbedBase
        New embedding layer.

    Returns
    -------
    RCClassifierBase
        Updated model with new embedding layer.
    """

    def where(m: "RCClassifierBase"):
        return m.embedding

    new_model = eqx.tree_at(where, self, embedding)
    return new_model

Models

ESNClassifier

ESNClassifier(data_dim: int, n_classes: int, res_dim: int, leak_rate: float = 0.6, bias: float = 1.6, embedding_scaling: float = 0.08, Wr_density: float = 0.02, Wr_spectral_radius: float = 0.8, dtype: type = float64, seed: int = 0, quadratic: bool = False, use_sparse_eigs: bool = True, state_repr: str = 'final')

Bases: RCClassifierBase

Basic implementation of an echo state network (ESN) for classification tasks.

Attributes:

Name Type Description
res_dim int

Reservoir dimension.

data_dim int

Input data dimension.

n_classes int

Number of classification classes.

driver ESNDriver

Driver implementing the Echo State Network dynamics.

readout ReadoutBase

Trainable readout layer with out_dim=n_classes: a LinearReadout by default, or a QuadraticReadout when quadratic=True.

embedding LinearEmbedding

Untrainable linear embedding layer.

Methods:

Name Description
force

Teacher forces the reservoir with the input sequence.

classify

Classify an input sequence, returning class probabilities.

set_readout

Replace readout layer.

set_embedding

Replace embedding layer.

Initialize the ESN classifier.

Parameters:

Name Type Description Default
data_dim int

Dimension of the input data.

required
n_classes int

Number of classification classes.

required
res_dim int

Dimension of the reservoir adjacency matrix Wr.

required
leak_rate float

Integration leak rate of the reservoir dynamics.

0.6
bias float

Bias term for the reservoir dynamics.

1.6
embedding_scaling float

Scaling factor for the embedding layer.

0.08
Wr_density float

Density of the reservoir adjacency matrix Wr.

0.02
Wr_spectral_radius float

Spectral radius (largest eigenvalue magnitude) of the reservoir adjacency matrix Wr.

0.8
dtype type

Data type of the model (jnp.float64 is highly recommended).

float64
seed int

Random seed for generating the PRNG key for the reservoir computer.

0
quadratic bool

Use a quadratic readout (QuadraticReadout) instead of a linear one (LinearReadout). Default False.

False
use_sparse_eigs bool

Whether to use a sparse eigensolver for setting the spectral radius of Wr. Default is True, which is recommended to save memory and compute time. If False, a dense eigensolver is used, which may be more accurate.

True
state_repr str

Reservoir state representation for classification. "final" uses the last reservoir state, "mean" averages states after spinup.

'final'
Source code in src/orc/classifier/models.py
def __init__(
    self,
    data_dim: int,
    n_classes: int,
    res_dim: int,
    leak_rate: float = 0.6,
    bias: float = 1.6,
    embedding_scaling: float = 0.08,
    Wr_density: float = 0.02,
    Wr_spectral_radius: float = 0.8,
    dtype: type = jnp.float64,
    seed: int = 0,
    quadratic: bool = False,
    use_sparse_eigs: bool = True,
    state_repr: str = "final",
) -> None:
    """
    Initialize the ESN classifier.

    Parameters
    ----------
    data_dim : int
        Dimension of the input data.
    n_classes : int
        Number of classification classes.
    res_dim : int
        Dimension of the reservoir adjacency matrix Wr.
    leak_rate : float
        Integration leak rate of the reservoir dynamics.
    bias : float
        Bias term for the reservoir dynamics.
    embedding_scaling : float
        Scaling factor for the embedding layer.
    Wr_density : float
        Density of the reservoir adjacency matrix Wr.
    Wr_spectral_radius : float
        Spectral radius (largest eigenvalue magnitude) of the reservoir adjacency matrix Wr.
    dtype : type
        Data type of the model (jnp.float64 is highly recommended).
    seed : int
        Random seed for generating the PRNG key for the reservoir computer.
    quadratic : bool
        Use a quadratic readout (QuadraticReadout) instead of a linear one, default False.
    use_sparse_eigs : bool
        Whether to use a sparse eigensolver for setting the spectral radius of Wr.
        Default is True, which is recommended to save memory and compute time. If
        False, a dense eigensolver is used, which may be more accurate.
    state_repr : str
        Reservoir state representation for classification.
        "final" uses the last reservoir state, "mean" averages states after spinup.
    """
    # Initialize the random key and reservoir dimension
    self.res_dim = res_dim
    self.seed = seed
    self.data_dim = data_dim
    key = jax.random.PRNGKey(seed)
    key_driver, key_readout, key_embedding = jax.random.split(key, 3)

    embedding = LinearEmbedding(
        in_dim=data_dim,
        res_dim=res_dim,
        seed=key_embedding[0],
        scaling=embedding_scaling,
    )
    driver = ESNDriver(
        res_dim=res_dim,
        seed=key_driver[0],
        leak=leak_rate,
        bias=bias,
        density=Wr_density,
        spectral_radius=Wr_spectral_radius,
        dtype=dtype,
        use_sparse_eigs=use_sparse_eigs,
    )
    if quadratic:
        readout = QuadraticReadout(
            out_dim=n_classes, res_dim=res_dim, seed=key_readout[0]
        )
    else:
        readout = LinearReadout(
            out_dim=n_classes, res_dim=res_dim, seed=key_readout[0]
        )

    super().__init__(
        driver=driver,
        readout=readout,
        embedding=embedding,
        n_classes=n_classes,
        state_repr=state_repr,
        dtype=dtype,
        seed=seed,
    )
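To make the roles of `Wr_density`, `Wr_spectral_radius`, `leak_rate`, and `bias` concrete, here is a NumPy sketch of a sparse random reservoir rescaled to a target spectral radius, plus one leaky-integrator update. The helpers `make_reservoir` and `leaky_step` are hypothetical; the uniform weight distribution and the update form (1 - leak) * r + leak * tanh(Wr r + proj + bias) are assumptions, and `ESNDriver` may construct or advance its reservoir differently.

```python
import numpy as np

def make_reservoir(res_dim, density, spectral_radius, seed=0):
    rng = np.random.default_rng(seed)
    # Keep each entry with probability `density`; nonzero weights uniform in [-1, 1]
    mask = rng.random((res_dim, res_dim)) < density
    Wr = np.where(mask, rng.uniform(-1.0, 1.0, (res_dim, res_dim)), 0.0)
    # Dense eigensolver, the analogue of use_sparse_eigs=False
    rho = np.max(np.abs(np.linalg.eigvals(Wr)))
    if rho == 0.0:
        raise ValueError("Wr has spectral radius 0; increase density or res_dim.")
    # Eigenvalues scale linearly with the matrix, so this sets the radius exactly
    return Wr * (spectral_radius / rho)

def leaky_step(Wr, r, proj, leak=0.6, bias=1.6):
    # Assumed leaky-integrator update (hypothetical, not ESNDriver's verified form)
    return (1.0 - leak) * r + leak * np.tanh(Wr @ r + proj + bias)

# Defaults taken from the ESNClassifier signature
Wr = make_reservoir(res_dim=200, density=0.02, spectral_radius=0.8)
print(np.max(np.abs(np.linalg.eigvals(Wr))))  # approximately 0.8

r_next = leaky_step(Wr, np.zeros(200), np.zeros(200))
```

Under this assumed form, a smaller `leak_rate` keeps more of the previous state at each step, giving the reservoir a longer memory of past inputs.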

Training Functions

train_ESNClassifier

train_ESNClassifier(model: ESNClassifier, train_seqs: Array, labels: Array, spinup: int = 0, beta: float = 8e-08) -> ESNClassifier

Training function for ESNClassifier.

Trains the classifier by forcing each input sequence through the reservoir, extracting a feature vector from the reservoir states, and solving ridge regression against one-hot encoded class labels.
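The two linear-algebra steps of training, one-hot encoding and ridge regression, can be sketched in NumPy as follows. `one_hot` and `ridge` are hypothetical helpers and the feature matrix is random toy data; orc's own `ridge_regression` is not shown on this page and may orient or solve the system differently. The closed form used here is the standard Tikhonov solution W = (X^T X + beta I)^{-1} X^T Y.

```python
import numpy as np

def one_hot(labels, n_classes):
    # Row i gets a 1.0 in column labels[i], mirroring the .at[...].set(1.0) update
    targets = np.zeros((labels.shape[0], n_classes))
    targets[np.arange(labels.shape[0]), labels] = 1.0
    return targets

def ridge(features, targets, beta):
    # W = (X^T X + beta I)^{-1} X^T Y, computed with a linear solve
    # X: (n_samples, res_dim), Y: (n_samples, n_classes) -> W: (res_dim, n_classes)
    gram = features.T @ features + beta * np.eye(features.shape[1])
    return np.linalg.solve(gram, features.T @ targets)

rng = np.random.default_rng(0)
labels = np.array([0, 2, 1, 2])
X = rng.normal(size=(4, 6))      # toy features standing in for reservoir states
Y = one_hot(labels, n_classes=3)
W = ridge(X, Y, beta=1e-6)

print((X @ W).argmax(axis=1))    # recovers the training labels [0 2 1 2]
```

Larger `beta` shrinks the readout weights toward zero, trading training fit for robustness.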

Parameters:

Name Type Description Default
model ESNClassifier

ESNClassifier model to train.

required
train_seqs Array

Batch of training input sequences, (shape=(n_samples, seq_len, data_dim)).

required
labels Array

Integer class labels for each sequence, (shape=(n_samples,)). Values should be in [0, n_classes).

required
spinup int

Number of initial reservoir states to discard before extracting features. Only used when model.state_repr="mean".

0
beta float

Tikhonov regularization parameter.

8e-08

Returns:

Name Type Description
model ESNClassifier

Trained ESN classifier model.

Source code in src/orc/classifier/train.py
def train_ESNClassifier(
    model: ESNClassifier,
    train_seqs: Array,
    labels: Array,
    spinup: int = 0,
    beta: float = 8e-8,
) -> ESNClassifier:
    """Training function for ESNClassifier.

    Trains the classifier by forcing each input sequence through the reservoir,
    extracting a feature vector from the reservoir states, and solving ridge
    regression against one-hot encoded class labels.

    Parameters
    ----------
    model : ESNClassifier
        ESNClassifier model to train.
    train_seqs : Array
        Batch of training input sequences,
        (shape=(n_samples, seq_len, data_dim)).
    labels : Array
        Integer class labels for each sequence, (shape=(n_samples,)).
        Values should be in [0, n_classes).
    spinup : int
        Number of initial reservoir states to discard before extracting
        features. Only used when model.state_repr="mean".
    beta : float
        Tikhonov regularization parameter.

    Returns
    -------
    model : ESNClassifier
        Trained ESN classifier model.
    """
    if not isinstance(model, ESNClassifier):
        raise TypeError("Model must be an ESNClassifier.")

    if train_seqs.shape[0] != labels.shape[0]:
        raise ValueError("Number of training sequences must match number of labels.")

    n_samples = train_seqs.shape[0]
    initial_res_states = jnp.zeros((n_samples, model.res_dim), dtype=model.dtype)

    # Force all sequences through the reservoir in parallel via vmap
    all_res_seqs = jax.vmap(model.force)(train_seqs, initial_res_states)
    # all_res_seqs shape: (n_samples, seq_len, res_dim)

    # Extract feature vectors
    if model.state_repr == "final":
        feature_matrix = all_res_seqs[:, -1, :]  # (n_samples, res_dim)
    else:  # "mean"
        feature_matrix = jnp.mean(all_res_seqs[:, spinup:, :], axis=1)

    # Build one-hot target matrix
    one_hot_targets = jnp.zeros((n_samples, model.n_classes), dtype=model.dtype)
    one_hot_targets = one_hot_targets.at[jnp.arange(n_samples), labels].set(1.0)

    # Apply optional nonlinear transform
    if isinstance(model.readout, NonlinearReadout):
        feature_matrix = model.readout.nonlinear_transform(feature_matrix)

    # Solve ridge regression
    cmat = ridge_regression(feature_matrix, one_hot_targets, beta)
    cmat = cmat.reshape(1, cmat.shape[0], cmat.shape[1])

    def where(m):
        return m.readout.wout

    model = eqx.tree_at(where, model, cmat)

    return model
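`train_ESNClassifier` forces the whole batch with `jax.vmap(model.force)`, so the time axis of `all_res_seqs` is axis 1 rather than axis 0 as in `classify`. The NumPy sketch below checks that a batch-vectorized forcing loop matches per-sample forcing and shows the matching feature extraction. The tanh update and the matrices `W_in` and `W_r` are hypothetical stand-ins for the embedding and driver.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples, seq_len, in_dim, res_dim = 4, 6, 3, 8
W_in = rng.uniform(-0.5, 0.5, (res_dim, in_dim))   # hypothetical embedding
W_r = 0.1 * rng.normal(size=(res_dim, res_dim))    # hypothetical reservoir

def force_one(seq):
    # Per-sample forcing from a zero initial state (the function vmap maps over)
    r = np.zeros(res_dim)
    out = []
    for u in seq:
        r = np.tanh(W_r @ r + W_in @ u)
        out.append(r)
    return np.stack(out)                            # (seq_len, res_dim)

def force_batch(seqs):
    # Vectorized over the sample axis, playing the role of jax.vmap(model.force)
    R = np.zeros((seqs.shape[0], res_dim))
    out = []
    for t in range(seqs.shape[1]):
        R = np.tanh(R @ W_r.T + seqs[:, t, :] @ W_in.T)
        out.append(R)
    return np.stack(out, axis=1)                    # (n_samples, seq_len, res_dim)

train_seqs = rng.normal(size=(n_samples, seq_len, in_dim))
all_res_seqs = force_batch(train_seqs)

spinup = 2
final_features = all_res_seqs[:, -1, :]                    # (n_samples, res_dim)
mean_features = all_res_seqs[:, spinup:, :].mean(axis=1)   # time is axis 1 here
```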