Skip to content

Remove activation on final layer of KeyNet and ICNN#702

Merged
marcocuturi merged 2 commits into
mainfrom
fix/keynet-icnn-linear-final-layer
Jun 13, 2026
Merged

Remove activation on final layer of KeyNet and ICNN#702
marcocuturi merged 2 commits into
mainfrom
fix/keynet-icnn-linear-final-layer

Conversation

@marcocuturi

@marcocuturi marcocuturi commented Jun 13, 2026

Copy link
Copy Markdown
Contributor

What

The output layer of both KeyNet.gradient and ICNN.__call__ previously applied the activation function (default jax.nn.relu) after the final layer. This is unusual and unwanted:

  • KeyNet: the predicted vector (interpreted as a gradient / key) was forced to be non-negative, so it could never represent signed displacements.
  • ICNN: the scalar potential was clamped to be non-negative.

Change

Make the final layer linear in both networks — activation is applied to every layer except the last (if i != num_layers - 1).

Convexity of the ICNN output is preserved: the final PositiveDense layer is a non-negatively weighted combination (plus bias) of the convex hidden features, which remains convex. The existing convexity (Jensen) and Hessian-PSD tests in tests/neural/networks/icnn_test.py test the convexity gap rather than output non-negativity, so they continue to hold.

🤖 Generated with Claude Code

marcocuturi and others added 2 commits June 13, 2026 20:41
The output layer of both KeyNet.gradient and ICNN.__call__ previously
applied the activation function (default ReLU) after the final layer.
This forced the outputs to be non-negative: KeyNet's predicted vectors
could not take signed values, and ICNN's scalar potential was clamped to
be non-negative.

Make the final layer linear in both networks. Convexity of the ICNN
output is preserved (a non-negatively weighted combination of convex
features remains convex).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Reformat the final-layer loop in ICNN/KeyNet to satisfy yapf (the CI
"code" Lint check). Switch conditional_monge_gap_test to LinenPotentialMLP
so it uses the linen init/apply API it was written for, matching
monge_gap_test (the nnx PotentialMLP now requires input_dim/rngs).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@marcocuturi marcocuturi merged commit b79b849 into main Jun 13, 2026
9 checks passed
@codecov

codecov Bot commented Jun 13, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 86.97%. Comparing base (f890467) to head (dc862ba).
⚠️ Report is 4 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #702      +/-   ##
==========================================
+ Coverage   86.90%   86.97%   +0.06%     
==========================================
  Files          83       85       +2     
  Lines        8670     8888     +218     
  Branches      596      616      +20     
==========================================
+ Hits         7535     7730     +195     
- Misses        983      998      +15     
- Partials      152      160       +8     
Files with missing lines Coverage Δ
src/ott/neural/networks/icnn.py 91.37% <100.00%> (+1.90%) ⬆️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

marcocuturi added a commit that referenced this pull request Jun 14, 2026
* docs: fix ICNN architecture block (linear final layer + optional bias)

Follow-up to #702. The ICNN docstring "Architecture" block still showed
the activation applied at every layer and omitted bias terms entirely:

- show the final layer is linear (no activation, per #702) and note that
  convexity is still preserved;
- include the optional per-layer bias (gated by `use_bias`), and clarify
  that the W_x input-injection terms are always bias-free.

Docstring-only change; no behavior change.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* docs: add KeyNet architecture block, unify math, cite olausson:26

- Give KeyNet an `Architecture::` block in the same style as ICNN, showing
  the linear final layer and optional bias.
- Unify formula rendering across both docstrings: inline expressions (the
  function signature, the inner-product potential, the residual output) now
  use `:math:` LaTeX consistently instead of a mix of plain text and inline
  code. Docstrings are now raw strings so LaTeX backslashes are literal.
- Add the KeyNet reference (Olausson et al., 2026, "Amortizing Maximum
  Inner Product Search with Learned Support Functions") to references.bib
  and cite it from the KeyNet docstring.

Docstring/bib-only change; no behavior change.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant