
Unconstrain Wishart#8246

Draft
ricardoV94 wants to merge 13 commits into pymc-devs:v6 from ricardoV94:Wishart

Conversation

@ricardoV94
Member

@ricardoV94 ricardoV94 commented Apr 9, 2026

Contains commits from #8243

Similar idea to #7380 (but this one is actually simpler). Almost the same rewrite as LKJCholeskyCov, except that here we unconstrain to the full dense matrix.
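The transform, roughly: a free packed vector fills the lower triangle of L, the diagonal slots are exponentiated so diag(L) > 0, and Σ = L Lᵀ is then positive definite by construction. A minimal numpy sketch of that general idea (not the PR's actual code; the function name is illustrative):

```python
import numpy as np

def unconstrained_to_cov(unc: np.ndarray, n: int) -> np.ndarray:
    """Map a free vector of length n(n+1)/2 to an SPD matrix:
    fill the lower triangle row-major, exponentiate the diagonal
    slots so diag(L) > 0, and return L @ L.T."""
    L = np.zeros((n, n))
    L[np.tril_indices(n)] = unc
    np.fill_diagonal(L, np.exp(np.diag(L)))
    return L @ L.T

Sigma = unconstrained_to_cov(np.random.default_rng(0).normal(size=6), n=3)
assert np.all(np.linalg.eigvalsh(Sigma) > 0)  # positive definite for any input
```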

ricardoV94 and others added 12 commits April 9, 2026 10:09
The public transposition attribute on DimShuffle was renamed to
_transposition in PyTensor v3, but the dropped input dimensions are
already exposed directly via DimShuffle.drop.
PyTensor v3 no longer carries the device flag (it dropped GPU support
via the legacy device knob), so the skip marker now errors out at
collection time. The test runs fine on CPU which is the only supported
backend now.
The pytensor.compile.function submodule was removed in PyTensor v3;
the function is still re-exported from the top-level pytensor namespace.
The trace setup compiled function uses trust_input=True, which under
PyTensor v3 strictly requires the storage to hold an ndarray for 0-d
inputs. The ADVI init paths in init_nuts produce per-chain initial
points by indexing the variational MultiTrace, which yields numpy
scalars rather than 0-d ndarrays for scalar RVs. Wrap with np.asarray
when constructing initial_points and likewise when bootstrapping the
NDArray trace inside Approximation.sample.

The BaseTrace.point return type is loosened to dict[str, Any] to make
this contract explicit -- callers needing strict ndarrays must wrap.
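The scalar-vs-0-d-ndarray distinction this commit guards against can be seen with plain numpy:

```python
import numpy as np

x = np.float64(1.5)   # numpy scalar, as yielded by indexing a trace
y = np.asarray(x)     # 0-d ndarray, the form trust_input=True requires

assert not isinstance(x, np.ndarray)              # scalar is not an ndarray
assert isinstance(y, np.ndarray) and y.ndim == 0  # 0-d array wrapper
```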
PyTensor v3 emits a DeprecationWarning from RandomVariable.__call__
unless the caller passes return_next_rng=True. Distribution.dist now
always invokes the underlying op via that opt-in API (so the warning
never fires from this codepath) and exposes a new return_next_rng
parameter, defaulting to False, that gives callers explicit access
to the (next_rng, rv) tuple. This replaces the awkward
.owner.outputs / .make_node().outputs dance used internally to grab
the next-rng output from RV calls.
Pass return_next_rng=True to every RandomVariable / XRV call that
PyTensor v3 emits a "RandomVariable Ops will stop hiding the rng
output" deprecation for. Two flavours of call site are covered:

* Sites that already wanted the next-rng output and were doing the
  awkward .owner.outputs unpacking are rewritten to take the tuple
  directly from the new API.
* Sites that did not capture the next-rng (DimDistribution.dist's
  XRV path, Empirical's integers sampler, the initial-point jitter,
  and the various pt.random.<dist>(...) call sites scattered across
  the distribution implementations) now spell the discard explicitly
  via ``_, rv = ...(..., return_next_rng=True)``.

SymbolicRandomVariable calls (e.g. PrecisionMvNormalRV.rv_op) keep
their .owner.outputs access since the kwarg is specific to plain
RandomVariable / XRV Ops.
PyTensor v3 also emits a 'Calling a RandomVariable without an explicit
rng' DeprecationWarning. This commit threads an explicit rng (a fresh
shared default_rng) through the few entry points where rng was being
left implicit:

* Distribution.dist / DimDistribution.dist (via the _call_rv_op helper
  and the analogous DimDistribution path).
* change_rv_size in shape_utils, which used to rebuild a resized RV
  via rv_node.op(*params, size=new_size) without passing rng.
* The initial-point jitter uniform call.
* CustomDist.rv_op, which now goes through _call_rv_op so the same
  rng-injection logic applies to ad-hoc CustomDistRV instances.
* The dims xrv_op classmethods that wrap a core RV; they now forward
  return_next_rng (and any other kwargs) to the underlying as_xrv
  call so DimDistribution.dist's opt-in to return_next_rng=True works
  for them too. Censored returns (None, rv) when return_next_rng is
  set since it has no rng of its own.
The legacy arviz package's dict_to_dataset and from_dict
have different signatures from the arviz_base ones that PyMC has
been migrating to. Switch the remaining call sites in
`pymc/backends/arviz.py`, `pymc/sampling/mcmc.py`, and the
`tests/stats/test_log_density.py` test fixtures to import from
`arviz_base` directly so the kwargs (`inference_library`,
`sample_dims`) line up.
PyTensor v3 deprecated assignment to SharedVariable.default_update
without offering a replacement: rng updates should now be threaded
through pytensor.function's updates argument or inferred from
the graph by collect_default_updates. Remove every remaining call
site that mutated default_update:

* change_rv_size no longer tries to replicate the old rng's
  default_update on the resized RV.
* The Scan logprob rewriter relies on the inferred updates returned
  by the surrounding construct_scan call instead of mutating each
  inner rng.
* collect_default_updates no longer respects user-provided
  input_rng.default_update.
* The jax fallback no longer rejects shared variables with a
  default_update set (it now only rejects shared RandomTypes).
* The corresponding test paths are dropped, and
  test_change_rv_size_default_update is removed entirely as it was
  exclusively exercising the deprecated mechanism.
@ricardoV94 ricardoV94 changed the title from Unconstraint Wishart to Unconstrain Wishart Apr 9, 2026
@read-the-docs-community

read-the-docs-community bot commented Apr 9, 2026

Documentation build overview

📚 pymc | 🛠️ Build #32194221 | 📁 Comparing c8a1764 against latest (4b0b6ee)


Show files changed (216 files in total): 📝 212 modified | ➕ 4 added | ➖ 0 deleted

@codecov

codecov bot commented Apr 9, 2026

Codecov Report

❌ Patch coverage is 55.52239% with 149 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (v6@627eb74). Learn more about missing BASE report.

Files with missing lines Patch % Lines
pymc/_pytensor_rewrites.py 0.00% 139 Missing ⚠️
pymc/distributions/multivariate.py 91.17% 6 Missing ⚠️
pymc/logprob/scan.py 33.33% 2 Missing ⚠️
pymc/dims/distributions/censored.py 83.33% 1 Missing ⚠️
pymc/distributions/discrete.py 0.00% 1 Missing ⚠️
Additional details and impacted files


@@          Coverage Diff          @@
##             v6    #8246   +/-   ##
=====================================
  Coverage      ?   84.83%           
=====================================
  Files         ?      125           
  Lines         ?    20133           
  Branches      ?        0           
=====================================
  Hits          ?    17080           
  Misses        ?     3053           
  Partials      ?        0           
Files with missing lines Coverage Δ
pymc/backends/arviz.py 94.94% <100.00%> (ø)
pymc/backends/base.py 88.26% <100.00%> (ø)
pymc/backends/ndarray.py 79.81% <100.00%> (ø)
pymc/data.py 85.07% <100.00%> (ø)
pymc/dims/distributions/core.py 91.70% <100.00%> (ø)
pymc/dims/distributions/scalar.py 96.61% <100.00%> (ø)
pymc/dims/distributions/vector.py 90.00% <100.00%> (ø)
pymc/distributions/continuous.py 98.18% <100.00%> (ø)
pymc/distributions/custom.py 72.17% <100.00%> (ø)
pymc/distributions/distribution.py 92.78% <100.00%> (ø)
... and 26 more

@ricardoV94
Member Author

ricardoV94 commented Apr 9, 2026

Took a look at the compiled logp+dlogp; we pay some price for constructing the full dense matrix.

For an n × n Wishart the unconstrained vector has length n(n+1)/2, with the diagonal entries sitting at positions [0, 2, 5, 9, …] (the cumulative sums of 1, …, n, each minus one).
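Those positions are quick to compute; a small helper (name is illustrative, not from the PR):

```python
import numpy as np

def packed_diag_positions(n: int) -> np.ndarray:
    """Positions of the diagonal entries inside the length-n(n+1)/2
    packed (row-major) lower-triangular vector.

    Row k contributes k+1 entries and its diagonal element is the last
    of them, so the positions are cumsum([1, ..., n]) - 1."""
    return np.cumsum(np.arange(1, n + 1)) - 1

packed_diag_positions(3)  # array([0, 2, 5])
```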

Full logp+dlogp graph

# logp
Composite{(2.079441547393799 + (0.5 * ((-14.159198660192542 + (2.0 * i2)) - i1)) + i0)} [id A] 17
 ├─ Sum{axes=None} [id B] 5
 │  └─ Mul [id C] 3
 │     ├─ [4. 3. 2.] [id D]
 │     └─ AdvancedSubtensor1 [id E] 0
 │        ├─ Sigma_cholesky-cov__ [id F]
 │        └─ [0 2 5] [id G]
 ├─ Sum{axes=None} [id H] 15
 │  └─ Sqr [id I] 13
 │     └─ SolveTriangular{unit_diagonal=False, lower=True, b_ndim=2, overwrite_b=True} [id J] 10
 │        ├─ [[1. 0. 0. ... 0. 0. 1.]] [id K]
 │        └─ AdvancedSetSubtensor [id L] 6
 │           ├─ Alloc [id M] 1
 │           │  ├─ 0.0 [id N]
 │           │  ├─ 3 [id O]
 │           │  └─ 3 [id O]
 │           ├─ AdvancedIncSubtensor1{no_inplace,set} [id P] 4
 │           │  ├─ Sigma_cholesky-cov__ [id F]
 │           │  ├─ Exp [id Q] 2
 │           │  │  └─ AdvancedSubtensor1 [id E] 0
 │           │  │     └─ ···
 │           │  └─ [0 2 5] [id G]
 │           ├─ [0 1 1 2 2 2] [id R]
 │           └─ [0 0 1 0 1 2] [id S]
 └─ Sum{axes=None} [id T] 12
    └─ Log [id U] 9
       └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=True} [id V] 7
          └─ AdvancedSetSubtensor [id L] 6
             └─ ···
# dlogp
AdvancedIncSubtensor1{inplace,inc} [id W] 'Sigma_cholesky-cov___grad' 23
 ├─ AdvancedIncSubtensor1{inplace,set} [id X] 21
 │  ├─ AdvancedSubtensor{idx_list=(0, 1)} [id Y] 19
 │  │  ├─ Add [id Z] 18
 │  │  │  ├─ AdvancedSetSubtensor [id BA] 11
 │  │  │  │  ├─ Alloc [id M] 1
 │  │  │  │  │  └─ ···
 │  │  │  │  ├─ Reciprocal [id BB] 8
 │  │  │  │  │  └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=True} [id V] 7
 │  │  │  │  │     └─ ···
 │  │  │  │  ├─ [0 1 2] [id BC]
 │  │  │  │  └─ [0 1 2] [id BC]
 │  │  │  └─ SolveTriangular{unit_diagonal=False, lower=False, b_ndim=2, overwrite_b=True} [id BD] 16
 │  │  │     ├─ [[1. 0. 0. ... 0. 0. 1.]] [id K]
 │  │  │     └─ Neg [id BE] 14
 │  │  │        └─ SolveTriangular{unit_diagonal=False, lower=True, b_ndim=2, overwrite_b=True} [id J] 10
 │  │  │           └─ ···
 │  │  ├─ [0 1 1 2 2 2] [id R]
 │  │  └─ [0 0 1 0 1 2] [id S]
 │  ├─ [0. 0. 0.] [id BF]
 │  └─ [0 2 5] [id G]
 ├─ Composite{((i1 * i2) + i0)} [id BG] 22
 │  ├─ [4. 3. 2.] [id D]
 │  ├─ AdvancedSubtensor1 [id BH] 20
 │  │  ├─ AdvancedSubtensor{idx_list=(0, 1)} [id Y] 19
 │  │  │  └─ ···
 │  │  └─ [0 2 5] [id G]
 │  └─ Exp [id Q] 2
 │     └─ ···
 └─ [0 2 5] [id G]

Inner graphs:

Composite{(2.079441547393799 + (0.5 * ((-14.159198660192542 + (2.0 * i2)) - i1)) + i0)} [id A]
 ← add [id BI]
    ├─ 2.079441547393799 [id BJ]
    ├─ mul [id BK]
    │  ├─ 0.5 [id BL]
    │  └─ sub [id BM]
    │     ├─ add [id BN]
    │     │  ├─ -14.159198660192542 [id BO]
    │     │  └─ mul [id BP]
    │     │     ├─ 2.0 [id BQ]
    │     │     └─ i2 [id BR]
    │     └─ i1 [id BS]
    └─ i0 [id BT]

Composite{((i1 * i2) + i0)} [id BG]
 ← add [id BU]
    ├─ mul [id BV]
    │  ├─ i1 [id BS]
    │  └─ i2 [id BR]
    └─ i0 [id BT]

1. Diagonal gradient routed through an (n, n) scatter

AdvancedSubtensor{idx_list=(0, 1)} [id Y]        ← read length-n(n+1)/2 packed lower-tri
 ├─ Add [id Z]                                    ← add (n,n) matrices
 │  ├─ AdvancedSetSubtensor [id BA]              ← scatter 1/L_kk onto diag of (n,n) zeros
 │  │  ├─ Alloc [id M]                            ← (n,n) zeros
 │  │  ├─ Reciprocal [id BB]                      ← 1/L_kk, length n
 │  │  ├─ [0 1 2]
 │  │  └─ [0 1 2]
 │  └─ SolveTriangular{lower=False} [id BD]       ← full (n,n) −V⁻¹·L gradient term
 ├─ [0 1 1 2 2 2]
 └─ [0 0 1 0 1 2]

The two gradient contributions (1/L_kk on the diagonal, −V⁻¹ L everywhere) are added as full (n, n) matrices, then only the n(n+1)/2 lower-triangular entries are read out. The upper triangle is dead work. The diagonal scatter writes n values into zeros solely so they align with the (n, n) layout of the triangular-solve result.
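The dead work is easy to reproduce in numpy (a standalone sketch, not the compiled graph itself): the dense add touches all n² entries, while only the n(n+1)/2 packed lower-tri entries survive the read-out.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 3
rows, cols = np.tril_indices(n)                 # [0 1 1 2 2 2], [0 0 1 0 1 2]
diag_pos = np.cumsum(np.arange(1, n + 1)) - 1   # [0 2 5] in the packed layout

# Dense path, as in the compiled graph: scatter 1/L_kk onto (n, n) zeros,
# add the full (n, n) triangular-solve result, read out the packed entries.
diag_term = np.zeros((n, n))
np.fill_diagonal(diag_term, rng.uniform(1.0, 2.0, n))
solve_term = rng.normal(size=(n, n))
dense_packed = (diag_term + solve_term)[rows, cols]

# Packed path: only the lower-tri entries are ever needed, so the
# addition can act on packed vectors directly.
packed = solve_term[rows, cols].copy()
packed[diag_pos] += np.diag(diag_term)

assert np.allclose(dense_packed, packed)
```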

2. Extracting the diagonal we just placed

ExtractDiag [id V]
 └─ AdvancedSetSubtensor [id L]      ← scatter packed vec into (n,n) zeros
    ├─ Alloc [id M]                   ← (n,n) zeros
    ├─ AdvancedIncSubtensor1 [id P]  ← packed vec with Exp on diag slots
    │  ├─ Sigma_cholesky-cov__ [id F]
    │  ├─ Exp [id Q]                  ← exp(unc[diag_idxs]) = diag(L)
    │  └─ [0 2 5]
    ├─ [0 1 1 2 2 2]
    └─ [0 0 1 0 1 2]

ExtractDiag(L) recovers exactly Exp [id Q], the values we scattered onto the diagonal in the first place. Recognizing this identity simplifies both consumers:

  • logp: Sum(Log(ExtractDiag(L))) = Sum(Log(Exp(unc[diag_idxs]))) = Sum(unc[diag_idxs]).
    The Log, ExtractDiag, and Exp all cancel.

  • dlogp: Reciprocal(ExtractDiag(L)) = 1/Exp(unc[diag_idxs]).
    This is multiplied by L_kk = Exp(unc[diag_idxs]) via the chain rule in the per-diagonal Composite, giving (1/L_kk) · L_kk = 1. The diagonal gradient from the log-det term becomes a constant +1 per diagonal slot, absorbable into the existing log-Jacobian coefficients: [n+1, n, …, 2] becomes [n+2, n+1, …, 3].
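Both cancellations can be checked numerically (standalone sketch; variable names are illustrative):

```python
import numpy as np

n = 3
rng = np.random.default_rng(1)
unc = rng.normal(size=n * (n + 1) // 2)        # unconstrained packed vector
diag_idx = np.cumsum(np.arange(1, n + 1)) - 1  # [0, 2, 5]

# Build L as in the graph: packed vector with Exp on the diagonal slots.
rows, cols = np.tril_indices(n)
packed = unc.copy()
packed[diag_idx] = np.exp(unc[diag_idx])
L = np.zeros((n, n))
L[rows, cols] = packed

# logp side: Log(ExtractDiag(L)) collapses back to unc[diag_idx].
assert np.allclose(np.log(np.diag(L)), unc[diag_idx])

# dlogp side: (1/L_kk) * L_kk from the chain rule is identically 1.
assert np.allclose((1.0 / np.diag(L)) * np.exp(unc[diag_idx]), 1.0)
```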

3. Set-then-inc on the same diagonal positions

AdvancedIncSubtensor1{inc} [id W]         ← inc at [0 2 5]
 ├─ AdvancedIncSubtensor1{set} [id X]     ← set [0 2 5] to zero
 │  ├─ [id Y]                              ← length-n(n+1)/2 packed gradient (from §1)
 │  ├─ [0. 0. 0.]
 │  └─ [0 2 5]
 ├─ Composite{(i1 * i2) + i0} [id BG]    ← length-n diagonal contribution
 │  ├─ [4. 3. 2.]                          ← log-Jacobian coefficients
 │  ├─ AdvancedSubtensor1 [id BH]         ← diag slice of [id Y] (read before zeroing)
 │  └─ Exp [id Q]                          ← L_kk
 └─ [0 2 5]

The set zeros the diagonal slots; the inc overwrites them with (Y[diag_idxs] · L_kk) + [4, 3, 2]. The zeroing is an autodiff artifact. With §1–§2 applied, Y[diag_idxs] at the diagonal simplifies (the Reciprocal scatter becomes a constant), and the entire set-then-inc collapses to a single inc_subtensor of one fused length-n vector.
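The collapse rests on set-then-inc at the same indices being a single write; in numpy terms (standalone sketch with illustrative values):

```python
import numpy as np

n = 3
rng = np.random.default_rng(2)
grad = rng.normal(size=n * (n + 1) // 2)  # packed gradient, as from §1
diag_idx = np.array([0, 2, 5])
contrib = rng.normal(size=n)              # fused length-n diagonal term

# Autodiff pattern: set the diagonal slots to zero, then increment them.
set_then_inc = grad.copy()
set_then_inc[diag_idx] = 0.0
set_then_inc[diag_idx] += contrib

# Equivalent single write at the same indices.
single_set = grad.copy()
single_set[diag_idx] = contrib

assert np.allclose(set_then_inc, single_set)
```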

4. Structural lower bound after all three simplifications

AdvancedIncSubtensor1{inc}                ← single inc at packed diag positions
 ├─ AdvancedSubtensor{idx_list=(0, 1)}    ← length-n(n+1)/2 packed lower-tri of −V⁻¹·L
 │  ├─ SolveTriangular{lower=False}       ← (n,n), unchanged
 │  ├─ [0 1 1 2 2 2]
 │  └─ [0 0 1 0 1 2]
 ├─ Composite                             ← length-n fused diagonal contribution
 │  ├─ [n+2, n+1, …, 3]                   ← merged log-Jacobian + log-det constant
 │  └─ Exp(unc[diag_idxs])                ← shared with L construction
 └─ [0 2 5]

What's already good

  • L built once, used three times: forward SolveTriangular, ExtractDiag, and gradient's second triangular solve.
  • Single Alloc: the (n, n) zero buffer is shared between L construction and the §1 diagonal-scatter.
  • Forward solve shared: M = L⁻¹ V serves both ‖M‖²_F (trace term) and −Lᵀ \ M (gradient term).
  • Unconstrained diagonal shared: unc[diag_idxs] and its Exp are CSE'd between L construction and the gradient's chain-rule factor.
  • No Cholesky op: the cholesky_ldotlt rewrite has already eliminated the chol(L Lᵀ) round trip.
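The last bullet relies on chol(L Lᵀ) = L for lower-triangular L with positive diagonal, which is quick to confirm:

```python
import numpy as np

rng = np.random.default_rng(3)
L = np.tril(rng.normal(size=(3, 3)))
np.fill_diagonal(L, np.abs(np.diag(L)) + 1.0)  # ensure positive diagonal

# Cholesky of L @ L.T recovers L exactly, so the round trip is removable.
assert np.allclose(np.linalg.cholesky(L @ L.T), L)
```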

@ricardoV94
Member Author

ricardoV94 commented Apr 9, 2026

With a few general rewrites, got it down to this form:

Composite{(2.079441547393799 + (0.5 * ((-35.37021969551046 + (2.0 * i2)) - i1)) + i0)} [id A] shape=() d={0: [0]} 13
 ├─ Sum{axes=None} [id B] shape=() 6
 │  └─ Mul [id C] shape=(?,) 4
 │     ├─ [4. 3. 2.] [id D] shape=(3,)
 │     └─ AdvancedSubtensor1 [id E] shape=(3,) 0
 │        ├─ Sigma_cholesky-cov__ [id F] shape=(?,)
 │        └─ [0 2 5] [id G] shape=(3,)
 ├─ Sum{axes=None} [id H] shape=() 11
 │  └─ Sqr [id I] shape=(3, 3) 9
 │     └─ SolveTriangular{unit_diagonal=False, lower=True, b_ndim=2, overwrite_b=True} [id J] shape=(3, 3) d={0: [1]} 8
 │        ├─ [[ 1.85564 ... 43760577]] [id K] shape=(3, 3)
 │        └─ AdvancedSetSubtensor [id L] shape=(3, 3) d={0: [0]} 7
 │           ├─ Alloc [id M] shape=(3, 3) 1
 │           │  ├─ 0.0 [id N] shape=()
 │           │  ├─ 3 [id O] shape=()
 │           │  └─ 3 [id O] shape=()
 │           ├─ AdvancedIncSubtensor1{no_inplace,set} [id P] shape=(?,) 5
 │           │  ├─ Sigma_cholesky-cov__ [id F] shape=(?,)
 │           │  ├─ Exp [id Q] shape=(?,) 2
 │           │  │  └─ AdvancedSubtensor1 [id E] shape=(3,) 0
 │           │  │     └─ ···
 │           │  └─ [0 2 5] [id G] shape=(3,)
 │           ├─ [0 1 1 2 2 2] [id R] shape=(6,)
 │           └─ [0 0 1 0 1 2] [id S] shape=(6,)
 └─ Sum{axes=None} [id T] shape=() 3
    └─ AdvancedSubtensor1 [id E] shape=(3,) 0
       └─ ···
AdvancedIncSubtensor1{inplace,set} [id U] shape=(6,) 'Sigma_cholesky-cov___grad' d={0: [0]} 17
 ├─ AdvancedSubtensor{idx_list=(0, 1)} [id V] shape=(6,) 14
 │  ├─ SolveTriangular{unit_diagonal=False, lower=False, b_ndim=2, overwrite_b=True} [id W] shape=(3, 3) d={0: [1]} 12
 │  │  ├─ [[ 1.85564 ... 43760577]] [id X] shape=(3, 3)
 │  │  └─ Neg [id Y] shape=(3, 3) d={0: [0]} 10
 │  │     └─ SolveTriangular{unit_diagonal=False, lower=True, b_ndim=2, overwrite_b=True} [id J] shape=(3, 3) d={0: [1]} 8
 │  │        └─ ···
 │  ├─ [0 1 1 2 2 2] [id R] shape=(6,)
 │  └─ [0 0 1 0 1 2] [id S] shape=(6,)
 ├─ Composite{(((i2 + reciprocal(exp(i3))) * i1) + i0)} [id Z] shape=(3,) d={0: [2]} 16
 │  ├─ [4. 3. 2.] [id D] shape=(3,)
 │  ├─ Exp [id Q] shape=(?,) 2
 │  │  └─ ···
 │  ├─ AdvancedSubtensor1 [id BA] shape=(3,) 15
 │  │  ├─ AdvancedSubtensor{idx_list=(0, 1)} [id V] shape=(6,) 14
 │  │  │  └─ ···
 │  │  └─ [0 2 5] [id G] shape=(3,)
 │  └─ AdvancedSubtensor1 [id E] shape=(3,) 0
 │     └─ ···
 └─ [0 2 5] [id G] shape=(3,)

Inner graphs:

Composite{(2.079441547393799 + (0.5 * ((-35.37021969551046 + (2.0 * i2)) - i1)) + i0)} [id A] d={0: [0]}
 ← add [id BB] shape=()
    ├─ 2.079441547393799 [id BC] shape=()
    ├─ mul [id BD] shape=()
    │  ├─ 0.5 [id BE] shape=()
    │  └─ sub [id BF] shape=()
    │     ├─ add [id BG] shape=()
    │     │  ├─ -35.37021969551046 [id BH] shape=()
    │     │  └─ mul [id BI] shape=()
    │     │     ├─ 2.0 [id BJ] shape=()
    │     │     └─ i2 [id BK] shape=()
    │     └─ i1 [id BL] shape=()
    └─ i0 [id BM] shape=()

Composite{(((i2 + reciprocal(exp(i3))) * i1) + i0)} [id Z] d={0: [2]}
 ← add [id BN] shape=()
    ├─ mul [id BO] shape=()
    │  ├─ add [id BP] shape=()
    │  │  ├─ i2 [id BK] shape=()
    │  │  └─ reciprocal [id BQ] shape=()
    │  │     └─ exp [id BR] shape=()
    │  │        └─ i3 [id BS] shape=()
    │  └─ i1 [id BL] shape=()
    └─ i0 [id BM] shape=()

So, about as good as I can think of.
