A Small Study of Neural ODEs

Above is the code for solving an ordinary differential equation (ODE) with Euler's method.
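A minimal sketch of the explicit Euler method in Python, assuming a scalar ODE dh/dt = f(t, h) on a fixed time grid. The function name `euler_solve` and the test problem dh/dt = -h are illustrative choices, not the post's original code:

```python
import math

def euler_solve(f, h0, t):
    """Integrate dh/dt = f(t, h) with the explicit Euler method on the time grid t.

    Returns the list of states [h(t[0]), h(t[1]), ...].
    """
    hs = [h0]
    for t0, t1 in zip(t[:-1], t[1:]):
        dt = t1 - t0
        hs.append(hs[-1] + dt * f(t0, hs[-1]))   # h_{n+1} = h_n + dt * f(t_n, h_n)
    return hs

# Test problem: dh/dt = -h, h(0) = 1, exact solution h(t) = exp(-t)
n = 100
t = [i / n for i in range(n + 1)]               # uniform grid on [0, 1], step 0.01
hs = euler_solve(lambda t, h: -h, 1.0, t)
print(hs[-1], math.exp(-1.0))
```

With 100 steps of size 0.01 the result is (up to rounding) 0.99^100 ≈ 0.366, against the exact exp(-1) ≈ 0.368; Euler's global error only shrinks linearly with the step size.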

● Midpoint method (or RK2): a 2nd-order method; the method only needs …
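A sketch of the midpoint (RK2) update, assuming the same f(t, h) convention. It evaluates f twice per step (at the start and at an estimated midpoint), and its global error is O(dt²) instead of Euler's O(dt). The helper names are hypothetical:

```python
import math

def euler_step(f, t, h, dt):
    return h + dt * f(t, h)                      # slope taken at the start only

def midpoint_step(f, t, h, dt):
    """One step of the explicit midpoint method (2nd-order Runge-Kutta)."""
    k1 = f(t, h)                                 # slope at the start of the step
    k2 = f(t + dt / 2, h + dt / 2 * k1)          # slope at the estimated midpoint
    return h + dt * k2                           # full step using the midpoint slope

def integrate(step, f, h0, t0, t1, n):
    """Take n equal steps with the given one-step method."""
    dt = (t1 - t0) / n
    h = h0
    for i in range(n):
        h = step(f, t0 + i * dt, h, dt)
    return h

f = lambda t, h: -h                              # exact solution h(t) = exp(-t)
exact = math.exp(-1.0)
err_euler = abs(integrate(euler_step, f, 1.0, 0.0, 1.0, 10) - exact)
err_mid = abs(integrate(midpoint_step, f, 1.0, 0.0, 1.0, 10) - exact)
print(err_euler, err_mid)
```

With 10 steps on [0, 1], the midpoint error comes out more than an order of magnitude smaller than Euler's, at the price of one extra evaluation of f per step.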

 

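Here odeint is a generic ODE solver: it must be given fun(t,ht), the initial condition, the time steps at which the function is evaluated, and the solver (integration method).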
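To make those four ingredients concrete, here is a toy `odeint(fun, h0, t, method)` in plain Python. The name mirrors the text, but the signature, the `SOLVERS` registry and the one-step-per-interval scheme are assumptions for illustration (libraries such as torchdiffeq expose a similar `odeint(func, y0, t, method=...)` call, with their own options):

```python
def euler_step(fun, t, h, dt):
    return h + dt * fun(t, h)

def midpoint_step(fun, t, h, dt):
    k1 = fun(t, h)
    return h + dt * fun(t + dt / 2, h + dt / 2 * k1)

# Hypothetical registry: the "solver" argument just selects a one-step update rule
SOLVERS = {"euler": euler_step, "midpoint": midpoint_step}

def odeint(fun, h0, t, method="euler"):
    """Toy generic solver: returns [h(t[0]), h(t[1]), ...] with h(t[0]) = h0.

    fun(t, h) is the right-hand side dh/dt, t the list of evaluation times,
    method the name of the update rule.
    """
    step = SOLVERS[method]
    hs = [h0]
    for t0, t1 in zip(t[:-1], t[1:]):
        hs.append(step(fun, t0, hs[-1], t1 - t0))   # one step per output interval
    return hs

# dh/dt = t with h(0) = 0 has the exact solution h(t) = t^2 / 2
times = [0.0, 0.5, 1.0]
print(odeint(lambda t, h: t, 0.0, times, method="euler"))      # → [0.0, 0.0, 0.25]
print(odeint(lambda t, h: t, 0.0, times, method="midpoint"))   # → [0.0, 0.125, 0.5]
```

Because this toy takes exactly one step per interval of `t`, the output grid doubles as the step grid; real solvers usually decouple the two and choose internal steps on their own.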

Higher-order methods such as Runge–Kutta (RK4) or Adams–Bashforth generally give better numerical accuracy for a given step size.
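The classical RK4 step below (four evaluations of f per step) illustrates the accuracy gain; Adams–Bashforth, a multistep method that reuses slopes from previous steps, is not shown. The test problem dh/dt = h and the step counts are arbitrary choices:

```python
import math

def euler_step(f, t, h, dt):
    """Explicit Euler step (one evaluation of f, 1st order)."""
    return h + dt * f(t, h)

def rk4_step(f, t, h, dt):
    """Classical Runge-Kutta step (four evaluations of f, 4th order)."""
    k1 = f(t, h)
    k2 = f(t + dt / 2, h + dt / 2 * k1)
    k3 = f(t + dt / 2, h + dt / 2 * k2)
    k4 = f(t + dt, h + dt * k3)
    return h + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def solve(step, f, h0, t_end, n):
    """Take n equal steps of the given method from t = 0 to t = t_end."""
    dt = t_end / n
    h = h0
    for i in range(n):
        h = step(f, i * dt, h, dt)
    return h

f = lambda t, h: h                     # dh/dt = h, exact solution h(t) = exp(t)
for n in (10, 20, 40):
    err_euler = abs(solve(euler_step, f, 1.0, 1.0, n) - math.e)
    err_rk4 = abs(solve(rk4_step, f, 1.0, 1.0, n) - math.e)
    print(f"n={n:3d}  Euler error={err_euler:.2e}  RK4 error={err_rk4:.2e}")
```

Halving the step size cuts Euler's error roughly in half but RK4's by roughly a factor of 16, as expected for first- and fourth-order methods.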

All of these can be implemented behind a common, generic interface (e.g. in scipy).
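SciPy's `scipy.integrate.solve_ivp` is one real example of such a common interface: the same `fun(t, h)` is passed to every solver and only the `method` string changes. The sketch assumes NumPy and SciPy are installed; the test problem and tolerances are arbitrary:

```python
import numpy as np
from scipy.integrate import solve_ivp

def fun(t, h):
    """Right-hand side of dh/dt = -h, using the fun(t, h) convention."""
    return -h

t_eval = np.linspace(0.0, 1.0, 11)            # times at which the solution is reported
results = {}
for method in ("RK23", "RK45", "DOP853"):     # explicit Runge-Kutta solvers of increasing order
    sol = solve_ivp(fun, (0.0, 1.0), [1.0], method=method,
                    t_eval=t_eval, rtol=1e-8, atol=1e-10)
    err = np.max(np.abs(sol.y[0] - np.exp(-t_eval)))   # compare with exact exp(-t)
    results[method] = (sol.nfev, err)
    print(method, sol.nfev, err)
```

`sol.nfev` reports how many times `fun` was evaluated, which makes it easy to compare the cost of different methods at the same tolerance.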

Integrating neural networks with an ODE solver
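A minimal NumPy sketch of the idea, assuming a tiny randomly initialised MLP as the dynamics f(t, h; θ) and a fixed-step RK4 solver; the network size, weights and the `ode_block` name are hypothetical. It also shows why a ResNet block h + f(h) can be read as a single Euler step with dt = 1:

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical tiny MLP: 2-d state, 16 hidden units, small random weights
W1, b1 = rng.normal(scale=0.3, size=(16, 2)), np.zeros(16)
W2, b2 = rng.normal(scale=0.3, size=(2, 16)), np.zeros(2)

def f(t, h):
    """Dynamics dh/dt = f(t, h; theta) given by the network (t is unused here)."""
    return W2 @ np.tanh(W1 @ h + b1) + b2

def rk4_step(f, t, h, dt):
    k1 = f(t, h)
    k2 = f(t + dt / 2, h + dt / 2 * k1)
    k3 = f(t + dt / 2, h + dt / 2 * k2)
    k4 = f(t + dt, h + dt * k3)
    return h + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def ode_block(h0, t0=0.0, t1=1.0, n_steps=20):
    """'ODE layer': map the input h(t0) to the output h(t1) by integrating f with RK4."""
    h, dt = h0, (t1 - t0) / n_steps
    for i in range(n_steps):
        h = rk4_step(f, t0 + i * dt, h, dt)
    return h

def resnet_block(h):
    """Residual block h + f(h): the same update as one Euler step with dt = 1."""
    return h + f(0.0, h)

x = np.array([1.0, -0.5])
print("ODE block output:   ", ode_block(x))
print("ResNet block output:", resnet_block(x))
```

In a trainable version, θ = (W1, b1, W2, b2) would be learned by backpropagating through the solver steps or with the adjoint method; that part is omitted here.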

 

The results are shown in the figure.

● We can use existing (and efficient) solver implementations to integrate the dynamics of neural networks.
● The memory cost is O(1) in the number of solver steps, thanks to reversibility: we don't need to store all intermediate activations in the computational graph, because they can be recovered by integrating backward (i.e. in reversed time).
● Complex dynamics can be modeled with fewer parameters.
● Adaptive solvers let us control the accuracy/speed trade-off by setting lower/higher error tolerances.
● Hidden states can be accessed at any value of t, not only at discrete time steps as with ResNet skip connections.
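The accuracy/speed trade-off of adaptive solvers can be seen with a small example. This sketch uses an embedded Heun–Euler pair (orders 2 and 1) as a simple stand-in for the higher-order pairs, such as Dormand–Prince, that practical adaptive solvers use; the step-size rule and the test problem are illustrative assumptions:

```python
import math

def adaptive_heun(f, h0, t0, t1, tol, dt=0.1):
    """Integrate dh/dt = f(t, h) from t0 to t1, adapting dt so the local error stays below tol.

    Returns (h(t1), number of f evaluations).
    """
    t, h, nfev = t0, h0, 0
    while t < t1:
        dt = min(dt, t1 - t)                    # never step past t1
        k1 = f(t, h)
        k2 = f(t + dt, h + dt * k1)
        nfev += 2
        h_low = h + dt * k1                     # 1st-order (Euler) estimate
        h_high = h + dt / 2 * (k1 + k2)         # 2nd-order (Heun) estimate
        err = abs(h_high - h_low)               # estimate of the local error
        if err <= tol:                          # accept: advance with the better estimate
            t, h = t + dt, h_high
        # resize the step after accepted and rejected attempts, clamped to [0.2x, 2x]
        factor = 2.0 if err == 0 else min(2.0, max(0.2, 0.9 * math.sqrt(tol / err)))
        dt *= factor
    return h, nfev

f = lambda t, h: -2.0 * t * h                   # exact solution: h(t) = exp(-t^2)
results = []
for tol in (1e-3, 1e-5, 1e-7):
    h_end, nfev = adaptive_heun(f, 1.0, 0.0, 2.0, tol)
    err = abs(h_end - math.exp(-4.0))
    results.append((tol, nfev, err))
    print(f"tol={tol:g}  nfev={nfev}  error={err:.2e}")
```

Tightening `tol` increases the number of function evaluations and reduces the final error, which is the trade-off described above.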

NeuralODE - adjoint method

The adjoint method can be understood as a continuous version of the chain rule.
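Concretely, in the usual Neural ODE setting with dynamics $\frac{d\mathbf{h}(t)}{dt} = f(\mathbf{h}(t), t, \theta)$ and a scalar loss $L$ computed from $\mathbf{h}(t_1)$, the adjoint $\mathbf{a}(t)^{\top} = \partial L / \partial \mathbf{h}(t)$ obeys its own ODE, which is integrated backward from $t_1$ to $t_0$ (the notation here is ours):

$$
\frac{d\mathbf{a}(t)^{\top}}{dt} = -\mathbf{a}(t)^{\top}\frac{\partial f(\mathbf{h}(t), t, \theta)}{\partial \mathbf{h}(t)},
\qquad
\mathbf{a}(t_1)^{\top} = \frac{\partial L}{\partial \mathbf{h}(t_1)},
\qquad
\frac{dL}{d\theta} = \int_{t_0}^{t_1} \mathbf{a}(t)^{\top}\frac{\partial f(\mathbf{h}(t), t, \theta)}{\partial \theta}\,dt
$$

Integrating backward from $\mathbf{a}(t_1)$ yields $\mathbf{a}(t_0)^{\top} = \partial L / \partial \mathbf{h}(t_0)$; this plays the role that multiplying Jacobians layer by layer plays in ordinary backpropagation.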
Chain rule: consider the following sequence of operations (L is a scalar loss):

We can compute the gradient of L with respect to the input state using the chain rule.
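Written out for a generic chain (notation assumed here): input $\mathbf{h}_0$, states $\mathbf{h}_k = g_k(\mathbf{h}_{k-1})$ for $k = 1, \dots, N$, and loss $L = \ell(\mathbf{h}_N)$. With $\mathbf{a}_k^{\top} = \partial L / \partial \mathbf{h}_k$ (a row vector), the chain rule gives

$$
\frac{\partial L}{\partial \mathbf{h}_0}
= \frac{\partial L}{\partial \mathbf{h}_N}\,
\frac{\partial \mathbf{h}_N}{\partial \mathbf{h}_{N-1}} \cdots
\frac{\partial \mathbf{h}_1}{\partial \mathbf{h}_0},
\qquad\text{i.e.}\qquad
\mathbf{a}_{k-1}^{\top} = \mathbf{a}_k^{\top}\,\frac{\partial g_k(\mathbf{h}_{k-1})}{\partial \mathbf{h}_{k-1}},
\quad
\mathbf{a}_N^{\top} = \frac{\partial L}{\partial \mathbf{h}_N}
$$

so the gradient is accumulated backward, one Jacobian at a time, starting from the loss.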

This formula is at the core of every deep-learning autograd system.
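A scalar sketch of this backward recursion on a hypothetical three-step chain, h1 = sin(h0), h2 = h1², h3 = exp(h2), with L = h3, checked against a central finite difference. Note that the backward pass needs every stored intermediate state; that storage is the memory cost the adjoint method avoids by recovering states through backward integration:

```python
import math

# Hypothetical chain of scalar operations h_k = g_k(h_{k-1}), each paired with its derivative
ops = [
    (math.sin, math.cos),                     # h1 = sin(h0)
    (lambda h: h ** 2, lambda h: 2 * h),      # h2 = h1^2
    (math.exp, math.exp),                     # h3 = exp(h2)
]

def forward(h0):
    """Run the chain and keep every intermediate state (needed by the backward pass)."""
    hs = [h0]
    for g, _ in ops:
        hs.append(g(hs[-1]))
    return hs

def backward(hs, dL_dhN=1.0):
    """Reverse-mode chain rule: a_{k-1} = a_k * g_k'(h_{k-1}), starting from a_N = dL/dh_N."""
    a = dL_dhN
    for (_, dg), h_prev in zip(reversed(ops), reversed(hs[:-1])):
        a = a * dg(h_prev)
    return a                                  # dL/dh_0

h0 = 0.7
hs = forward(h0)
grad = backward(hs)                           # L = h3, so dL/dh3 = 1
eps = 1e-6
fd = (forward(h0 + eps)[-1] - forward(h0 - eps)[-1]) / (2 * eps)
print(grad, fd)
```

The two printed numbers should agree to about six significant digits; autograd libraries automate exactly this bookkeeping for tensors.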
