Embedding Layers¶
Overview¶
The embedding layer serves as the interface between your input data and the reservoir computer. It transforms input vectors of dimension in_dim into embedded vectors of dimension res_dim that can be processed by the reservoir driver.
All embedding layers inherit from the EmbedBase abstract class, ensuring a consistent API across different implementations.
Base Class: EmbedBase¶
The EmbedBase class defines the core interface that all embedding implementations must follow:
Key Attributes¶
in_dim: Input dimensionres_dim: Reservoir dimensiondtype: JAX array dtype (jnp.float32 or jnp.float64)
Core Methods¶
embed(in_state): Transform a single input vectorbatch_embed(in_state): Transform multiple input vectors efficiently__call__(in_state): Flexible interface supporting both single and batch inputs
ParallelLinearEmbedding¶
The ParallelLinearEmbedding class provides matrix multiplication-based embedding with support for parallel reservoirs to handle spatial/high-dimensional data.
Basic Usage¶
import jax.numpy as jnp
from orc.embeddings import ParallelLinearEmbedding
# Create a simple linear embedding
embedding = ParallelLinearEmbedding(
in_dim=10, # Input dimension
res_dim=100, # Reservoir dimension
scaling=0.1, # Input matrix values range [-0.1, 0.1]
seed=42 # Random seed for reproducibility
)
# Embed a single input
input_vector = jnp.ones(10)
embedded = embedding(input_vector) # Shape: (1, 100) for single reservoir
Parallel Reservoirs¶
ParallelLinearEmbedding supports decomposing high-dimensional inputs into multiple (overlapping) parallel reservoirs:
# Parallel reservoirs for spatial data
embedding = ParallelLinearEmbedding(
in_dim=1000, # Total spatial dimension
res_dim=200, # Reservoir dimension per chunk
scaling=0.1,
chunks=50, # Create 4 parallel reservoirs
locality=2, # Each reservoir sees 2 neighbors on each side
# in addition to in_dim / chunks values
periodic=True, # Use periodic boundary conditions
seed=42
)
# Input gets decomposed into overlapping chunks
spatial_input = jnp.sin(jnp.linspace(0, 4*jnp.pi, 1000))
embedded = embedding(spatial_input) # Shape: (4, 200)
Key Parameters¶
- scaling: Controls the magnitude of random input weights (range [-scaling, scaling])
- chunks: Number of parallel reservoirs to create
- locality: Number of neighboring points each reservoir can see
- periodic: Whether to use periodic boundary conditions for spatial decomposition
Spatial Decomposition¶
When chunks > 1, the embedding automatically:
- Localizes: Splits the input into overlapping chunks based on
locality - Handles boundaries: Uses periodic or extended boundary conditions
- Embeds: Applies linear transformation to each chunk independently
This is particularly useful for: - Spatiotemporal systems (PDEs, cellular automata)
- Problems with translational symmetry
Boundary Conditions¶
- periodic=True: Connects end and beginning smoothly (default)
- periodic=False: Extends edge values into locality regions
Custom Embeddings¶
To create your own embedding layer:
from orc.embeddings import EmbedBase
import equinox as eqx
class CustomEmbedding(EmbedBase):
def __init__(self, in_dim, res_dim, **kwargs):
super().__init__(in_dim, res_dim)
# Initialize your parameters
def embed(self, in_state):
# Implement your embedding logic
# Return shape depends on parallel reservoir support
pass
Requirements¶
- Inherit from
EmbedBase - Implement the abstract
embed()method - Choose your parallel reservoir support level:
Single Reservoir Only¶
For embeddings that don't support parallel reservoirs:
def embed(self, in_state):
# Transform input to reservoir dimension
embedded = your_transformation(in_state)
return embedded # Shape: (res_dim,)
Parallel Reservoir Support¶
For embeddings that support parallel reservoirs like ParallelLinearEmbedding:
def embed(self, in_state):
# Transform input with parallel processing
embedded = your_parallel_transformation(in_state)
return embedded # Shape: (chunks, res_dim)
Shape Summary¶
- Single reservoir:
embed()returns(res_dim,) - Parallel reservoirs:
embed()returns(chunks, res_dim) - The base class
__call__()method handles single reservoirs automatically - For parallel reservoir support, override
__call__()to handle shape processing
Design Considerations¶
Memory and Performance¶
- Use
chunks > 1for large spatial systems to reduce memory usage - The
batch_embedmethod is JIT-compiled for efficient batch processing - All operations use JAX for automatic differentiation and GPU acceleration
Integration with RC Models¶
- Embedding output feeds directly into reservoir drivers
- The
chunksdimension corresponds to parallel reservoir processing - Shape consistency:
embedding_output.shape[0]must match driver expectations