提高numpy.linalg.det()的运算速度

8142阅读 0评论2012-11-17 HyryStudio
分类:Python/Ruby

在上有人问道是否能对如下的循环进行提速。

import numpy as np

N = 1000
data = np.random.rand(N, 10, 10)
dm = np.zeros(N)
for i in xrange(N):
    dm[i] = np.linalg.det(data[i])

即调用N次det()计算N个相同大小的矩阵的行列式。NumPy给人的印象是它包装了大量高速运算的Fortran库,因此除非使用编译语言,很难再对其进行加速。然而实际上NumPy除了对Fortran库进行包装之外,它还需要做许多额外的工作,我们可以想办法提高这些额外工作的效率。

NumPy中的det()代码

首先下面是numpy.linalg.det()相关的代码:

def slogdet(a):
    a = asarray(a)
    _assertRank2(a)
    _assertSquareness(a)
    t, result_t = _commonType(a)
    a = _fastCopyAndTranspose(t, a)
    a = _to_native_byte_order(a)
    n = a.shape[0]
    if isComplexType(t):
        lapack_routine = lapack_lite.zgetrf
    else:
        lapack_routine = lapack_lite.dgetrf
    pivots = zeros((n,), fortran_int)
    results = lapack_routine(n, n, a, n, pivots, 0)
    info = results['info']
    if (info < 0):
        raise TypeError, "Illegal input to Fortran routine"
    elif (info > 0):
        return (t(0.0), _realType(t)(-Inf))
    sign = 1. - 2. * (add.reduce(pivots != arange(1, n + 1)) % 2)
    d = diagonal(a)
    absd = absolute(d)
    sign *= multiply.reduce(d / absd)
    log(absd, absd)
    logdet = add.reduce(absd, axis=-1)
    return sign, logdet

def det(a):
    sign, logdet = slogdet(a)
    return sign * exp(logdet)

由这段代码可知,对于实数矩阵会调用Fortran库中的lapack_lite.dgetrf()。在这句关键的Fortran函数调用之前,NumPy对输入数组进行了许多检测和转换工作。而在调用之后,还通过一些其它函数对输出进行运算。

我们可以重新编写这段代码,将循环集中在调用lapack_lite.dgetrf()之上。尽量删除掉对输入数据的检测和转换工作,而对于输出结果我们希望能使用NumPy的广播功能代替循环计算。

不过在着手做这些事情之前,让我们先进行一次Profiling,看看能否真的提高计算速度。

Profiling

在IPython的notebook中,我们可以使用%%prun命令对代码进行Profing:

%%prun
import numpy as np

N = 5000
data = np.random.rand(N, 10, 10)
dm = np.zeros(N)
for i in xrange(N):
    dm[i] = np.linalg.det(data[i])

对上面的代码进行Profiling的结果如下:

165004 function calls in 1.581 seconds

Ordered by: internal time

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
 5000    0.551    0.000    1.432    0.000 linalg.py:1560(slogdet)
15000    0.130    0.000    0.130    0.000 {method 'reduce' of 'numpy.ufunc' objects}
 5000    0.078    0.000    1.510    0.000 linalg.py:1642(det)
 5000    0.068    0.000    0.068    0.000 {numpy.linalg.lapack_lite.dgetrf}
 5000    0.068    0.000    0.068    0.000 {numpy.core.multiarray._fastCopyAndTranspose}
 5000    0.060    0.000    0.130    0.000 linalg.py:99(_commonType)
 5000    0.052    0.000    0.052    0.000 {method 'diagonal' of 'numpy.ndarray' objects}
10000    0.051    0.000    0.097    0.000 numeric.py:167(asarray)
 5000    0.047    0.000    0.123    0.000 linalg.py:139(_fastCopyAndTranspose)
10000    0.046    0.000    0.046    0.000 {numpy.core.multiarray.array}
    1    0.040    0.040    1.581    1.581 :2()
 5000    0.039    0.000    0.057    0.000 linalg.py:127(_to_native_byte_order)
 5000    0.038    0.000    0.142    0.000 fromnumeric.py:902(diagonal)
10000    0.038    0.000    0.058    0.000 linalg.py:71(isComplexType)
 5000    0.034    0.000    0.056    0.000 linalg.py:157(_assertSquareness)
 5000    0.034    0.000    0.034    0.000 {numpy.core.multiarray.arange}
15000    0.034    0.000    0.034    0.000 {issubclass}
    1    0.031    0.031    0.031    0.031 {method 'rand' of 'mtrand.RandomState' objects}
 5000    0.031    0.000    0.040    0.000 linalg.py:151(_assertRank2)
 5001    0.025    0.000    0.025    0.000 {numpy.core.multiarray.zeros}
15000    0.025    0.000    0.025    0.000 {len}
 5000    0.020    0.000    0.029    0.000 linalg.py:84(_realType)
 5000    0.012    0.000    0.012    0.000 {max}
 5000    0.010    0.000    0.010    0.000 {method 'append' of 'list' objects}
 5000    0.010    0.000    0.010    0.000 {min}
 5000    0.009    0.000    0.009    0.000 {method 'get' of 'dict' objects}
    1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

可以看到Fortran库函数lapack_lite.dgetrf()并不是最耗时的,如果我们将循环集中在它上面,尽量减少其它函数的调用次数,能提高将近10倍的运算速度。

编写高速运算的代码

我们将NumPy中的det()slogdet()的运算放在一起,并对程序做了如下改动:

import numpy as np
from numpy.core import intc
from numpy.linalg import lapack_lite

def dets_fast(a):
    m = a.shape[0]
    n = a.shape[1]
    lapack_routine = lapack_lite.dgetrf
    pivots = np.zeros((m, n), intc)
    flags = np.arange(1, n + 1).reshape(1, -1)
    for i in xrange(m):
        tmp = a[i]
        lapack_routine(n, n, tmp, n, pivots[i], 0)
    sign = 1. - 2. * (np.add.reduce(pivots != flags, axis=1) % 2)
    idx = np.arange(n)
    d = a[:, idx, idx]
    absd = np.absolute(d)
    sign *= np.multiply.reduce(d / absd, axis=1)
    np.log(absd, absd)
    logdet = np.add.reduce(absd, axis=-1)
    return sign * np.exp(logdet)

下面是直接采用循环调用linalg.det()的代码:

import numpy as np
from numpy.core import intc
from numpy.linalg import lapack_lite

def dets(a):
    length = a.shape[0]
    dm = np.zeros(length)
    for i in xrange(length):
        dm[i] = np.linalg.det(M[i])
    return dm

首先检测计算结果是否正确,由于dets_fast()中没有对原始矩阵进行转置,因此运算结果和dets()有微小差别,因此使用numpy.allclose()比较二者的结果:

N = 1000
M = np.random.rand(N*10*10).reshape(N, 10, 10)
print np.allclose(dets(M), dets_fast(M.copy()))
True

下面比较运算速度,可以看出运算速度有10多倍的提升。由于dets_fast()会改变输入数组,因此我们将M复制一份再传递给它。

%timeit dets(M)
%timeit dets_fast(M.copy())
1 loops, best of 3: 173 ms per loop
100 loops, best of 3: 14.1 ms per loop
上一篇:继承对属性访问速度的影响
下一篇:IPython Notebook简介1