对于Neural ODE的小研究

 

上面就是用欧拉方法解常微分方程的代码。

●Midpoint method (or RK2) - 2nd order method方法只需

 

这里odeint是一种通用的ODE求解器,必须提供fun(t,ht),初始条件,评估函数的时间步和求解器

像Runge–Kutta(RK4)或Adams–Bashforth这样的高阶方法可以保证更好的数值精度

所有这些都可以在形式通用的接口中实现(例如scipy

将神经网络与ODE求解器集成

 

结果如图所示 

We can use existing (and efficient) implementation of solvers to integrate NNs dynamics
The memory cost is O(1) , due to reversibility i.e. we don’t need to store all activations in the graph, we can easily recover them by backward integration (i.e. time reversed integration)
Complex dynamics can be modeled with fewer parameters
We can control accuracy/speed trade-off with adaptive solvers by setting lower/higher error tolerances
Hidden states can be accessed at any value of t - no discrete time steps as in RestNet skip connection

 NeuralODE - adjoint method

Adjoint method can be understand as a continuous version of chain rule
Chain rule: Consider following sequence of operations ( L is a scalar loss):

We can compute gradient of L w.r.t input state using chain rule

 此公式是任何深度学习autograd的核心

 

©️2020 CSDN 皮肤主题: Age of Ai 设计师:meimeiellie 返回首页