When bergson build is run with include_bias=True and the Hessian step uses method='kfac', apply_hessian (and the scoring paths downstream of it) fails at a per-layer reshape: the gradient store has one extra column per layer that the K-FAC activation covariance doesn't account for.
This is a pre-existing limitation independent of the K-FAC compression work in #275 — the same failure happens on main. Filing separately so the fix doesn't get tangled with the compression PR.
What goes wrong (concrete shapes)
For each linear layer nn.Linear(I, O, bias=True):
bergson build with include_bias=True stores per-sample gradients of shape [O, I+1] per layer — the bias gradient is concatenated as an extra "activation" column. (HookCollectorBase.shapes() at bergson/collector/collector.py:264-270 sets grad_shape[-1] += 1 when collect_bias; _compute_gradient does the matching torch.cat.)
K-FAC CovarianceCollector at bergson/hessians/kfac.py computes A = aᵀa from the raw forward input a: [N·S, I] — no bias column. The collect_bias flag is unpacked in _init_covariance_dict (bergson/hessians/sharded_computation.py:28) but never used. Result: A: [I, I].
At apply time, compute_ivhp_sharded reshapes the loaded query gradient via:
gradients_noi.view(-1, eigen_g[k].shape[1], eigen_a[k].shape[1])
# = view(-1, O, I)
# but stored flat size is N·O·(I+1) → reshape error
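The mismatch can be reproduced in isolation with plain PyTorch. This is a minimal sketch, not bergson code: the tensor names (`stored`, `A`) and the toy sizes are assumptions chosen for illustration, and the gradient layout simply follows the description above (bias gradient concatenated as a last column).

```python
import torch

torch.manual_seed(0)
N, S, I, O = 3, 5, 4, 2  # samples, sequence length, in features, out features

a = torch.randn(N, S, I)  # raw forward input to the linear layer
g = torch.randn(N, S, O)  # gradient w.r.t. the layer output

# Per-sample weight gradient: sum over the sequence of outer(g, a) -> [N, O, I]
grad_w = torch.einsum("nso,nsi->noi", g, a)
# Per-sample bias gradient: sum over the sequence of g -> [N, O]
grad_b = g.sum(dim=1)
# Build-time layout as described: bias gradient as an extra column -> [N, O, I+1],
# then flattened per sample -> [N, O*(I+1)]
stored = torch.cat([grad_w, grad_b.unsqueeze(-1)], dim=-1).reshape(N, -1)

# Activation covariance from the raw input, with no bias column -> [I, I]
a_flat = a.reshape(N * S, I)
A = a_flat.T @ a_flat

# Apply-time reshape uses A's dimension (I), not the stored one (I+1)
try:
    stored.view(-1, O, A.shape[-1])
    reshape_failed = False
except RuntimeError:
    reshape_failed = True

print(stored.shape, A.shape, reshape_failed)
```

One caveat of this toy setup: the view only raises when N·O·(I+1) is not divisible by O·I, which is the case here because N = 3 is not a multiple of I = 4. If N happened to be a multiple of I, the same call would succeed and silently produce a misaligned tensor.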
Repro
bergson build --processor.include_bias true ...
bergson approximate-hessians --hessian_cfg.method kfac ...
bergson apply-hessian ... # raises at .view(-1, O, I)
Fix sketch
Teach K-FAC's covariance collection to operate on the augmented activation [a; 1] of shape [N·S, I+1] when the layer's bias is being collected:
CovarianceCollector.forward_hook appends a 1-column to a (matching the build-time gradient layout) before aᵀa, giving A: [I+1, I+1].
_init_covariance_dict sizes the activation covariance as [I+1, I+1] when collect_bias=True.
Downstream (compute_eigendecomposition, compute_whitening_projection_matrices, apply_hessian) all derive d_A from A.shape[-1], so they pick up the new dimension automatically.
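The fix sketch above can be illustrated with a standalone PyTorch example. It is an assumption-laden sketch rather than the actual patch: `augment_with_ones` is a hypothetical helper, not an existing bergson function, and the toy tensors stand in for what the forward hook and the gradient store would hold. It checks two things: the covariance of the augmented activation has shape [I+1, I+1], and the outer product of the output gradient with [a; 1] reproduces the stored [O, I+1] layout, so the apply-time view lines up.

```python
import torch

def augment_with_ones(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: append a constant 1-column to the last dim."""
    return torch.cat([a, torch.ones_like(a[..., :1])], dim=-1)

torch.manual_seed(0)
N, S, I, O = 3, 5, 4, 2  # samples, sequence length, in features, out features

a = torch.randn(N, S, I)  # raw forward input
g = torch.randn(N, S, O)  # gradient w.r.t. the layer output

# Augmented activation [a; 1] and its covariance -> [I+1, I+1]
a_aug = augment_with_ones(a)
a_flat = a_aug.reshape(N * S, I + 1)
A = a_flat.T @ a_flat

# Stored layout described in the issue: weight gradient with the bias
# gradient concatenated as the last column, flattened per sample
grad_w = torch.einsum("nso,nsi->noi", g, a)
grad_b = g.sum(dim=1)
stored = torch.cat([grad_w, grad_b.unsqueeze(-1)], dim=-1).reshape(N, -1)

# The apply-time view now succeeds because A.shape[-1] == I + 1
per_layer = stored.view(-1, O, A.shape[-1])

# Same gradient computed directly from the augmented activation
from_aug = torch.einsum("nso,nsi->noi", g, a_aug)
layouts_match = torch.allclose(per_layer, from_aug, atol=1e-5)

print(A.shape, per_layer.shape, layouts_match)
```

Appending the 1-column last matters here only because the stored gradient places the bias column last; if the build-time layout differed, the augmentation would need to follow it.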
Found/generated by Claude during PR 275.
The gradient covariance S: [O, O] is unchanged.
Scope
main (legacy IVHP path).