Scalaでクラス変数を一時的に変更する

今回は,以下のようなシングルトンオブジェクトの value を特定のブロック内でのみ "hoge" として扱いたい,というようなケースを考えます.(そもそもそんなケースないようにするのが自然であるのは,それはそう)

object SharedString {
    var v = "default"

    def value = v

    def value_=(s: String) = {
        v = s
    }
}

第I章 ローンパターンを使う

はい.ローンパターンを使いましょう,ということで,下記のような関数が書けます.(最近このパターンをローンパターンと呼ぶことを知りました)

def withValue[T](s: String)(block: => T) {
    SharedString.value_=(s)
    try {
        block
    } finally {
        SharedString.value_=("default")
    }
}

すると,

withValue("hoge") {
    // この中では SharedString.value == "hoge"
}
// ここでは SharedString.value == "default"

のように書けますね...

と終わりたいところですが,このままだとマルチスレッドでは正常に動きません.

以下のようなコードを考えましょう.上のコードを2つのスレッドで実行しただけです. (あとで Future を扱うので, Thread で書いてます)

class Thread1 extends Thread {
    override def run() = {
        withValue("hoge") {  // ①
            Thread.sleep(1000)
            printf("Thread1: SharedString.value(%s) should be %s\n", SharedString.value, "hoge")  // ③
        }  // ④
    }
}
class Thread2 extends Thread {
    override def run() = {
        withValue("hoge") {  // ②
            Thread.sleep(2000)
            printf("Thread2: SharedString.value(%s) should be %s\n", SharedString.value, "hoge")  // ⑤
        }
    }
}
val th1 = new Thread1()
val th2 = new Thread2()

th1.start()
th2.start()
th1.join()
th2.join()

実行すると,

Thread1: SharedString.value(hoge) should be hoge
Thread2: SharedString.value(hoge) should be default

となります. Thread2 ではうまく "hoge" に置き換わっていないようです. これは,シングルトンオブジェクトの変数(所謂クラス変数)がスレッド間で共有されることに起因します.

どういうことかというと,わざとスレッドをスリープさせたりしているので,処理は①→②→...→⑤の順番に実行されますが, 問題は,④が⑤より先に実行されてしまうことです. withValue 関数では, teardown の処理として,SharedString.v"default" に戻してしまいます. そのため,④で SharedString.v"default" に書き換り,これはスレッド間で共有されているため,⑤で SharedString.v を参照しても "default" が得られてしまいます.

これは期待通りではないですね.どうにかマルチスレッドで動くようにしたい.

第II章 ThreadLocal を使う

調べてみると, ThreadLocal というものがあります.

このクラスはスレッド・ローカル変数を提供します。これらの変数は、getメソッドまたはsetメソッドを使ってアクセスするスレッドがそれぞれ独自に、変数の初期化されたコピーを持つという点で、通常の変数と異なります。通常、ThreadLocalインスタンスは、状態をスレッドに関連付けようとするクラスでのprivate staticフィールドです(ユーザーID、トランザクションIDなど)。

https://docs.oracle.com/javase/jp/8/docs/api/java/lang/ThreadLocal.html

詳しい説明は他に譲りますが,これを使うと,スレッド間で共有されている変数に関しても,スレッド固有に管理できるようになります.(内部的にはスレッドIDをキーに持つMapが保持されているイメージ)

これを踏まえて, SharedString を書き換えるとこのようになります.

object SharedString {
    var v = new ThreadLocal[String] {
        override def initialValue = "default"
    }

    def value = v.get

    def value_=(s: String) = {
        v.set(s)
    }
}

すると先ほどのコードは期待通り動くことがわかります.

でも実はこれでも終わりではありません. ネストしたスレッドだと動かないのです!!!なぜなら以下で言えば Thread2 では, withValue による初期化はされないから.

class Thread1 extends Thread {
    class Thread2 extends Thread {
        override def run(): = {
            printf("SharedString.value(%s) should be %s\n", SharedString.value, "hoge")
        }
    }
    override def run(): = {
        withValue("hoge") {
            printf("SharedString.value(%s) should be %s\n", SharedString.value, "hoge")
            val th2 = new Thread2()
            th2.start()
            th2.join()
        }
    }
}

val th1 = new Thread1()

th1.start()
th1.join()

第III章 InheritableThreadLocal / DynamicVariable を使う

ということで,親スレッドから子スレッドに ThreadLocal の値を引き継ぐ仕組みが必要そうです. なんて勿体ぶりましたが,実は普通に InheritableThreadLocal というものがあります.これを使うと,親スレッドから子スレッドに ThreadLocal の値を引き継げます.

InheritableThreadLocal を使うようにするとこうなります.(ただ ThreadLocalInheritableThreadLocal に置き換えるだけです)

object SharedString {
    var v = new InheritableThreadLocal[String] {
        override def initialValue = "default"
    }

    def value = v.get

    def value_=(s: String) = {
        v.set(s)
    }
}

Scalaでは InheritableThreadLocal を wrap した DynamicVariable というものが用意されているので,そちらを使うと以下のようにかけます.

object SharedString {
    var v = new DynamicVariable[String]("default")

    def value = v.value

    def value_=(s: String) = {
        v.value_=(s)
    }
}

今度こそ...,という感じですが,最後に関門があります.

スレッドプール です.

Scalaで並行処理を書くなら,意識する/しないには関わらず,スレッドプールは大体使うでしょう.

実はスレッドプールを考えると, InheritableThreadLocal を使っても動かないケースがあります.

InheritableThreadLocal は下記にある通り,スレッドが作成されたときに親スレッドから値が引き継がれる仕組みです.

... 子スレッドの作成時に、子は、親が値を保持する継承可能なスレッドローカル変数すべての初期値を受け取ります。

https://docs.oracle.com/javase/jp/6/api/java/lang/InheritableThreadLocal.html

しかし,スレッドプールはスレッドを再利用することがあります. 子スレッドを実行するときに,新しくスレッドが作られるのではなく,スレッドが再利用された場合,親スレッドから子スレッドへの ThreadLocal の値の引き継ぎはされません.

...

なので,

親スレッドから子スレッドに ThreadLocal の値を引き継ぐ

のようにスレッド単位ではなく,

タスクから子タスクThreadLocal の値を引き継ぐ

というタスク単位での仕組みが必要です.

最終章 ExecutionContext を override する

これを実現するためには,ExecutionContext を override して,タスク実行時に値を引き継ぐようにすると良さそうです.(いや,もっといい方法あるかもしれません...) ということで書いてみたのが以下の通りです.

implicit class RichExecutionContext(ec: ExecutionContext) {
    def withSharedString: ExecutionContext = new ExecutionContext {
        override def execute(task: Runnable) {
            val copyValue = SharedString.v.value
            ec.execute(new Runnable {
                override def run = {
                    SharedString.v.value_=(copyValue)
                    task.run
                }
            })
        }

        override def reportFailure(cause: Throwable): Unit = ec.reportFailure _
    }
}

使う側は,任意の ExecutionContext に対して,上記を使って新しい ExecutionContext を生成し,それをDIさせたコンテキスト下で,withValue を使えば以下のように Future を使ったコードでもうまく動かすことができます.

implicit val ec = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(2)).withSharedString

def f1() = Future {
    withValue("hoge") {
        SharedString.value shouldBe "hoge"
        printf("SharedString.value(%s) should be %s\n", SharedString.value, "hoge")
        val f2 = Future {
                SharedString.value shouldBe "hoge"
                printf("SharedString.value(%s) should be %s\n", SharedString.value, "hoge")
        }
        Await.result(f2, Duration.Inf)
        }
}

Await.result(f1(), Duration.Inf)
Await.result(f1(), Duration.Inf)

これで,ブロック内の処理がどんな処理であっても動くようになったのではないでしょうか. 本当は, ExecutionContext を override するところまで withValue の関数内でやりたかったのですが,implicit value の上書きとかは難しそうなので,使う側で ExecutionContext をよしなに設定してもらうようにしました.

ということで,最終的な全体像は下記の通りです.

object SharedString {
    var v = new DynamicVariable[String]("default")

    def value = v.value

    def value_=(s: String) = {
        v.value_=(s)
    }
}

def withValue[T](s: String)(block: => T) {
    SharedString.value_=(s)
    try {
        block
    } finally {
        SharedString.value_=("default")
    }
}

implicit class RichExecutionContext(ec: ExecutionContext) {
    def withSharedString: ExecutionContext = new ExecutionContext {
        override def execute(task: Runnable) {
            val copyValue = SharedString.v.value
            ec.execute(new Runnable {
                override def run = {
                    SharedString.v.value_=(copyValue)
                    task.run
                }
            })
        }

        override def reportFailure(cause: Throwable): Unit = ec.reportFailure _
    }
}

もう少しいい書き方あるよ,という方はぜひコメントください m( ) m