NumPy高速化を目指して
Contents
背景
業務で計算速度高速化を目指している際に、線形代数ライブラリに関する言及をいくつか見つけたので、備忘録としてまとめておく
まとめ
線形代数ライブラリの概要
線形代数ライブラリとは、以下2つをパッケージングしたライブラリ
- BLAS (Basic Linear Algebra Subprograms):ベクトルと行列のかけ算や足し算を担う
- LAPACK (Linear Algebra PACKage):連立線型方程式や固有値方程式、特異値分解の解法を担う
BLASとLAPACKは数値計算を行うソフトウェアの大部分で使用されている
- Python (NumPy, SciPy)
- MATLAB
- R
- TensorFlow
- etc…
BLASで定義される演算には、Level 1 BLAS、Level 2 BLAS、Level 3 BLASという分類があり、より高いLevelのBLASを利用したほうが効率よく計算できる
- Level 1 BLAS:ベクトルとベクトルの演算
- Level 2 BLAS:行列とベクトルの演算
- Level 3 BLAS:行列と行列の演算
線形代数ライブラリは、アセンブリ言語で実装されているソフトウェアが多いので、使用するコンピュータのアーキテクチャに対応したライブラリを選ぶ必要がある
DGEMM (BLASに含まれる行列積計算ルーチン) は高性能比較のベンチマークによく用いられる
主要線形代数ライブラリ
BLAS/LAPACK
- Netlib公式参照実装
- DGEMM:理論最高性能の 10 % 程度
- パブリックドメインライセンスで自由に利用可能
ATLAS
- パラメータサーベイによって BLASを自動チューニングする実装
- DGEMM:理論最高性能の 80 % 程度
- BSDライセンスのオープンソース
MKL
- Intel製の数値計算ライブラリ
- DGEMM:理論最高性能の 96~97 % 程度
- 有料
NumPy高速化のTips
NumPyの関数を出来るだけ使用 (どうしてもforループから逃げられない場合のみNumba等を検討 (参考記事))
以下は上記スクリプトの結果だが、この結果を見る限りだと、計算速度という観点では、jitとnumpyのパフォーマンス差は小さそうだと言える
------------ array_size = 100,000 ------------ 実行された関数: simple_sum 処理時間: 0.000秒 計算結果: 49988.75 実行された関数: jit_sum 処理時間: 0.073秒 計算結果: 49988.75 実行された関数: numpy_sum 処理時間: 0.000秒 計算結果: 49988.75 ------------ array_size = 1,000,000 ------------ 実行された関数: simple_sum 処理時間: 0.103秒 計算結果: 499994.21 実行された関数: jit_sum 処理時間: 0.001秒 計算結果: 499994.21 実行された関数: numpy_sum 処理時間: 0.001秒 計算結果: 499994.21 ------------ array_size = 10,000,000 ------------ 実行された関数: simple_sum 処理時間: 0.926秒 計算結果: 4999245.42 実行された関数: jit_sum 処理時間: 0.012秒 計算結果: 4999245.42 実行された関数: numpy_sum 処理時間: 0.007秒 計算結果: 4999245.42 ------------ array_size = 100,000,000 ------------ 実行された関数: simple_sum 処理時間: 9.227秒 計算結果: 50003393.92 実行された関数: jit_sum 処理時間: 0.125秒 計算結果: 50003393.92 実行された関数: numpy_sum 処理時間: 0.147秒 計算結果: 50003393.92
行列の演算は、NumPyの中で利用されているBLASを意識しないと、本来のパフォーマンスを出すことができない (≒ forループの排除)
インプレイス演算 (a += b) で ミュータブルオブジェクトを更新 (説明記事)