機器之心整理

參與:思、Jamin

一直以來,自動微分都在 DL 框架背後默默地運行着,本文希望探討它到底是什麼,通過 JAX,自動微分又能怎麼用。

自動微分現在已經是深度學習框架的標配,我們寫的任何模型都需要靠自動微分機制分配模型損失信息,從而更新模型。在廣闊的科學世界中,自動微分也是必不可少的。說到底,大多數算法都是由基本數學運算與基本函數組建的。
在 ICLR 2020 的一篇 Oral 論文中(滿分 8/8/8),圖賓根大學的研究者表示,目前深度學習框架中的自動微分模塊只會計算批量數據反傳梯度,但批量梯度的方差、海塞矩陣等其它量也很重要,它們可以在計算梯度的過程中快速算出來。
目前自動微分框架只計算出梯度,因此就限定了研究方向只能放在梯度下降變體之上,而不能做更廣的探討。爲此,研究者構建了 BACKPACK,它建立在 PyTorch 之上,還擴展了自動微分與反向傳播能獲得的信息。

自動微分到底是什麼?這裏有一份自我簡述

選自論文 BACKPACK,arXiv:1912.10985。
除此之外,Julia Computing 團隊去年 7 月份也發表了一份論文,提出了可微編程系統,它能將自動微分內嵌於 Julia 語言,從而將其作爲第一級的語言特性。由於廣泛的科學計算和機器學習領域都需要線性代數的支持,因此這種可微編程能成爲更加通用的一種模式。
從這些前沿研究可以清晰地感受到,自動微分越來越重要。
自動微分是什麼
在數學與計算代數學中,自動微分也被稱爲微分算法或數值微分。它是一種數值計算的方式,用來計算因變量對某個自變量的導數。此外,它還是一種計算機程序,與我們手動計算微分的「分析法」不太一樣。
自動微分基於一個事實,即每一個計算機程序,不論它有多麼複雜,都是在執行加減乘除這一系列基本算數運算,以及指數、對數、三角函數這類初等函數運算。通過將鏈式求導法則應用到這些運算上,我們能以任意精度自動地計算導數,而且最多隻比原始程序多一個常數級的運算。

自動微分到底是什麼?這裏有一份自我簡述

一般而言會存在兩種不同的自動微分模式,即前向累積梯度(前向模式)和反向累計梯度(反向模式)。前向累積會指定從內到外的鏈式法則遍歷路徑,即先計算 d_w1/d_x,再計算 d_w2/d_w1,最後計算 dy/dw_2。
反向梯度累積正好相反,它會先計算 dy/dw_2,然後計算 d_w2/d_w1,最後計算 d_w1/d_x。這是我們最爲熟悉的反向傳播模式,它非常符合「沿模型誤差反向傳播」這一直觀思路。

自動微分到底是什麼?這裏有一份自我簡述

如圖所示,兩種自動微分模式都在求 dy/dx,只不過根據鏈式法則展開的形式不太一樣。
來一個實例:誤差傳播
在統計學上,由於變量含有誤差,使得函數也含有誤差,我們將其稱之爲誤差傳播。闡述這種關係的定律叫做誤差傳播定律。
先定義一個函數 q(x,y) ,我們想通過 q 傳遞 x 與 y 的不確定性信息,即 𝜎_x 與 𝜎_y。最直接的方式是隨機採樣 x 與 y,並計算 q 的值,然後查看它的分佈。這就是「傳播不確定性」這個概念的意義。
誤差傳播的積分公式可以是一個近似值, q(x,y) 的一般表達式可以寫爲:

自動微分到底是什麼?這裏有一份自我簡述

如果我們定義一個特殊案例,即 q(x,y)=x±y,那麼總不確定性可以寫爲:

自動微分到底是什麼?這裏有一份自我簡述

對於特例 q(x,y)=xy 與 q(x,y)=x/y ,不確定性分別爲 (σ_q/q)^2 = (σ_x/x)^2+(σ_y/y)^2 與 σ_q=(x/y)* sqrt((σ_x/x)^2+(σ_y/y)^2)。
我們可以嘗試這些方法,並對比根據這些近似公式算出來的反傳誤差,以及實際發生的反傳誤差。
實戰 JAX 自動微分
Jax 是谷歌開源的一個科學計算庫,能對 Python 程序與 NumPy 運算執行自動微分,而且能夠在 GPU 和 TPU 上運行,具有很高的性能。
如下先導入 JAX,然後用三行代碼就能定義之前給出的反傳不確定性度量。

    from jax *import* grad, jacfwd  
    import jax.numpy *as* np  

    def error_prop_jax_gen(q,x,dx):  
        jac = jacfwd(q)  
        return np.sqrt(np.sum(np.power(jac(x)*dx,2)))

自動微分到底是什麼?這裏有一份自我簡述

這裏計算的反傳梯度是根據 jax 完成的,後面的反傳誤差會直接通過公式計算,並對比兩者。
1. 配置兩個具有不確定性的觀察值
我們需要使用 x 與 y 作爲符號推理,但可以把它們都儲存在數組 x 中,x[0]=x、x[1]=y。

    x_ = np.array([2.,3.])  
    dx_ = np.array([.1,.1])

2. 加減法
在 𝑞(𝑥,𝑦)=𝑥±𝑦 這一特例情況下,誤差傳播公式可以簡化爲

自動微分到底是什麼?這裏有一份自我簡述

自動微分到底是什麼?這裏有一份自我簡述

上圖所示,通過誤差傳播公式計算出來的值與 JAX 計算出來的是一致地。
3. 乘除法
在 𝑞(𝑥,𝑦)=𝑥𝑦 與 𝑞(𝑥,𝑦)=𝑥/𝑦 這兩種特例中,誤差傳播公式可以寫爲:

自動微分到底是什麼?這裏有一份自我簡述

自動微分到底是什麼?這裏有一份自我簡述

4. 冪
對於特例 𝑞(𝑥,𝑦)=𝑥^𝑚*𝑦^𝑛,傳播公式可以表示爲:

自動微分到底是什麼?這裏有一份自我簡述

我們可以寫成

自動微分到底是什麼?這裏有一份自我簡述

自動微分到底是什麼?這裏有一份自我簡述

JAX 的使用非常多樣,甚至能直接使用它搭建神經網絡。例如 JAXnet 框架,它是一個基於 JAX 的深度學習庫,它的 API 提供了便利的模型搭建體驗。比如說,以下代碼就能建個神經網絡:

    from jaxnet import *  

    net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), logsoftmax)

此外,不久之前,DeepMind 也發佈了兩個新庫:在 Jax 上進行面向對象開發 的 Haiku 和 Jax 上的強化學習庫 RLax。JAX 這樣的通用自動微分庫也許能在更廣泛的領域發揮作用。
參考閱讀:

自動微分到底是什麼?這裏有一份自我簡述

來源鏈接:mp.weixin.qq.com