itertools.groupbyでグルーピング

ライブラリ

itertools、いいですよね

みなさん、itertoolsは使っておりますでしょうか。大変便利なPython標準ライブラリなのですが、使い方に少しコツがいるので、私はなかなかうまく使いこなせていない感があります。このライブラリにクセがあるというわけではなく、アルゴリズム自体の難しさに起因しているんだとは思いますが。

そこで今回は、このitertoolsのgroupby関数を使って遊んでみようと思います。SQLにおけるROW_NUMBER関数のような使い方をしてみます。

itertools --- 効率的なループ用のイテレータ生成関数群
このモジュールは イテレータ を構築する部品を実装しています。プログラム言語 APL, Haskell, SML からアイデアを得ていますが、 Python に適した形に修正されています。 このモジュールは、高速でメモリ効率に優れ、単独でも...

今回のテーマ

ある学校に、3つの学年、それぞれに3つのクラスがあるとします。各クラスには生徒が10人います。各生徒の成績データが、表形式で与えられるとします。

全生徒について、クラス内順位を求めてみましょう。

コード

from itertools import groupby
from random import randint
from pprint import pprint

# 成績データを作成
scores = [
    {
        "学年": grade,
        "組": clas,
        "出席番号": student_num,
        "成績": randint(0, 100)
    }
    for grade in range(1, 4)
    for clas in range(1, 4)
    for student_num in range(1, 11)
]

# 学年、組、成績(降順)でソート
sorted_scores = sorted(
    scores,
    key=lambda row: (row["学年"], row["組"], -row["成績"])
)

# クラス内順位を追加する
result = [
    {
        **student_info,
        "クラス内順位": class_rank
    }

    # 学年、組でグルーピング
    for _, group in groupby(
        sorted_scores,
        key=lambda row: (row["学年"], row["組"])
    )

    # クラス内順位
    for class_rank, student_info in enumerate(group, start=1)
]

pprint(result)

成績データは、dictのlistで表現しています。[{生徒1のデータ}, {生徒2のデータ}, …]という感じです。csvなどからデータをインポートするときは、この形式にすることが多い気がします。pandasというライブラリを使わないのであれば。

成績をソートするときは、sorted関数を使います。pythonでは、tuple同士を比較するとき、tupleの第n番目が同じ場合は第n+1番目を比較する、という特徴を利用します。比較関数(key)で第1ソートキー〜第3ソートキーを指定しています。成績はマイナスにして降順にしています。

クラス内順位を求めるときは、リスト内包表記を使用します。そして、class_groupをenumerate関数で順番付けします。このとき、class_groupは成績順にソートされているので、class_rankはその生徒の順位となります。

実行結果

[{'クラス内順位': 1, '出席番号': 5, '学年': 1, '成績': 94, '組': 1},
 {'クラス内順位': 2, '出席番号': 8, '学年': 1, '成績': 71, '組': 1},
 {'クラス内順位': 3, '出席番号': 6, '学年': 1, '成績': 51, '組': 1},
 {'クラス内順位': 4, '出席番号': 4, '学年': 1, '成績': 36, '組': 1},
 {'クラス内順位': 5, '出席番号': 3, '学年': 1, '成績': 32, '組': 1},
 {'クラス内順位': 6, '出席番号': 7, '学年': 1, '成績': 29, '組': 1},
 {'クラス内順位': 7, '出席番号': 1, '学年': 1, '成績': 15, '組': 1},
 {'クラス内順位': 8, '出席番号': 9, '学年': 1, '成績': 13, '組': 1},
 {'クラス内順位': 9, '出席番号': 10, '学年': 1, '成績': 4, '組': 1},
 {'クラス内順位': 10, '出席番号': 2, '学年': 1, '成績': 0, '組': 1},
 {'クラス内順位': 1, '出席番号': 2, '学年': 1, '成績': 95, '組': 2},
 {'クラス内順位': 2, '出席番号': 5, '学年': 1, '成績': 71, '組': 2},
(省略)
 {'クラス内順位': 8, '出席番号': 6, '学年': 3, '成績': 53, '組': 3},
 {'クラス内順位': 9, '出席番号': 7, '学年': 3, '成績': 47, '組': 3},
 {'クラス内順位': 10, '出席番号': 4, '学年': 3, '成績': 11, '組': 3}]

クラス内順位が追加されていますね。

ROW_NUMBERしたい

冒頭にも出てきたSQLのROW_NUMBER関数は、

  • パーティションカラム
  • ソートカラム
  • 各ソートキーの昇順or降順

を指定し、連番をふることができます。これをPythonでもやりたいですね。

やってみました。

def row_number(
        data: list[dict],
        partitions: list,
        sortkeys: list,
        reverse: list[bool],
        start: int
    ):

    # ソートする優先順位に並べ替え
    reverse.reverse()
    sortkeys.reverse()

    # dataをpartitionsでグルーピング
    data.sort(key=lambda row: [row[partition] for partition in partitions])
    groups = [
        list(group)
        for _, group in groupby(
            data,
            key=lambda row: [row[partition] for partition in partitions]
        )
    ]

    # 各グループをソート
    for group in groups:
        for sortkey, rv in zip(sortkeys, reverse):
            group.sort(key=lambda row: row[sortkey], reverse=rv)

    # 各グループにインデックスを付与
    groups = [list(enumerate(group, start=start)) for group in groups]

    # グループ結合
    result = sum(groups, [])

    # sortkeysによりソート
    for sortkey, rv in zip(sortkeys, reverse):
        result.sort(key=lambda row: row[1][sortkey], reverse=rv)

    return result
  1. データをグルーピング
  2. 各グループに番号を振る
  3. データをソート

の順番で処理しています。

昇順、降順を混ぜてソートする必要があるのですが、それについては以下の記事を参考にさせていただきました。

先程のクラス内順位を求めるときは、以下のように指定します。

scores = row_number(
    data=scores,
    partitions=["学年", "組"],
    sortkeys=["学年", "組", "成績"],
    reverse=[False, False, True],
    start=1
)
pprint(scores)

結果は以下のようになります。

[(1, {'出席番号': 5, '学年': 1, '成績': 92, '組': 1}),
 (2, {'出席番号': 8, '学年': 1, '成績': 77, '組': 1}),
 (3, {'出席番号': 3, '学年': 1, '成績': 68, '組': 1}),
 (4, {'出席番号': 9, '学年': 1, '成績': 43, '組': 1}),
 (5, {'出席番号': 6, '学年': 1, '成績': 38, '組': 1}),
 (6, {'出席番号': 4, '学年': 1, '成績': 31, '組': 1}),
 (7, {'出席番号': 7, '学年': 1, '成績': 24, '組': 1}),
 (8, {'出席番号': 10, '学年': 1, '成績': 20, '組': 1}),
 (9, {'出席番号': 1, '学年': 1, '成績': 18, '組': 1}),
 (10, {'出席番号': 2, '学年': 1, '成績': 5, '組': 1}),
 (1, {'出席番号': 4, '学年': 1, '成績': 92, '組': 2}),
 (2, {'出席番号': 1, '学年': 1, '成績': 67, '組': 2}),
(省略)
 (8, {'出席番号': 7, '学年': 3, '成績': 16, '組': 3}),
 (9, {'出席番号': 6, '学年': 3, '成績': 3, '組': 3}),
 (10, {'出席番号': 10, '学年': 3, '成績': 0, '組': 3})]

計算時間、メモリ効率はあまり考慮していません。

まとめ

itertools、いいですね。

タイトルとURLをコピーしました