
Forecaster API Reference

The orc.forecaster subpackage provides reservoir computing models for time series forecasting.

Base Classes

RCForecasterBase

RCForecasterBase(driver: DriverBase, readout: ReadoutBase, embedding: EmbedBase, chunks: int = 0, dtype: Float = float64, seed: int = 0)

Bases: Module, ABC

Base class for reservoir computer forecasters.

Defines the interface for the reservoir computer which includes the driver, readout and embedding layers.

Attributes:

Name Type Description
driver DriverBase

Driver layer of the reservoir computer.

readout ReadoutBase

Readout layer of the reservoir computer.

embedding EmbedBase

Embedding layer of the reservoir computer.

in_dim int

Dimension of the input data.

out_dim int

Dimension of the output data.

res_dim int

Dimension of the reservoir.

chunks int

Number of parallel reservoirs.

dtype type

Data type of the reservoir computer (jnp.float64 is highly recommended).

seed int

Random seed for generating the PRNG key for the reservoir computer.

Methods:

Name Description
force

Teacher forces the reservoir with the input sequence.

set_readout

Replaces the readout layer of the reservoir computer.

set_embedding

Replaces the embedding layer of the reservoir computer.

forecast

Forecast from an initial reservoir state.

forecast_from_IC

Forecast from a sequence of spinup data.

Initialize RCForecaster Base.

Parameters:

Name Type Description Default
driver DriverBase

Driver layer of the reservoir computer.

required
readout ReadoutBase

Readout layer of the reservoir computer.

required
embedding EmbedBase

Embedding layer of the reservoir computer.

required
chunks int

Number of parallel reservoirs.

0
dtype type

Data type of the reservoir computer (jnp.float64 is highly recommended).

float64
seed int

Random seed for generating the PRNG key for the reservoir computer.

0
Source code in src/orc/forecaster/base.py
def __init__(
    self,
    driver: DriverBase,
    readout: ReadoutBase,
    embedding: EmbedBase,
    chunks: int = 0,
    dtype: Float = jnp.float64,
    seed: int = 0,
) -> None:
    """Initialize RCForecaster Base.

    Parameters
    ----------
    driver : DriverBase
        Driver layer of the reservoir computer.
    readout : ReadoutBase
        Readout layer of the reservoir computer.
    embedding : EmbedBase
        Embedding layer of the reservoir computer.
    chunks : int
        Number of parallel reservoirs.
    dtype : type
        Data type of the reservoir computer (jnp.float64 is highly recommended).
    seed : int
        Random seed for generating the PRNG key for the reservoir computer.
    """
    self.driver = driver
    self.readout = readout
    self.embedding = embedding
    self.in_dim = self.embedding.in_dim
    self.out_dim = self.readout.out_dim
    self.res_dim = self.driver.res_dim
    self.chunks = chunks
    self.dtype = dtype
    self.seed = seed

force

force(in_seq: Array, res_state: Array) -> Array

Teacher forces the reservoir.

Parameters:

Name Type Description Default
in_seq Array

Input sequence to force the reservoir, (shape=(seq_len, data_dim)).

required
res_state Array

Initial reservoir state, (shape=(chunks, res_dim,)).

required

Returns:

Type Description
Array

Forced reservoir sequence, (shape=(seq_len, chunks, res_dim)).

Source code in src/orc/forecaster/base.py
@eqx.filter_jit
def force(self, in_seq: Array, res_state: Array) -> Array:
    """Teacher forces the reservoir.

    Parameters
    ----------
    in_seq: Array
        Input sequence to force the reservoir, (shape=(seq_len, data_dim)).
    res_state : Array
        Initial reservoir state, (shape=(chunks, res_dim,)).

    Returns
    -------
    Array
        Forced reservoir sequence, (shape=(seq_len, chunks, res_dim)).
    """

    def scan_fn(state, in_vars):
        proj_vars = self.embedding.embed(in_vars)
        res_state = self.driver.advance(proj_vars, state)
        return (res_state, res_state)

    _, res_seq = jax.lax.scan(scan_fn, res_state, in_seq)
    return res_seq
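The scan above applies embed-then-advance at each step and records every reservoir state. A minimal NumPy sketch of the same loop, using a hypothetical linear embedding and leaky-tanh driver as stand-ins (the matrices `W_in`, `W_r` and the `leak` value are illustrative, not the library's actual layers):

```python
import numpy as np

# Hypothetical stand-ins for the embedding and driver layers; the real
# layers are EmbedBase/DriverBase instances.
def embed(u, W_in):
    return W_in @ u                      # linear embedding of one input

def advance(proj, r, W_r, leak=0.6):
    # leaky-tanh update, a common discrete ESN driver form
    return (1 - leak) * r + leak * np.tanh(W_r @ r + proj)

def force(in_seq, res_state, W_in, W_r):
    """Unrolled equivalent of the jax.lax.scan in `force`."""
    states = []
    r = res_state
    for u in in_seq:
        r = advance(embed(u, W_in), r, W_r)
        states.append(r)
    return np.stack(states)              # (seq_len, res_dim)

rng = np.random.default_rng(0)
in_seq = rng.normal(size=(10, 3))        # (seq_len, data_dim)
W_in = rng.normal(size=(8, 3)) * 0.08
W_r = rng.normal(size=(8, 8)) * 0.1
res_seq = force(in_seq, np.zeros(8), W_in, W_r)
print(res_seq.shape)                     # (10, 8)
```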

__call__

__call__(in_seq: Array, res_state: Array) -> Array

Teacher forces the reservoir, wrapper for force method.

Parameters:

Name Type Description Default
in_seq Array

Input sequence to force the reservoir, (shape=(seq_len, data_dim)).

required
res_state Array

Initial reservoir state, (shape=(chunks, res_dim,)).

required

Returns:

Type Description
Array

Forced reservoir sequence, (shape=(seq_len, chunks, res_dim)).

Source code in src/orc/forecaster/base.py
def __call__(self, in_seq: Array, res_state: Array) -> Array:
    """Teacher forces the reservoir, wrapper for `force` method.

    Parameters
    ----------
    in_seq: Array
        Input sequence to force the reservoir, (shape=(seq_len, data_dim)).
    res_state : Array
        Initial reservoir state, (shape=(chunks, res_dim,)).

    Returns
    -------
    Array
        Forced reservoir sequence, (shape=(seq_len, chunks, res_dim)).
    """
    return self.force(in_seq, res_state)

set_readout

set_readout(readout: ReadoutBase)

Replace readout layer.

Parameters:

Name Type Description Default
readout ReadoutBase

New readout layer.

required

Returns:

Type Description
RCForecasterBase

Updated model with new readout layer.

Source code in src/orc/forecaster/base.py
def set_readout(self, readout: ReadoutBase):
    """Replace readout layer.

    Parameters
    ----------
    readout : ReadoutBase
        New readout layer.

    Returns
    -------
    RCForecasterBase
        Updated model with new readout layer.
    """

    def where(m: RCForecasterBase):
        return m.readout

    new_model = eqx.tree_at(where, self, readout)
    return new_model
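`eqx.tree_at` builds a new model with the readout swapped out while leaving the original untouched (Equinox modules are immutable PyTrees). A rough pure-Python analogue of that functional, out-of-place update, using `dataclasses.replace` on a hypothetical frozen model:

```python
from dataclasses import dataclass, replace

# Hypothetical frozen model standing in for an immutable Equinox module.
@dataclass(frozen=True)
class Model:
    driver: str
    readout: str

m = Model(driver="esn", readout="linear")
m2 = replace(m, readout="quadratic")   # new model; m is unchanged
print(m.readout, m2.readout)           # linear quadratic
```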

set_embedding

set_embedding(embedding: EmbedBase)

Replace embedding layer.

Parameters:

Name Type Description Default
embedding EmbedBase

New embedding layer.

required

Returns:

Type Description
RCForecasterBase

Updated model with new embedding layer.

Source code in src/orc/forecaster/base.py
def set_embedding(self, embedding: EmbedBase):
    """Replace embedding layer.

    Parameters
    ----------
    embedding : EmbedBase
        New embedding layer.

    Returns
    -------
    RCForecasterBase
        Updated model with new embedding layer.
    """

    def where(m: RCForecasterBase):
        return m.embedding

    new_model = eqx.tree_at(where, self, embedding)
    return new_model

forecast

forecast(fcast_len: int, res_state: Array) -> Array

Forecast from an initial reservoir state.

Parameters:

Name Type Description Default
fcast_len int

Steps to forecast.

required
res_state Array

Initial reservoir state, (shape=(chunks, res_dim)).

required

Returns:

Type Description
Array

Forecasted states, (shape=(fcast_len, data_dim))

Source code in src/orc/forecaster/base.py
@eqx.filter_jit
def forecast(self, fcast_len: int, res_state: Array) -> Array:
    """Forecast from an initial reservoir state.

    Parameters
    ----------
    fcast_len : int
        Steps to forecast.
    res_state : Array
        Initial reservoir state, (shape=(chunks, res_dim)).

    Returns
    -------
    Array
        Forecasted states, (shape=(fcast_len, data_dim))
    """

    def scan_fn(state, _):
        out_state = self.driver.advance(
            self.embedding.embed(self.readout.readout(state)), state
        )
        return (out_state, self.readout.readout(out_state))

    _, state_seq = jax.lax.scan(scan_fn, res_state, None, length=fcast_len - 1)
    pre_append_state = self.readout.readout(res_state)
    return jnp.vstack([pre_append_state, state_seq])
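The closed-loop rollout above feeds each readout back through the embedding and driver, and prepends the readout of the initial state so the output has exactly `fcast_len` rows. A NumPy sketch with hypothetical linear stand-ins for the three layers:

```python
import numpy as np

# Hypothetical linear stand-ins for the readout/embedding/driver layers.
rng = np.random.default_rng(1)
res_dim, data_dim = 8, 3
W_out = rng.normal(size=(data_dim, res_dim))
W_in = rng.normal(size=(res_dim, data_dim)) * 0.08
W_r = rng.normal(size=(res_dim, res_dim)) * 0.1

readout = lambda r: W_out @ r
advance = lambda proj, r: 0.4 * r + 0.6 * np.tanh(W_r @ r + proj)

def forecast(fcast_len, res_state):
    """Closed-loop rollout: feed the readout back in as the next input."""
    outs = [readout(res_state)]          # step 0 reads the initial state
    r = res_state
    for _ in range(fcast_len - 1):
        r = advance(W_in @ readout(r), r)
        outs.append(readout(r))
    return np.stack(outs)                # (fcast_len, data_dim)

r0 = rng.normal(size=res_dim)
traj = forecast(5, r0)
print(traj.shape)                        # (5, 3)
```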

forecast_from_IC

forecast_from_IC(fcast_len: int, spinup_data: Array) -> Array

Forecast from a sequence of spinup data.

Parameters:

Name Type Description Default
fcast_len int

Steps to forecast.

required
spinup_data Array

Initial condition sequence, (shape=(seq_len, data_dim)).

required

Returns:

Type Description
Array

Forecasted states, (shape=(fcast_len, data_dim)).

Source code in src/orc/forecaster/base.py
@eqx.filter_jit
def forecast_from_IC(self, fcast_len: int, spinup_data: Array) -> Array:
    """Forecast from a sequence of spinup data.

    Parameters
    ----------
    fcast_len : int
        Steps to forecast.
    spinup_data : Array
        Initial condition sequence, (shape=(seq_len, data_dim)).

    Returns
    -------
    Array
        Forecasted states, (shape=(fcast_len, data_dim)).
    """
    if self.chunks > 0:
        res_seq = self.force(
            spinup_data, jnp.zeros((self.chunks, self.res_dim), dtype=self.dtype)
        )
    elif self.chunks == 0:
        res_seq = self.force(
            spinup_data, jnp.zeros((self.res_dim), dtype=self.dtype)
        )
    else:
        raise ValueError(f"chunks must be >= 0, but found chunks = {self.chunks}")

    return self.forecast(fcast_len, res_seq[-1])
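`forecast_from_IC` chains the two methods: teacher-force from a zero reservoir state over the spinup sequence, then forecast autonomously from the final forced state. A toy NumPy sketch of that pipeline (the layer matrices are illustrative stand-ins, not the library's layers):

```python
import numpy as np

# Hypothetical linear stand-ins for the embedding/driver/readout layers.
rng = np.random.default_rng(3)
res_dim, data_dim = 8, 3
W_in = rng.normal(size=(res_dim, data_dim)) * 0.08
W_r = rng.normal(size=(res_dim, res_dim)) * 0.1
W_out = rng.normal(size=(data_dim, res_dim))

advance = lambda u, r: 0.4 * r + 0.6 * np.tanh(W_r @ r + W_in @ u)
readout = lambda r: W_out @ r

def forecast_from_IC(fcast_len, spinup_data):
    # 1) teacher-force from a zero reservoir state over the spinup data
    r = np.zeros(res_dim)
    for u in spinup_data:
        r = advance(u, r)
    # 2) forecast closed-loop from the final forced state
    outs = [readout(r)]
    for _ in range(fcast_len - 1):
        r = advance(readout(r), r)
        outs.append(readout(r))
    return np.stack(outs)                # (fcast_len, data_dim)

fc = forecast_from_IC(4, rng.normal(size=(20, data_dim)))
print(fc.shape)                          # (4, 3)
```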

CRCForecasterBase

CRCForecasterBase(driver: DriverBase, readout: ReadoutBase, embedding: EmbedBase, chunks: int = 0, dtype: Float = float64, seed: int = 0, solver: AbstractSolver = None, stepsize_controller: AbstractAdaptiveStepSizeController = None)

Bases: RCForecasterBase, ABC

Base class for continuous reservoir computer forecasters.

Overrides the force and forecast methods of RCForecasterBase to time-step the RC forward using a continuous-time ODE solver.

Attributes:

Name Type Description
driver DriverBase

Driver layer of the reservoir computer.

readout ReadoutBase

Readout layer of the reservoir computer.

embedding EmbedBase

Embedding layer of the reservoir computer.

in_dim int

Dimension of the input data.

out_dim int

Dimension of the output data.

res_dim int

Dimension of the reservoir.

chunks int

Number of parallel reservoirs.

dtype type

Data type of the reservoir computer (jnp.float64 is highly recommended).

seed int

Random seed for generating the PRNG key for the reservoir computer.

solver Solver

ODE solver to use for the reservoir computer.

stepsize_controller StepsizeController

Stepsize controller to use for the ODE solver.

Methods:

Name Description
force

Teacher forces the reservoir with the input sequence.

set_readout

Replaces the readout layer of the reservoir computer.

set_embedding

Replaces the embedding layer of the reservoir computer.

forecast

Forecast from an initial reservoir state.

forecast_from_IC

Forecast from a sequence of spinup data.

Initialize the continuous reservoir computer.

Parameters:

Name Type Description Default
driver DriverBase

Driver layer of the reservoir computer.

required
readout ReadoutBase

Readout layer of the reservoir computer.

required
embedding EmbedBase

Embedding layer of the reservoir computer.

required
chunks int

Number of parallel reservoirs.

0
dtype type

Data type of the reservoir computer (jnp.float64 is highly recommended).

float64
seed int

Random seed for generating the PRNG key for the reservoir computer.

0
solver AbstractSolver

ODE solver to use for the reservoir computer.

None
stepsize_controller AbstractAdaptiveStepSizeController

Stepsize controller to use for the ODE solver.

None
Source code in src/orc/forecaster/base.py
def __init__(
    self,
    driver: DriverBase,
    readout: ReadoutBase,
    embedding: EmbedBase,
    chunks: int = 0,
    dtype: Float = jnp.float64,
    seed: int = 0,
    solver: diffrax.AbstractSolver = None,
    stepsize_controller: diffrax.AbstractAdaptiveStepSizeController = None,
):
    """Initialize the continuous reservoir computer.

    Parameters
    ----------
    driver : DriverBase
        Driver layer of the reservoir computer.
    readout : ReadoutBase
        Readout layer of the reservoir computer.
    embedding : EmbedBase
        Embedding layer of the reservoir computer.
    chunks : int
        Number of parallel reservoirs.
    dtype : type
        Data type of the reservoir computer (jnp.float64 is highly recommended).
    seed : int
        Random seed for generating the PRNG key for the reservoir computer.
    solver : diffrax.AbstractSolver
        ODE solver to use for the reservoir computer.
    stepsize_controller : diffrax.AbstractAdaptiveStepSizeController
        Stepsize controller to use for the ODE solver.
    """
    super().__init__(driver, readout, embedding, chunks, dtype, seed)
    if solver is None:
        solver = diffrax.Tsit5()
    if stepsize_controller is None:
        stepsize_controller = diffrax.PIDController(
            rtol=1e-3, atol=1e-6, icoeff=1.0
        )
    self.solver = solver
    self.stepsize_controller = stepsize_controller

force

force(in_seq: Array, res_state: Array, ts: Array) -> Array

Teacher forces the reservoir.

Parameters:

Name Type Description Default
in_seq Array

Input sequence to force the reservoir, (shape=(seq_len, data_dim)).

required
res_state Array

Initial reservoir state, (shape=(chunks, res_dim,)).

required
ts Array

Time steps for the input sequence, (shape=(seq_len,)).

required

Returns:

Type Description
Array

Forced reservoir sequence, (shape=(seq_len, chunks, res_dim)).

Source code in src/orc/forecaster/base.py
@eqx.filter_jit
def force(self, in_seq: Array, res_state: Array, ts: Array) -> Array:
    """
    Teacher forces the reservoir.

    Parameters
    ----------
    in_seq: Array
        Input sequence to force the reservoir, (shape=(seq_len, data_dim)).
    res_state : Array
        Initial reservoir state, (shape=(chunks, res_dim,)).
    ts: Array
        Time steps for the input sequence, (shape=(seq_len,)).

    Returns
    -------
    Array
        Forced reservoir sequence, (shape=(seq_len, chunks, res_dim)).
    """
    # form interpolants
    coeffs = diffrax.backward_hermite_coefficients(ts, in_seq)
    in_seq_interp = diffrax.CubicInterpolation(ts, coeffs)

    # RC forced ODE definition
    @eqx.filter_jit
    def res_ode(t, r, args):
        interp = args
        proj_vars = self.embedding.embed(interp.evaluate(t))
        return self.driver.advance(proj_vars, r)

    # integrate RC
    dt0 = ts[1] - ts[0]
    ts = ts + dt0  # roll time forward one step for targets
    term = diffrax.ODETerm(res_ode)
    args = in_seq_interp
    save_at = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        term,
        t0=0.0,
        t1=ts[-1],
        dt0=dt0,
        y0=res_state,
        solver=self.solver,
        stepsize_controller=self.stepsize_controller,
        args=args,
        saveat=save_at,
        max_steps=None,
    )
    res_seq = sol.ys
    return res_seq
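The continuous `force` builds an interpolant of the input sequence and integrates the driven reservoir ODE through it. A rough NumPy sketch, with per-channel linear interpolation and fixed-step forward Euler standing in for diffrax's `CubicInterpolation` and adaptive Tsit5 solve (all matrices are illustrative):

```python
import numpy as np

# Hypothetical stand-ins; the real class uses diffrax with a cubic
# Hermite interpolant and an adaptive solver.
rng = np.random.default_rng(2)
res_dim, data_dim = 6, 2
W_in = rng.normal(size=(res_dim, data_dim)) * 0.08
W_r = rng.normal(size=(res_dim, res_dim)) * 0.1

ts = np.linspace(0.0, 1.0, 11)
in_seq = rng.normal(size=(ts.size, data_dim))

def u_interp(t):
    # linear interpolation per input channel (stand-in for CubicInterpolation)
    return np.array([np.interp(t, ts, in_seq[:, i]) for i in range(data_dim)])

def res_ode(t, r):
    # driven reservoir vector field dr/dt
    return np.tanh(W_r @ r + W_in @ u_interp(t))

# fixed-step forward Euler (stand-in for the adaptive Tsit5 solve)
r = np.zeros(res_dim)
dt = 0.01
for k in range(100):
    r = r + dt * res_ode(k * dt, r)
print(r.shape)                           # (6,)
```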

__call__

__call__(in_seq: Array, res_state: Array, ts: Array) -> Array

Teacher forces the reservoir, wrapper for force method.

Parameters:

Name Type Description Default
in_seq Array

Input sequence to force the reservoir, (shape=(seq_len, data_dim)).

required
res_state Array

Initial reservoir state, (shape=(chunks, res_dim,)).

required
ts Array

Time steps for the input sequence, (shape=(seq_len,)).

required

Returns:

Type Description
Array

Forced reservoir sequence, (shape=(seq_len, chunks, res_dim)).

Source code in src/orc/forecaster/base.py
def __call__(self, in_seq: Array, res_state: Array, ts: Array) -> Array:
    """
    Teacher forces the reservoir, wrapper for `force` method.

    Parameters
    ----------
    in_seq: Array
        Input sequence to force the reservoir, (shape=(seq_len, data_dim)).
    res_state : Array
        Initial reservoir state, (shape=(chunks, res_dim,)).
    ts: Array
        Time steps for the input sequence, (shape=(seq_len,)).

    Returns
    -------
    Array
        Forced reservoir sequence, (shape=(seq_len, chunks, res_dim)).
    """
    return self.force(in_seq, res_state, ts)

forecast

forecast(ts: Array, res_state: Array) -> Array

Forecast from an initial reservoir state.

Parameters:

Name Type Description Default
ts Array

Time steps for the forecast, (shape=(fcast_len,)).

required
res_state Array

Initial reservoir state, (shape=(chunks, res_dim)).

required

Returns:

Type Description
Array

Forecasted states, (shape=(fcast_len, data_dim))

Source code in src/orc/forecaster/base.py
@eqx.filter_jit
def forecast(self, ts: Array, res_state: Array) -> Array:
    """Forecast from an initial reservoir state.

    Parameters
    ----------
    ts : Array
        Time steps for the forecast, (shape=(fcast_len,)).
    res_state : Array
        Initial reservoir state, (shape=(chunks, res_dim)).

    Returns
    -------
    Array
        Forecasted states, (shape=(fcast_len, data_dim))
    """

    # RC autonomous ODE definition
    @eqx.filter_jit
    def res_ode(t, r, args):
        out_state = self.driver.advance(
            self.embedding.embed(self.readout.readout(r)), r
        )
        return out_state

    # integrate RC
    dt0 = ts[1] - ts[0]
    term = diffrax.ODETerm(res_ode)
    save_at = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        term,
        t0=0.0,
        t1=ts[-1],
        dt0=dt0,
        y0=res_state,
        solver=self.solver,
        stepsize_controller=self.stepsize_controller,
        saveat=save_at,
        max_steps=None,
    )
    res_seq = sol.ys
    return eqx.filter_vmap(self.readout.readout)(res_seq)

forecast_from_IC

forecast_from_IC(ts: Array, spinup_data: Array, spinup_ts: Array = None) -> Array

Forecast from a sequence of spinup data.

Parameters:

Name Type Description Default
ts Array

Time steps for the forecast, (shape=(fcast_len,)).

required
spinup_data Array

Initial condition sequence, (shape=(seq_len, data_dim)).

required
spinup_ts Array

Time steps for the spinup data, (shape=(seq_len,)). If None, the spinup data is assumed to have the same dt as the forecast data. Default is None.

None

Returns:

Type Description
Array

Forecasted states, (shape=(fcast_len, data_dim)).

Source code in src/orc/forecaster/base.py
@eqx.filter_jit
def forecast_from_IC(
    self, ts: Array, spinup_data: Array, spinup_ts: Array = None
) -> Array:
    """Forecast from a sequence of spinup data.

    Parameters
    ----------
    ts : Array
        Time steps for the forecast, (shape=(fcast_len,)).
    spinup_data : Array
        Initial condition sequence, (shape=(seq_len, data_dim)).
    spinup_ts : Array
        Time steps for the spinup data, (shape=(seq_len,)).
        If None, the spinup data is assumed to have the same dt
        as the forecast data. Default is None.

    Returns
    -------
    Array
        Forecasted states, (shape=(fcast_len, data_dim)).
    """
    if spinup_ts is None:
        dt0 = ts[1] - ts[0]
        spinup_ts = jnp.arange(0.0, spinup_data.shape[0], dtype=self.dtype) * dt0

    if self.chunks > 0:
        res_seq = self.force(
            spinup_data,
            jnp.zeros((self.chunks, self.res_dim), dtype=self.dtype),
            spinup_ts,
        )
    elif self.chunks == 0:
        res_seq = self.force(
            spinup_data, jnp.zeros((self.res_dim), dtype=self.dtype), spinup_ts
        )
    else:
        raise ValueError(f"chunks must be >= 0, but found chunks = {self.chunks}")

    return self.forecast(ts, res_seq[-1])
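When `spinup_ts` is omitted, the spinup data is assumed to sit on the same uniform grid as the forecast, and the time axis is rebuilt from the forecast step (`arange * dt0`, mirroring the source above). A small NumPy sketch of that default:

```python
import numpy as np

ts = np.array([0.0, 0.5, 1.0, 1.5])      # forecast time grid
dt0 = ts[1] - ts[0]                      # forecast step, here 0.5
spinup_len = 6
# default spinup time axis: same dt as the forecast, starting at 0
spinup_ts = np.arange(spinup_len) * dt0
print(spinup_ts)                         # [0.  0.5 1.  1.5 2.  2.5]
```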

Models

ESNForecaster

ESNForecaster(data_dim: int, res_dim: int, leak_rate: float = 0.6, bias: float = 1.6, embedding_scaling: float = 0.08, Wr_density: float = 0.02, Wr_spectral_radius: float = 0.8, dtype: type = float64, seed: int = 0, chunks: int = 1, locality: int = 0, quadratic: bool = False, periodic: bool = True, use_sparse_eigs: bool = True)

Bases: RCForecasterBase

Basic implementation of ESN for forecasting.

Attributes:

Name Type Description
res_dim int

Reservoir dimension.

data_dim int

Input/output dimension.

driver ParallelESNDriver

Driver implementing the Echo State Network dynamics.

readout BaseReadout

Trainable linear readout layer.

embedding ParallelLinearEmbedding

Untrainable linear embedding layer.

Methods:

Name Description
force

Teacher forces the reservoir with sequence in_seq and init. cond. res_state.

forecast

Perform a forecast of fcast_len steps from res_state.

forecast_from_IC

Forecast from a sequence of spinup data.

set_readout

Replace readout layer.

set_embedding

Replace embedding layer.

Initialize the ESN model.

Parameters:

Name Type Description Default
data_dim int

Dimension of the input data.

required
res_dim int

Dimension of the reservoir adjacency matrix Wr.

required
leak_rate float

Integration leak rate of the reservoir dynamics.

0.6
bias float

Bias term for the reservoir dynamics.

1.6
embedding_scaling float

Scaling factor for the embedding layer.

0.08
Wr_density float

Density of the reservoir adjacency matrix Wr.

0.02
Wr_spectral_radius float

Largest eigenvalue of the reservoir adjacency matrix Wr.

0.8
dtype type

Data type of the model (jnp.float64 is highly recommended).

float64
seed int

Random seed for generating the PRNG key for the reservoir computer.

0
chunks int

Number of parallel reservoirs, must evenly divide data_dim.

1
locality int

Overlap in adjacent parallel reservoirs.

0
quadratic bool

Use quadratic nonlinearity in output, default False.

False
periodic bool

Periodic BCs for embedding layer.

True
use_sparse_eigs bool

Whether to use sparse eigensolver for setting the spectral radius of Wr. Default is True, which is recommended to save memory and compute time. If False, will use a dense eigensolver, which may be more accurate.

True
Source code in src/orc/forecaster/models.py
def __init__(
    self,
    data_dim: int,
    res_dim: int,
    leak_rate: float = 0.6,
    bias: float = 1.6,
    embedding_scaling: float = 0.08,
    Wr_density: float = 0.02,
    Wr_spectral_radius: float = 0.8,
    dtype: type = jnp.float64,
    seed: int = 0,
    chunks: int = 1,
    locality: int = 0,
    quadratic: bool = False,
    periodic: bool = True,
    use_sparse_eigs: bool = True,
) -> None:
    """
    Initialize the ESN model.

    Parameters
    ----------
    data_dim : int
        Dimension of the input data.
    res_dim : int
        Dimension of the reservoir adjacency matrix Wr.
    leak_rate : float
        Integration leak rate of the reservoir dynamics.
    bias : float
        Bias term for the reservoir dynamics.
    embedding_scaling : float
        Scaling factor for the embedding layer.
    Wr_density : float
        Density of the reservoir adjacency matrix Wr.
    Wr_spectral_radius : float
        Largest eigenvalue of the reservoir adjacency matrix Wr.
    dtype : type
        Data type of the model (jnp.float64 is highly recommended).
    seed : int
        Random seed for generating the PRNG key for the reservoir computer.
    chunks : int
        Number of parallel reservoirs, must evenly divide data_dim.
    locality : int
        Overlap in adjacent parallel reservoirs.
    quadratic : bool
        Use quadratic nonlinearity in output, default False.
    periodic : bool
        Periodic BCs for embedding layer.
    use_sparse_eigs : bool
        Whether to use sparse eigensolver for setting the spectral radius of Wr.
        Default is True, which is recommended to save memory and compute time. If
        False, will use dense eigensolver which may be more accurate.
    """
    # Initialize the random key and reservoir dimension
    self.res_dim = res_dim
    self.seed = seed
    self.data_dim = data_dim
    key = jax.random.PRNGKey(seed)
    key_driver, key_readout, key_embedding = jax.random.split(key, 3)

    # init in embedding, driver and readout
    embedding = ParallelLinearEmbedding(
        in_dim=data_dim,
        res_dim=res_dim,
        seed=key_embedding[0],
        scaling=embedding_scaling,
        chunks=chunks,
        locality=locality,
        periodic=periodic,
    )
    driver = ParallelESNDriver(
        res_dim=res_dim,
        seed=key_driver[0],
        leak=leak_rate,
        bias=bias,
        density=Wr_density,
        spectral_radius=Wr_spectral_radius,
        chunks=chunks,
        dtype=dtype,
        use_sparse_eigs=use_sparse_eigs,
    )
    if quadratic:
        readout = ParallelQuadraticReadout(
            out_dim=data_dim, res_dim=res_dim, seed=key_readout[0], chunks=chunks
        )
    else:
        readout = ParallelLinearReadout(
            out_dim=data_dim, res_dim=res_dim, seed=key_readout[0], chunks=chunks
        )

    super().__init__(
        driver=driver,
        readout=readout,
        embedding=embedding,
        dtype=dtype,
        seed=seed,
    )
    self.chunks = chunks

CESNForecaster

CESNForecaster(data_dim: int, res_dim: int, time_const: float = 50.0, bias: float = 1.6, embedding_scaling: float = 0.08, Wr_density: float = 0.02, Wr_spectral_radius: float = 0.8, dtype: type = float64, seed: int = 0, chunks: int = 1, locality: int = 0, quadratic: bool = False, periodic: bool = True, use_sparse_eigs: bool = True, solver: AbstractSolver = None, stepsize_controller: AbstractAdaptiveStepSizeController = None)

Bases: CRCForecasterBase

Basic implementation of a Continuous ESN for forecasting.

Attributes:

Name Type Description
res_dim int

Reservoir dimension.

data_dim int

Input/output dimension.

driver ParallelESNDriver

Driver implementing the Echo State Network dynamics in continuous time.

readout BaseReadout

Trainable linear readout layer.

embedding ParallelLinearEmbedding

Untrainable linear embedding layer.

Methods:

Name Description
force

Teacher forces the reservoir with sequence in_seq and init. cond. res_state.

forecast

Perform a forecast of fcast_len steps from res_state.

forecast_from_IC

Forecast from a sequence of spinup data.

set_readout

Replace readout layer.

set_embedding

Replace embedding layer.

Initialize the CESN model.

Parameters:

Name Type Description Default
data_dim int

Dimension of the input data.

required
res_dim int

Dimension of the reservoir adjacency matrix Wr.

required
time_const float

Time constant of the reservoir dynamics.

50.0
bias float

Bias term for the reservoir dynamics.

1.6
embedding_scaling float

Scaling factor for the embedding layer.

0.08
Wr_density float

Density of the reservoir adjacency matrix Wr.

0.02
Wr_spectral_radius float

Largest eigenvalue of the reservoir adjacency matrix Wr.

0.8
dtype type

Data type of the model (jnp.float64 is highly recommended).

float64
seed int

Random seed for generating the PRNG key for the reservoir computer.

0
chunks int

Number of parallel reservoirs, must evenly divide data_dim.

1
locality int

Overlap in adjacent parallel reservoirs.

0
quadratic bool

Use quadratic nonlinearity in output, default False.

False
periodic bool

Periodic BCs for embedding layer.

True
use_sparse_eigs bool

Whether to use sparse eigensolver for setting the spectral radius of Wr. Default is True, which is recommended to save memory and compute time. If False, will use a dense eigensolver, which may be more accurate.

True
Source code in src/orc/forecaster/models.py
def __init__(
    self,
    data_dim: int,
    res_dim: int,
    time_const: float = 50.0,
    bias: float = 1.6,
    embedding_scaling: float = 0.08,
    Wr_density: float = 0.02,
    Wr_spectral_radius: float = 0.8,
    dtype: type = jnp.float64,
    seed: int = 0,
    chunks: int = 1,
    locality: int = 0,
    quadratic: bool = False,
    periodic: bool = True,
    use_sparse_eigs: bool = True,
    solver: diffrax.AbstractSolver = None,
    stepsize_controller: diffrax.AbstractAdaptiveStepSizeController = None,
) -> None:
    """
    Initialize the CESN model.

    Parameters
    ----------
    data_dim : int
        Dimension of the input data.
    res_dim : int
        Dimension of the reservoir adjacency matrix Wr.
    time_const : float
        Time constant of the reservoir dynamics.
    bias : float
        Bias term for the reservoir dynamics.
    embedding_scaling : float
        Scaling factor for the embedding layer.
    Wr_density : float
        Density of the reservoir adjacency matrix Wr.
    Wr_spectral_radius : float
        Largest eigenvalue of the reservoir adjacency matrix Wr.
    dtype : type
        Data type of the model (jnp.float64 is highly recommended).
    seed : int
        Random seed for generating the PRNG key for the reservoir computer.
    chunks : int
        Number of parallel reservoirs, must evenly divide data_dim.
    locality : int
        Overlap in adjacent parallel reservoirs.
    quadratic : bool
        Use quadratic nonlinearity in output, default False.
    periodic : bool
        Periodic BCs for embedding layer.
    use_sparse_eigs : bool
        Whether to use sparse eigensolver for setting the spectral radius of Wr.
        Default is True, which is recommended to save memory and compute time. If
        False, will use dense eigensolver which may be more accurate.
    """
    # Initialize the random key and reservoir dimension
    self.res_dim = res_dim
    self.seed = seed
    self.data_dim = data_dim
    key = jax.random.PRNGKey(seed)
    key_driver, key_readout, key_embedding = jax.random.split(key, 3)

    # init in embedding, driver and readout
    embedding = ParallelLinearEmbedding(
        in_dim=data_dim,
        res_dim=res_dim,
        seed=key_embedding[0],
        scaling=embedding_scaling,
        chunks=chunks,
        locality=locality,
        periodic=periodic,
    )
    driver = ParallelESNDriver(
        res_dim=res_dim,
        seed=key_driver[0],
        time_const=time_const,
        bias=bias,
        density=Wr_density,
        spectral_radius=Wr_spectral_radius,
        chunks=chunks,
        mode="continuous",
        dtype=dtype,
        use_sparse_eigs=use_sparse_eigs,
    )
    if quadratic:
        readout = ParallelQuadraticReadout(
            out_dim=data_dim, res_dim=res_dim, seed=key_readout[0], chunks=chunks
        )
    else:
        readout = ParallelLinearReadout(
            out_dim=data_dim, res_dim=res_dim, seed=key_readout[0], chunks=chunks
        )

    if solver is None:
        solver = diffrax.Tsit5()
    if stepsize_controller is None:
        stepsize_controller = diffrax.PIDController(
            rtol=1e-3, atol=1e-6, icoeff=1.0
        )

    super().__init__(
        driver=driver,
        readout=readout,
        embedding=embedding,
        dtype=dtype,
        seed=seed,
        solver=solver,
        stepsize_controller=stepsize_controller,
    )
    self.chunks = chunks

EnsembleESNForecaster

EnsembleESNForecaster(data_dim: int, res_dim: int, leak_rate: float = 0.6, bias: float = 1.6, embedding_scaling: float = 0.08, Wr_density: float = 0.02, Wr_spectral_radius: float = 0.8, dtype: type = float64, seed: int = 0, chunks: int = 1, use_sparse_eigs: bool = True)

Bases: RCForecasterBase

Ensembled ESNs for forecasting.

Attributes:

Name Type Description
res_dim int

Reservoir dimension.

data_dim int

Input/output dimension.

driver ParallelESNDriver

Driver implementing the Echo State Network dynamics.

readout EnsembleLinearReadout

Trainable linear readout layer.

embedding EnsembleLinearEmbedding

Untrainable linear embedding layer.

Methods:

Name Description
force

Teacher forces the reservoir with sequence in_seq and init. cond. res_state.

forecast

Perform a forecast of fcast_len steps from res_state.

forecast_from_IC

Forecast from a sequence of spinup data.

set_readout

Replace readout layer.

set_embedding

Replace embedding layer.

Initialize the ESN model.

Parameters:

Name Type Description Default
data_dim int

Dimension of the input data.

required
res_dim int

Dimension of the reservoir adjacency matrix Wr.

required
leak_rate float

Integration leak rate of the reservoir dynamics.

0.6
bias float

Bias term for the reservoir dynamics.

1.6
embedding_scaling float

Scaling factor for the embedding layer.

0.08
Wr_density float

Density of the reservoir adjacency matrix Wr.

0.02
Wr_spectral_radius float

Largest eigenvalue of the reservoir adjacency matrix Wr.

0.8
dtype type

Data type of the model (jnp.float64 is highly recommended).

float64
seed int

Random seed for generating the PRNG key for the reservoir computer.

0
chunks int

Number of parallel reservoirs, must evenly divide data_dim.

1
use_sparse_eigs bool

Whether to use a sparse eigensolver when setting the spectral radius of Wr. Default is True, which is recommended to save memory and compute time; if False, a dense eigensolver is used, which may be more accurate.

True
Source code in src/orc/forecaster/models.py
def __init__(
    self,
    data_dim: int,
    res_dim: int,
    leak_rate: float = 0.6,
    bias: float = 1.6,
    embedding_scaling: float = 0.08,
    Wr_density: float = 0.02,
    Wr_spectral_radius: float = 0.8,
    dtype: type = jnp.float64,
    seed: int = 0,
    chunks: int = 1,
    use_sparse_eigs: bool = True,
) -> None:
    """
    Initialize the ESN model.

    Parameters
    ----------
    data_dim : int
        Dimension of the input data.
    res_dim : int
        Dimension of the reservoir adjacency matrix Wr.
    leak_rate : float
        Integration leak rate of the reservoir dynamics.
    bias : float
        Bias term for the reservoir dynamics.
    embedding_scaling : float
        Scaling factor for the embedding layer.
    Wr_density : float
        Density of the reservoir adjacency matrix Wr.
    Wr_spectral_radius : float
        Largest eigenvalue of the reservoir adjacency matrix Wr.
    dtype : type
        Data type of the model (jnp.float64 is highly recommended).
    seed : int
        Random seed for generating the PRNG key for the reservoir computer.
    chunks : int
        Number of parallel reservoirs, must evenly divide data_dim.
    use_sparse_eigs : bool
        Whether to use a sparse eigensolver when setting the spectral radius of
        Wr. Default is True, which is recommended to save memory and compute
        time. If False, a dense eigensolver is used, which may be more accurate.
    """
    # Initialize the random key and reservoir dimension
    self.res_dim = res_dim
    self.seed = seed
    self.data_dim = data_dim
    key = jax.random.PRNGKey(seed)
    key_driver, key_readout, key_embedding = jax.random.split(key, 3)

    # init in embedding, driver and readout
    embedding = EnsembleLinearEmbedding(
        in_dim=data_dim,
        res_dim=res_dim,
        seed=key_embedding[0],
        scaling=embedding_scaling,
        chunks=chunks,
    )
    driver = ParallelESNDriver(
        res_dim=res_dim,
        seed=key_driver[0],
        leak=leak_rate,
        bias=bias,
        density=Wr_density,
        spectral_radius=Wr_spectral_radius,
        chunks=chunks,
        dtype=dtype,
        use_sparse_eigs=use_sparse_eigs,
    )

    readout = EnsembleLinearReadout(
        out_dim=data_dim, res_dim=res_dim, seed=key_readout[0], chunks=chunks
    )

    super().__init__(
        driver=driver,
        readout=readout,
        embedding=embedding,
        dtype=dtype,
        seed=seed,
    )
    self.chunks = chunks

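The following NumPy sketch illustrates the ensemble structure described above: `chunks` independent reservoirs, each carrying its own readout from the full reservoir state (res_dim) to the full data vector (data_dim). The array names and the mean combination are assumptions for illustration only; this is not the orc API, and the page does not specify how ensemble outputs are combined.

```python
import numpy as np

# Illustrative stand-in for an ensemble of `chunks` reservoirs; each member
# has its own readout mapping res_dim -> data_dim.
chunks, res_dim, data_dim = 4, 50, 3
rng = np.random.default_rng(0)

res_states = rng.standard_normal((chunks, res_dim))      # one state per member
wout = rng.standard_normal((chunks, data_dim, res_dim))  # one readout per member

# Each ensemble member produces a full-dimensional prediction...
preds = np.einsum("cdr,cr->cd", wout, res_states)        # shape (chunks, data_dim)

# ...and averaging across members is one plausible way to combine them
# (an assumption for illustration; the combination rule is not shown here).
ensemble_pred = preds.mean(axis=0)                       # shape (data_dim,)
```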
Training Functions

train_ESNForecaster

train_ESNForecaster(model: ESNForecaster, train_seq: Array, target_seq: Array = None, spinup: int = 0, initial_res_state: Array = None, beta: float = 8e-08, batch_size: int = None) -> tuple[ESNForecaster, Array]

Training function for ESNForecaster.

Parameters:

Name Type Description Default
model ESNForecaster

ESNForecaster model to train.

required
train_seq Array

Training input sequence for reservoir, (shape=(seq_len, data_dim)).

required
target_seq Array

Target sequence for training reservoir, (shape=(seq_len, data_dim)).

None
initial_res_state Array

Initial reservoir state, (shape=(chunks, res_dim,)).

None
spinup int

Initial transient of reservoir states to discard.

0
beta float

Tikhonov regularization parameter.

8e-08
batch_size int

Number of parallel reservoirs to process in each batch for ridge regression. If None (default), processes all reservoirs at once. Use smaller values to reduce memory usage for large numbers of parallel reservoirs.

None

Returns:

Name Type Description
model ESNForecaster

Trained ESN model.

res_seq Array

Training sequence of reservoir states.

Source code in src/orc/forecaster/train.py
def train_ESNForecaster(
    model: ESNForecaster,
    train_seq: Array,
    target_seq: Array = None,
    spinup: int = 0,
    initial_res_state: Array = None,
    beta: float = 8e-8,
    batch_size: int = None,
) -> tuple[ESNForecaster, Array]:
    """Training function for ESNForecaster.

    Parameters
    ----------
    model : ESNForecaster
        ESNForecaster model to train.
    train_seq : Array
        Training input sequence for reservoir, (shape=(seq_len, data_dim)).
    target_seq : Array
        Target sequence for training reservoir, (shape=(seq_len, data_dim)).
    initial_res_state : Array
        Initial reservoir state, (shape=(chunks, res_dim,)).
    spinup : int
        Initial transient of reservoir states to discard.
    beta : float
        Tikhonov regularization parameter.
    batch_size : int, optional
        Number of parallel reservoirs to process in each batch for ridge regression.
        If None (default), processes all reservoirs at once. Use smaller values
        to reduce memory usage for large numbers of parallel reservoirs.

    Returns
    -------
    model : ESNForecaster
        Trained ESN model.
    res_seq : Array
        Training sequence of reservoir states.
    """
    # Check that model is an ESN
    if not isinstance(model, ESNForecaster):
        raise TypeError("Model must be an ESNForecaster.")

    # check that spinup is less than the length of the training sequence
    if spinup >= train_seq.shape[0]:
        raise ValueError(
            "spinup must be less than the length of the training sequence."
        )

    if initial_res_state is None:
        initial_res_state = jnp.zeros(
            (
                model.embedding.chunks,
                model.res_dim,
            ),
            dtype=model.dtype,
        )

    if target_seq is None:
        tot_seq = train_seq
        target_seq = train_seq[1:, :]
        train_seq = train_seq[:-1, :]
    else:
        tot_seq = jnp.vstack((train_seq, target_seq[-1:]))

    tot_res_seq = model.force(tot_seq, initial_res_state)
    res_seq = tot_res_seq[:-1]
    if isinstance(model.readout, ParallelNonlinearReadout):
        res_seq_train = model.readout.nonlinear_transform(res_seq)
    else:
        res_seq_train = res_seq

    if batch_size is None:
        cmat = _solve_all_ridge_reg(
            res_seq_train[spinup:],
            target_seq[spinup:].reshape(
                res_seq[spinup:].shape[0], res_seq.shape[1], -1
            ),
            beta,
        )
    else:
        cmat = _solve_all_ridge_reg_batched(
            res_seq_train[spinup:],
            target_seq[spinup:].reshape(
                res_seq[spinup:].shape[0], res_seq.shape[1], -1
            ),
            beta,
            batch_size,
        )

    def where(m):
        return m.readout.wout

    model = eqx.tree_at(where, model, cmat)

    return model, tot_res_seq

train_CESNForecaster

train_CESNForecaster(model: CESNForecaster, train_seq: Array, t_train: Array, target_seq: Array = None, spinup: int = 0, initial_res_state: Array = None, beta: float = 8e-08, batch_size: int = None) -> tuple[CESNForecaster, Array]

Training function for CESNForecaster.

Parameters:

Name Type Description Default
model CESNForecaster

CESNForecaster model to train.

required
train_seq Array

Training input sequence for reservoir, (shape=(seq_len, data_dim)).

required
t_train Array

Time vector corresponding to the training sequence, (shape=(seq_len,)).

required
target_seq Array

Target sequence for training reservoir, (shape=(seq_len, data_dim)).

None
initial_res_state Array

Initial reservoir state, (shape=(chunks, res_dim,)).

None
spinup int

Initial transient of reservoir states to discard.

0
beta float

Tikhonov regularization parameter.

8e-08
batch_size int

Number of parallel reservoirs to process in each batch for ridge regression. If None (default), processes all reservoirs at once. Use smaller values to reduce memory usage for large numbers of parallel reservoirs.

None

Returns:

Name Type Description
model CESNForecaster

Trained CESN model.

res_seq Array

Training sequence of reservoir states.

Source code in src/orc/forecaster/train.py
def train_CESNForecaster(
    model: CESNForecaster,
    train_seq: Array,
    t_train: Array,
    target_seq: Array = None,
    spinup: int = 0,
    initial_res_state: Array = None,
    beta: float = 8e-8,
    batch_size: int = None,
) -> tuple[CESNForecaster, Array]:
    """Training function for CESNForecaster.

    Parameters
    ----------
    model : CESNForecaster
        CESNForecaster model to train.
    train_seq : Array
        Training input sequence for reservoir, (shape=(seq_len, data_dim)).
    t_train : Array
        Time vector corresponding to the training sequence, (shape=(seq_len,)).
    target_seq : Array
        Target sequence for training reservoir, (shape=(seq_len, data_dim)).
    initial_res_state : Array
        Initial reservoir state, (shape=(chunks, res_dim,)).
    spinup : int
        Initial transient of reservoir states to discard.
    beta : float
        Tikhonov regularization parameter.
    batch_size : int, optional
        Number of parallel reservoirs to process in each batch for ridge regression.
        If None (default), processes all reservoirs at once. Use smaller values
        to reduce memory usage for large numbers of parallel reservoirs.

    Returns
    -------
    model : CESNForecaster
        Trained CESN model.
    res_seq : Array
        Training sequence of reservoir states.
    """
    # check that model is continuous
    if not isinstance(model, CESNForecaster):
        raise TypeError("Model must be a CESNForecaster.")

    # check that train_seq and t_train have the same length
    if train_seq.shape[0] != t_train.shape[0]:
        raise ValueError("train_seq and t_train must have the same length.")

    # check that spinup is less than the length of the training sequence
    if spinup >= train_seq.shape[0]:
        raise ValueError(
            "spinup must be less than the length of the training sequence."
        )

    if initial_res_state is None:
        initial_res_state = jnp.zeros(
            (
                model.embedding.chunks,
                model.res_dim,
            ),
            dtype=model.dtype,
        )

    if target_seq is None:
        tot_seq = train_seq
        target_seq = train_seq[1:, :]
        train_seq = train_seq[:-1, :]
    else:
        tot_seq = jnp.vstack((train_seq, target_seq[-1:]))

    tot_res_seq = model.force(tot_seq, initial_res_state, ts=t_train)
    res_seq = tot_res_seq[:-1]
    if isinstance(model.readout, ParallelNonlinearReadout):
        res_seq_train = eqx.filter_vmap(model.readout.nonlinear_transform)(res_seq)
    else:
        res_seq_train = res_seq

    if batch_size is None:
        cmat = _solve_all_ridge_reg(
            res_seq_train[spinup:],
            target_seq[spinup:].reshape(
                res_seq[spinup:].shape[0], res_seq.shape[1], -1
            ),
            beta,
        )
    else:
        cmat = _solve_all_ridge_reg_batched(
            res_seq_train[spinup:],
            target_seq[spinup:].reshape(
                res_seq[spinup:].shape[0], res_seq.shape[1], -1
            ),
            beta,
            batch_size,
        )

    def where(m):
        return m.readout.wout

    model = eqx.tree_at(where, model, cmat)

    return model, tot_res_seq

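The sequence alignment in the training code above is easy to get wrong, so here is a NumPy sketch of it, assuming model.force returns one reservoir state per input row (an assumption for illustration; this is not the orc API). When target_seq is provided, the last target row is appended so that dropping the final reservoir state leaves states paired row for row with the targets.

```python
import numpy as np

seq_len, data_dim = 50, 3
rng = np.random.default_rng(0)
train_seq = rng.standard_normal((seq_len, data_dim))
target_seq = rng.standard_normal((seq_len, data_dim))

# tot_seq = jnp.vstack((train_seq, target_seq[-1:])) in the code above:
tot_seq = np.vstack((train_seq, target_seq[-1:]))  # shape (seq_len + 1, data_dim)

# Stand-in for model.force(tot_seq, ...): one state per input row.
tot_res_seq = np.zeros((tot_seq.shape[0], 8))

# res_seq = tot_res_seq[:-1] then pairs res_seq[t] with target_seq[t].
res_seq = tot_res_seq[:-1]
```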
train_EnsembleESNForecaster

train_EnsembleESNForecaster(model: EnsembleESNForecaster, train_seq: Array, target_seq: Array | None = None, spinup: int = 0, initial_res_state: Array | None = None, beta: float = 8e-08, batch_size: int | None = None) -> tuple[EnsembleESNForecaster, Array]

Training function for EnsembleESNForecaster.

Parameters:

Name Type Description Default
model ESNForecaster

ESNForecaster model to train.

required
train_seq Array

Training input sequence for reservoir, (shape=(seq_len, data_dim)).

required
target_seq Array

Target sequence for training reservoir, (shape=(seq_len, data_dim)).

None
initial_res_state Array

Initial reservoir state, (shape=(chunks, res_dim,)).

None
spinup int

Initial transient of reservoir states to discard.

0
beta float

Tikhonov regularization parameter.

8e-08
batch_size int

Number of parallel reservoirs to process in each batch for ridge regression. If None (default), processes all reservoirs at once. Use smaller values to reduce memory usage for large numbers of parallel reservoirs.

None

Returns:

Name Type Description
model EnsembleESNForecaster

Trained ensemble ESN model.

res_seq Array

Training sequence of reservoir states.

Source code in src/orc/forecaster/train.py
def train_EnsembleESNForecaster(
    model: EnsembleESNForecaster,
    train_seq: Array,
    target_seq: Array | None = None,
    spinup: int = 0,
    initial_res_state: Array | None = None,
    beta: float = 8e-8,
    batch_size: int | None = None,
) -> tuple[EnsembleESNForecaster, Array]:
    """Training function for EnsembleESNForecaster.

    Parameters
    ----------
    model : EnsembleESNForecaster
        EnsembleESNForecaster model to train.
    train_seq : Array
        Training input sequence for reservoir, (shape=(seq_len, data_dim)).
    target_seq : Array
        Target sequence for training reservoir, (shape=(seq_len, data_dim)).
    initial_res_state : Array
        Initial reservoir state, (shape=(chunks, res_dim,)).
    spinup : int
        Initial transient of reservoir states to discard.
    beta : float
        Tikhonov regularization parameter.
    batch_size : int, optional
        Number of parallel reservoirs to process in each batch for ridge regression.
        If None (default), processes all reservoirs at once. Use smaller values
        to reduce memory usage for large numbers of parallel reservoirs.

    Returns
    -------
    model : EnsembleESNForecaster
        Trained ensemble ESN model.
    res_seq : Array
        Training sequence of reservoir states.
    """
    # Check that model is an ESN
    if not isinstance(model, EnsembleESNForecaster):
        raise TypeError("Model must be an EnsembleESNForecaster.")

    # check that spinup is less than the length of the training sequence
    if spinup >= train_seq.shape[0]:
        raise ValueError(
            "spinup must be less than the length of the training sequence."
        )

    if initial_res_state is None:
        initial_res_state = jnp.zeros(
            (
                model.embedding.chunks,
                model.res_dim,
            ),
            dtype=model.dtype,
        )

    if target_seq is None:
        tot_seq = train_seq
        target_seq = train_seq[1:, :]
        train_seq = train_seq[:-1, :]
    else:
        tot_seq = jnp.vstack((train_seq, target_seq[-1:]))

    tot_res_seq = model.force(tot_seq, initial_res_state)
    res_seq = tot_res_seq[:-1]
    res_seq_train = res_seq

    repeated_target_seq = jnp.repeat(target_seq[:, None, :], model.chunks, axis=1)
    if batch_size is None:
        cmat = _solve_all_ridge_reg(
            res_seq_train[spinup:],
            repeated_target_seq[spinup:],
            beta,
        )
    else:
        cmat = _solve_all_ridge_reg_batched(
            res_seq_train[spinup:],
            repeated_target_seq[spinup:],
            beta,
            batch_size,
        )

    def where(m):
        return m.readout.wout

    model = eqx.tree_at(where, model, cmat)

    return model, tot_res_seq
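The distinctive step in the ensemble training function above is the target replication: every ensemble member regresses onto an identical copy of the full target sequence. A NumPy equivalent of the jnp.repeat call (an illustration, not the orc implementation):

```python
import numpy as np

seq_len, data_dim, chunks = 50, 3, 4
rng = np.random.default_rng(0)
target_seq = rng.standard_normal((seq_len, data_dim))

# NumPy equivalent of jnp.repeat(target_seq[:, None, :], chunks, axis=1):
repeated_target_seq = np.repeat(target_seq[:, None, :], chunks, axis=1)
# shape (seq_len, chunks, data_dim) -- one identical target copy per member,
# so each readout in the ensemble is fit against the same full target.
```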