Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions history.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,6 @@ History
* new suite of competitive learning synapses (generalized formats), including vector-quantization, self-organizing map, adaptive resonance theory (contus-version), and modern hopfield network
* revisions to metric and model utils
* some additional clean-up, including supported retinal ganglion encoder
* fixed pkg-resource header (in both ngcsimlib and ngclearn) to account for newer python(s)


2 changes: 1 addition & 1 deletion ngclearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"currently installed!")

##################################################################################
## Needed to preload is called before anything in ngclearn
## Following are needed to preload is called before anything in ngclearn
from pathlib import Path
from sys import argv
import numpy
Expand Down
69 changes: 38 additions & 31 deletions ngclearn/components/input_encoders/ganglionCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ def _create_patches(obs, patch_shape, step_shape):

class RetinalGanglionCell(JaxComponent):
"""
A group of retinal ganglion cell that senses the input stimuli and sends out the filtered signal to the brain.
A group of retinal ganglion cell that sense input stimuli and send out filtered
signals (as output). Note that these simulated cells employ internal generalized
filters based on either Gaussian or difference-of-Gaussian kernels) to recover
historical receptive field processing effects.

| --- Cell Input Compartments: ---
| inputs - input (takes in external signals)
Expand All @@ -85,32 +88,34 @@ class RetinalGanglionCell(JaxComponent):
filter_type: string name of filter function (Default: identity)
:Note: supported filters include "gaussian", "difference_of_gaussian"

sigma: standard deviation of gaussian kernel
sigma: standard deviation of (gaussian) kernel

area_shape: receptive field area of ganglion cells in this module all together
area_shape: shape of receptive field area of ganglion cells in this module (all together)

n_cells: number of ganglion cells in this module

patch_shape: each ganglion cell receptive field area
patch_shape: shape of each ganglion cell's receptive field area

step_shape: the non-overlapping area between each two ganglion cells
step_shape: the non-overlapping area between each pair (two) of ganglion cells

batch_size: batch size dimension of this cell (Default: 1)
batch_size: batch size dimension of this cell/module (Default: 1)
"""

def __init__(self, name: str,
filter_type: str,
area_shape: Tuple[int, int],
n_cells: int,
patch_shape: Tuple[int, int],
step_shape: Tuple[int, int],
batch_size: int = 1,
sigma: float = 1.0,
key: Union[jax.Array, None] = None,
**kwargs):
def __init__(
self,
name: str,
filter_type: str,
area_shape: Tuple[int, int],
n_cells: int,
patch_shape: Tuple[int, int],
step_shape: Tuple[int, int],
batch_size: int = 1,
sigma: float = 1.0,
key: Union[jax.Array, None] = None,
**kwargs
):
super().__init__(name=name, key=key)


## Layer Size Setup
self.filter_type = filter_type
self.n_cells = n_cells
Expand Down Expand Up @@ -143,14 +148,14 @@ def __init__(self, name: str,
@compilable
def advance_state(self, t):
inputs = self.inputs.get()
filter = self.filter.get()
_filter = self.filter.get()
px, py = self.patch_shape

# ═══════════════════ extract pathches for filters ══════════════════
input_patches = _create_patches(inputs, patch_shape=self.patch_shape, step_shape=self.step_shape)

# ═══════════════════ apply filter to all pathches ══════════════════
filtered_input = input_patches * filter ## shape: (B | n_cells | px | py)
filtered_input = input_patches * _filter ## shape: (B | n_cells | px | py)

# ════════════ reshape all cells responses to a single input to brain ════════════
filtered_input = filtered_input.reshape(-1, self.n_cells * (px * py)) ## shape: (B | n_cells * px * py)
Expand Down Expand Up @@ -184,7 +189,7 @@ def reset(self): ## reset core components/statistics
@classmethod
def help(cls): ## component help function
properties = {
"cell_type": "RetinalGanglionCell - filters the input stimuli, "
"cell_type": "RetinalGanglionCell - filters the input stimuli according retinal ganglion dynamics"
}
compartment_props = {
"inputs":
Expand All @@ -196,11 +201,11 @@ def help(cls): ## component help function
}
hyperparams = {
"filter_type": "Type of the filter for preprocessing the input",
"sigma": "Standard deviation of gaussian kernel",
"sigma": "Standard deviation of gaussian kernel/filter",
"area_shape": "Effective receptive field area shape of ganglion cells in this module",
"n_cells": "Number of Retinal Ganglion (center-surround) cells to model in this layer",
"patch_shape": "Classical Receptive field area shape of individual ganglion cells in this module",
"step_shape": "Extra-Classical Receptive field area shape each ganglion cell in this module",
"n_cells": "Number of retinal ganglion (center-surround) cells to model in this layer",
"patch_shape": "Classical receptive field area shape of individual ganglion cells in this module",
"step_shape": "Extra-classical receptive field area shape each ganglion cell in this module",
"batch_size": "Batch size dimension of this component"
}
info = {cls.__name__: properties,
Expand All @@ -212,13 +217,15 @@ def help(cls): ## component help function
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
X = RetinalGanglionCell("RGC", filter_type="gaussian",
sigma=2.3,
area_shape=(16, 26),
n_cells = 3,
patch_shape=(16, 16),
step_shape=(0, 5)
)
X = RetinalGanglionCell(
"RGC",
filter_type="gaussian",
sigma=2.3,
area_shape=(16, 26),
n_cells = 3,
patch_shape=(16, 16),
step_shape=(0, 5)
)
print(X)


Expand Down
4 changes: 4 additions & 0 deletions ngclearn/utils/metric_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,13 @@ def measure_sparsity(codes, tolerance=0., preserve_batch=True, flip_measure=Fals
this matrix is a non-negative vector.

Formally, this means we compute, per i-th row:

| rho(x_i) = num_zeros(x_i) / dim(x_i)

and for a global score for matrix X with N codes/rows, we measure:

| rho_mean(X) = 1/N Sum^N_{i=1} rho(x_i)

where lower/closer to 0 means codes more sparse and closer to 1 means
codes are more dense.

Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ matplotlib>=3.9.4
# patchify # patchify has issues with pip installation
jax>=0.4.28
jaxlib>=0.4.28
ngcsimlib>=3.0.0
ngcsimlib>=3.1.0
imageio>=2.37.0
pandas>=2.2.3
typing_extensions>=4.15.0
Loading