Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion brainscore_vision/benchmarks/allen2022_fmri/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _Allen2022fmri(region,
def Allen2022fmri(region: str, metric_type: str,
dataset_prefix: str = 'Allen2022_fmri',
alphas: list = ALPHA_LIST):
similarity_metric = load_metric(f'{metric_type}_split', alphas=alphas)
similarity_metric = load_metric(f'dual_{metric_type}_split', alphas=alphas)
return _Allen2022fmri(region, similarity_metric=similarity_metric,
identifier_metric_suffix=metric_type,
dataset_prefix=dataset_prefix,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _Allen2022fmriSurface(region,
def Allen2022fmriSurface(region: str, metric_type: str,
dataset_prefix: str = 'Allen2022_fmri_surface',
alphas: list = ALPHA_LIST):
similarity_metric = load_metric(f'{metric_type}_split', alphas=alphas)
similarity_metric = load_metric(f'dual_{metric_type}_split', alphas=alphas)
return _Allen2022fmriSurface(region, similarity_metric=similarity_metric,
identifier_metric_suffix=metric_type,
dataset_prefix=dataset_prefix,
Expand Down
7 changes: 6 additions & 1 deletion brainscore_vision/metrics/regression_correlation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from brainscore_vision import metric_registry
from .metric import CrossRegressedCorrelation, pls_regression, ridge_cv_regression, ridge_regression, single_regression, linear_regression,\
from .metric import CrossRegressedCorrelation, pls_regression, ridge_cv_regression, ridge_regression, \
dual_ridge_regression, dual_ridge_cv_regression, single_regression, linear_regression,\
pearsonr_correlation, ReverseCrossRegressedCorrelation, ReverseTrainTestSplitCorrelation


Expand All @@ -26,6 +27,10 @@
regression=linear_regression(), correlation=pearsonr_correlation(), *args, **kwargs)
# Train/test-split metric factories. Each registry value is a lambda that,
# when called, builds a TrainTestSplitCorrelation around a freshly created
# regression object and the result of pearsonr_correlation().
# NOTE(review): the *cv entries forward the same **kwargs both to the
# regression factory and to the metric constructor; confirm that every key
# (e.g. `alphas`) is accepted by both.
metric_registry['ridgecv_split'] = lambda *args, **kwargs: TrainTestSplitCorrelation(
regression=ridge_cv_regression(**kwargs), correlation=pearsonr_correlation(), *args, **kwargs)
# Dual-form ridge with a fixed alpha: the regression is built with its
# defaults; caller kwargs go to the metric constructor only.
metric_registry['dual_ridge_split'] = lambda *args, **kwargs: TrainTestSplitCorrelation(
regression=dual_ridge_regression(), correlation=pearsonr_correlation(), *args, **kwargs)
# Dual-form ridge with cross-validated alpha: kwargs are passed to
# dual_ridge_cv_regression as well as to the metric constructor.
metric_registry['dual_ridgecv_split'] = lambda *args, **kwargs: TrainTestSplitCorrelation(
regression=dual_ridge_cv_regression(**kwargs), correlation=pearsonr_correlation(), *args, **kwargs)

# Reverse-direction metric built on ReverseCrossRegressedCorrelation with a
# PLS regression created from pls_regression() defaults.
metric_registry["reverse_pls_cv"] = lambda *args, **kwargs: ReverseCrossRegressedCorrelation(
regression=pls_regression(), correlation=pearsonr_correlation(), *args, **kwargs)
Expand Down
272 changes: 270 additions & 2 deletions brainscore_vision/metrics/regression_correlation/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def apply(self, source_train, target_train, source_test, target_test):
prediction = self.regression.predict(source_test)
score = self.correlation(prediction, target_test)

if self.regression._regression.__class__ in [RidgeCV]:
if hasattr(self.regression._regression, 'alpha_'):
score.attrs['alpha'] = self.regression._regression.alpha_

return score
Expand All @@ -128,6 +128,254 @@ def __call__(self, *, source_train, target_train, source_test, target_test):
target_test=source_test,
)

class DualRidgeRegression:
    """Ridge regression solved in dual (kernel) form to limit memory use.

    Inputs are converted to float32. When ``n_samples < n_features`` the
    model is fitted on the centered ``(n_samples, n_samples)`` Gram matrix
    and stores whichever representation is smaller:

    - ``n_targets < n_samples``: the ``(n_features, n_targets)`` coefficient
      matrix is computed and the training data is not retained.
    - ``n_targets >= n_samples``: the centered training data and the
      inverse kernel are kept and predictions are produced by kernel
      projection, so the coefficient matrix is never materialized.

    When ``n_samples >= n_features`` the work is delegated to sklearn's
    ``Ridge``. The dual solution is intended to match ``Ridge`` with an
    intercept: both features and targets are centered before solving.

    Targets may be 1-D (``(n_samples,)``) or 2-D
    (``(n_samples, n_targets)``); predictions have the matching rank.

    Args:
        alpha: Regularization strength added to the kernel diagonal.
        chunk_size: Number of target columns processed per step when
            predicting via kernel projection.
    """

    def __init__(self, alpha: float = 1.0, chunk_size: int = 5000):
        self.alpha = alpha
        self.chunk_size = chunk_size

    def fit(self, X, Y) -> None:
        """Fit the model on ``X`` (samples x features) and targets ``Y``."""
        X = np.asarray(X, dtype=np.float32)
        Y = np.asarray(Y, dtype=np.float32)

        # Accept 1-D targets: work on a single column internally and
        # squeeze the prediction back in predict().
        self._squeeze_output = Y.ndim == 1
        if self._squeeze_output:
            Y = Y[:, np.newaxis]

        n_samples, n_features = X.shape
        n_targets = Y.shape[1]

        # Drop anything left over from a previous fit so a refit in the
        # other mode does not keep large arrays alive.
        self._primal = None
        self._coef = None
        self._X_train_centered = None
        self._Y_train_centered = None
        self._K_inv = None

        if n_samples >= n_features:
            self._use_dual = False
            self._primal = Ridge(alpha=self.alpha)
            self._primal.fit(X, Y)
            return

        self._use_dual = True
        self._X_mean = X.mean(axis=0)
        self._Y_mean = Y.mean(axis=0)
        X_c = X - self._X_mean
        Y_c = Y - self._Y_mean

        # The Gram matrix is formed from the float32 data and promoted to
        # float64 so that the linear solve runs in double precision.
        K = (X_c @ X_c.T).astype(np.float64)
        K[np.diag_indices_from(K)] += self.alpha

        if n_targets < n_samples:
            # coef_ is smaller than X_train: solve directly against the
            # targets (no explicit inverse needed) and discard X_c.
            dual_coef = np.linalg.solve(K, Y_c.astype(np.float64))
            self._coef = X_c.T @ dual_coef.astype(np.float32)
            self._use_coef = True
        else:
            # X_train is smaller than coef_: keep it for kernel projection.
            self._X_train_centered = X_c
            self._Y_train_centered = Y_c
            self._K_inv = np.linalg.solve(K, np.eye(n_samples)).astype(np.float32)
            self._use_coef = False

    def predict(self, X) -> np.ndarray:
        """Predict targets for ``X`` (samples x features)."""
        if not self._use_dual:
            predictions = self._primal.predict(X)
        else:
            X = np.asarray(X, dtype=np.float32)
            X_test_c = X - self._X_mean
            if self._use_coef:
                predictions = X_test_c @ self._coef + self._Y_mean
            else:
                predictions = self._predict_by_projection(X_test_c)

        if self._squeeze_output:
            return np.asarray(predictions)[:, 0]
        return predictions

    def _predict_by_projection(self, X_test_c) -> np.ndarray:
        """Kernel-projection prediction, evaluated in chunks of target columns."""
        proj = X_test_c @ self._X_train_centered.T @ self._K_inv

        n_test = X_test_c.shape[0]
        n_targets = self._Y_train_centered.shape[1]
        predictions = np.empty((n_test, n_targets), dtype=np.float32)
        for start in range(0, n_targets, self.chunk_size):
            end = min(start + self.chunk_size, n_targets)
            predictions[:, start:end] = (
                proj @ self._Y_train_centered[:, start:end] + self._Y_mean[start:end]
            )
        return predictions


class DualRidgeCVRegression:
    """RidgeCV-style regression solved in dual form for memory efficiency.

    Inputs are converted to float32. Behaviour depends on the data shape
    and on the extra keyword arguments:

    - ``n_samples >= n_features``: everything is delegated to sklearn's
      ``RidgeCV``.
    - ``n_samples < n_features`` and none of ``scoring``, ``cv`` or
      ``alpha_per_target`` is set: alpha is selected by leave-one-out
      error computed from an eigendecomposition of the Gram matrix, and
      the final model is solved in dual form.
    - ``n_samples < n_features`` with ``scoring``/``cv`` set: sklearn's
      ``RidgeCV`` selects alpha, then ``DualRidgeRegression`` is fitted
      with that alpha for prediction.
    - If sklearn reports one alpha per target (an array), the fitted
      ``RidgeCV`` itself is kept for prediction, because a single dual
      solve cannot represent several regularization strengths.

    Targets may be 1-D or 2-D; predictions have the matching rank.
    ``alpha_`` holds the selected regularization strength after ``fit``.

    NOTE(review): ``fit_intercept`` in ``ridgecv_kwargs`` is not inspected;
    the dual LOO path always centers features and targets.

    Args:
        alphas: Candidate regularization strengths. ``None`` means
            ``[0.1, 1.0, 10.0]`` on the dual LOO path and sklearn's own
            default wherever ``RidgeCV`` is used.
        chunk_size: Number of target columns processed per step when
            predicting via kernel projection.
        **ridgecv_kwargs: Passed through to ``RidgeCV`` whenever sklearn
            is used.
    """

    def __init__(self, alphas=None, chunk_size: int = 5000, **ridgecv_kwargs):
        self.alphas = alphas
        self.chunk_size = chunk_size
        self._ridgecv_kwargs = ridgecv_kwargs
        self.alpha_ = None
        self._dual = None

    def _can_use_dual_loo(self) -> bool:
        """Return True when alpha can be selected by LOO in kernel space.

        That is only done when no custom scoring function, no explicit cv
        and no per-target alpha was requested.
        """
        if self._ridgecv_kwargs.get('scoring') is not None:
            return False
        if self._ridgecv_kwargs.get('cv') is not None:
            return False
        if self._ridgecv_kwargs.get('alpha_per_target', False):
            return False
        return True

    def _build_ridgecv(self):
        """Create the sklearn RidgeCV configured from this object's settings."""
        kwargs = dict(self._ridgecv_kwargs)
        if self.alphas is not None:
            kwargs['alphas'] = self.alphas
        return RidgeCV(**kwargs)

    def fit(self, X, Y) -> None:
        """Select alpha and fit on ``X`` (samples x features) and targets ``Y``."""
        X = np.asarray(X, dtype=np.float32)
        Y = np.asarray(Y, dtype=np.float32)
        n_samples, n_features = X.shape

        # Drop anything left over from a previous fit.
        self._primal = None
        self._dual = None
        self._coef = None
        self._X_train_centered = None
        self._Y_train_centered = None
        self._K_inv = None
        self._squeeze_output = False

        if n_samples >= n_features:
            # sklearn handles 1-D targets itself, so Y is passed as given.
            self._use_dual = False
            self._primal = self._build_ridgecv()
            self._primal.fit(X, Y)
            self.alpha_ = self._primal.alpha_
            return

        # The dual code paths work on 2-D targets; remember to squeeze the
        # prediction again when the caller supplied a 1-D vector.
        if Y.ndim == 1:
            self._squeeze_output = True
            Y = Y[:, np.newaxis]

        self._use_dual = True
        if self._can_use_dual_loo():
            self._fit_dual_loo(X, Y, n_samples)
        else:
            self._fit_sklearn_then_dual(X, Y)

    def _fit_dual_loo(self, X, Y, n_samples) -> None:
        """Select alpha via LOO in kernel space, then solve in dual form.

        Mirrors the eigendecomposition approach of sklearn's ridge GCV:
        center X, add the intercept to the kernel as a constant, take the
        eigendecomposition, zero the weight of the intercept eigenvector,
        and evaluate the LOO error for every candidate alpha.

        The data stays in float32; the kernel, its eigendecomposition and
        the LOO scores are handled in float64.
        """
        self._X_mean = X.mean(axis=0)
        self._Y_mean = Y.mean(axis=0)
        X_c = X - self._X_mean
        Y_c = Y - self._Y_mean

        # The Gram matrix is the dominant cost (O(n^2 * n_features)), so it
        # is computed once and reused for both alpha selection and the
        # final solve.
        gram = (X_c @ X_c.T).astype(np.float64)

        # Adding 1 to every entry is the same as adding ones((n, n)).
        eigenvalues, Q = np.linalg.eigh(gram + 1.0)
        # Un-centered Y is projected; its constant component lies along the
        # intercept eigenvector, whose weight is zeroed below.
        QT_y = Q.T @ Y.astype(np.float64)

        # The intercept eigenvector is the one most aligned with the
        # normalized ones vector.
        normalized_sw = np.ones(n_samples) / np.sqrt(n_samples)
        intercept_dim = np.argmax(np.abs(Q.T @ normalized_sw))

        alphas = self.alphas if self.alphas is not None else [0.1, 1.0, 10.0]
        best_alpha = alphas[0]
        best_score = -np.inf
        Q_sq = Q ** 2

        for alpha in alphas:
            w = 1.0 / (eigenvalues + alpha)
            w[intercept_dim] = 0  # no regularization on the intercept

            c = Q @ (w[:, None] * QT_y)
            # Floor the diagonal to avoid division by (near) zero.
            G_inv_diag = np.maximum(Q_sq @ w, 1e-12)

            loo_errors = c / G_inv_diag[:, None]
            score = -np.mean(loo_errors ** 2)  # negative MSE, higher is better

            if score > best_score:
                best_score = score
                best_alpha = alpha

        self.alpha_ = best_alpha

        # Final solve with the selected alpha on the regularized kernel.
        gram[np.diag_indices_from(gram)] += self.alpha_

        n_targets = Y_c.shape[1]
        if n_targets < n_samples:
            # coef_ is smaller than X_train: solve directly against the
            # targets (no explicit inverse needed) and discard X_c.
            dual_coef = np.linalg.solve(gram, Y_c.astype(np.float64))
            self._coef = X_c.T @ dual_coef.astype(np.float32)
            self._use_coef = True
        else:
            # X_train is smaller than coef_: keep it for kernel projection.
            self._X_train_centered = X_c
            self._Y_train_centered = Y_c
            self._K_inv = np.linalg.solve(gram, np.eye(n_samples)).astype(np.float32)
            self._use_coef = False

    def _fit_sklearn_then_dual(self, X, Y) -> None:
        """Fallback: sklearn RidgeCV picks alpha, DualRidge predicts.

        Used when custom scoring/cv/alpha_per_target prevents dual LOO.
        """
        rcv = self._build_ridgecv()
        rcv.fit(X, Y)
        self.alpha_ = rcv.alpha_

        if np.ndim(self.alpha_) > 0:
            # Per-target alphas come back as an array, which cannot be
            # turned into the single float DualRidgeRegression needs.
            # Keep the fitted sklearn model for prediction instead.
            self._primal = rcv
            self._use_dual = False
            return
        del rcv

        self._dual = DualRidgeRegression(
            alpha=float(self.alpha_), chunk_size=self.chunk_size
        )
        self._dual.fit(X, Y)

    def predict(self, X) -> np.ndarray:
        """Predict targets for ``X`` (samples x features)."""
        if not self._use_dual:
            predictions = self._primal.predict(X)
        elif self._dual is not None:
            predictions = self._dual.predict(X)
        else:
            X = np.asarray(X, dtype=np.float32)
            X_test_c = X - self._X_mean
            if self._use_coef:
                predictions = X_test_c @ self._coef + self._Y_mean
            else:
                predictions = self._predict_by_projection(X_test_c)

        if self._squeeze_output:
            return np.asarray(predictions)[:, 0]
        return predictions

    def _predict_by_projection(self, X_test_c) -> np.ndarray:
        """Kernel-projection prediction, evaluated in chunks of target columns."""
        proj = X_test_c @ self._X_train_centered.T @ self._K_inv

        n_test = X_test_c.shape[0]
        n_targets = self._Y_train_centered.shape[1]
        predictions = np.empty((n_test, n_targets), dtype=np.float32)
        for start in range(0, n_targets, self.chunk_size):
            end = min(start + self.chunk_size, n_targets)
            predictions[:, start:end] = (
                proj @ self._Y_train_centered[:, start:end] + self._Y_mean[start:end]
            )
        return predictions


def pls_regression(regression_kwargs=None, xarray_kwargs=None):
regression_defaults = dict(n_components=25, scale=False)
regression_kwargs = {**regression_defaults, **(regression_kwargs or {})}
Expand Down Expand Up @@ -165,12 +413,32 @@ def ridge_cv_regression(regression_kwargs=None, xarray_kwargs=None, alphas=ALPHA
regression_defaults = dict(alphas=alphas, store_cv_results=False)
regression_kwargs = {**regression_defaults, **(regression_kwargs or {})}
regression_kwargs.pop('alpha', None) # RidgeCV does not accept 'alpha' as a parameter

regression = RidgeCV(**regression_kwargs)
xarray_kwargs = xarray_kwargs or {}
regression = XarrayRegression(regression, **xarray_kwargs)
return regression


def dual_ridge_regression(regression_kwargs=None, xarray_kwargs=None):
    """Build a DualRidgeRegression wrapped in an XarrayRegression.

    Args:
        regression_kwargs: Optional keyword arguments for
            DualRidgeRegression; they override the default ``alpha=1``.
        xarray_kwargs: Optional keyword arguments for XarrayRegression.

    Returns:
        The XarrayRegression wrapping the configured estimator.
    """
    estimator_kwargs = {'alpha': 1}
    estimator_kwargs.update(regression_kwargs or {})
    estimator = DualRidgeRegression(**estimator_kwargs)
    return XarrayRegression(estimator, **(xarray_kwargs or {}))


def dual_ridge_cv_regression(regression_kwargs=None, xarray_kwargs=None, alphas=ALPHA_LIST):
    """Build a DualRidgeCVRegression wrapped in an XarrayRegression.

    Args:
        regression_kwargs: Optional keyword arguments for
            DualRidgeCVRegression; they override the default ``alphas``.
            An ``alpha`` key, if present, is discarded before construction.
        xarray_kwargs: Optional keyword arguments for XarrayRegression.
        alphas: Default candidate regularization strengths.

    Returns:
        The XarrayRegression wrapping the configured estimator.
    """
    estimator_kwargs = {'alphas': alphas}
    estimator_kwargs.update(regression_kwargs or {})
    # A single fixed 'alpha' is dropped; only the 'alphas' candidates are kept.
    estimator_kwargs.pop('alpha', None)

    estimator = DualRidgeCVRegression(**estimator_kwargs)
    return XarrayRegression(estimator, **(xarray_kwargs or {}))

def single_regression(xarray_kwargs=None):
regression = SingleRegression()
xarray_kwargs = xarray_kwargs or {}
Expand Down
Loading