Embeddings API Reference

This module contains input embedding layers that preprocess data before it is fed into the reservoir.

Embedding Base Class

EmbedBase

EmbedBase(in_dim, res_dim, dtype=float64)

Bases: Module, ABC

Base class dictating API for all implemented embedding layers.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| in_dim | int | Input dimension. |
| res_dim | int | Reservoir dimension. |
| dtype | Float | Dtype of JAX arrays, jnp.float32 or jnp.float64. |

Methods:

| Name | Description |
| --- | --- |
| embed | Embed input into reservoir dimension. |
| batch_embed | Embed multiple inputs into reservoir dimension. |

Ensure in_dim, res_dim, and dtype are the correct type: both dimensions must be integers, and dtype must be jnp.float32 or jnp.float64.

Source code in src/orc/embeddings.py
def __init__(self, in_dim, res_dim, dtype=jnp.float64):
    """Ensure in_dim, res_dim, and dtype are the correct type."""
    if not isinstance(res_dim, int):
        raise TypeError("Reservoir dimension res_dim must be an integer.")
    if not isinstance(in_dim, int):
        raise TypeError("Input dimension in_dim must be an integer.")
    if not (dtype == jnp.float64 or dtype == jnp.float32):
        raise TypeError("dtype must be jnp.float64 or jnp.float32.")
    self.res_dim = res_dim
    self.in_dim = in_dim
    self.dtype = dtype
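
The constructor checks can be illustrated outside the class. The following is a minimal sketch, not the library's code: `check_embed_args` and `raises_type_error` are hypothetical helpers written for this example, and the default dtype is set to jnp.float32 here only so the sketch runs without enabling 64-bit mode in JAX.

```python
import jax.numpy as jnp


def check_embed_args(in_dim, res_dim, dtype=jnp.float32):
    """Hypothetical helper mirroring the type checks described above."""
    if not isinstance(in_dim, int):
        raise TypeError("Input dimension in_dim must be an integer.")
    if not isinstance(res_dim, int):
        raise TypeError("Reservoir dimension res_dim must be an integer.")
    if not (dtype == jnp.float64 or dtype == jnp.float32):
        raise TypeError("dtype must be jnp.float64 or jnp.float32.")
    return in_dim, res_dim, dtype


def raises_type_error(*args):
    """Return True if check_embed_args rejects the given arguments."""
    try:
        check_embed_args(*args)
    except TypeError:
        return True
    return False


ok = check_embed_args(3, 100)                      # valid arguments pass through
bad_dim = raises_type_error(3.0, 100)              # float dimension is rejected
bad_dtype = raises_type_error(3, 100, jnp.int32)   # non-float dtype is rejected
```

Note that the check is on the Python type, so a dimension given as `3.0` is rejected even though its value is integral.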

embed abstractmethod

embed(in_state: Array) -> Array

Embed input signal to reservoir dimension.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_state | Array | Input state, (shape=(in_dim,)). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Embedded input state to reservoir dimension. |

Source code in src/orc/embeddings.py
@abstractmethod
def embed(
    self,
    in_state: Array,
) -> Array:
    """Embed input signal to reservoir dimension.

    Parameters
    ----------
    in_state : Array
        Input state, (shape=(in_dim,)).

    Returns
    -------
    Array
        Embedded input state to reservoir dimension.
    """
    pass

batch_embed

batch_embed(in_state: Array) -> Array

Batch apply embedding from input states.

Parameters:

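| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_state | Array | Input states, (shape=(batch_dim, in_dim)). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Embedded input states to reservoir, (shape=(batch_dim, res_dim,)). |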

Source code in src/orc/embeddings.py
@eqx.filter_jit
def batch_embed(
    self,
    in_state: Array,
) -> Array:
    """Batch apply embedding from input states.

    Parameters
    ----------
    in_state : Array
        Input states, (shape=(batch_dim, in_dim)).

    Returns
    -------
    Array
        Embedded input states to reservoir, (shape=(batch_dim, res_dim,)).
    """
    return eqx.filter_vmap(self.embed)(in_state)
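
To show what mapping a single-state embedding over a batch looks like, here is a small sketch using plain `jax.vmap` instead of `eqx.filter_vmap`. The `embed` function and the matrix `win` below are stand-ins invented for this example, not part of the library.

```python
import jax
import jax.numpy as jnp

in_dim, res_dim, batch_dim = 3, 5, 4

# Stand-in input matrix; any (res_dim, in_dim) array works for the illustration.
win = jnp.arange(res_dim * in_dim, dtype=jnp.float32).reshape(res_dim, in_dim)


def embed(in_state):
    """Hypothetical single-state embedding: (in_dim,) -> (res_dim,)."""
    return win @ in_state


in_states = jnp.ones((batch_dim, in_dim), dtype=jnp.float32)

# vmap applies `embed` independently to each row along the leading axis.
batched = jax.vmap(embed)(in_states)
```

Each row of `batched` is the embedding of the corresponding row of `in_states`, so the leading batch axis is preserved.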

__call__

__call__(in_state: Array) -> Array

Embed state to reservoir dimensions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_state | Array | Input state, (shape=(in_dim,) or shape=(seq_len, in_dim)). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Embedded input to reservoir, (shape=(chunks, res_dim,) or shape=(seq_len, chunks, res_dim)). |

Source code in src/orc/embeddings.py
def __call__(
    self,
    in_state: Array,
) -> Array:
    """Embed state to reservoir dimensions.

    Parameters
    ----------
    in_state : Array
        Input state, (shape=(in_dim,) or shape=(seq_len, in_dim)).

    Returns
    -------
    Array
        Embedded input to reservoir, (shape=(chunks, res_dim,) or
        shape=(seq_len, chunks, res_dim)).
    """
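    if in_state.ndim == 1:
        return self.embed(in_state)
    if in_state.ndim == 2:
        return self.batch_embed(in_state)
    # Any other rank is unsupported; fail with a clear message.
    raise ValueError(
        "in_state must have shape (in_dim,) or (seq_len, in_dim), got "
        f"{in_state.shape}."
    )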
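
The rank-based dispatch can be sketched in isolation. In the example below, `embed` and `call_embedding` are hypothetical functions made up for illustration, and raising a ValueError for other ranks is a choice of this sketch.

```python
import jax
import jax.numpy as jnp


def embed(in_state):
    """Hypothetical single-state embedding that doubles the state length."""
    return jnp.concatenate([in_state, in_state])


def call_embedding(in_state):
    """Treat a 1-D input as one state and a 2-D input as a sequence of states."""
    if in_state.ndim == 1:
        return embed(in_state)
    if in_state.ndim == 2:
        return jax.vmap(embed)(in_state)
    raise ValueError(f"Expected a 1-D or 2-D input, got {in_state.ndim}-D.")


single = call_embedding(jnp.zeros(3))           # one state
sequence = call_embedding(jnp.zeros((10, 3)))   # ten states in sequence
```

The sequence case keeps the leading `seq_len` axis and applies the single-state embedding to every step.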

Linear Embedding

ParallelLinearEmbedding

ParallelLinearEmbedding(in_dim: int, res_dim: int, scaling: float, dtype: Float = float64, chunks: int = 1, locality: int = 0, periodic: bool = True, *, seed: int)

Bases: EmbedBase

Linear embedding layer that maps an input state to one or more parallel reservoirs.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| in_dim | int | Reservoir input dimension. |
| res_dim | int | Reservoir dimension. |
| scaling | float | Min/max values of input matrix; entries are drawn uniformly from [-scaling, scaling). |
| win | Array | Input matrix, (shape=(chunks, res_dim, chunk_size + 2*locality)). |
| chunks | int | Number of parallel reservoirs. |
| locality | int | Adjacent reservoir overlap: the number of extra input entries included on each side of a chunk. |
| periodic | bool | Assume periodic BCs when decomposing the input state to parallel network inputs. If True, the input is padded by wrapping around, so that the end of the signal joins its beginning. If False, the edge values are extended into the locality region, which may not match the true spatial dynamics. Default is True. |

Methods:

| Name | Description |
| --- | --- |
| __call__ | Embed input state to reservoir dimension. |
| localize | Decompose input_state to parallel network inputs. |
| moving_window | Helper function for localize. |
| embed | Embed single input state to reservoir dimension. |

Instantiate linear embedding.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_dim | int | Input dimension to reservoir. | required |
| res_dim | int | Reservoir dimension. | required |
| scaling | float | Min/max values of input matrix. | required |
| seed | int | Random seed for generating the PRNG key for the reservoir computer. | required |
| dtype | Float | Dtype of model, jnp.float64 or jnp.float32. | float64 |
| chunks | int | Number of parallel reservoirs. Must evenly divide in_dim. | 1 |
| locality | int | Adjacent reservoir overlap. | 0 |
| periodic | bool | Assume periodic BCs when decomposing the input state to parallel network inputs. If True, the input is padded by wrapping around, so that the end of the signal joins its beginning. If False, the edge values are extended into the locality region, which may not match the true spatial dynamics. | True |
Source code in src/orc/embeddings.py
def __init__(
    self,
    in_dim: int,
    res_dim: int,
    scaling: float,
    dtype: Float = jnp.float64,
    chunks: int = 1,
    locality: int = 0,
    periodic: bool = True,
    *,
    seed: int,
) -> None:
    """Instantiate linear embedding.

    Parameters
    ----------
    in_dim : int
        Input dimension to reservoir.
    res_dim : int
        Reservoir dimension.
    scaling : float
        Min/max values of input matrix.
    seed : int
        Random seed for generating the PRNG key for the reservoir computer.
    dtype : Float
        Dtype of model, jnp.float64 or jnp.float32.
    chunks : int
        Number of parallel reservoirs. Must evenly divide in_dim.
    locality : int
        Adjacent reservoir overlap.
    periodic : bool
        Assume periodic BCs when decomposing the input state to parallel network
        inputs. If True, the input is padded by wrapping around, so that the end
        of the signal joins its beginning. If False, the edge values are extended
        into the locality region, which may not match the true spatial dynamics.
        Default is True.
    """
    super().__init__(in_dim=in_dim, res_dim=res_dim, dtype=dtype)
    # Validate the decomposition before deriving the chunk size from it.
    if in_dim % chunks:
        raise ValueError(
            f"The number of chunks {chunks} must evenly divide in_dim {in_dim}."
        )
    self.scaling = scaling
    key = jax.random.key(seed)
    self.chunk_size = in_dim // chunks

    self.win = jax.random.uniform(
        key,
        (chunks, res_dim, self.chunk_size + 2 * locality),
        minval=-scaling,
        maxval=scaling,
        dtype=dtype,
    )
    self.locality = locality
    self.chunks = chunks
    self.periodic = periodic
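
The shape of the input matrix follows directly from the constructor arguments. The sketch below reproduces only the sampling step with made-up values. It uses `jax.random.PRNGKey` and jnp.float32 as assumptions for portability, whereas the listing above uses `jax.random.key` and defaults to jnp.float64.

```python
import jax
import jax.numpy as jnp

# Illustrative values, chosen only for this example.
in_dim, res_dim, chunks, locality, scaling = 6, 4, 3, 1, 0.5

if in_dim % chunks:
    raise ValueError("chunks must evenly divide in_dim")
chunk_size = in_dim // chunks  # inputs owned by each parallel reservoir

key = jax.random.PRNGKey(0)

# One (res_dim, chunk_size + 2 * locality) matrix per parallel reservoir.
win = jax.random.uniform(
    key,
    (chunks, res_dim, chunk_size + 2 * locality),
    minval=-scaling,
    maxval=scaling,
    dtype=jnp.float32,
)
```

With these values each of the 3 reservoirs sees 2 of its own inputs plus 1 neighboring input on each side, so `win` has shape (3, 4, 4).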

moving_window

moving_window(a)

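Generate window to compute localized states. The padded state a is sliced into chunks windows of length chunk_size + 2*locality, with consecutive windows starting chunk_size entries apart.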

Source code in src/orc/embeddings.py
@eqx.filter_jit
def moving_window(self, a):
    """Generate window to compute localized states."""
    size = int(self.in_dim / self.chunks + 2 * self.locality)
    starts = jnp.arange(len(a) - size + 1)[: self.chunks] * int(
        self.in_dim / self.chunks
    )
    return eqx.filter_vmap(
        lambda start: jax.lax.dynamic_slice(a, (start,), (size,))
    )(starts)
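
A small worked example may help to see which entries each window picks up. This sketch assumes an already padded state and uses plain `jax.vmap` with made-up sizes; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

chunks, chunk_size, locality = 3, 2, 1
size = chunk_size + 2 * locality  # width of each window

# Already padded state of length chunks * chunk_size + 2 * locality = 8.
# It corresponds to the state [0, 1, 2, 3, 4, 5] with periodic padding.
padded = jnp.array([5, 0, 1, 2, 3, 4, 5, 0])

# One window per chunk, starting every `chunk_size` entries.
starts = jnp.arange(chunks) * chunk_size

windows = jax.vmap(
    lambda start: jax.lax.dynamic_slice(padded, (start,), (size,))
)(starts)
# windows:
# [[5 0 1 2]
#  [1 2 3 4]
#  [3 4 5 0]]
```

Each row contains the two entries owned by that chunk in the middle, flanked by one entry borrowed from each neighbor.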

localize

localize(in_state: Array) -> Array

Generate parallel reservoir inputs from input state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_state | Array | Input state, (shape=(in_dim,)) | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Parallel reservoir inputs, (shape=(chunks, chunk_size + 2*locality)) |

Source code in src/orc/embeddings.py
@eqx.filter_jit
def localize(self, in_state: Array) -> Array:
    """Generate parallel reservoir inputs from input state.

    Parameters
    ----------
    in_state : Array
        Input state, (shape=(in_dim,))

    Returns
    -------
    Array
        Parallel reservoir inputs, (shape=(chunks, chunk_size + 2*locality))
    """
    if len(in_state.shape) != 1:
        raise ValueError(
            "Only 1-dimensional localization is currently supported, detected a "
            f"{len(in_state.shape)}D field."
        )
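    if self.locality == 0:
        # No overlap: a -0 slice would select the whole array, so skip padding.
        return self.moving_window(in_state)
    aug_state = jnp.hstack(
        [in_state[-self.locality :], in_state, in_state[: self.locality]]
    )
    if not self.periodic:
        # Extend the first and last true values into the padded regions.
        aug_state = aug_state.at[: self.locality].set(in_state[0])
        aug_state = aug_state.at[-self.locality :].set(in_state[-1])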
    return self.moving_window(aug_state)
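
The two padding behaviors described for periodic can be compared side by side. The sketch below uses `jnp.pad` with the "wrap" and "edge" modes as a stand-in technique for illustration; it is not how `localize` builds its padded state.

```python
import jax.numpy as jnp

in_state = jnp.arange(6)  # [0, 1, 2, 3, 4, 5]
locality = 1

# Periodic boundaries: the end of the signal wraps around to the beginning.
periodic_pad = jnp.pad(in_state, locality, mode="wrap")
# [5 0 1 2 3 4 5 0]

# Non-periodic boundaries: the edge values are repeated outward.
edge_pad = jnp.pad(in_state, locality, mode="edge")
# [0 0 1 2 3 4 5 5]
```

Both results have length in_dim + 2*locality, which is the length that the windowing step expects.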

embed

embed(in_state: Array) -> Array

Embed single state to reservoir dimensions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_state | Array | Input state, (shape=(in_dim,)). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Embedded input to reservoir, (shape=(chunks, res_dim,)). |

Source code in src/orc/embeddings.py
@eqx.filter_jit
def embed(self, in_state: Array) -> Array:
    """Embed single state to reservoir dimensions.

    Parameters
    ----------
    in_state : Array
        Input state, (shape=(in_dim,)).

    Returns
    -------
    Array
        Embedded input to reservoir, (shape=(chunks, res_dim,)).
    """
    if in_state.shape != (self.in_dim,):
        raise ValueError("Incorrect dimension for input state.")
    localized_states = self.localize(in_state)

    return eqx.filter_vmap(jnp.matmul)(self.win, localized_states)
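
The final step multiplies each chunk's input matrix by that chunk's localized input. The sketch below shows the shapes involved with stand-in arrays and plain `jax.vmap`, and checks the result against an equivalent `jnp.einsum`. The values and names are illustrative only.

```python
import jax
import jax.numpy as jnp

chunks, res_dim, width = 3, 4, 4  # width = chunk_size + 2 * locality

# Stand-in input matrices, one per parallel reservoir.
win = jnp.arange(chunks * res_dim * width, dtype=jnp.float32).reshape(
    chunks, res_dim, width
)
# Stand-in localized inputs, one row per parallel reservoir.
localized = jnp.ones((chunks, width), dtype=jnp.float32)

# vmap pairs win[c] with localized[c] and computes win[c] @ localized[c].
embedded = jax.vmap(jnp.matmul)(win, localized)

# The same contraction written as an einsum over the window axis.
same = jnp.einsum("crw,cw->cr", win, localized)
```

The result has one row of length res_dim per parallel reservoir, matching the documented output shape (chunks, res_dim,).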