broadcast系男子

櫻井研では、以前、「list内包表記系男子」という言葉が流行りましたね。
しかし、list内包表記に書き換えることよる実行時間削減効果は、他の高速化テクニックに比べてそこまで効果が大きいものではありません。

ところで、Pythonを使って科学技術計算をする人々の間ではこんな言葉が流行っているのをご存知ですか?
broadcast系男子
broadcastとは、numpyで使われる配列の演算の方法です。
さっそく具体例を見てみましょう。

例1
>>> import numpy as np 
>>> a=np.arange(5)
>>> a
array([0, 1, 2, 3, 4])
>>> a*5
array([0, 5, 10, 15, 20])

「え、普通じゃん」と思う方もいらっしゃることでしょう。
(初歩的な質問ですが、aがnumpyの配列ではなくPythonのリストである場合、a*5はどうなるかわかりますよね?)

それでは別の具体例を見てみましょう。

例2
>>> a=np.array([[10, 20], [30, 40]])
>>> b=np.array([1, 2])
>>> a+b
array([[11, 22],
       [31, 42]])

aとbの配列の次元(以下shapeと呼ぶ)は一致しませんが、足し算ができていますね。
これがbroadcastするということなのです。

次に少しだけ違う例を見てみましょう。

例3
>>> a=np.array([[10, 20], [30, 40]])
>>> b=np.array([[1], [2]])
>>> a+b
array([[11, 21],
       [32, 42]])

例2ではb.shapeは(2,)でしたが、例3ではa.shapeは(2, 1)になっていますよね。
小さな違いが大きな差を生むのです。


実は、broadcastを使うと、大きな行列(あるいはテンソル)を少ない計算量で生成することができることがあるのです。

問題

# この配列を生成
[[0,   4,   8,   12,  16,  20,  24,  28,  32,  36],
 [1,   5,   9,   13,  17,  21,  25,  29,  33,  37],
 [2,   6,   10,  14,  18,  22,  26,  30,  34,  38],
 [3,   7,   11,  15,  19,  23,  27,  31,  35,  39],
 [40,  44,  48,  52,  56,  60,  64,  68,  72,  76],
 [41,  45,  49,  53,  57,  61,  65,  69,  73,  77],
 [42,  46,  50,  54,  58,  62,  66,  70,  74,  78],
 [43,  47,  51,  55,  59,  63,  67,  71,  75,  79],
 [80,  84,  88,  92,  96,  100, 104, 108, 112, 116],
 [81,  85,  89,  93,  97,  101, 105, 109, 113, 117],
 [82,  86,  90,  94,  98,  102, 106, 110, 114, 118],
 [83,  87,  91,  95,  99,  103, 107, 111, 115, 119],
 [120, 124, 128, 132, 136, 140, 144, 148, 152, 156],
 [121, 125, 129, 133, 137, 141, 145, 149, 153, 157],
 [122, 126, 130, 134, 138, 142, 146, 150, 154, 158],
 [123, 127, 131, 135, 139, 143, 147, 151, 155, 159],
 [160, 164, 168, 172, 176, 180, 184, 188, 192, 196],
 [161, 165, 169, 173, 177, 181, 185, 189, 193, 197],
 [162, 166, 170, 174, 178, 182, 186, 190, 194, 198],
 [163, 167, 171, 175, 179, 183, 187, 191, 195, 199]]

解答例

ここでは4つの解法を用意してみました。
それでは早速ご覧ください。

import numpy as np


# リスト内包表記を使わない
def func1(step=4, dim=10, n=200, length=5):
    indices1 = []
    for l in range(length):
        for s in range(step):
            index = range(l * step * dim + s, (l + 1) * step * dim + s, step)
            indices1.append(index)
    return(indices1)


# リスト内包表記を使う
def func2(step=4, dim=10, n=200, length=5):
    indices2 = [range(l * step * dim + s, (l + 1) * step * dim + s, step)
                for l in range(length) for s in range(step)]
    return(indices2)


# ブロードキャストを使う その1
def func3(step=4, dim=10, n=200, length=5):
    l = np.arange(0, n, step * dim)
    ll = l.reshape(length, 1)
    s = np.arange(step)
    ss = s.reshape(1, step)
    ls = (ll + ss).reshape(length * step, 1)  # ここでブロードキャスト
    i = np.arange(0, dim * step, step)
    ii = x.reshape(1, dim)
    return(ls + ii)                   # ここでもう一度ブロードキャスト


# ブロードキャストを使う その2
def func4(step=4, dim=10, n=200, length=5):
    x = np.arange(0, dim * step, step)
    y = np.arange(0, n, step * dim)
    z = np.arange(step)
    xx = x.reshape(1, 1, dim)
    yy = y.reshape(length, 1, 1)
    zz = z.reshape(1, step, 1)
    indices = xx + yy + zz	              # ここでブロードキャスト
    return(indices.reshape(n / dim, dim))

結果

今回は、jupyter notebook のマジックコマンドの1つである timeitを用いて時間を測定しました。

%timeit func1()
10000 loops, best of 3: 71.4 µs per loop

%timeit func2()
10000 loops, best of 3: 70.7 µs per loop

%timeit func3()
10000 loops, best of 3: 46.6 µs per loop

%timeit func4()
10000 loops, best of 3: 44.5 µs per loop

結構速くなっていますね。
実は、この高速化は行列が大きいほど効果が大きいのです。
step=40, dim=100, n=200000, length=50にして(要素数が1000倍になった)実行時間を測定してみましょう。

%timeit func1(40,100,200000,50)
100 loops, best of 3: 13 ms per loop

%timeit func2(40,100,200000,50)
100 loops, best of 3: 10.1 ms per loop

%timeit func3(40,100,200000,50)
1000 loops, best of 3: 783 µs per loop

%timeit func4(40,100,200000,50)
1000 loops, best of 3: 788 µs per loop

10倍以上高速化できましたね。

結論

broadcastを使うと良い。
broadcast系男子がいまアツい。

質疑応答

Q. めんどくさいんだけど?
A. 慣れれば簡単です。

Q. numpyを使ったから速いってだけなんじゃないの?
A. numpyをテキトーに使っても、2倍程度しか速くなりませんでした。
broadcastと上手く組み合わせることにより10倍以上高速化することができます。

Q. 解答例が意味分かんないんだけど?
A. 各変数に何が格納されているかをぜひ自分で確かめてみてください。

Q. 足し算でしか使えないの?
A. 引き算、掛け算、割り算、その他numpyが対応する二項演算でも使うことができます。

Q. numpy特有のテクニックなの?
A. 若干仕様が異なりますが、Mathematicaでも使えます。