从「无限深」到「一步之遥」

8月 21, 2025·
Gemini 2.5 Pro
Gemini 2.5 Pro
Yanbo Zhang
Yanbo Zhang
· 13 分钟阅读时长

说到深度学习,我们脑海里浮现的第一个画面,可能就是一层层堆叠起来的神经网络。从 AlexNet 的 8 层,到 ResNet 的 152 层,再到如今动辄上千层的 Transformer 变体,模型的「深度」似乎成了其强大能力的代名词。这种「堆叠」的哲学简单直观:就像盖楼一样,楼层越多,功能就越复杂、视野就越开阔。

但高楼并非想盖多高就盖多高。梯度消失 / 爆炸、巨大的内存开销(反向传播时需要存储每一层的激活值)等问题,都像是摩天大楼头顶的乌云。于是,研究者们开始思考一个有趣的问题:如果我们不一层层地盖新楼,而是拿着同一套图纸,对同一层楼反复地精装修,会发生什么?

这就是权重共享(Weight-Tied)或循环网络(Recurrent Network)的思想。在这种模式下,输入信号会在同一个模块里反复迭代,不断地自我更新。如果我们迭代足够多次,这个信号最终可能会达到一个「炉火纯青」的稳定状态——它不再发生任何变化,我们称之为不动点(Fixed Point)均衡点(Equilibrium)。这个不动点,就蕴含了模型在「无限深度」下所能提取的所有信息。

这个想法非常诱人,因为它意味着我们可能用一个很小的模型,通过多次迭代,就能模拟出一个无限深的网络。但新的问题也随之而来:

  1. 前向传播:我们真的要迭代成百上千次来找到那个不动点吗?这效率也太低了。
  2. 反向传播:如果真的迭代了 1000 次,那反向传播时岂不是要存储 1000 个中间状态?这内存开销简直是灾难,比堆叠 1000 个不同层还要命。

那么,有没有一种方法,能让我们像开了「上帝视角」一样,直接跳到最终的那个不动点,并且在训练时,只用一步反向传播就算出梯度呢?答案是肯定的,而这把钥匙,就藏在数学中的一个强大工具——**隐函数定理(Implicit Function Theorem, IFT)**里。

新的起点:从「反复打磨」说起

为了理解这个过程,我们先来构思一个生动的类比。想象一下,我们是一位玉雕大师,手里的输入数据 $x$ 是一块璞玉,我们的目标是把它雕琢成一件精美的玉器(最终的输出)。我们的神经网络 $f_\theta$ 就是一套雕刻工具和技法,参数 $\theta$ 代表了工具的锋利程度和技法的精妙程度。

传统的深度网络,就像是一条流水线,有几十道工序,每道工序(一层网络)都用不同的工具对玉石进行一次加工。而我们现在讨论的这种循环迭代模型,则更像是一位老工匠,他只用一套自己最称手的工具,对着这块璞玉反复打磨、精雕细琢。

我们用 $z_i$ 表示第 $i$ 次打磨后的玉石状态。那么整个过程可以写成:

$$ z_{i+1} = f_\theta(z_i, x) $$

这里,我们把原始的璞玉信息 $x$ 也作为每次打磨的参考,以确保不会偏离主题。经过无数次打磨,玉石的状态会越来越完美,最终达到一个「再多一刀都嫌多,再少一刀都嫌少」的境界,这就是不动点 $z^*$。在这个状态下,我们的雕刻技法已经无法再对其进行任何改进,即:

$$ z^* = f_\theta(z^*, x) $$

这,就是我们梦寐以求的最终输出。现在,我们的核心问题变成了:如何高效地找到 $z^*$,以及如何根据 $z^*$ 的好坏来调整我们的技法 $\theta$ ?

笨办法:一步一脚印地「回看」

最直观的训练方法是什么?那就是老老实实地记录下每一次打磨的过程。

  1. 前向传播:从 $z_0$ (比如就是输入 $x$)开始,迭代 $N$ 次,得到 $z_1, z_2, \dots, z_N$。我们假设 $z_N$ 已经非常接近不动点 $z^*$ 了。
  2. 反向传播:计算最终得到的 $z_N$ 和我们心中完美玉器(真实标签 $y$)的差距(损失函数 $\mathcal{L}$)。然后,就像放电影倒带一样,从 $z_N$ 开始,一步步地将梯度反向传播到 $z_{N-1}$,再到 $z_{N-2}$,……,一直传回 $z_0$。这个过程就是大名鼎鼎的时间反向传播(Backpropagation Through Time, BPTT)

这个方法的致命缺陷在于,为了「倒带」,我们必须把每一帧的画面($z_1, \dots, z_N$)都存下来。当迭代次数 $N$ 很大时,内存占用会成为一个天文数字。这就像是,为了优化雕刻技法,工匠必须记住自己刚刚挥出的成百上千刀的每一个细节,这显然是不现实的。

神来之笔:「抄近道」的隐函数定理

我们真的需要关心中间过程吗?其实不必。我们只关心最终的作品 $z^*$ 和我们的技法 $\theta$ 之间的关系。换句话说,我们想知道,如果我的刻刀(参数 $\theta$)稍微变锋利一点点,最终的玉器(不动点 $z^*$)会发生什么样的变化?

隐函数定理(IFT)正是回答这个问题的神器。

我们再来看一下不动点的定义方程:

$$ z^* - f_\theta(z^*, x) = 0 $$

这个方程定义了一个关于 $\theta$ 和 $z^*$ 的隐式关系。它没有像 $z^* = g(\theta)$ 那样直接写出 $z^*$ 是 $\theta$ 的什么函数,而是通过一个平衡方程把它们联系在了一起。IFT告诉我们,即使我们不知道这个显式函数 $g(\cdot)$ 是什么,我们依然可以直接求出它的导数 $\cfrac{\text d z^*}{\text d \theta}$ !

推导过程虽然有点数学,但思想非常直观。我们假设参数 $\theta$ 发生了一个无穷小的变化 $\text d\theta$,这会导致不动点也发生一个无穷小的变化 $\text dz^*$。但即使变化后,新的不动点 $z^*+\text dz^*$ 和新参数 $\theta+\text d\theta$ 依然要满足那个平衡方程。我们对上面那个平衡方程两边关于 $\theta$ 求全导数:

$$ \frac{\text d}{\text d\theta} \left( z^* - f_\theta(z^*, x) \right) = \frac{\text d}{\text d\theta}(0) $$

利用链式法则展开左边,我们得到:

$$ \frac{\text d z^*}{\text d \theta} - \left( \frac{\partial f_\theta}{\partial z^*} \frac{\text d z^*}{\text d \theta} + \frac{\partial f_\theta}{\partial \theta} \right) = 0 $$

这是一个关于我们想要的导数 $\cfrac{\text d z^*}{\text d \theta}$ 的线性方程!我们把它整理一下:

$$ \left( I - \frac{\partial f_\theta}{\partial z^*} \right) \frac{\text d z^*}{\text d \theta} = \frac{\partial f_\theta}{\partial \theta} $$

这里的 $I$ 是单位矩阵,$\cfrac{\partial f_\theta}{\partial z^*}$ 是函数 $f_\theta$ 对其第一个输入(即上一步的状态)的雅可比矩阵,我们记为 $J_f$。解这个方程,我们就得到了:

$$ \frac{\text d z^*}{\text d \theta} = \left( I - J_f \right)^{-1} \frac{\partial f_\theta}{\partial \theta} $$

这还没完,我们最终关心的是损失函数 $\mathcal{L}$ 对参数 $\theta$ 的梯度。再用一次链式法则:

$$ \frac{\partial \mathcal{L}}{\partial \theta} = \frac{\partial \mathcal{L}}{\partial z^*} \frac{\text d z^*}{\text d \theta} = \frac{\partial \mathcal{L}}{\partial z^*} \left( I - J_f \right)^{-1} \frac{\partial f_\theta}{\partial \theta} $$

看!这就是那个神奇的「一步」梯度公式!

我们来解读一下这个公式的非凡之处:

  1. 告别BPTT:公式的计算只依赖于最终的不动点 $z^*$,完全不涉及中间的迭代步骤 $z_1, z_2, \dots$。
  2. 内存恒定:我们只需要存储 $z^*$ 这一个状态,就可以计算出完整的梯度。内存消耗从 $\mathcal{O}(N)$ 骤降到 $\mathcal{O}(1)$ !
  3. 前向后向解耦:前向传播(如何找到 $z^*$)和反向传播(如何用 $z^*$ 计算梯度)被分开了。前向我们可以用任何高效的求根算法(比如 Broyden 法),而后向则直接套用这个公式。

这就像我们发明了一种魔法,可以直接分析最终的玉器成品 $z^*$,并瞬间推算出我们的技法 $\theta$ 应该如何改进,而完全无需回忆整个雕刻过程。这就是**深度均衡模型(Deep Equilibrium Models, DEQ)**的核心思想。

又一个拦路虎:矩阵求逆

天下没有免费的午餐。IFT虽然帮我们绕过了 BPTT 的内存地狱,但它给我们留下了一个新的挑战:公式中那个巨大的雅可比矩阵的逆 $(I - J_f)^{-1}$。

在神经网络中,$z^*$ 的维度可能高达数百万,这意味着 $J_f$ 是一个百万乘百万的矩阵。直接计算它的逆,计算复杂度高达 $\mathcal{O}(d^3)$($d$ 是 $z^*$ 的维度),这在计算上是不可行的。

怎么办?幸运的是,我们通常不需要计算出完整的逆矩阵。我们真正需要的是向量 $\cfrac{\partial \mathcal{L}}{\partial z^*}$ 和矩阵 $(I-J_f)^{-1}$ 的乘积。这是一个经典的线性方程组求解问题,可以用共轭梯度法等迭代方法高效求解。

不过,研究者们发现了一条更简洁的路,这条路在《On Training Implicit Models》和《Hierarchical Reasoning Model》论文中都有体现,那就是近似

走捷径的捷径:诺依曼级数近似

还记得高中数学里的等比数列求和吗?当公比 $|q|<1$ 时,有 $\sum_{i=0}^\infty q^i = \cfrac{1}{1-q}$。这个思想可以推广到矩阵上,就得到了诺依曼级数(Neumann series)

如果矩阵 $A$ 的谱半径小于 $1$,那么:

$$ (I - A)^{-1} = I + A + A^2 + A^3 + \dots $$

把这个用在我们的问题上,令 $A = J_f$,我们就得到了:

$$ (I - J_f)^{-1} \approx I + J_f + J_f^2 + \dots + J_f^{k-1} $$

我们可以只取这个级数的前 $k$ 项来近似那个该死的矩阵逆!

  • 当 $k=1$ 时,$(I - J_f)^{-1} \approx I$。这被称为单步梯度(one-step gradient)。梯度公式简化为 $\cfrac{\partial \mathcal{L}}{\partial \theta} \approx \cfrac{\partial \mathcal{L}}{\partial z^*} \cfrac{\partial f_\theta}{\partial \theta}$。这是最快但可能最不准的近似。
  • 当 $k > 1$ 时,我们得到的就是所谓的幻影梯度(phantom gradient)。它在计算成本和梯度精度之间提供了一个灵活的权衡。$k$ 越大,近似越准,但计算量也越大。

实践中,我们甚至不需要显式地计算 $J_f$ 的幂,而是通过 $k$ 次雅可比-向量积的迭代来高效地计算最终结果。这使得整个反向传播过程变得极其轻快。

前向过程

我们已经知道,对于不动点 $z^* = f_\theta(z^*, x)$,我们可以通过隐函数定理(IFT)推导出一个解析的梯度表达式。然而,在实践中,我们更希望利用 PyTorch 这类自动微分框架的便利性,而不是去手动实现复杂的梯度计算。这就引出了一个核心问题:既然我们已经有了目标梯度公式,我们能否设计一个与之等效的前向传播过程,让自动微分引擎在反向求导时,能「自动」地计算出我们想要的那个梯度?

答案是肯定的。要理解这一点,首先需要回顾自动微分的本质。自动微分(Autodiff)并非进行符号微分,而是对一个已发生的前向计算过程应用链式法则。它首先在前向传播时,将每一步操作记录下来,构建成一个计算图(Computational Graph)。在反向传播时,它会沿着这个图反向追溯,机械地、一步步地应用链式法则。因此,前向计算图的结构,完全决定了反向传播时计算出的梯度是什么。

我们的任务,就是「投其所好」,为我们手中的近似梯度公式,「伪造」一个对应的前向计算图。

单步梯度 (One-Step Gradient)

我们先从最简单的单步梯度近似开始。这种近似来源于将诺依曼级数 $(I - J_f)^{-1} = I + J_f + J_f^2 + \dots$ 截断到第一项,即 $(I - J_f)^{-1} \approx I$。将它代入完整的 IFT 梯度公式 $\cfrac{\partial \mathcal{L}}{\partial \theta} = \cfrac{\partial \mathcal{L}}{\partial z^*} (I - J_f)^{-1} \cfrac{\partial f_\theta}{\partial \theta}$,我们得到的目标梯度是:

$$ \frac{\partial \mathcal{L}}{\partial \theta} \approx \frac{\partial \mathcal{L}}{\partial z^*} \frac{\partial f_\theta}{\partial \theta} $$

现在我们来分析这个梯度。根据链式法则,它其实等价于对这样一个函数 $g(\theta) = f_\theta(z^*, x)$ 求梯度,并且在求导过程中,必须将 $z^*$ 视为一个与 $\theta$ 无关的常数

如何让自动微分引擎将 $z^*$ 视为常数呢?方法很简单:在计算 $z^*$ 的过程中,关闭梯度追踪即可。这就引出了一个清晰的两阶段前向构造方法,这个技巧在 HRM 模型的伪代码中有清晰的体现。

 1# 目标:构建一个前向过程,其自动微分梯度等价于单步梯度
 2
 3def forward_one_step(x, params):
 4    # 阶段一:在不追踪梯度的环境中,通过迭代求解器找到不动点z*
 5    # 这里的操作不会被记录到计算图中
 6    with torch.no_grad():
 7        z_star = solver(func, x, params)
 8
 9    # 阶段二:在正常的、追踪梯度的环境中,
10    # 将z_star(现在被视为一个常数输入)进行一次变换
11    z_final = func(z_star, x, params)
12
13    return z_final

这个前向过程构建的计算图非常小。它只记录了从z_starz_final的这一步变换。由于z_star是在no_grad环境中计算的,它不包含任何关于params的历史信息。因此,当自动微分引擎反向传播时,梯度只会通过这最后一步,完美地计算出了我们想要的

展开式幻影梯度 (UPG)

单步梯度虽然实现简单,但其近似程度较为粗糙。为了获得更精确的梯度估计,我们可以采用展开式幻影梯度(UPG),它对应于诺依曼级数中更长的截断,从而在梯度精度和计算成本之间取得了更好的平衡。

从数学上看,UPG的梯度被定义为对一个从不动点 $z^*$ 开始、迭代 $k$ 次的序列进行微分的结果,同时依然严格遵守一个核心前提:将起点 $z^*$ 视为一个与模型参数 $\theta$ 无关的常数

为了清晰地看到这个梯度到底是什么,我们首先构建这个长度为 $k$ 的计算序列。设序列的起点为 $z^{(0)} = z^*$,后续的每一项都通过函数 $f_\theta$ 进行更新:

$$ z^{(i)} = f_\theta(z^{(i-1)}, x) \quad \text{for} \quad i = 1, 2, \dots, k $$

最终的损失函数 $\mathcal{L}$ 是作用于这个序列的最后一项 $z^{(k)}$ 上的。我们想要求解的梯度是 $\cfrac{d\mathcal{L}}{d\theta}$。根据链式法则,我们有:

$$ \frac{\text d\mathcal{L}}{\text d\theta} = \frac{\partial \mathcal{L}}{\partial z^{(k)}} \frac{\text d z^{(k)}}{\text d\theta} $$

这里的关键在于计算 $\cfrac{\text d z^{(k)}}{\text d\theta}$。我们对迭代公式 $z^{(k)} = f_\theta(z^{(k-1)}, x)$ 两边关于 $\theta$ 求全导数:

$$ \frac{\text d z^{(k)}}{\text d\theta} = \frac{\partial f_\theta}{\partial z^{(k-1)}} \frac{\text d z^{(k-1)}}{\text d\theta} + \frac{\partial f_\theta}{\partial \theta} $$

这是一个递归式。我们可以将它反复展开。为了简化符号,我们记 $J_f^{(i-1)} = \cfrac{\partial f_\theta}{\partial z^{(i-1)}}$ 为在点 $z^{(i-1)}$ 处的雅可比矩阵。展开一层,我们得到:

$$ \frac{\text d z^{(k)}}{\text d\theta} = J_f^{(k-1)} \left( \frac{\partial f_\theta}{\partial z^{(k-2)}} \frac{\text d z^{(k-2)}}{\text d\theta} + \frac{\partial f_\theta}{\partial \theta} \right) + \frac{\partial f_\theta}{\partial \theta} = J_f^{(k-1)} J_f^{(k-2)} \frac{\text d z^{(k-2)}}{\text d\theta} + J_f^{(k-1)} \frac{\partial f_\theta}{\partial \theta} + \frac{\partial f_\theta}{\partial \theta} $$

持续这个过程,直到我们展开至起点 $z^{(0)}$。此时,我们应用核心前提:$z^{(0)} = z^*$ 被视为常数,因此它的导数 $\cfrac{\text d z^{(0)}}{\text d\theta} = 0$。这使得递归的尽头是一个零项,整个表达式完全展开后,只剩下与 $\cfrac{\partial f_\theta}{\partial \theta}$ 相关的项。

这个展开后的梯度形式,正是自动微分引擎在反向传播一个包含 $k$ 次函数调用的计算图时,通过链式法则所计算出的结果。因此,为了让自动微分引擎为我们计算 UPG,我们需要构建一个与之完全对应的前向过程。这个过程必须满足两个条件:1) 序列的起点是真正的不动点 $z^*$;2) 在求导时,这个起点没有关于 $\theta$ 的梯度历史。这引导我们再次使用no_grad技巧,这一方法在《On Training Implicit Models》中有清晰的论述和实现。

 1# 目标:构建一个前向过程,其自动微分梯度等价于UPG
 2
 3def forward_upg(x, params, k=5):
 4    # 阶段一:在不追踪梯度的环境中找到不动点z*
 5    # 这一步屏蔽了求解过程,确保z_star在计算图上没有「历史\
 6    with torch.no_grad():
 7        z_star = solver(func, x, params)
 8
 9    # 阶段二:将z_star作为起点,进行k次迭代
10    # 这一段循环的操作会被自动微分完整记录下来,构成一个长度为k的计算图
11    z_k_steps = z_star
12    for _ in range(k):
13        z_k_steps = func(z_k_steps, x, params)
14
15    return z_k_steps

这段代码构建的前向过程与我们的推导完美契合。solverno_grad环境中运行,保证了z_star在计算图意义上是一个常数。随后的for循环则在正常的梯度追踪环境中,构建了一个长度为 $k$ 的计算链条。当自动微分引擎对此反向求导时,它所执行的链式法则运算,其结果与我们手动推导的 UPG 梯度表达式完全一致,从而高效且准确地实现了我们的目标。

文章小结

本文从「无限深」网络的诱人前景出发,踏上了一段寻找高效训练方法的旅程。我们从最直观的循环迭代思想开始,很快就遇到了时间反向传播(BPTT)带来的内存与计算双重困境。这迫使我们思考一个更本质的问题:我们是否必须完整地「倒带」整个计算过程,才能对模型进行有效的优化?

答案是否定的。隐函数定理(IFT)如同一把钥匙,为我们打开了新世界的大门。它优雅地揭示了,我们无需关心模型是如何一步步达到「均衡」的,只需聚焦于最终的「不动点」状态本身,就能精确地推导出梯度。这一深刻的见解,将反向传播的内存开销从与迭代步数相关的 $\mathcal O(T)$,一举降低到了恒定的 $\mathcal O(1)$,从根本上解决了训练深度循环模型的最大瓶颈。

当然,理论的完美在实践中总会遇到新的挑战——巨大的雅可比矩阵求逆问题。但这并未阻碍我们前进的脚步,反而催生了更具工程智慧的解决方案。无论是简单直接的「单步梯度」,还是更为精妙的「展开式幻影梯度」(UPG),我们都学会了如何通过操控自动微分引擎的计算图,来「伪造」出与复杂数学推导等价的前向过程。这些「戏法」不仅让理论得以落地,更展示了现代深度学习框架的强大灵活性。

以 DEQ、HRM 为代表的隐式模型,为我们展示了超越传统「逐层堆叠」范式的一种可能。它们虽然在前向传播时需要迭代求解,不如前馈网络那般直接,但它们用恒定的内存开销,换来了模拟「无限深度」计算的潜力。这使得处理需要长时程、复杂推理的任务成为了可能,而这在过去是难以想象的。

它或许揭示了深度学习的一种新可能:真正的「深度」,或许不在于我们构建了多高的楼,而在于我们能否让系统在反复的自我审视与迭代中,最终收敛到一个深刻而自洽的「均衡」。