numpy.take_along_axis

多次元アレイの操作時に,任意の軸に沿って値を選択したい時に非常に便利な関数.
numpy 1.15.0からの関数で,Colaboratoryでは標準でnumpy 1.14.6なのと,
基本的に2D,せいぜい3D程度までしか扱わないので普段は使わないんだけど,
シンプルに処理できるだけじゃなくて時間効率も良い.

2Dアレイに於ける行ごとの要素の比較(numpy.isinな処理の2Dへの拡張)

 
 

いくつか,サンプルケースをみていくと.

【python】numpyで多次元配列のargsortと値の取り出し – 静かなる名辞

2次元程度であれば,インデックス操作も然程複雑ではないし,効率もそれ程悪くはない(時間効率に於いては大体のケースでほぼ同じ).また,Cライクに処理する方法もあって(Integer array indexing),おそらく時間効率だけ考えればその方が速い(もっと高次元なケースでは分からないけど2D程度なら).

import numpy as np


a = np.array([[2,0,1,8],[1,1,0,7]])
r, c = a.shape
print(a[np.arange(len(a))[:, None], a.argsort(1)])
%timeit a[np.arange(len(a))[:, None], a.argsort(1)]
print(a.ravel()[a.argsort(1)+np.arange(r)[:, None]*c])
%timeit r, c = a.shape;a.ravel()[a.argsort(1)+np.arange(r)[:, None]*c]
print(np.take_along_axis(a, a.argsort(1), 1))
%timeit np.take_along_axis(a, a.argsort(1), 1)

a = np.random.randint(1, 10000, size=(1000, 100))
res1 = a[np.arange(len(a))[:, None], a.argsort(1)]
res2 = np.take_along_axis(a, a.argsort(1), 1)
np.testing.assert_allclose(res1, res2)
%timeit res1 = a[np.arange(len(a))[:, None], a.argsort(1)]
%timeit r, c = a.shape;a.ravel()[a.argsort(1)+np.arange(r)[:, None]*c]
%timeit res2 = np.take_along_axis(a, a.argsort(1), 1)
[[0 1 2 8]
 [0 1 1 7]]
100000 loops, best of 3: 2.83 µs per loop
[[0 1 2 8]
 [0 1 1 7]]
100000 loops, best of 3: 4.69 µs per loop
[[0 1 2 8]
 [0 1 1 7]]
100000 loops, best of 3: 8.5 µs per loop

100 loops, best of 3: 6.36 ms per loop
100 loops, best of 3: 4.01 ms per loop
100 loops, best of 3: 6.43 ms per loop

Sort invariant for numpy.argsort with multiple dimensions – StackOverflow

2Dなので,軸を合わせるのは難しくない.

import numpy as np


A = np.array([[3,2,1],[4,0,6]])
r, c = A.shape
B = np.array([[3,1,4],[1,5,9]])

print(A[np.arange(len(A))[:, None], A.argsort(-1)])
%timeit A[np.arange(len(A))[:, None], A.argsort(-1)]
print(A.ravel()[A.argsort(1)+np.arange(r)[:, None]*c])
%timeit r, c = A.shape;A.ravel()[A.argsort(1)+np.arange(r)[:, None]*c]
print(np.take_along_axis(A, A.argsort(-1), -1))
%timeit np.take_along_axis(A, A.argsort(-1), -1)

print(B[np.arange(len(A))[:, None], A.argsort(-1)])
%timeit B[np.arange(len(A))[:, None], A.argsort(-1)]
print(B.ravel()[A.argsort(1)+np.arange(r)[:, None]*c])
%timeit r, c = A.shape;B.ravel()[A.argsort(1)+np.arange(r)[:, None]*c]
print(np.take_along_axis(B, A.argsort(-1), -1))
%timeit np.take_along_axis(B, A.argsort(-1), -1)

A = np.concatenate([A]*1000000)
B = np.concatenate([B]*1000000)
%timeit A[np.arange(len(A))[:, None], A.argsort(-1)]
%timeit r, c = A.shape;A.ravel()[A.argsort(1)+np.arange(r)[:, None]*c]
%timeit np.take_along_axis(A, A.argsort(-1), -1)

%timeit B[np.arange(len(A))[:, None], A.argsort(-1)]
%timeit r, c = A.shape;B.ravel()[A.argsort(1)+np.arange(r)[:, None]*c]
%timeit np.take_along_axis(B, A.argsort(-1), -1)
[[1 2 3]
 [0 4 6]]
100000 loops, best of 3: 2.76 µs per loop
[[1 2 3]
 [0 4 6]]
100000 loops, best of 3: 4.67 µs per loop
[[1 2 3]
 [0 4 6]]
100000 loops, best of 3: 8.65 µs per loop

[[4 1 3]
 [5 1 9]]
100000 loops, best of 3: 2.84 µs per loop
[[4 1 3]
 [5 1 9]]
100000 loops, best of 3: 4.78 µs per loop
[[4 1 3]
 [5 1 9]]
100000 loops, best of 3: 8.6 µs per loop

10 loops, best of 3: 135 ms per loop
10 loops, best of 3: 127 ms per loop
10 loops, best of 3: 136 ms per loop

10 loops, best of 3: 136 ms per loop
10 loops, best of 3: 127 ms per loop
10 loops, best of 3: 136 ms per loop

3D位のインデックス操作でもできないことも無いけど難しい,
という例を探したかったけどちょっと見当たらなかった…….

Sorting 4D numpy array but keeping one axis tied together – StackOverflow

4Dでもう訳が分からない.ノートに書き出しながら考えれば処理できなくもないけど,
この辺はもう素直にnumpy.take_along_axisのありがたみを噛み締めたい.

arr = np.random.randint(0,10,(3,3,2,2))
idx = arr[..., 0].argsort(0)
np.take_along_axis(arr, idx[..., None], 0)
array([[[[0, 5],
         [2, 7]],

        [[1, 3],
         [5, 8]],

        [[3, 1],
         [5, 3]]],


       [[[8, 1],
         [3, 3]],

        [[8, 7],
         [7, 1]],

        [[6, 5],
         [5, 8]]],


       [[[9, 1],
         [7, 9]],

        [[8, 0],
         [7, 9]],

        [[9, 9],
         [7, 3]]]])

Numpy select matrix specified by a matrix of indices, from multidimensional array – StackOverflow

これだけ複雑でも,numpy.take_along_axisなら5次元でも6次元でも直感的に操作できる.

カテゴリー: 未分類 パーマリンク

コメントを残す

以下に詳細を記入するか、アイコンをクリックしてログインしてください。

WordPress.com ロゴ

WordPress.com アカウントを使ってコメントしています。 ログアウト /  変更 )

Google フォト

Google アカウントを使ってコメントしています。 ログアウト /  変更 )

Twitter 画像

Twitter アカウントを使ってコメントしています。 ログアウト /  変更 )

Facebook の写真

Facebook アカウントを使ってコメントしています。 ログアウト /  変更 )

%s と連携中

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください