
Control API Reference

The orc.control subpackage provides reservoir computing models for dynamical system control.

Base Classes

RCControllerBase

RCControllerBase(driver: DriverBase, readout: ReadoutBase, embedding: EmbedBase, in_dim: int, control_dim: int, dtype: Float = float64, alpha_1: float = 100, alpha_2: float = 1, alpha_3: float = 5, seed: int = 0)

Bases: Module, ABC

Base class for reservoir computer controllers.

Defines the interface for the reservoir computer controller, which includes the driver, readout, and embedding layers. Unlike the forecaster, the controller handles an additional control input at each time step.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| driver | DriverBase | Driver layer of the reservoir computer. |
| readout | ReadoutBase | Readout layer of the reservoir computer. |
| embedding | EmbedBase | Embedding layer of the reservoir computer. Must accept concatenated [input, control] vectors of length in_dim + control_dim. |
| in_dim | int | Dimension of the system input data. |
| control_dim | int | Dimension of the control input. |
| out_dim | int | Dimension of the output data (taken from readout.out_dim). |
| res_dim | int | Dimension of the reservoir (taken from driver.res_dim). |
| dtype | type | Data type of the reservoir computer (jnp.float64 is highly recommended). |
| alpha_1 | float | Weight for trajectory deviation penalty in control optimization. |
| alpha_2 | float | Weight for control magnitude penalty in control optimization. |
| alpha_3 | float | Weight for control derivative penalty in control optimization. |
| seed | int | Random seed for generating the PRNG key for the reservoir computer. |

Methods:

| Name | Description |
| --- | --- |
| force | Teacher forces the reservoir with input and control sequences. |
| apply_control | Apply a predefined control sequence in closed-loop. |
| set_readout | Replaces the readout layer of the reservoir computer. |
| set_embedding | Replaces the embedding layer of the reservoir computer. |
| compute_penalty | Compute the control penalty for a given control sequence. |
| compute_control | Compute optimal control sequence to track a reference trajectory. |

Initialize RCControllerBase.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| driver | DriverBase | Driver layer of the reservoir computer. | required |
| readout | ReadoutBase | Readout layer of the reservoir computer. | required |
| embedding | EmbedBase | Embedding layer of the reservoir computer. | required |
| in_dim | int | Dimension of the system input data. | required |
| control_dim | int | Dimension of the control input. | required |
| dtype | type | Data type of the reservoir computer (jnp.float64 is highly recommended). | float64 |
| alpha_1 | float | Weight for trajectory deviation penalty in control optimization. | 100 |
| alpha_2 | float | Weight for control magnitude penalty in control optimization. | 1 |
| alpha_3 | float | Weight for control derivative penalty in control optimization. | 5 |
| seed | int | Random seed for generating the PRNG key for the reservoir computer. | 0 |
Source code in src/orc/control/base.py
def __init__(
    self,
    driver: DriverBase,
    readout: ReadoutBase,
    embedding: EmbedBase,
    in_dim: int,
    control_dim: int,
    dtype: Float = jnp.float64,
    alpha_1: float = 100,
    alpha_2: float = 1,
    alpha_3: float = 5,
    seed: int = 0,
) -> None:
    """Initialize RCControllerBase.

    Parameters
    ----------
    driver : DriverBase
        Driver layer of the reservoir computer.
    readout : ReadoutBase
        Readout layer of the reservoir computer.
    embedding : EmbedBase
        Embedding layer of the reservoir computer.
    in_dim : int
        Dimension of the system input data.
    control_dim : int
        Dimension of the control input.
    dtype : type
        Data type of the reservoir computer (jnp.float64 is highly recommended).
    alpha_1 : float
        Weight for trajectory deviation penalty in control optimization.
    alpha_2 : float
        Weight for control magnitude penalty in control optimization.
    alpha_3 : float
        Weight for control derivative penalty in control optimization.
    seed : int
        Random seed for generating the PRNG key for the reservoir computer.
    """
    self.driver = driver
    self.readout = readout
    self.embedding = embedding
    self.in_dim = in_dim
    self.control_dim = control_dim
    self.out_dim = self.readout.out_dim
    self.res_dim = self.driver.res_dim
    self.dtype = dtype
    self.alpha_1 = alpha_1
    self.alpha_2 = alpha_2
    self.alpha_3 = alpha_3
    self.seed = seed

force

force(in_seq: Array, control_seq: Array, res_state: Array) -> Array

Teacher forces the reservoir with input and control sequences.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_seq | Array | Input sequence to force the reservoir, (shape=(seq_len, in_dim)). | required |
| control_seq | Array | Control sequence to force the reservoir, (shape=(seq_len, control_dim)). | required |
| res_state | Array | Initial reservoir state, (shape=(res_dim,)). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Forced reservoir sequence, (shape=(seq_len, res_dim)). |

Source code in src/orc/control/base.py
@eqx.filter_jit
def force(self, in_seq: Array, control_seq: Array, res_state: Array) -> Array:
    """Teacher forces the reservoir with input and control sequences.

    Parameters
    ----------
    in_seq : Array
        Input sequence to force the reservoir, (shape=(seq_len, in_dim)).
    control_seq : Array
        Control sequence to force the reservoir, (shape=(seq_len, control_dim)).
    res_state : Array
        Initial reservoir state, (shape=(res_dim,)).

    Returns
    -------
    Array
        Forced reservoir sequence, (shape=(seq_len, res_dim)).
    """

    def scan_fn(state, in_vars):
        in_state, control_state = in_vars
        # Concatenate input and control for embedding
        combined_input = jnp.concatenate([in_state, control_state])
        proj_vars = self.embedding.embed(combined_input)
        res_state = self.driver.advance(proj_vars, state)
        return (res_state, res_state)

    _, res_seq = jax.lax.scan(scan_fn, res_state, (in_seq, control_seq))
    return res_seq
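To make the data flow of `force` concrete, here is a minimal NumPy sketch of teacher forcing with an [input, control] concatenation. It assumes a leaky-tanh echo state update and a dense uniform input matrix; `toy_esn_step`, `force_sketch`, `W_in`, and `W_r` are hypothetical stand-ins for `ESNDriver` and `LinearEmbedding`, whose exact update rules are not shown on this page. The Python loop plays the role of `jax.lax.scan`.

```python
import numpy as np


def toy_esn_step(proj, state, W_r, leak=0.6, bias=1.6):
    # ASSUMED leaky-tanh echo state update; ESNDriver.advance may differ in detail.
    return (1.0 - leak) * state + leak * np.tanh(W_r @ state + proj + bias)


def force_sketch(in_seq, control_seq, res_state, W_in, W_r):
    """Teacher forcing with [input, control] concatenation, mirroring force()."""
    states = []
    for u_t, c_t in zip(in_seq, control_seq):       # plays the role of jax.lax.scan
        combined = np.concatenate([u_t, c_t])       # shape (in_dim + control_dim,)
        proj = W_in @ combined                      # hypothetical linear embedding
        res_state = toy_esn_step(proj, res_state, W_r)
        states.append(res_state)                    # record the state after step t
    return np.stack(states)                         # shape (seq_len, res_dim)


rng = np.random.default_rng(0)
seq_len, in_dim, control_dim, res_dim = 50, 3, 1, 20
W_in = rng.uniform(-0.08, 0.08, size=(res_dim, in_dim + control_dim))
W_r = rng.normal(size=(res_dim, res_dim))
W_r *= 0.8 / np.max(np.abs(np.linalg.eigvals(W_r)))  # rescale to spectral radius 0.8

in_seq = rng.normal(size=(seq_len, in_dim))
control_seq = rng.normal(size=(seq_len, control_dim))
res_seq = force_sketch(in_seq, control_seq, np.zeros(res_dim), W_in, W_r)
print(res_seq.shape)  # (50, 20)
```

Row t of the result is the reservoir state after consuming in_seq[t] and control_seq[t], which is the alignment the training routine relies on.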

__call__

__call__(in_seq: Array, control_seq: Array, res_state: Array) -> Array

Teacher forces the reservoir; wrapper for the force method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_seq | Array | Input sequence to force the reservoir, (shape=(seq_len, in_dim)). | required |
| control_seq | Array | Control sequence to force the reservoir, (shape=(seq_len, control_dim)). | required |
| res_state | Array | Initial reservoir state, (shape=(res_dim,)). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Forced reservoir sequence, (shape=(seq_len, res_dim)). |

Source code in src/orc/control/base.py
def __call__(self, in_seq: Array, control_seq: Array, res_state: Array) -> Array:
    """Teacher forces the reservoir; wrapper for the `force` method.

    Parameters
    ----------
    in_seq : Array
        Input sequence to force the reservoir, (shape=(seq_len, in_dim)).
    control_seq : Array
        Control sequence to force the reservoir, (shape=(seq_len, control_dim)).
    res_state : Array
        Initial reservoir state, (shape=(res_dim,)).

    Returns
    -------
    Array
        Forced reservoir sequence, (shape=(seq_len, res_dim)).
    """
    return self.force(in_seq, control_seq, res_state)

apply_control

apply_control(control_seq: Array, res_state: Array) -> Array

Apply a predefined control sequence in closed-loop.

The readout feeds back as the next input: u(t+1) = readout(x(t)), where x(t) is the reservoir state. Control c(t) comes from the provided control_seq.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| control_seq | Array | Control sequence to apply, (shape=(fcast_len, control_dim)). | required |
| res_state | Array | Initial reservoir state, (shape=(res_dim,)). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Controlled output trajectory, (shape=(fcast_len, out_dim)). |

Source code in src/orc/control/base.py
@eqx.filter_jit
def apply_control(
    self, control_seq: Array, res_state: Array
) -> Array:
    """Apply a predefined control sequence in closed-loop.

    The readout feeds back as the next input: u(t+1) = readout(x(t)),
    where x(t) is the reservoir state. Control c(t) comes from the
    provided control_seq.

    Parameters
    ----------
    control_seq : Array
        Control sequence to apply, (shape=(fcast_len, control_dim)).
    res_state : Array
        Initial reservoir state, (shape=(res_dim,)).

    Returns
    -------
    Array
        Controlled output trajectory, (shape=(fcast_len, out_dim)).
    """

    def scan_fn(state, control_vars):
        # Get output from current reservoir state
        out_state = self.readout(state)
        # Concatenate output (as next input) with control
        combined_input = jnp.concatenate([out_state, control_vars])
        # Embed and advance reservoir
        proj_vars = self.embedding(combined_input)
        next_res_state = self.driver(proj_vars, state)
        return (next_res_state, self.readout(next_res_state))

    res_state, state_seq = jax.lax.scan(scan_fn, res_state, control_seq)
    return state_seq
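The closed loop in `apply_control` can be sketched in NumPy as follows, again assuming a leaky-tanh update. `apply_control_sketch`, `W_in`, `W_r`, and `W_out` are hypothetical; the point is the feedback structure: the readout of the current state is concatenated with the supplied control, and the readout of the advanced state is recorded.

```python
import numpy as np


def apply_control_sketch(control_seq, res_state, W_in, W_r, W_out, leak=0.6, bias=1.6):
    """Closed-loop rollout mirroring apply_control() under an ASSUMED leaky-tanh update."""
    outputs = []
    for c_t in control_seq:
        u_t = W_out @ res_state                       # readout(x(t)) becomes the next input
        proj = W_in @ np.concatenate([u_t, c_t])      # embed [input, control]
        res_state = (1.0 - leak) * res_state + leak * np.tanh(W_r @ res_state + proj + bias)
        outputs.append(W_out @ res_state)             # readout of the advanced state
    return np.stack(outputs)                          # shape (fcast_len, out_dim)


rng = np.random.default_rng(1)
out_dim, control_dim, res_dim, fcast_len = 3, 1, 20, 25
W_in = rng.uniform(-0.08, 0.08, size=(res_dim, out_dim + control_dim))
W_r = 0.1 * rng.normal(size=(res_dim, res_dim))
W_out = rng.normal(size=(out_dim, res_dim)) / res_dim   # hypothetical trained readout

x0 = np.zeros(res_dim)
fcast_zero = apply_control_sketch(np.zeros((fcast_len, control_dim)), x0, W_in, W_r, W_out)
fcast_push = apply_control_sketch(np.ones((fcast_len, control_dim)), x0, W_in, W_r, W_out)
print(fcast_zero.shape)  # (25, 3)
```

Because each recorded output is taken after the control at that step has been applied, output row t already reflects control_seq[t]; this is why compute_penalty compares it directly with ref_traj[t].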

set_readout

set_readout(readout: ReadoutBase) -> RCControllerBase

Replace the readout layer, returning an updated copy of the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| readout | ReadoutBase | New readout layer. | required |

Returns:

| Type | Description |
| --- | --- |
| RCControllerBase | Updated model with new readout layer. |

Source code in src/orc/control/base.py
def set_readout(self, readout: ReadoutBase) -> "RCControllerBase":
    """Replace the readout layer, returning an updated copy of the model.

    Parameters
    ----------
    readout : ReadoutBase
        New readout layer.

    Returns
    -------
    RCControllerBase
        Updated model with new readout layer.
    """

    def where(m: "RCControllerBase"):
        return m.readout

    new_model = eqx.tree_at(where, self, readout)
    return new_model

set_embedding

set_embedding(embedding: EmbedBase) -> RCControllerBase

Replace the embedding layer, returning an updated copy of the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embedding | EmbedBase | New embedding layer. | required |

Returns:

| Type | Description |
| --- | --- |
| RCControllerBase | Updated model with new embedding layer. |

Source code in src/orc/control/base.py
def set_embedding(self, embedding: EmbedBase) -> "RCControllerBase":
    """Replace the embedding layer, returning an updated copy of the model.

    Parameters
    ----------
    embedding : EmbedBase
        New embedding layer.

    Returns
    -------
    RCControllerBase
        Updated model with new embedding layer.
    """

    def where(m: "RCControllerBase"):
        return m.embedding

    new_model = eqx.tree_at(where, self, embedding)
    return new_model
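Equinox modules are immutable, so `set_readout` and `set_embedding` build a modified copy with `eqx.tree_at` rather than changing the model in place. The pure-Python analogy below shows the same copy-on-update behavior using a frozen dataclass and `dataclasses.replace`; `ToyController` is hypothetical and not part of orc.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyController:
    """Hypothetical stand-in for an immutable controller with swappable layers."""
    readout: str
    embedding: str

    def set_readout(self, readout):
        # Return a modified copy; the original instance is left untouched,
        # analogous to eqx.tree_at(lambda m: m.readout, self, readout).
        return replace(self, readout=readout)

    def set_embedding(self, embedding):
        return replace(self, embedding=embedding)


model = ToyController(readout="linear", embedding="linear")
new_model = model.set_readout("quadratic")
print(model.readout, new_model.readout)  # linear quadratic
```

With the real model, rebind the result (`model = model.set_readout(new_readout)`); a bare `model.set_readout(new_readout)` leaves `model` unchanged.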

compute_penalty

compute_penalty(control_seq: Array, res_state: Array, ref_traj: Array) -> Float

Compute the control penalty for a given control sequence.

The penalty consists of three terms:

- Deviation penalty: sum of squared errors between the forecast and the reference trajectory, weighted by alpha_1.
- Magnitude penalty: squared norm of the control inputs, weighted by alpha_2.
- Derivative penalty: squared norm of the differences between consecutive control inputs, weighted by alpha_3.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| control_seq | Array | Control sequence to evaluate, (shape=(fcast_len, control_dim)). | required |
| res_state | Array | Initial reservoir state, (shape=(res_dim,)). | required |
| ref_traj | Array | Reference trajectory to track, (shape=(fcast_len, out_dim)). | required |

Returns:

| Type | Description |
| --- | --- |
| Float | Total penalty value (scalar). |

Source code in src/orc/control/base.py
def compute_penalty(
    self,
    control_seq: Array,
    res_state: Array,
    ref_traj: Array,
) -> Float:
    """Compute the control penalty for a given control sequence.

    The penalty consists of three terms:
    - Deviation penalty: squared error between forecast and reference trajectory
    - Magnitude penalty: squared norm of control inputs
    - Derivative penalty: squared norm of control input differences

    Parameters
    ----------
    control_seq : Array
        Control sequence to evaluate, (shape=(fcast_len, control_dim)).
    res_state : Array
        Initial reservoir state, (shape=(res_dim,)).
    ref_traj : Array
        Reference trajectory to track, (shape=(fcast_len, out_dim)).

    Returns
    -------
    Float
        Total penalty value (scalar).
    """
    fcast = self.apply_control(control_seq, res_state)
    deviation = fcast - ref_traj
    dev_penalty = jnp.sum(deviation**2) * self.alpha_1
    mag_penalty = jnp.sum(control_seq**2) * self.alpha_2
    deriv_penalty = jnp.sum(jnp.diff(control_seq, axis=0) ** 2) * self.alpha_3
    return dev_penalty + mag_penalty + deriv_penalty
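Written out, with y_t the t-th row of the controlled forecast returned by apply_control, r_t = ref_traj[t], c_t = control_seq[t], and T = fcast_len, compute_penalty evaluates:

```latex
J(c) = \alpha_1 \sum_{t=0}^{T-1} \lVert y_t - r_t \rVert_2^2
     + \alpha_2 \sum_{t=0}^{T-1} \lVert c_t \rVert_2^2
     + \alpha_3 \sum_{t=0}^{T-2} \lVert c_{t+1} - c_t \rVert_2^2
```

The last sum has T - 1 terms because jnp.diff along axis 0 returns one fewer row than control_seq. With the default weights (alpha_1, alpha_2, alpha_3) = (100, 1, 5), tracking error is weighted most heavily; raising alpha_3 favors smoother control sequences.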

compute_control

compute_control(control_seq: Array, res_state: Array, ref_traj: Array) -> Array

Compute optimal control sequence to track a reference trajectory.

Uses BFGS optimization to find a control sequence that minimizes the penalty function (deviation from reference + control magnitude + control smoothness).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| control_seq | Array | Initial guess for control sequence, (shape=(fcast_len, control_dim)). | required |
| res_state | Array | Initial reservoir state, (shape=(res_dim,)). | required |
| ref_traj | Array | Reference trajectory to track, (shape=(fcast_len, out_dim)). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Optimized control sequence, (shape=(fcast_len, control_dim)). |

Source code in src/orc/control/base.py
@eqx.filter_jit
def compute_control(
    self,
    control_seq: Array,
    res_state: Array,
    ref_traj: Array,
) -> Array:
    """Compute optimal control sequence to track a reference trajectory.

    Uses BFGS optimization to find a control sequence that minimizes the
    penalty function (deviation from reference + control magnitude + control
    smoothness).

    Parameters
    ----------
    control_seq : Array
        Initial guess for control sequence, (shape=(fcast_len, control_dim)).
    res_state : Array
        Initial reservoir state, (shape=(res_dim,)).
    ref_traj : Array
        Reference trajectory to track, (shape=(fcast_len, out_dim)).

    Returns
    -------
    Array
        Optimized control sequence, (shape=(fcast_len, control_dim)).
    """

    def loss_fn(control_seq):
        control_seq = control_seq.reshape(-1, self.control_dim)
        return self.compute_penalty(control_seq, res_state, ref_traj)

    # TODO: Implement optimization to allow finer grained control of tolerances
    # linesearch = optax.scale_by_backtracking_linesearch(
    #     max_backtracking_steps=30,
    #     decrease_factor=0.5,
    # )
    # solver = optax.lbfgs(linesearch=linesearch)

    # if not use_builtin_solver:
    #     solver = optax.lbfgs()

    #     @jax.jit
    #     def run_lbfgs(x0):
    #         value_and_grad_fn = jax.value_and_grad(loss_fn)

    #         def step(carry, _):
    #             x, state = carry
    #             value, grad = value_and_grad_fn(x)
    #             updates, state = solver.update(
    #                 grad,
    #                 state,
    #                 x,
    #                 value=value,
    #                 grad=grad,
    #                 value_fn=loss_fn)
    #             x = optax.apply_updates(x, updates)
    #             return (x, state), None

    #         init_state = solver.init(x0)
    #         (x_final, final_state), _ = jax.lax.scan(step, (x0, init_state),
    #                                                   None,
    #                                                   length=max_iter)
    #         return x_final, final_state

    #     control_opt, state = run_lbfgs(control_seq)

    optimize_results = jax.scipy.optimize.minimize(
        loss_fn, control_seq.reshape(-1), method="BFGS"
    )
    control_opt = optimize_results.x.reshape(-1, self.control_dim)

    return control_opt
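The flatten, optimize, reshape pattern used by compute_control can be tried outside JAX with SciPy's BFGS, whose interface `jax.scipy.optimize.minimize` mirrors. This is a sketch under assumptions: `toy_rollout` is a hypothetical integrator plant standing in for `apply_control`, and SciPy estimates gradients by finite differences instead of JAX autodiff.

```python
import numpy as np
from scipy.optimize import minimize

T, control_dim = 20, 1
alpha_1, alpha_2, alpha_3 = 100.0, 1.0, 5.0             # the defaults documented above
ref_traj = np.sin(np.linspace(0.0, np.pi, T))[:, None]  # reference, shape (T, 1)


def toy_rollout(control_seq):
    # Hypothetical plant standing in for apply_control: the output integrates the control.
    return np.cumsum(control_seq, axis=0)


def penalty(control_seq):
    # Same three terms as compute_penalty.
    deviation = toy_rollout(control_seq) - ref_traj
    return (alpha_1 * np.sum(deviation**2)
            + alpha_2 * np.sum(control_seq**2)
            + alpha_3 * np.sum(np.diff(control_seq, axis=0) ** 2))


def loss_fn(flat_controls):
    # The optimizer works on a flat vector; reshape back to (T, control_dim).
    return penalty(flat_controls.reshape(-1, control_dim))


init = np.zeros((T, control_dim))                       # initial guess
result = minimize(loss_fn, init.reshape(-1), method="BFGS")
control_opt = result.x.reshape(-1, control_dim)
print(penalty(init), penalty(control_opt))              # the penalty drops after optimization
```

In the JAX version, `jax.scipy.optimize.minimize` differentiates the penalty through the reservoir rollout automatically, so no gradient has to be supplied. BFGS is a local method, and with the nonlinear reservoir in the loop the result can depend on the initial guess passed as control_seq.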

Models

ESNController

ESNController(data_dim: int, control_dim: int, res_dim: int, leak_rate: float = 0.6, bias: float = 1.6, embedding_scaling: float = 0.08, Wr_density: float = 0.02, Wr_spectral_radius: float = 0.8, dtype: type = float64, alpha_1: float = 100, alpha_2: float = 1, alpha_3: float = 5, seed: int = 0, quadratic: bool = False, use_sparse_eigs: bool = True)

Bases: RCControllerBase

Basic Echo State Network (ESN) implementation for control tasks.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| res_dim | int | Reservoir dimension. |
| data_dim | int | System input/output dimension. |
| control_dim | int | Control input dimension. |
| driver | ESNDriver | Driver implementing the Echo State Network dynamics. |
| readout | ReadoutBase | Trainable readout layer: LinearReadout, or QuadraticReadout when quadratic=True. |
| embedding | LinearEmbedding | Untrainable linear embedding layer for [input, control] concatenation, with input dimension data_dim + control_dim. |
| alpha_1 | float | Weight for trajectory deviation penalty in control optimization. |
| alpha_2 | float | Weight for control magnitude penalty in control optimization. |
| alpha_3 | float | Weight for control derivative penalty in control optimization. |

Methods:

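| Name | Description |
| --- | --- |
| force | Teacher forces the reservoir with input and control sequences. |
| apply_control | Apply a predefined control sequence in closed-loop. |
| set_readout | Replace readout layer. |
| set_embedding | Replace embedding layer. |
| compute_penalty | Compute the control penalty for a given control sequence (inherited from RCControllerBase). |
| compute_control | Compute optimal control sequence to track a reference trajectory (inherited from RCControllerBase). |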

Initialize the ESN controller.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_dim | int | Dimension of the system input/output data. | required |
| control_dim | int | Dimension of the control input. | required |
| res_dim | int | Dimension of the reservoir; the adjacency matrix Wr has shape (res_dim, res_dim). | required |
| leak_rate | float | Integration leak rate of the reservoir dynamics. | 0.6 |
| bias | float | Bias term for the reservoir dynamics. | 1.6 |
| embedding_scaling | float | Scaling factor for the embedding layer. | 0.08 |
| Wr_density | float | Density (fraction of nonzero entries) of the reservoir adjacency matrix Wr. | 0.02 |
| Wr_spectral_radius | float | Spectral radius (largest absolute eigenvalue) of the reservoir adjacency matrix Wr. | 0.8 |
| dtype | type | Data type of the model (jnp.float64 is highly recommended). | float64 |
| alpha_1 | float | Weight for trajectory deviation penalty in control optimization. | 100 |
| alpha_2 | float | Weight for control magnitude penalty in control optimization. | 1 |
| alpha_3 | float | Weight for control derivative penalty in control optimization. | 5 |
| seed | int | Random seed for generating the PRNG key for the reservoir computer. | 0 |
| quadratic | bool | If True, use a QuadraticReadout instead of a LinearReadout. | False |
| use_sparse_eigs | bool | Whether to use a sparse eigensolver when setting the spectral radius of Wr. True is recommended to save memory and compute time; False uses a dense eigensolver, which may be more accurate. | True |
Source code in src/orc/control/models.py
def __init__(
    self,
    data_dim: int,
    control_dim: int,
    res_dim: int,
    leak_rate: float = 0.6,
    bias: float = 1.6,
    embedding_scaling: float = 0.08,
    Wr_density: float = 0.02,
    Wr_spectral_radius: float = 0.8,
    dtype: type = jnp.float64,
    alpha_1: float = 100,
    alpha_2: float = 1,
    alpha_3: float = 5,
    seed: int = 0,
    quadratic: bool = False,
    use_sparse_eigs: bool = True,
) -> None:
    """
    Initialize the ESN controller.

    Parameters
    ----------
    data_dim : int
        Dimension of the system input/output data.
    control_dim : int
        Dimension of the control input.
    res_dim : int
        Dimension of the reservoir adjacency matrix Wr.
    leak_rate : float
        Integration leak rate of the reservoir dynamics.
    bias : float
        Bias term for the reservoir dynamics.
    embedding_scaling : float
        Scaling factor for the embedding layer.
    Wr_density : float
        Density of the reservoir adjacency matrix Wr.
    Wr_spectral_radius : float
        Spectral radius (largest absolute eigenvalue) of the reservoir
        adjacency matrix Wr.
    dtype : type
        Data type of the model (jnp.float64 is highly recommended).
    alpha_1 : float
        Weight for trajectory deviation penalty in control optimization.
    alpha_2 : float
        Weight for control magnitude penalty in control optimization.
    alpha_3 : float
        Weight for control derivative penalty in control optimization.
    seed : int
        Random seed for generating the PRNG key for the reservoir computer.
    quadratic : bool
        If True, use a QuadraticReadout instead of a LinearReadout. Default False.
    use_sparse_eigs : bool
        Whether to use a sparse eigensolver when setting the spectral radius of Wr.
        Default is True, which is recommended to save memory and compute time. If
        False, a dense eigensolver is used, which may be more accurate.
    """
    # Initialize the random key and reservoir dimension
    self.res_dim = res_dim
    self.seed = seed
    self.data_dim = data_dim
    self.control_dim = control_dim
    key = jax.random.PRNGKey(seed)
    key_driver, key_readout, key_embedding = jax.random.split(key, 3)

    # Embedding accepts concatenated [input, control] vectors
    # Total input dimension is data_dim + control_dim
    embedding = LinearEmbedding(
        in_dim=data_dim + control_dim,
        res_dim=res_dim,
        seed=key_embedding[0],
        scaling=embedding_scaling,
    )
    driver = ESNDriver(
        res_dim=res_dim,
        seed=key_driver[0],
        leak=leak_rate,
        bias=bias,
        density=Wr_density,
        spectral_radius=Wr_spectral_radius,
        dtype=dtype,
        use_sparse_eigs=use_sparse_eigs,
    )
    if quadratic:
        readout = QuadraticReadout(
            out_dim=data_dim, res_dim=res_dim, seed=key_readout[0]
        )
    else:
        readout = LinearReadout(
            out_dim=data_dim, res_dim=res_dim, seed=key_readout[0]
        )

    super().__init__(
        driver=driver,
        readout=readout,
        embedding=embedding,
        in_dim=data_dim,
        control_dim=control_dim,
        dtype=dtype,
        seed=seed,
        alpha_1=alpha_1,
        alpha_2=alpha_2,
        alpha_3=alpha_3
    )
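The Wr_density and Wr_spectral_radius arguments are passed to ESNDriver as density and spectral_radius. As a hedged illustration of what these two knobs typically control, the NumPy sketch below samples a sparse random matrix and rescales it to a target spectral radius. `sparse_reservoir` is hypothetical, and ESNDriver's actual sampling scheme may differ; this sketch uses a dense eigensolver, analogous to the use_sparse_eigs=False path.

```python
import numpy as np


def sparse_reservoir(res_dim, density, spectral_radius, seed=0):
    """Hypothetical sketch of what Wr_density and Wr_spectral_radius control.

    Keeps roughly `density` of the entries (uniform in [-1, 1]) and rescales
    the matrix so its largest absolute eigenvalue equals `spectral_radius`.
    """
    rng = np.random.default_rng(seed)
    mask = rng.random((res_dim, res_dim)) < density
    W = np.where(mask, rng.uniform(-1.0, 1.0, size=(res_dim, res_dim)), 0.0)
    rho = np.max(np.abs(np.linalg.eigvals(W)))   # dense eigensolver, O(res_dim^3)
    if rho == 0.0:
        raise ValueError("sampled matrix is nilpotent; try another seed or density")
    return W * (spectral_radius / rho)


W_r = sparse_reservoir(res_dim=200, density=0.02, spectral_radius=0.8)
print(np.count_nonzero(W_r) / W_r.size)          # close to 0.02
print(np.max(np.abs(np.linalg.eigvals(W_r))))    # 0.8 up to rounding
```

Keeping the spectral radius below 1, as the default of 0.8 does, is the usual heuristic for obtaining fading memory in an echo state network.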

Training Functions

train_ESNController

train_ESNController(model: ESNController, train_seq: Array, control_seq: Array, target_seq: Array = None, spinup: int = 0, initial_res_state: Array = None, beta: float = 8e-08) -> tuple[ESNController, Array]

Train the readout of an ESNController by ridge (Tikhonov-regularized) regression on teacher-forced reservoir states.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | ESNController | ESNController model to train. | required |
| train_seq | Array | Training input sequence for reservoir, (shape=(seq_len, data_dim)). | required |
| control_seq | Array | Control input sequence for reservoir, (shape=(seq_len, control_dim)). Must have the same length as train_seq. | required |
| target_seq | Array | Target sequence for training the readout, (shape=(seq_len, data_dim)). If None, defaults to train_seq[1:] (one-step-ahead prediction). | None |
| spinup | int | Initial transient of reservoir states to discard; must be less than seq_len. | 0 |
| initial_res_state | Array | Initial reservoir state, (shape=(res_dim,)). If None, a zero state is used. | None |
| beta | float | Tikhonov regularization parameter. | 8e-08 |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| model | ESNController | Trained ESN controller model. |
| res_seq | Array | Teacher-forced reservoir states, including the spinup transient, (shape=(seq_len, res_dim)). |

Source code in src/orc/control/train.py
def train_ESNController(
    model: ESNController,
    train_seq: Array,
    control_seq: Array,
    target_seq: Array = None,
    spinup: int = 0,
    initial_res_state: Array = None,
    beta: float = 8e-8,
) -> tuple[ESNController, Array]:
    """Train the readout of an ESNController by ridge regression.

    Parameters
    ----------
    model : ESNController
        ESNController model to train.
    train_seq : Array
        Training input sequence for reservoir, (shape=(seq_len, data_dim)).
    control_seq : Array
        Control input sequence for reservoir, (shape=(seq_len, control_dim)).
    target_seq : Array
        Target sequence for training reservoir, (shape=(seq_len, data_dim)).
        If None, defaults to train_seq[1:] (one-step-ahead prediction).
    spinup : int
        Initial transient of reservoir states to discard.
    initial_res_state : Array
        Initial reservoir state, (shape=(res_dim,)). If None, a zero state
        is used.
    beta : float
        Tikhonov regularization parameter.

    Returns
    -------
    model : ESNController
        Trained ESN controller model.
    res_seq : Array
        Teacher-forced reservoir states, including the spinup transient,
        (shape=(seq_len, res_dim)).
    """
    # Check that model is an ESN Controller
    if not isinstance(model, ESNController):
        raise TypeError("Model must be an ESNController.")

    # check that train_seq and control_seq have the same length
    if train_seq.shape[0] != control_seq.shape[0]:
        raise ValueError("train_seq and control_seq must have the same length.")

    # check that spinup is less than the length of the training sequence
    if spinup >= train_seq.shape[0]:
        raise ValueError(
            "spinup must be less than the length of the training sequence."
        )

    if initial_res_state is None:
        initial_res_state = jnp.zeros(model.res_dim, dtype=model.dtype)

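    # Force the reservoir once over the full sequence; train_seq and
    # control_seq have the same length, as checked above.
    tot_res_seq = model.force(train_seq, control_seq, initial_res_state)
    if target_seq is None:
        # One-step-ahead targets: the state after consuming train_seq[t] is
        # regressed onto train_seq[t + 1]; the final state is not used.
        target_seq = train_seq[1:, :]
        res_seq = tot_res_seq[:-1]
    else:
        # Explicit targets must align row-by-row with train_seq.
        if target_seq.shape[0] != train_seq.shape[0]:
            raise ValueError("target_seq and train_seq must have the same length.")
        res_seq = tot_res_seq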
    if isinstance(model.readout, NonlinearReadout):
        res_seq_train = model.readout.nonlinear_transform(res_seq)
    else:
        res_seq_train = res_seq

    cmat = ridge_regression(res_seq_train[spinup:], target_seq[spinup:], beta)
    cmat = cmat.reshape(1, cmat.shape[0], cmat.shape[1])

    def where(m):
        return m.readout.wout

    model = eqx.tree_at(where, model, cmat)

    return model, tot_res_seq
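train_ESNController delegates the fit to `ridge_regression`, whose source is not shown on this page. As an assumption-labeled sketch, Tikhonov-regularized least squares with parameter beta has the closed form W = (X^T X + beta I)^(-1) X^T Y, where the rows of X are reservoir states after the spinup and the rows of Y are the matching targets. `ridge_readout` is hypothetical, and orc's helper may use a different orientation (the code above reshapes its result to (1, ...) before storing it in readout.wout).

```python
import numpy as np


def ridge_readout(res_seq, target_seq, beta):
    """Closed-form Tikhonov (ridge) regression, W = (X^T X + beta I)^{-1} X^T Y.

    Hypothetical helper: returns W with shape (res_dim, out_dim) so that
    res_seq @ W approximates target_seq.
    """
    X, Y = res_seq, target_seq
    gram = X.T @ X + beta * np.eye(X.shape[1])   # regularized Gram matrix
    return np.linalg.solve(gram, X.T @ Y)        # solve instead of an explicit inverse


rng = np.random.default_rng(0)
seq_len, res_dim, out_dim, spinup, beta = 500, 30, 3, 50, 8e-8
res_seq = rng.normal(size=(seq_len, res_dim))    # stand-in for teacher-forced states
W_true = rng.normal(size=(res_dim, out_dim))
target_seq = res_seq @ W_true                    # synthetic, noise-free targets

# Discard the spinup transient, as train_ESNController does, then fit.
W_out = ridge_readout(res_seq[spinup:], target_seq[spinup:], beta)
print(np.max(np.abs(W_out - W_true)))            # tiny: beta is small and targets are exact
```

When target_seq is None, train_ESNController pairs the state produced by train_seq[t] with the target train_seq[t + 1], so the fitted readout performs one-step-ahead prediction, which is what apply_control feeds back in closed loop.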