微分方程数值算法的误差分析

4264阅读 0评论2012-05-04 HyryStudio
分类:Python/Ruby

弹簧-质量-阻力系统

下图为弹簧-质量-阻力系统的示意图:

根据牛顿力学定理,很容易列出如下的微分方程:

.ddot{x} + { c .over m} .dot{x} + {k .over m} x = {F .over m}

这是一个二阶的微分方程,我们把它改写为如下的两个一阶微分方程组:

.dot{x} = v

.dot{v} = {{F - k x - c v} .over m}

下面是一些系统的一些参数,其中外力用f()表示,这样可以用来计算外力随时间变化的情况。

k = 3.0
c = 0.2
m = 0.1
f = lambda t: 10.0
init = 5.0, 0.0

mass_dump_spring()计算在时刻t、状态x, v为status时,.dot{x}.dot{v}的值。

def mass_dump_spring(status, t):
    x, v = status
    dx = v
    dv = (f(t) - k*x - c*v)/m
    return dx, dv

我们可以使用scipy.integrate中提供的odeint()对微分方程进行求解:

import numpy as np
from scipy.integrate import odeint
import pylab as pl

def solve_by_odeint(time, h):
    t = np.arange(0, time, h)
    result = odeint(mass_dump_spring, init, t)
    return t, result[:, 0], result[:, 1]
欧拉方法

欧拉方法,命名自它的发明者莱昂哈德·欧拉,是一种一阶数值方法,用以对给定初值的常微分方程(即初值问题)求解。

对于如下的微分方程:

y'(t) = f(t,y(t)), .qquad .qquad y(t_0)=y_0

可以通过如下的公式得到时刻t_{n+1} = t_{n} + h的近似值:

y_{n+1} = y_n + h f(t_n,y_n)

下面是实现欧拉方法的程序:

def add(status, dstatus, h):
    return [status[i] + dstatus[i]*h for i in xrange(len(status))]

def euler(func, status, time, h):
    tlist = np.arange(0, time, h).tolist()
    result = []
    for t in tlist:
        result.append(status)
        dstatus = func(status, t)
        status = add(status, dstatus, h)
    return tlist, np.array(result)

def solve_by_euler(time, h):
    t, result = euler(mass_dump_spring, init, time, h)
    return t, result[:, 0], result[:, 1]

为了支持任意长度的状态矢量,在add()中使用列表推导式计算”status + h * dstatus”的值。此函数在后面实现龙格-库塔法时还会用到。当h足够小时,欧拉方法所得到的值足够精确,但是随着h的增大,误差也会明显增加,下面的程序比较当h为0.01和0.001时的误差:

def euler_plot():
    for i, h in enumerate([0.01, 0.001]):
        pl.subplot(211 + i)
        t, x, v = solve_by_odeint(5, h)
        t, x_euler, v_euler = solve_by_euler(5, h)
        pl.plot(t, x, label="odeint")
        pl.plot(t, x_euler, "r", label="euler")
        pl.legend(loc="best")
        pl.title("h = %g" % h)

其结果如下图所示:

中点法

欧拉法的积累误差与h成正比,为了在较大的h时也能计算出较为精确的解,需要使用更高阶的算法。下面是中点法的计算公式,其积累误差与h^{2}成正比:

y_{n+1} = y_n + hf.left(t_n+.frac{1}{2}h,y_n+.frac{1}{2}hf(t_n, y_n).right)

下面是实现中点法的程序,程序中需要调用两次计算微分的函数:

def midpoint(func, status, time, h):
    tlist = np.arange(0, time, h).tolist()
    result = []
    for t in tlist:
        result.append(status)
        dstatus = func(status, t)
        status2 = add(status, dstatus, 0.5*h)
        dstatus2 = func(status2, t+0.5*h)
        status = add(status, dstatus2, h)
    return tlist, np.array(result)

def solve_by_midpoint(time, h):
    t, result = midpoint(mass_dump_spring, init, time, h)
    return t, result[:, 0], result[:, 1]
经典四阶龙格库塔法

龙格库塔法是欧拉方法和中点法的推广,其中最常用的是经典四阶龙格库塔法,通常被称为”RK4”,下面是其计算公式:

y_{n+1} = y_n + .tfrac{1}{6} .left(k_1 + 2k_2 + 2k_3 + k_4 .right)

其中的k_1, k_2, k_3, k_4等参数使用下面的公式运算,由公式可知需要调用4次计算微分的函数。

k_1 &= hf(t_n, y_n)

k_2 &= hf(t_n + .tfrac{1}{2}h , y_n +  .tfrac{1}{2} k_1)

k_3 &= hf(t_n + .tfrac{1}{2}h , y_n +   .tfrac{1}{2} k_2)

k_4 &= hf(t_n + h , y_n + k_3)

下面是实现RK4算法的程序:

def rk4(func, status, time, h):
    tlist = np.arange(0, time, h).tolist()
    h2 = 0.5*h
    result = []
    for t in tlist:
        result.append(status)
        k1 = func(status, t)
        k2 = func(add(status, k1, h2), t+h2)
        k3 = func(add(status, k2, h2), t+h2)
        k4 = func(add(status, k3, h), t+h)
        dstatus = [v1+2*v2+2*v3+v4 for (v1,v2,v3,v4) in zip(k1,k2,k3,k4)]
        status = add(status, dstatus, h/6)
    return tlist, np.array(result)

def solve_by_rk4(time, h):
    t, result = rk4(mass_dump_spring, init, time, h)
    return t, result[:, 0], result[:, 1]
误差比较

为了比较各种算法的积累误差与h之间的关系,我们让h为从0.001到0.1的等比数列,并计算各种算法的运算结果与odeint()的结果之间的误差和。

def error(func1, func2, time, h_list):
    ex = []
    ev = []
    for h in h_list:
        _, x1, v1 = func1(time, h)
        _, x2, v2 = func2(time, h)
        ex.append(np.mean(np.abs(x1-x2)))
        ev.append(np.mean(np.abs(v1-v2)))
    return ex, ev

def error_plot(func1, func2, title):
    h_list = np.logspace(-3, -1, 20)
    ex, ev = error(func1, func2, 5.0, h_list)
    pl.loglog(h_list, ex, lw=2, label="error x of %s" % title)
    pl.loglog(h_list, ev, lw=2, label="error v of %s" % title)

error_plot(solve_by_odeint, solve_by_euler, "euler")
error_plot(solve_by_odeint, solve_by_midpoint, "midpoint")
error_plot(solve_by_odeint, solve_by_rk4, "rk4")
pl.rcParams["legend.fontsize"] = "small"
pl.legend(loc="best")
pl.show()

下图是程序的输出结果,图中X轴和Y轴都是对数坐标,因此在图中误差和h之间成线性关系。而直线的斜率表示算法的阶数。由于odeint()也并非完全精确,其缺省的精度设置为1.49012e-8,因此对于RK4算法,图中h<0.01部分的误差不再与h成线性关系。

上一篇:制作IPython notebook的便携环境
下一篇:Cython中的Memoryview切片