numpy.arrayの要素を任意に置き換える

How to replace a list of values in a numpy array? – StackOverflow

まずは愚直に.

import numpy as np

# Test data: a shuffled permutation of 0..999999 plus 10,000 (old, new) pairs
# with values in [0, 100).  ``np.int`` (an alias of the builtin ``int``) was
# deprecated in NumPy 1.20 and removed in 1.24; ``np.int_`` is the same default
# integer dtype and matches the ``np.int_t`` buffer type of the typed Cython cell.
numbers = np.arange(0, 1000000, dtype=np.int_)
np.random.shuffle(numbers)
prob_n = np.random.randint(100, size=10000, dtype=np.int_)  # table, night_stand, plant
alt_n = np.random.randint(100, size=10000, dtype=np.int_)  # desk, dresser, flower_pot
%%timeit
# Naive approach: every (x, y) pair rescans the whole array, so the total cost
# is O(len(numbers) * len(prob_n)).
# The replacement is done in place on the global `numbers`, so each timing
# repetition (and every later cell that uses `numbers`) sees already-modified
# data.  Pairs are applied in order: an element turned into y can be replaced
# again by a later pair whose x equals y.
for x, y in zip(prob_n, alt_n):
    numbers[numbers==x] = y

14.8 s ± 655 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

ループ1回ごとに配列全体を走査するので,全体の計算量は O(N×M)(N: 配列の要素数, M: 置換ペア数)になり,毎回の全要素比較が律速になっている(と思う).
置換ペアを100万組にすると単純計算で約100倍,25分前後かかる事に…….

この問題はベクトル化できるイメージが湧かないので,
まずはnumbaやcythonでどの程度高速化できるか試してみる.
あまり期待できなければ,リストで処理した方が良いのではないか.

from numba import jit

@jit
def func(numbers, prob_n, alt_n):
    """Apply the (old, new) pairs from prob_n/alt_n to `numbers`, in order.

    Same algorithm as the plain loop, compiled with numba's ``jit``.
    `numbers` is modified in place and also returned.
    """
    for old_value, new_value in zip(prob_n, alt_n):
        hit = numbers == old_value
        numbers[hit] = new_value
    return numbers
%%timeit
func(numbers, prob_n, alt_n)

12.6 s ± 509 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%load_ext Cython
%%cython

def func2(numbers, prob_n, alt_n):
    """Apply the (old, new) pairs from prob_n/alt_n to `numbers`, in order.

    Identical to the plain loop, compiled by Cython without any type
    declarations.  `numbers` is modified in place and also returned.
    """
    pairs = zip(prob_n, alt_n)
    for old_value, new_value in pairs:
        hit = numbers == old_value
        numbers[hit] = new_value
    return numbers
%%timeit
func2(numbers, prob_n, alt_n)

15.4 s ± 465 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%cython
cimport numpy as np
from numpy cimport ndarray

# Element type of the typed buffers below (C long, i.e. NumPy's default int).
ctypedef np.int_t DTYPE_t

def func2(np.ndarray[DTYPE_t, ndim=1] numbers, np.ndarray[DTYPE_t, ndim=1] prob_n, np.ndarray[DTYPE_t, ndim=1] alt_n):
    """Apply the (old, new) pairs from prob_n/alt_n to `numbers`, in order.

    `numbers` is modified in place and also returned.
    """
    # Declare the loop variables with the buffers' element type.  A plain C
    # ``int`` is narrower than ``np.int_t`` (C long) on LP64 platforms, so
    # values outside the int range would raise OverflowError on assignment.
    cdef DTYPE_t x, y
    for x, y in zip(prob_n, alt_n):
        # NOTE: the comparison and the boolean-mask assignment still go
        # through NumPy's Python-level API, so the type declarations above do
        # not make this loop faster; an explicit indexed C loop would be needed.
        numbers[numbers==x] = y
    return numbers
%%timeit
func2(numbers, prob_n, alt_n)

14.9 s ± 295 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

def func3(numbers, prob_n, alt_n):
    """Pure-Python (list-based) version of the sequential replacement.

    For each (x, y) pair taken in order from prob_n/alt_n, every element
    currently equal to x becomes y.  This gives the same result as the NumPy
    loop ``numbers[numbers==x] = y``.  The input list is not modified; a new
    list is returned.

    Args:
        numbers: iterable of values to transform.
        prob_n: old values, applied in order.
        alt_n: replacement values, paired with prob_n (zip truncates).

    Returns:
        A new list with all replacements applied.
    """
    result = list(numbers)
    for x, y in zip(prob_n, alt_n):
        # Non-matching elements must be kept unchanged (`e`, not `x`), and
        # each pass must start from the previous pass's output so that
        # replacements chain exactly as in the in-place NumPy loop.
        result = [y if e == x else e for e in result]
    return result
%%timeit
func3(numbers.tolist(), prob_n.tolist(), alt_n.tolist())

遅すぎて強制終了…… 
 
 

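型宣言を付けても,`numbers==x` の比較とブールインデックスによる代入はNumPyのPython APIを経由するので速くならない.Cythonで速くするには,インデックスアクセスによる明示的なループとして書く必要がある.
Cythonを使いこなせる様になりたいね.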
 
 
 

追記:

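素面で見直して,回答の素晴らしさに気付く.これも,indexerを用意してやれば物凄く効率的.
(ただ,回答者のコードは間違っているけど)
ただし,この方法は各要素を表で1回引くだけの「一斉置換」なので,上のループ(逐次置換)とは一般に結果が一致しない.1→2,続いて2→3というペアがあっても,ループでは1は3になるが,表引きでは2のままになる.また,prob_nに同じ値が複数ある場合,ループでは先に現れたペアが効く.一方,表への代入ではどの値が残るかNumPyでは保証されない.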

import numpy as np

# Fresh test data (default integer dtype): a shuffled permutation of
# 0..999999 and 10,000 (old, new) pairs with values in [0, 100).
numbers = np.arange(0, 1000000)
np.random.shuffle(numbers)
prob_n = np.random.randint(100, size=10000)
alt_n = np.random.randint(100, size=10000)
# Lookup-table ("indexer") approach: build a table mapping every value in
# [n_min, n_max] to itself, overwrite the entries for the values in prob_n,
# then translate the whole array with one fancy-indexing lookup.
n_min, n_max = numbers.min(), numbers.max()
# Identity table; its size is the value range of `numbers`, not its length.
replacer = np.arange(n_min, n_max + 1)
# Drop pairs whose old value lies outside the table's range; they would
# otherwise index a wrong (negative) entry or go out of bounds.
mask =(n_min <= prob_n) & (prob_n <= n_max)
replacer[prob_n[mask] - n_min] = alt_n[mask]
# Each element is looked up exactly once, so the replacement is simultaneous:
# unlike the sequential loop, chains such as 1->2 then 2->3 are not followed,
# and for repeated old values NumPy does not guarantee which new value wins.
numbers2 = replacer[numbers - n_min]
%%timeit
# Time the complete lookup-table approach (table construction + lookup);
# these are the same statements as the previous cell.
n_min, n_max = numbers.min(), numbers.max()
replacer = np.arange(n_min, n_max + 1)
mask =(n_min <= prob_n) & (prob_n <= n_max)
replacer[prob_n[mask] - n_min] = alt_n[mask]
numbers2 = replacer[numbers - n_min]

35 ms ± 3.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

from numba import jit

@jit
def fanc(numbers, prob_n, alt_n):
    """Translate `numbers` through a lookup table built from prob_n -> alt_n.

    A table covering the value range of `numbers` starts as the identity.
    The pairs whose old value lies inside that range are written into it,
    and a single fancy-indexing lookup produces the result.  Every element
    is looked up once, so chained pairs (a->b, then b->c) are not followed
    the way they are in the sequential loop.  `numbers` itself is not modified.
    """
    lo = numbers.min()
    hi = numbers.max()
    table = np.arange(lo, hi + 1)
    in_range = (prob_n >= lo) & (prob_n <= hi)
    table[prob_n[in_range] - lo] = alt_n[in_range]
    offsets = numbers - lo
    return table[offsets]


# Warm-up call: numba's @jit compiles the function on its first call, so
# calling it once here keeps compilation time out of the measurement below.
fanc(numbers, prob_n, alt_n)
%timeit fanc(numbers, prob_n, alt_n)

25.4 ms ± 2.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

カテゴリー: 未分類 パーマリンク