Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/neat/actions/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,6 @@ def _compile_nest(
):
from pynestml.frontend.pynestml_frontend import generate_nest_compartmental_target

print("!!! codegen_opts in _compile_nest:", codegen_opts)

# assert that `model_name` is a pure name
assert not "/" in model_name
assert not "." in model_name
Expand Down Expand Up @@ -538,6 +536,21 @@ def _install_models(
code uses a fast polynomial approximation only for dynamic propagator
``exp()`` terms in hot loops; all other exponentials use
``std::exp``/``std::expf``.
- ``single_precision_propagator_exp_mode``: ``"bounded"`` or
``"plain"`` (default: ``"bounded"``). Only used when
``fp_precision="single"`` and ``use_fastexp=False``. Selects
bounded or raw ``std::expf`` evaluation for propagator exponentials.
- ``with_profiling``: bool (default: ``False``). If ``True``, generated
models expose cumulative profiling recordables for matrix assembly,
Hines solves, and current evaluation.
- ``with_detailed_recordables``: bool (default: ``False``). If
``True``, generated models expose additional multimeter recordables
for runtime propagators and pure helper functions such as ``*_inf_*``
and ``tau_*`` values.
- ``freeze_exp_mode``: ``"none"`` or ``"freeze_init"`` (default:
``"none"``). ``"freeze_init"`` replaces runtime ``exp`` evaluations
in ``f_numstep()`` with values computed once in ``pre_run_hook()``
from the initialized model state and timestep.
"""
model_name = _resolve_model_name(model_name, channel_path_arg)

Expand Down
9 changes: 8 additions & 1 deletion src/neat/factorydefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy.typing as npt

from dataclasses import dataclass, field
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple


@dataclass
Expand Down Expand Up @@ -133,6 +133,13 @@ class FitParams:
# ]) # ms
# )

# admittance-kernel correction: minimal degree of the rational fit
min_degree: int = 4
# admittance-kernel correction: maximal degree of the rational fit
max_degree: int = 40
# admittance-kernel correction: maximum relative error of the fit
max_rel_error: float = 1e-4


@dataclass
class MechParams:
Expand Down
281 changes: 262 additions & 19 deletions src/neat/modelreduction/compartmentfitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ..trees.stree import STree
from ..trees.phystree import PhysTree
from ..trees.compartmenttree import CompartmentTree
from ..tools.kernelextraction import Kernel
from ..tools.kernelextraction import Kernel, fExpFitter, FourierTools
from ..channels.ionchannels import SPDict
from ..factorydefaults import FitParams, MechParams
from .cachetrees import CachedGreensTree, CachedSOVTree, EquilibriumTree
Expand Down Expand Up @@ -595,6 +595,47 @@ def fit_concentration(self, fit_arg, ion):

return ctree, locs

def _get_passified_greenstree(self, suffix="_passified_"):
"""
Construct a `CachedGreensTree` whose membrane has been linearized
("passified") around the equilibrium potentials of the full tree.

This helper is used by `fit_passive` (when
`use_all_channels=True`). The
resulting tree is cache-backed using `self.cache_path` /
`self.cache_name + "_eq" + suffix` and `... + "_gf" + suffix`, so
repeated calls reuse cached evaluations.

Parameters
----------
suffix: str
Cache name suffix. Defaults to ``"_passified_"``, matching the
existing convention in `fit_passive`.

Returns
-------
`neat.CachedGreensTree`
"""
fit_tree = EquilibriumTree(self)
fit_tree.set_cache_params(
cache_path=self.cache_path,
cache_name=self.cache_name + "_eq" + suffix,
save_cache=self.save_cache,
recompute_cache=self.recompute_cache,
)
# set the channels to passive
fit_tree.as_passive_membrane()
# convert to a greens tree for further evaluation
fit_tree = CachedGreensTree(
fit_tree,
cache_path=self.cache_path,
cache_name=self.cache_name + "_gf" + suffix,
save_cache=self.save_cache,
recompute_cache=self.recompute_cache,
)
fit_tree.set_comp_tree(eps=self.fit_cfg.fit_comptree_eps)
return fit_tree

def fit_passive(self, fit_arg, use_all_channels=True, pprint=False):
"""
Fit the steady state passive model, consisting only of leak and coupling
Expand Down Expand Up @@ -626,24 +667,7 @@ def fit_passive(self, fit_arg, use_all_channels=True, pprint=False):
suffix = f"_passified_"

if use_all_channels:
fit_tree = EquilibriumTree(self)
fit_tree.set_cache_params(
cache_path=self.cache_path,
cache_name=self.cache_name + "_eq" + suffix,
save_cache=self.save_cache,
recompute_cache=self.recompute_cache,
)
# set the channels to passive
fit_tree.as_passive_membrane()
# convert to a greens tree for further evaluation
fit_tree = CachedGreensTree(
fit_tree,
cache_path=self.cache_path,
cache_name=self.cache_name + "_gf" + suffix,
save_cache=self.save_cache,
recompute_cache=self.recompute_cache,
)
fit_tree.set_comp_tree(eps=self.fit_cfg.fit_comptree_eps)
fit_tree = self._get_passified_greenstree(suffix=suffix)
else:
fit_tree = self.create_tree_gf(
[], # empty list of channel to include
Expand Down Expand Up @@ -1210,12 +1234,214 @@ def fit_e_eq(self, fit_arg):

return ctree, locs

def compute_admittance_correction(self, fit_arg, kernel_correction, pprint=False):
"""
Add admittance-kernel-correcting dummy compartments to a set of host
compartments in the reduced model.

For each host compartment ``p`` specified by ``kernel_correction``,
the missing admittance
:math:`\Delta Y_p(s) = Y_p^{\\mathrm{full}}(s) - Y_p^{\\mathrm{red}}(s)`
is computed by comparing the driving-point admittance of the
passified full model against the leak-only driving-point admittance
of the reduced model on a log-spaced frequency grid spanning
``self.fit_cfg.freq_band`` Hz. The residual is then fit by a sum of
zero-DC high-pass rational terms

.. math::

\\Delta Y_p(s) \\approx \\sum_q \\alpha_q \\frac{s}{s + b_q}

with ``alpha_q > 0`` and ``b_q > 0`` (see ``min_degree``,
``max_degree`` and ``max_rel_error`` on ``FitParams``). Each
accepted term yields one passive dummy compartment with
``g_c = alpha_q``, ``c = alpha_q / b_q``, ``g_l = 0``, attached
as a leaf to the host node. ``locs`` is extended with ``None``
entries to keep its length in sync with the number of tree nodes.

If for a given host the residual is negligible relative to
``Y_p^full``, or if the fit is non-physical and degrades accuracy,
no dummy compartments are added. If the fit does not reach the
configured tolerance but improves accuracy over the uncorrected
reduction, it is kept and a warning is issued.

Parameters
----------
fit_arg: see docstring of `CompartmentFitter.convert_fit_args()`
Specifying the fit that is being performed.
kernel_correction: list of int
Indices into the list of fit locations selecting the host
compartments to which an admittance correction is applied.
pprint: bool

Returns
-------
`neat.CompartmentTree`
The compartmenttree, with dummy compartments attached.
list of <neat.MorphLoc>
The corresponding list of fit locations, extended with
``None`` for each added dummy compartment.
"""
ctree, locs = self.convert_fit_arg(fit_arg)

if kernel_correction is None or len(kernel_correction) == 0:
return ctree, locs

# validate host indices
n_locs = len(locs)
for p in kernel_correction:
if not (0 <= p < n_locs):
raise IndexError(
f"`kernel_correction` index {p} is out of bounds for "
f"a fit with {n_locs} locations."
)

# frequency grid for admittance evaluation and fitting
ft = FourierTools(
np.array([0.0, 0.1])
) # dummy time array, we don't need it here
s_arr = ft.freqs_vfit

# input impedances full model
gtree = self.create_tree_gf(
[], # empty list of channel to include
cache_name_suffix="_pas_",
)
gtree.set_impedances_in_tree(freqs=s_arr, pprint=pprint)
zs_full = [gtree.calc_zf(loc, loc) for loc in locs]

# to be corrected input impedances reduced model
z_mat = ctree.calc_impedance_matrix(
freqs=s_arr, channel_names=["L"], indexing="locs"
)
zs_red = [z_mat[:, p, p] for p in kernel_correction]

# negligibility threshold: an order of magnitude tighter than
# `max_rel_error`, so we only skip when the residual is truly tiny.
eps_skip = 0.1 * self.fit_cfg.max_rel_error

# use the existing rational-fit engine
fef = fExpFitter()

for ii, loc_idx in enumerate(kernel_correction):
host = ctree.get_nodes_from_loc_idxs(loc_idx)
y_full_p = 1.0 / zs_full[ii]
y_red_p = 1.0 / zs_red[ii]
dy_p = y_full_p - y_red_p

# skip if the residual is negligible relative to the full admittance
if np.max(np.abs(dy_p) / np.abs(y_full_p)) < eps_skip:
continue

# somatic capacitance correction
idx_bool = np.abs(s_arr) > 1e3
(dca,), _, _, _ = np.linalg.lstsq(
s_arr[idx_bool][:, None].imag, dy_p[idx_bool].imag, rcond=None
)
if pprint:
print(f"somatic capcitance correction dca = {dca} uF")
# apply capacitance correction to the host compartment and the residual
host.ca += dca
dy_p -= dca * s_arr

# pl.figure("impedance full/red")
# ax1, ax2, ax3 = pl.subplot(311), pl.subplot(312), pl.subplot(313)
# ax1.plot(s_arr.imag, np.abs(zs_full[ii]), 'b', label="full")
# ax1.plot(s_arr.imag, np.abs(zs_red[ii]), 'r--', label="red")
# ax1.legend(loc=0)
# ax2.plot(s_arr.imag, (dy_p).real, 'g', label="full real")
# ax2.plot(s_arr.imag, (dy_p).imag, 'g--', label="full imag")
# ax2.plot(s_arr.imag, np.abs(dy_p), 'g:', label="full abs")
# ax2.plot(s_arr.imag, (dca * s_arr).real, 'y', label="fit real")
# ax2.plot(s_arr.imag, (dca * s_arr).imag, 'y--', label="fit imag")
# ax2.plot(s_arr.imag, np.abs(dca * s_arr), 'y:', label="fit abs")

# ax3.plot(s_arr.imag, (dy_p - (dca * s_arr)).real, 'm', label="residual real")
# ax3.plot(s_arr.imag, (dy_p - (dca * s_arr)).imag, 'm--', label="residual imag")
# ax3.plot(s_arr.imag, np.abs(dy_p - (dca * s_arr)), 'm:', label="residual abs")
# pl.show()

for Q in range(self.fit_cfg.min_degree, self.fit_cfg.max_degree + 1, 4):
alphas, gammas, _, rms = fef.fitFExp(
s_arr,
dy_p,
deg=Q,
rtol=1e-4,
realpoles=True,
initpoles="log10",
zerostart=False,
constrained=True,
reduce_numexp=False,
return_real=True,
)
if pprint:
print(
f">>> admittance correction for host loc {loc_idx} and degree {Q}: rms = {rms}"
)

if rms < self.fit_cfg.max_rel_error:
print(
f"Error criterion satisfied for Q = {Q} | {rms} < {self.fit_cfg.max_rel_error}, stopping fit."
)
print("alpha:\n", alphas, "\ngamma:\n", gammas)
break

# f_corr = Kernel((alphas*1e-3, gammas*1e-3))
# pl.figure(f"fit Q={Q}")
# pl.plot(s_arr.imag, dy_p.real, 'g', label="target real")
# pl.plot(s_arr.imag, dy_p.imag, 'g--', label="target real")
# pl.plot(s_arr.imag, np.abs(dy_p), 'g:', label="target real")
# pl.plot(s_arr.imag, f_corr.ft(s_arr).real + bias, 'y', label="fit real")
# pl.plot(s_arr.imag, f_corr.ft(s_arr).imag + bias, 'y--', label="fit imag")
# pl.plot(s_arr.imag, np.abs(f_corr.ft(s_arr)) + bias, 'y:', label="fit abs")
# pl.legend(loc=0)
# pl.show()

if pprint:
print("Original compartment tree:\n", ctree)

# attach dummy compartments to the host nodes according to fit results
for ar, gr in zip(alphas, gammas):
b_q = float(ar.real)
gc_q = float((-gr / ar).real)
ca_q = gc_q / b_q

dummy_idx = max(n.index for n in ctree) + 1
dummy = ctree.create_corresponding_node(
dummy_idx,
ca=ca_q,
g_c=gc_q,
g_l=0.0,
)
# dummy has no associated MorphLoc
dummy.loc_idx = None
dummy.e_eq = host.e_eq
# leak reversal of the dummy follows the host leak
# reversal. The dummy has
# `g_l = 0`, so this is dynamically inert but keeps the
# convention explicit.
if "L" in host.currents:
dummy.currents["L"] = (0.0, host.currents["L"][1])
else:
dummy.currents["L"] = (0.0, host.e_eq)
ctree.add_node_with_parent(dummy, host)
locs.append(None)

if pprint:
print(
f"\nNew compartment tree after correcting compartment {host.index}:\n",
ctree,
)

return ctree, locs

def fit_model(
self,
loc_arg,
fit_name="",
alpha_inds=[0],
use_all_channels_for_passive=True,
kernel_correction=None,
pprint=False,
):
"""
Expand All @@ -1234,6 +1460,15 @@ def fit_model(
Indices of all mode time-scales to be included in the fit
use_all_channels_for_passive: bool (optional, default ``True``)
Uses all channels in the tree to compute coupling conductances
kernel_correction: list of int or ``None`` (optional, default ``None``)
Indices into ``loc_arg`` (after bifurcation extension) selecting
host compartments to which an admittance-kernel correction will be
applied via additional passive dummy compartments. ``None`` or an
empty list disables the correction (default behavior). Note that this
correction can result in negative couplings / capacitances of dummy
compartments, which NEURON / Brian 2 exports do not support.
`NeuronCompartmentTree` and `Brian2CompartmentTree` will raise a warning
and ignore the correction.
pprint: bool
whether to print information

Expand Down Expand Up @@ -1273,6 +1508,14 @@ def fit_model(
# fit the resting potentials
fit_arg = self.fit_e_eq(fit_arg)

# admittance-kernel correction by dummy compartments
if kernel_correction:
fit_arg = self.compute_admittance_correction(
fit_arg,
kernel_correction,
pprint=pprint,
)

if fit_name == "temp":
self.remove_fit(fit_name)
else:
Expand Down
Loading
Loading