Python で Go のチャンネルを実装してみる

Go 言語では,並行処理プリミティブとして channel が定義されていて,並行処理の中核をなします. Python でも Go-like に並行処理を書きたいなと思ったので実装してみよう!というのが今回のお題です.

調査編

まずはどう実現しようかと調査をしました. すると,Go のチャンネルによく似た機能を持つ Queue というものが asyncio ライブラリには存在することがわかりました. Queue では, Queue の 最大容量 ( maxsize ) を指定したり,Queue の出し入れ ( get / put ) を awaitable に実行できます.

これを使えばええやん!というところなのですが,Go のチャンネルに比べると以下の機能が不足しています.

  1. チャンネルの close 機能
  2. for 文での処理機能

また, close があるなら Python ならコンテキストマネージャー(withブロック)を使いたいだろう!ということで,これにコンテキストマネージャー機能を加えた,以下の機能を Queue をベースに実装することにしました.

  1. チャンネルの close 機能
  2. for 文での処理機能
  3. コンテキストマネージャー機能

実装編

ということで実装しました.

from asyncio import Queue
from typing import Generic, TypeVar, Tuple

T = TypeVar('T')

class ChannelClosed(Exception):
    pass


class Channel(Generic[T]):
    def __init__(self, buffer: int=1):
        self._closed = False
        self._queue = Queue(maxsize=buffer)

    async def get(self) -> Tuple[T, bool]:
        if self._closed:
            raise ChannelClosed('channel already closed')
        return await self._queue.get()

    async def put(self, v: T) -> None:
        if self._closed:
            raise ChannelClosed('channel already closed')
        await self._queue.put((v, True))

    async def close(self) -> None:
        if self._closed:
            raise ChannelClosed('channel already closed')
        await self._queue.put((None, False))
        self._closed = True

    def __aiter__(self):
        return self

    async def __anext__(self) -> T:
        try:
            v, ok = await self.get()
        except ChannelClosed:
            raise StopAsyncIteration
        if not ok:
            raise StopAsyncIteration
        return v

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.close()

以下にそれぞれの機能をどのように実装したか簡単に解説します.

1. チャンネルの close 機能

close メソッドに対応します.

close されたときは, _closed フラグを True にしています.そうすることで以後 getput が呼び出されてもエラーを吐き出すようにしています.

また, Queue に終了通知 ( (None, False) ) を送信しています.これは,for 文でイテレーションするときに終了条件に使っていたり,ユーザもチャンネルのクローズを検知できるようにしています. Go のチャンネルで v, ok := <- ch のようにできることを意識しています.

2. for 文での処理機能

__aiter____anext__ メソッドに対応します.Python では,これらのメソッドを実装することによって,非同期イテレーション ( async for ) が使えるようになります.つまり,チャンネルを以下のように使えるようになります.

async for v in ch:
    # some process with v

__anext__ メソッド内では単純に get を呼び出していて,終了は 1. で述べたように終了通知で判断しています. async for では, StopAsyncIteration を raise することによって,イテレーションの終了を通知します.

3. コンテキストマネージャー機能

__aenter____aexit__ メソッドに対応します.Python では,これらのメソッドを実装することによって,非同期 with ブロック ( async with ) が使えるようになります.つまり,チャンネルを以下のように使えるようになります.( with ブロックを抜けるとチャンネルがクローズされます.)

async with Channel() as ch:
    # some process with ch

__aexit__ メソッド内では単純に close を呼び出しているだけです.


ということで,まぁ実装はできたわけですが,もう少し Go-like に書きたい!

今だと,チャンネルを受信したり送信したりするときに,以下のように書く必要があります.

# 送信
await ch.put(0)
# 受信
v = await ch.get()

一方 Go だと以下のように記号で書けて美しい.

// 送信
ch <- 0
// 受信
v := <- ch

Python だと <- を扱う方法はなさそうだけど, << なら扱えて,オーバーライドできる!というのを利用して,以下のように扱えるようにしたい!...

# 送信
await (ch << 0)
# 受信
v = await (<< ch)

けど, (<<ch) は明らかにおかしいので,仕方なく (_ << ch) とできるように _ も実装します.

追加実装編

送信

# 送信
await (ch << 0)

Python では,<<__lshift__ という関数に対応するため,これを実現するためには, Channel の __lshift__ という関数を以下のようにオーバーライドすれば良いです.

    async def __lshift__(self, v: T) -> None:
        await self.put(v)

受信

# 受信
v = await (_ << ch)

これは少し難しいですが, Channel から値を get してそれを単純に返すように _ を実装すれば良さそう!ということで,以下のように実装しました.

class _Mediator:
    async def __lshift__(self, ch: Channel[T]) -> Tuple[T, bool]:
        return await ch.get()

_ = _Mediator()

これで,

# 送信
await (ch << 0)
# 受信
v = await (_ << ch)

のように書くことができるようになりました!!

まとめ

Python で Go のチャンネルを実装してみました.割と Go-like に書くことができたのではないでしょうか.

コードをまとめると以下のようになります.

from asyncio import Queue
from typing import Generic, TypeVar, Tuple

T = TypeVar('T')

class ChannelClosed(Exception):
    pass


class Channel(Generic[T]):
    def __init__(self, buffer: int=1):
        self._closed = False
        self._queue = Queue(maxsize=buffer)

    async def get(self) -> Tuple[T, bool]:
        if self._closed:
            raise ChannelClosed('channel already closed')
        return await self._queue.get()

    async def put(self, v: T) -> None:
        if self._closed:
            raise ChannelClosed('channel already closed')
        await self._queue.put((v, True))

    async def __lshift__(self, v: T) -> None:
        await self.put(v)

    async def close(self) -> None:
        if self._closed:
            raise ChannelClosed('channel already closed')
        await self._queue.put((None, False))
        self._closed = True

    def __aiter__(self):
        return self

    async def __anext__(self) -> T:
        try:
            v, ok = await self.get()
        except ChannelClosed:
            raise StopAsyncIteration
        if not ok:
            raise StopAsyncIteration
        return v

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.close()


class _Mediator:
    async def __lshift__(self, ch: Channel[T]) -> Tuple[T, bool]:
        return await ch.get()

_ = _Mediator()

実は select / case も実装していたのですが,少し長くなってしまったので,またの機会にします.