
Readouts API Reference

This module contains readout layers that map reservoir states to outputs.

Readout Base Class

ReadoutBase

ReadoutBase(out_dim, res_dim, dtype=float64)

Bases: Module, ABC

Base class defining the API shared by all implemented readout layers.

Attributes:

out_dim (int): Dimension of reservoir output.
res_dim (int): Reservoir dimension.
dtype (Float): Dtype of JAX arrays, jnp.float32 or jnp.float64.

Methods:

readout: Map from reservoir state to output state.
batch_readout: Map from reservoir states to output states.

Ensure out_dim, res_dim, and dtype have the correct types.

Source code in src/orc/readouts.py
def __init__(self, out_dim, res_dim, dtype=jnp.float64):
    """Ensure out_dim, res_dim, and dtype have the correct types."""
    self.res_dim = res_dim
    self.out_dim = out_dim
    self.dtype = dtype
    # Validate argument types; dimensions must be Python ints.
    if not isinstance(res_dim, int):
        raise TypeError("Reservoir dimension res_dim must be an integer.")
    if not isinstance(out_dim, int):
        raise TypeError("Output dimension out_dim must be an integer.")
    if not (dtype == jnp.float64 or dtype == jnp.float32):
        raise TypeError("dtype must be jnp.float64 or jnp.float32.")
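To see which constructor arguments pass these checks, the validation can be reproduced in isolation. This is a minimal sketch assuming JAX and NumPy are installed; validate_readout_args is a hypothetical helper, not part of orc, and its error messages are illustrative. One consequence worth noting is that NumPy integer scalars are not subclasses of Python int, so dimensions computed with NumPy are rejected.

```python
import jax.numpy as jnp
import numpy as np


def validate_readout_args(out_dim, res_dim, dtype=jnp.float64):
    """Hypothetical helper mirroring the checks in ReadoutBase.__init__."""
    if not isinstance(res_dim, int):
        raise TypeError("res_dim must be a Python int.")
    if not isinstance(out_dim, int):
        raise TypeError("out_dim must be a Python int.")
    if not (dtype == jnp.float64 or dtype == jnp.float32):
        raise TypeError("dtype must be jnp.float64 or jnp.float32.")


validate_readout_args(3, 100)                     # accepted
validate_readout_args(3, 100, dtype=jnp.float32)  # accepted

# NumPy integers are not subclasses of int, so they are rejected.
try:
    validate_readout_args(np.int64(3), 100)
except TypeError as err:
    print(err)  # out_dim must be a Python int.
```

Casting with int(...) before constructing a readout avoids this failure mode.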

readout abstractmethod

readout(res_state: Array) -> Array

Readout from reservoir state.

Parameters:

res_state (Array, required): Reservoir state.

Returns:

Array: Output from reservoir state.

Source code in src/orc/readouts.py
@abstractmethod
def readout(
    self,
    res_state: Array,
) -> Array:
    """Readout from reservoir state.

    Parameters
    ----------
    res_state : Array
        Reservoir state.

    Returns
    -------
    Array
        Output from reservoir state.
    """
    pass

batch_readout

batch_readout(res_state: Array) -> Array

Apply readout to each reservoir state in a batch.

Parameters:

res_state (Array, required): Batch of reservoir states, stacked along the leading axis.

Returns:

Array: Outputs from reservoir states, stacked along the leading axis.

Source code in src/orc/readouts.py
def batch_readout(
    self,
    res_state: Array,
) -> Array:
    """Batch apply readout from reservoir states.

    Parameters
    ----------
    res_state : Array
        Batch of reservoir states, stacked along the leading axis.

    Returns
    -------
    Array
        Outputs from reservoir states, stacked along the leading axis.
    """
    return eqx.filter_vmap(self.readout)(res_state)
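The abstract readout plus the vmap-based batch_readout form a small template for new readouts: a subclass implements readout for a single state and inherits batching. The sketch below reproduces that pattern with the standard abc module and jax.vmap in place of eqx.filter_vmap (equivalent here, since the only mapped argument is a plain array). MiniReadoutBase and SumReadout are hypothetical names; the real base class additionally inherits from equinox.Module and validates its constructor arguments.

```python
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp


class MiniReadoutBase(ABC):
    """Stand-in for ReadoutBase without equinox (illustrative only)."""

    @abstractmethod
    def readout(self, res_state):
        """Map one reservoir state to one output state."""

    def batch_readout(self, res_state):
        # Map readout over the leading (batch/time) axis, as batch_readout does.
        return jax.vmap(self.readout)(res_state)


class SumReadout(MiniReadoutBase):
    """Hypothetical readout whose single output is the sum of the state."""

    def readout(self, res_state):
        return jnp.sum(res_state, keepdims=True)


ro = SumReadout()
print(ro.readout(jnp.array([1.0, 2.0, 3.0])).tolist())  # [6.0]
print(ro.batch_readout(jnp.ones((4, 3))).shape)         # (4, 1)
```

Because readout is abstract, MiniReadoutBase itself cannot be instantiated; only subclasses that implement readout can.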

__call__

__call__(res_state: Array) -> Array

Call either readout or batch_readout depending on dimensions.

Parameters:

res_state (Array, required): Reservoir state, shape=(chunks, res_dim) or shape=(seq_len, chunks, res_dim).

Returns:

Array: Output state, shape=(out_dim,) or shape=(seq_len, out_dim).

Source code in src/orc/readouts.py
def __call__(
    self,
    res_state: Array,
) -> Array:
    """Call either readout or batch_readout depending on dimensions.

    Parameters
    ----------
    res_state : Array
        Reservoir state, shape=(chunks, res_dim) or
        shape=(seq_len, chunks, res_dim).

    Returns
    -------
    Array
        Output state, shape=(out_dim,) or shape=(seq_len, out_dim).
    """
    # Readouts that do not define chunks are treated as unchunked.
    chunks = getattr(self, "chunks", 0)
    # A single state is (chunks, res_dim) when chunked, else (res_dim,).
    single_rank = 2 if chunks > 0 else 1
    if res_state.ndim == single_rank:
        return self.readout(res_state)
    if res_state.ndim == single_rank + 1:
        return self.batch_readout(res_state)
    raise ValueError(
        f"Expected reservoir state of rank {single_rank} or {single_rank + 1}, "
        f"got shape {res_state.shape}."
    )
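__call__ chooses between readout and batch_readout purely from the rank of res_state, and the expected rank depends on whether the readout is chunked. The Parallel* readouts below always take the chunked branch in practice, since chunks=0 would divide by zero when wout is built. The sketch isolates that rule; dispatch_path is a hypothetical helper that returns the name of the method that would be called instead of calling it, and it raises for any other rank.

```python
import jax.numpy as jnp


def dispatch_path(res_state, chunks):
    """Hypothetical helper reproducing the rank-based dispatch in __call__."""
    # A single state is (chunks, res_dim) when chunked, else (res_dim,).
    single_rank = 2 if chunks > 0 else 1
    if res_state.ndim == single_rank:
        return "readout"
    if res_state.ndim == single_rank + 1:
        return "batch_readout"
    raise ValueError(f"Unexpected reservoir state shape {res_state.shape}.")


print(dispatch_path(jnp.zeros((2, 50)), chunks=2))      # readout
print(dispatch_path(jnp.zeros((10, 2, 50)), chunks=2))  # batch_readout
print(dispatch_path(jnp.zeros((50,)), chunks=0))        # readout
```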

Linear Readout

ParallelLinearReadout

ParallelLinearReadout(out_dim: int, res_dim: int, chunks: int = 1, dtype: Float = float64, *, seed: int = 0)

Bases: ReadoutBase

Linear readout layer.

Attributes:

out_dim (int): Dimension of reservoir output.
res_dim (int): Reservoir dimension.
chunks (int): Number of parallel reservoirs.
wout (Array): Output matrix, shape=(chunks, out_dim // chunks, res_dim).
dtype (Float): Dtype, default jnp.float64.

Methods:

readout: Map from reservoir state to output state.

Initialize readout layer to zeros.

Parameters:

out_dim (int, required): Dimension of reservoir output.
res_dim (int, required): Reservoir dimension.
chunks (int, default 1): Number of parallel reservoirs.
dtype (Float, default jnp.float64): Dtype of JAX arrays.
seed (int, default 0): Not used by ParallelLinearReadout; present to keep a consistent interface.

Source code in src/orc/readouts.py
def __init__(
    self,
    out_dim: int,
    res_dim: int,
    chunks: int = 1,
    dtype: Float = jnp.float64,
    *,
    seed: int = 0,
) -> None:
    """Initialize readout layer to zeros.

    Parameters
    ----------
    out_dim : int
        Dimension of reservoir output.
    res_dim : int
        Reservoir dimension.
    chunks : int
        Number of parallel reservoirs.
    dtype : Float
        Dtype, default jnp.float64.
    seed : int
        Not used for ParallelLinearReadout, here to maintain consistent
        interface.
    """
    super().__init__(out_dim=out_dim, res_dim=res_dim, dtype=dtype)
    self.out_dim = out_dim
    self.res_dim = res_dim
    self.wout = jnp.zeros((chunks, int(out_dim / chunks), res_dim), dtype=dtype)
    self.dtype = dtype
    self.chunks = chunks
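Because __init__ allocates int(out_dim / chunks) rows per chunk, out_dim should be a multiple of chunks. The sketch below reproduces only that shape arithmetic, without importing orc, to show what happens otherwise: nothing in the constructor checks divisibility, so the extra outputs are silently dropped and readout returns chunks * (out_dim // chunks) values instead of out_dim.

```python
import jax.numpy as jnp

out_dim, res_dim, chunks = 7, 4, 3

# Same shape expression as __init__: (chunks, int(out_dim / chunks), res_dim).
rows_per_chunk = int(out_dim / chunks)  # 7 / 3 = 2.33... truncates to 2
wout = jnp.zeros((chunks, rows_per_chunk, res_dim))

print(wout.shape)               # (3, 2, 4)
print(chunks * rows_per_chunk)  # 6 outputs from readout, not out_dim = 7
```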

readout

readout(res_state: Array) -> Array

Readout from reservoir state.

Parameters:

res_state (Array, required): Reservoir state, shape=(chunks, res_dim).

Returns:

Array: Output from reservoir, shape=(out_dim,).

Source code in src/orc/readouts.py
@eqx.filter_jit
def readout(self, res_state: Array) -> Array:
    """Readout from reservoir state.

    Parameters
    ----------
    res_state : Array
        Reservoir state, (shape=(chunks, res_dim,)).

    Returns
    -------
    Array
        Output from reservoir, (shape=(out_dim,)).
    """
    if res_state.shape[1] != self.res_dim:
        raise ValueError(
            "Incorrect reservoir dimension for instantiated output map."
        )
    return jnp.ravel(eqx.filter_vmap(jnp.matmul)(self.wout, res_state))
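ParallelLinearReadout.readout applies one matrix per chunk and concatenates the results, which is the same as multiplying the flattened reservoir state by a single block-diagonal matrix. The sketch below checks that equivalence numerically with random weights. It assumes JAX is installed, uses jax.vmap in place of eqx.filter_vmap (equivalent for plain array arguments), and does not import orc.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

chunks, rows, res_dim = 3, 2, 4
key_w, key_r = jax.random.split(jax.random.PRNGKey(0))
wout = jax.random.normal(key_w, (chunks, rows, res_dim))
res_state = jax.random.normal(key_r, (chunks, res_dim))

# Same computation as readout: one matmul per chunk, then flatten.
out = jnp.ravel(jax.vmap(jnp.matmul)(wout, res_state))

# Equivalent single matrix: the per-chunk matrices placed on the block diagonal.
big_w = block_diag(*wout)  # shape (chunks * rows, chunks * res_dim) = (6, 12)
out_dense = big_w @ jnp.ravel(res_state)

print(out.shape)                                # (6,)
print(bool(jnp.allclose(out, out_dense, atol=1e-5)))  # True
```

The chunked form never materializes the zero off-diagonal blocks, which is why it stores chunks * rows * res_dim weights rather than (chunks * rows) * (chunks * res_dim).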

Nonlinear Readout

ParallelNonlinearReadout

ParallelNonlinearReadout(out_dim: int, res_dim: int, nonlin_list: list[Callable], chunks: int = 1, dtype: Float = float64, *, seed: int = 0)

Bases: ReadoutBase

Readout layer with user specified nonlinearities.

Attributes:

out_dim (int): Dimension of reservoir output.
res_dim (int): Reservoir dimension.
chunks (int): Number of parallel reservoirs.
wout (Array): Output matrix, shape=(chunks, out_dim // chunks, res_dim).
nonlin_list (list): User-specified nonlinearities.
dtype (Float): Dtype, default jnp.float64.

Methods:

nonlinear_transform: Nonlinear transform that acts entrywise on the reservoir state.
readout: Map from reservoir state to output state.
__call__: Map from reservoir state to output state; handles both batched and single inputs.

Initialize readout layer to zeros.

Parameters:

out_dim (int, required): Dimension of reservoir output.
res_dim (int, required): Reservoir dimension.
nonlin_list (list[Callable], required): User-specified entrywise nonlinearities. Each entry should map scalar values to scalar values and must also work elementwise on JAX arrays, because it is applied to a whole slice of the reservoir state at once, e.g. lambda x: x ** 2 or lambda x: jnp.sin(x).
chunks (int, default 1): Number of parallel reservoirs.
dtype (Float, default jnp.float64): Dtype of JAX arrays.
seed (int, default 0): Not used by ParallelNonlinearReadout; present to keep a consistent interface.

Source code in src/orc/readouts.py
def __init__(
    self,
    out_dim: int,
    res_dim: int,
    nonlin_list: list[Callable],
    chunks: int = 1,
    dtype: Float = jnp.float64,
    *,
    seed: int = 0,
) -> None:
    """Initialize readout layer to zeros.

    Parameters
    ----------
    out_dim : int
        Dimension of reservoir output.
    res_dim : int
        Reservoir dimension.
    nonlin_list : list[Callable]
        List containing user-specified entrywise nonlinearities. Each entry
        should map scalar values to scalar values and must also work
        elementwise on JAX arrays, because it is applied to a whole slice of
        the reservoir state at once, e.g. lambda x: x ** 2 or
        lambda x: jnp.sin(x).
    chunks : int
        Number of parallel reservoirs.
    dtype : Float
        Dtype, default jnp.float64.
    seed : int
        Not used for ParallelNonlinearReadout, here to maintain consistent
        interface.
    """
    super().__init__(out_dim=out_dim, res_dim=res_dim, dtype=dtype)
    self.out_dim = out_dim
    self.res_dim = res_dim
    self.wout = jnp.zeros((chunks, int(out_dim / chunks), res_dim), dtype=dtype)
    self.dtype = dtype
    self.chunks = chunks
    self.nonlin_list = nonlin_list

nonlinear_transform

nonlinear_transform(res_state: Array) -> Array

Perform nonlinear transformation on reservoir state.

Let tot_list be nonlin_list with the identity mapping prepended, and let n = len(tot_list) = len(nonlin_list) + 1. The transform acts along the last axis of res_state, grouping its entries into consecutive blocks of length n. For all 0 <= k < chunks, all j >= 0, and all 1 <= i < n (restricted to indices below res_dim):

res_state[k, j * n] <- res_state[k, j * n]
res_state[k, j * n + i] <- f_{i-1}(res_state[k, j * n + i])

where f_m is the m-th (zero-indexed) entry of nonlin_list. The first entry of each block is left unchanged, and the remaining n - 1 entries have f_0, ..., f_{n-2} applied in order.

Parameters:

res_state (Array, required): Reservoir state, shape=(..., res_dim).

Returns:

Array: Transformed reservoir state.

Source code in src/orc/readouts.py
def nonlinear_transform(self, res_state: Array) -> Array:
    """Perform nonlinear transformation on reservoir state.

    Let tot_list be nonlin_list with the identity mapping prepended, and let
    n = len(tot_list) = len(nonlin_list) + 1. Then nonlinear_transform acts
    along the last axis such that for all 0 <= k < chunks, j >= 0, and
    1 <= i < n (restricted to indices below res_dim):
    res_state[k, j * n] <- res_state[k, j * n]
    res_state[k, j * n + i] <- f_{i-1}(res_state[k, j * n + i])
    where f_m is the m-th (zero-indexed) entry of nonlin_list.

    Parameters
    ----------
    res_state : Array
        Reservoir state, (shape=(..., res_dim,)).

    Returns
    -------
    Array
        Transformed reservoir state.
    """
    num_nonlins = len(self.nonlin_list)
    transformed_res_state = res_state
    for idx in range(num_nonlins):
        transformed_res_state = transformed_res_state.at[
            ..., idx + 1 :: num_nonlins + 1
        ].set(
            self.nonlin_list[idx](
                transformed_res_state[..., idx + 1 :: num_nonlins + 1]
            )
        )
    return transformed_res_state
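For a concrete picture of the interleaving, the following standalone sketch copies the loop from nonlinear_transform into a free function of the same name (hypothetical, not part of orc; assumes JAX is installed) and applies nonlin_list = [x**2, x**3], so n = 3 and the reservoir entries are processed in blocks of three.

```python
import jax.numpy as jnp


def nonlinear_transform(res_state, nonlin_list):
    """Standalone copy of the slicing logic in nonlinear_transform."""
    n = len(nonlin_list) + 1  # identity plus the user nonlinearities
    out = res_state
    for idx, f in enumerate(nonlin_list):
        # Entries idx+1, idx+1+n, idx+1+2n, ... of the last axis get f.
        out = out.at[..., idx + 1 :: n].set(f(out[..., idx + 1 :: n]))
    return out


x = jnp.arange(6.0)  # [0, 1, 2, 3, 4, 5]
print(nonlinear_transform(x, [lambda v: v**2, lambda v: v**3]).tolist())
# [0.0, 1.0, 8.0, 3.0, 16.0, 125.0]
```

Entries 0 and 3 pass through unchanged, entries 1 and 4 are squared, and entries 2 and 5 are cubed. Because the slicing uses `...`, the same function also works row by row on a (chunks, res_dim) array.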

readout

readout(res_state: Array) -> Array

Readout from reservoir state.

Parameters:

res_state (Array, required): Reservoir state, shape=(chunks, res_dim).

Returns:

Array: Output from reservoir, shape=(out_dim,).

Source code in src/orc/readouts.py
@eqx.filter_jit
def readout(self, res_state: Array) -> Array:
    """Readout from reservoir state.

    Parameters
    ----------
    res_state : Array
        Reservoir state, (shape=(chunks, res_dim,)).

    Returns
    -------
    Array
        Output from reservoir, (shape=(out_dim,)).
    """
    if res_state.shape[1] != self.res_dim:
        raise ValueError(
            "Incorrect reservoir dimension for instantiated output map."
        )
    transformed_res_state = self.nonlinear_transform(res_state)
    return jnp.ravel(eqx.filter_vmap(jnp.matmul)(self.wout, transformed_res_state))
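readout first applies nonlinear_transform and then multiplies each chunk by its own block of wout, so wout acts on transformed features rather than on raw reservoir states. The sketch below shows that composition with hypothetical free functions in place of the class methods (assumes JAX is installed; jax.vmap stands in for eqx.filter_vmap). An all-ones wout makes each output simply the sum of its transformed chunk.

```python
import jax
import jax.numpy as jnp


def nonlinear_transform(res_state, nonlin_list):
    # Blocks of n = len(nonlin_list) + 1 entries along the last axis:
    # entry 0 of each block is kept, entry i >= 1 gets nonlin_list[i - 1].
    n = len(nonlin_list) + 1
    for idx, f in enumerate(nonlin_list):
        res_state = res_state.at[..., idx + 1 :: n].set(f(res_state[..., idx + 1 :: n]))
    return res_state


def nonlinear_readout(wout, res_state, nonlin_list):
    """Sketch of ParallelNonlinearReadout.readout: transform, then per-chunk matmul."""
    if res_state.shape[1] != wout.shape[2]:
        raise ValueError("Incorrect reservoir dimension for instantiated output map.")
    transformed = nonlinear_transform(res_state, nonlin_list)
    return jnp.ravel(jax.vmap(jnp.matmul)(wout, transformed))


wout = jnp.ones((2, 1, 4))  # chunks=2, one output per chunk, res_dim=4
res_state = jnp.array([[1.0, -2.0, 3.0, -4.0], [-1.0, -1.0, 2.0, 2.0]])
print(nonlinear_readout(wout, res_state, [jnp.abs]).tolist())  # [10.0, 4.0]
```

With nonlin_list = [jnp.abs], only odd-indexed entries are made positive: the first chunk becomes [1, 2, 3, 4] and the second [-1, 1, 2, 2], giving sums of 10 and 4.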

Quadratic Readout

ParallelQuadraticReadout

ParallelQuadraticReadout(out_dim: int, res_dim: int, chunks: int = 1, dtype: Float = float64, *, seed: int = 0)

Bases: ParallelNonlinearReadout

Quadratic readout layer.

Attributes:

out_dim (int): Dimension of reservoir output.
res_dim (int): Reservoir dimension.
chunks (int): Number of parallel reservoirs.
wout (Array): Output matrix, shape=(chunks, out_dim // chunks, res_dim).
dtype (Float): Dtype, default jnp.float64.

Methods:

nonlinear_transform: Squares every odd-indexed entry of the reservoir state along the last axis and leaves the even-indexed entries unchanged.
readout: Map from reservoir state to output state with quadratic nonlinearity.
__call__: Map from reservoir state to output state with quadratic nonlinearity; handles both batched and single inputs.

Initialize readout layer to zeros.

Parameters:

out_dim (int, required): Dimension of reservoir output.
res_dim (int, required): Reservoir dimension.
chunks (int, default 1): Number of parallel reservoirs.
dtype (Float, default jnp.float64): Dtype of JAX arrays.
seed (int, default 0): Not used by ParallelQuadraticReadout; present to keep a consistent interface.

Source code in src/orc/readouts.py
def __init__(
    self,
    out_dim: int,
    res_dim: int,
    chunks: int = 1,
    dtype: Float = jnp.float64,
    *,
    seed: int = 0,
) -> None:
    """Initialize readout layer to zeros.

    Parameters
    ----------
    out_dim : int
        Dimension of reservoir output.
    res_dim : int
        Reservoir dimension.
    chunks : int
        Number of parallel reservoirs.
    dtype : Float
        Dtype, default jnp.float64.
    seed : int
        Not used for ParallelQuadraticReadout, here to maintain consistent
        interface.
    """
    super().__init__(
        out_dim=out_dim,
        res_dim=res_dim,
        dtype=dtype,
        nonlin_list=[lambda x: x**2],
        chunks=chunks,
    )
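ParallelQuadraticReadout is ParallelNonlinearReadout with nonlin_list = [lambda x: x**2], so n = 2 and its features interleave linear and squared reservoir entries. A minimal jax.numpy sketch of the resulting transform follows (quadratic_transform is a hypothetical name, not part of orc). With an odd res_dim the final entry falls on an even index and stays linear.

```python
import jax.numpy as jnp


def quadratic_transform(res_state):
    # Equivalent to nonlinear_transform with nonlin_list=[lambda x: x**2]:
    # blocks of n = 2, so even indices stay linear and odd indices are squared.
    return res_state.at[..., 1::2].set(res_state[..., 1::2] ** 2)


single = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])  # odd res_dim = 5
print(quadratic_transform(single).tolist())  # [1.0, 4.0, 3.0, 16.0, 5.0]

chunked = jnp.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])  # (chunks, res_dim)
print(quadratic_transform(chunked).tolist())
# [[1.0, 4.0, 3.0, 16.0], [5.0, 36.0, 7.0, 64.0]]
```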