多重リストのフラット化あれこれ(with Python)

多重リストのフラット化あれこれ

(with Python)

『Effective Python』Item 8: リスト内包表記では2つ以上の式を避ける-何かを書き留める何か

を読んでいて,forループと内包表記って凄く相性が良くて,2,3重ループまでは普通に内包表記の方が可読性良い気がするなあなんて思いつつ(complexだけどcomplicatedでは無いよね?と),「多重化のレベルに依存しない平滑化」かーと思い,色々と調べてみた.

[1] pythonでflatten-Qiita

[2] Pythonでリストをflattenする方法まとめ-Soleil cou coupé

[2]の「抽象基底クラスを用いて変数がIterableであるかどうかを判断」というのが,凄く面白い(勉強になった).

簡単に纏めると,1段ネストの場合,「chain.from_iterable(from itertools import chain)」がリーズナブルチョイスになる.これは,スクレイピングの時とか,割と1段ネストは遭遇するのでよく使う.2段程度の場合,構造が明らかならlist(chain.from_iterable(chain.from_iterable(Multi_Lists)))でも良いし,内包表記でも良いと思う.ただ,どちらにせよ,データ構造が明らかじゃない場合,或いはリスト内にiterableではない要素が含まれる場合,それらの方法では「TypeError: ‘int’ object is not iterable

」になってしまう.そこで,[2]の話.iterableかどうかを判断し,再帰的に処理する事で,ネスト構造によらず,フラット化を実現できる.[2]の記事では,Python2系で書かれているけど,Python3系では「basestring」という抽象クラスは削除されているので,「str」に書き換え.

from itertools import chain
import collections

def flatten(nested_list):
    result = []
    for element in nested_list:
        if isinstance(element, collections.Iterable) and not isinstance(element, str):
            result.extend(flatten(element))
        else:
            result.append(element)
    return result

my_lists = [
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
]

my_lists2 = [
    [1,[2,3],[[4,[5,6]],7]]
]

print(list(flatten(my_lists)))
print(list(flatten(my_lists2)))

再帰処理は見栄えは良いけど,complicatedになりがちで,Pythonは言語哲学的にも,言語仕様的にも,complicatedな構造は駄目(苦手)なので,「再帰なしの実装」の方が良いんだろうけど,僕はwhileが好きではないので略.

そして,その上で[1]の

flatten=lambda i:[a for b in i for a in (flatten(b) if hasattr(b,'__iter__') else(b,))]

が凄い.最初,多重リストのフラット化を考えていた時,「リスト内包表記で再帰的な処理ってどうやるんだろう,うーん無理かな……」と悩んでいた時に,[2]をみて,そして[1]をみて,[2]の話は多分今後同様の問題については,自分の中で消化できたので,上手く適用していけるけど,これは多分まだ真似はできない.ただ,「成る程なあ」と.

ただ,それで終わってしまっては面白くないので,素人考えでもっと簡単にできて(どんな場合にも適用できて),それでいて十分に高速な方法は無いだろうか?と考えてみた.

目安(比較対象)として,

l = []
for i in range(500):
    l = [l, i]

に対して,

[1]の

%%timeit
flatten=lambda i:[a for b in i for a in (flatten(b) if hasattr(b,'__iter__') else(b,))]
flatten(l)

7.01 ms ± 276 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

[2]の

%%timeit
from itertools import chain
import collections

def flatten2(nested_list):
    result = []
    for element in nested_list:
        if isinstance(element, collections.Iterable) and not isinstance(element, str):
            result.extend(flatten(element))
        else:
            result.append(element)
    return result

flatten2(l)

7.05 ms ± 190 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

を基準に,まあ10ms程度で簡単な方法であれば良しとしよう.

まず,ぱっと思いつく方法.numpyを使ってフラット化を試みる.

import numpy as np

np.array(list).reshape(-1)

(ほぼ同じものにravel()

or

np.array(list).flatten()

ただ,やってみて気付いたけど,np.array化した時,ネスト構造の要素数が同じでなければ,要素がオブジェクトになってしまって,np.array内にそのままリストを持ってしまい,やりたい事が何もできなくなる.例えば,上記my_lists2やlみたいな(dtype=objectになる)データ構造だと,使えない.まあ,

[[[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]]みたいなデータの場合は,

list(chain.from_iterable(chain.from_iterable(chain.from_iterable(chain.from_iterable(a)))))

となる所,

a = np.array([[[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]])
a.reshape(-1)

と書ける程度には便利だろうか.

a = [[[[i]]] for i in range(500)]

でflatten, flatten2と速度比較すると,

PAGE_BREAK: PageBreak

%timeit flatten(a)
%timeit flatten2(a)

2.05 ms ± 43.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.31 ms ± 70.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%%timeit
b = np.array(a)
b.reshape(-1)

445 µs ± 4.69 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

なので,numpyで処理できる(dtype=objectにならない)データ構造なら,numpyを使うという選択はありかもしれない.

dtype=objectについて,あれこれ.

How to flatten a numpy array of dtype object-StackOverflow

a = np.array(((np.array((1,2)), np.array((1,2,3))), (np.array((1,2)), np.array((1,2,3,4,5,6,7,8)))))
print(a)
print(a.reshape(-1))
print(np.hstack(a))
print(np.hstack(a.flat))

[[array([1, 2]) array([1, 2, 3])]

[array([1, 2]) array([1, 2, 3, 4, 5, 6, 7, 8])]]

[array([1, 2]) array([1, 2, 3]) array([1, 2])

array([1, 2, 3, 4, 5, 6, 7, 8])]

[array([1, 2]) array([1, 2, 3]) array([1, 2])

array([1, 2, 3, 4, 5, 6, 7, 8])]

[1 2 1 2 3 1 2 1 2 3 4 5 6 7 8]

np.hstack(np.hstack(a))でもできる.複雑そうにみえてそれぞれがarray型だから操作し易い.

余談:

my_lists2 = [
    [1,[2,3],[[4,[5,6]],7]]
]

np_a = np.array(my_lists2)

ValueError: setting an array element with a sequence.

だけど,

import pandas as pd

c = pd.DataFrame(my_lists2)
c.values

array([[1, [2, 3], [[4, [5, 6]], 7]]], dtype=object)

Pandas便利.

print(c.values)
print(c.values.flatten())
print(np.hstack(c.values.flatten()))

[[1 [2, 3] [[4, [5, 6]], 7]]]
[1 [2, 3] [[4, [5, 6]], 7]]
[1 2 3 [4, [5, 6]] 7]

何がしたいのか分からなくなってきた.取り敢えず,手軽さと速さを基準に主観的観点から纏める.

まとめ:

1.1段ネストのデータ構造(例:[[1, 2, 3], [4, 5, 6], [7, 8, 9], …])

from itertools import chain
list(chain.from_iterable(multi_list))

2.2段以上で同一要素数のネストを持つデータ構造(例:[[[[[1, 2], [3, 4], [5, 6], …]]]])

import numpy as np
list(np.array(multi_list).reshape(-1))

3.どんな場合にも適応できる素晴らしいコード([1]や[2]の話)

flatten = lambda i:[a for b in i for a in (flatten(b) if hasattr(b,'__iter__') else(b,))]
flatten(multi_list)

or

from itertools import chain
import collections

def flatten(nested_list):
    result = []
    for element in nested_list:
        if isinstance(element, collections.Iterable) and not isinstance(element, str):
            result.extend(flatten(element))
        else:
            result.append(element)
    return result

list(flatten(multi_list))
カテゴリー: 未分類 パーマリンク