M.Hiroi's Home Page
http://www.geocities.jp/m_hiroi/

Functional Programming

お気楽 OCaml プログラミング入門

[ PrevPage | OCaml | NextPage ]

メモ化と遅延評価

今回は「たらいまわし関数」を例題にして、「メモ化」と「遅延評価」について説明します。なお、このドキュメントは拙作のページ Algorithms with Python 再帰定義 (たらいまわし関数) のプログラムを OCaml で書き直したものです。内容は重複していますが、ご了承くださいませ。

●たらいまわし関数

最初に「たらいまわし関数」について説明します。次のリストを見てください。

リスト 1 : たらいまわし関数

let rec tarai x y z =
  if x <= y then y
  else tarai (tarai (x - 1) y z) (tarai (y - 1) z x) (tarai (z - 1) x y)

let rec tak x y z = 
  if x <= y then z
  else tak (tak (x - 1) y z) (tak (y - 1) z x) (tak (z - 1) x y)

関数 tarai や tak は「たらいまわし関数」といって、再帰的に定義されています。これらの関数は、引数の与え方によっては実行に時間がかかるため、Lisp などのベンチマークに利用されることがあります。Common Lisp のプログラムは ぬえ 鵺 NUETAK Function にあります。

関数 tarai は通称「竹内関数」と呼ばれていて、日本の代表的な Lisper である竹内郁雄先生によって考案されたそうです。そして、関数 tak は関数 tarai のバリエーションで、John Macarthy によって作成されたそうです。たらいまわし関数が Lisp のベンチマークで使われていたことは知っていましたが、このような由緒ある関数だとは思ってもいませんでした。

それでは、さっそく実行してみましょう。実行環境は Windows XP, celeron 1.40 GHz, ocamlc (version 3.10.0) です。

tarai 12 6 0 : 1.25 [s]
tak 18 9 0   : 1.22 [s]

このように、たらいまわし関数は引数の値が小さくても実行に時間がかかります。

●メモ化による高速化

たらいまわし関数が遅いのは、同じ値を何度も計算しているためです。この場合、表 (table) を使って処理を高速化することができます。同じ値を何度も計算することがないように、計算した値は表に格納しておいて、2 回目以降は表から計算結果を求めるようにします。このような手法を「表計算法」とか「メモ化 (memoization または memoisation) 」といいます。

OCaml の場合、メモ化はハッシュ表 (Hashtbl) を使うと簡単です。次のリストを見てください。

リスト 2 : たらいまわし関数のメモ化 (1)

(* メモ用のハッシュ表 *)
let table = Hashtbl.create 2048

let rec tarai x y z =
  let key = (x, y, z) in
  if Hashtbl.mem table key then Hashtbl.find table key
  else
    let value = if x <= y then y
    else
      tarai (tarai (x - 1) y z) (tarai (y - 1) z x) (tarai (z - 1) x y)
    in
    Hashtbl.add table key value;
    value

関数 tarai の値を格納するハッシュ表を大域変数 table に用意します。関数 tarai では、引数 x, y, z を要素とする組を作り、それをキーとしてハッシュ表 table を検索します。table に key があれば、その値を返します。そうでなければ、値を計算して table にセットして、その値を返します。

ところで、ハッシュ表は局所変数に格納することもできます。次のリストを見てください。

リスト 3 : たらいまわし関数のメモ化 (2)

(* 探索 *)
let lookup table func args =
  if Hashtbl.mem table args then
    Hashtbl.find table args
  else
    let value = func args in
    Hashtbl.add table args value;
    value

(* たらいまわし関数 *)
let rec tak (x, y, z) =
  if x <= y then z
  else memo_tak (memo_tak (x - 1, y, z),
                 memo_tak (y - 1, z, x),
                 memo_tak (z - 1, x, y))
and memo_tak =
  let table = Hashtbl.create 2048 in
  fun x -> lookup table tak x

let rec tarai (x, y, z) =
  if x <= y then y
  else memo_tarai (memo_tarai (x - 1, y, z),
                   memo_tarai (y - 1, z ,x),
                   memo_tarai (z - 1, x, y))
and memo_tarai =
  let table = Hashtbl.create 2048 in
  fun x -> lookup table tarai x

関数 lookup はハッシュ表 table から関数 func の引数 args に対応するデータを探します。ここでは関数の引数を組にまとめて args に渡すものとします。ハッシュ表にデータがある場合はその値を返します。そうでなければ、func args を評価して値 value を求め、それをハッシュ表に登録します。

関数 tak と tarai は自分自身を再帰呼び出しするのではなく、関数 memo_tak と memo_tarai を呼び出します。memo_tak と memo_tarai は、ハッシュ表を局所変数 table にセットしてから、匿名関数を使って関数本体を定義します。ハッシュ表が生成されるのは、memo_tak, memo_tarai に関数をセットするときの一回だけです。これで、その関数専用のハッシュ表を局所変数に用意することができます。memo_tak と memo_tarai の本体は、lookup を呼び出してハッシュ表から値を探索するだけです。

関数の型は次のようになります。

val memoize : ('a, 'b) Hashtbl.t -> ('a -> 'b) -> 'a -> 'b = <fun>
val tak : int * int * int -> int = <fun>
val memo_tak : int * int * int -> int = <fun>
val tarai : int * int * int -> int = <fun>
val memo_tarai : int * int * int -> int = <fun>

それでは実際に実行してみましょう。実行環境は Windows XP, celeron 1.40 GHz, ocamlc (version 3.10.0) です。

tarai (192, 96, 0) : 0.26 [s]
tak (192, 96, 0)   : 1.91 [s]

このように、引数の値を増やしても高速に実行することができます。メモ化の効果は十分に出ていると思います。また、同じ計算を再度実行すると、メモ化の働きにより値をすぐに求めることができます。

●メモ化関数

このように関数をメモ化することは簡単にできますが、メモ化を行うたびに関数を修正するのは面倒です。このような場合、関数をメモ化する「メモ化関数」があると便利です。メモ化関数については Structure and Interpretation of Computer Programs (SICP) 3.3.3 Representing Tables に詳しい説明があります。

ただし、変数の値を書き換えることができない関数型言語の場合、汎用的なメモ化関数を作成することは難しく、OCaml でも簡単ではありません。そこで、今回は Lisp を使ってメモ化関数を作成してみましょう。Common Lisp と Scheme のプログラムは次のようになります。

リスト 4 : メモ化関数 (Common Lisp)

(defun memoize (func)
  (let ((table (make-hash-table :test #'equal)))
    #'(lambda (&rest args)
        (let ((value (gethash args table nil)))
          (unless value
            (setf value (apply func args))
            (setf (gethash args table) value))
          value))))

; たらいまわし関数
(defun tak (x y z)
  (if (<= x y)
      z
      (tak (tak (- x 1) y z) (tak (- y 1) z x) (tak (- z 1) x y))))

(defun tarai (x y z)
  (if (<= x y)
      y
      (tarai (tarai (- x 1) y z) (tarai (- y 1) z x) (tarai (- z 1) x y))))

; 関数を書き換える
(setf (symbol-function 'tak) (memoize #'tak))
(setf (symbol-function 'tarai) (memoize #'tarai))
リスト 5 : メモ化関数 (Scheme : Gauche)

; 汎用のメモ化関数
(define (memoize func)
  (let ((table (make-hash-table 'equal?)))
    (lambda args
      (if (hash-table-exists? table args)
          (hash-table-get table args)
          (let ((value (apply func args)))
            (hash-table-put! table args value)
            value)))))

; たらいまわし関数
(define (tak x y z)
  (if (<= x y)
      z
      (tak (tak (- x 1) y z) (tak (- y 1) z x) (tak (- z 1) x y))))

(define (tarai x y z)
  (if (<= x y)
      y
      (tarai (tarai (- x 1) y z) (tarai (- y 1) z x) (tarai (- z 1) x y))))

; 値を書き換える
(set! tak (memoize tak))
(set! tarai (memoize tarai))

関数 memoize は関数 func を引数に受け取り、それをメモ化した関数を返します。memoize が返す関数はクロージャなので、memoize の引数 func や局所変数 table にアクセスすることができます。また、無名関数 lambda の引数 args は可変個の引数を受け取るように定義します。これで、複数の引数を持つ関数にも対応することができます。

args の値は引数を格納したリストになるので、これをキーとして扱います。ハッシュ表 table に値がなければ、関数 func を呼び出して値を計算し、それを table にセットします。そしで、最後に値を返します。なお、変数 tak と tarai の値 (Common Lisp の場合は関数) を書き換えないと、関数 tak, tarai の中で再帰呼び出しするとき、メモ化した関数を呼び出すことはできません。ご注意ください。

●遅延評価による高速化

関数 tarai は「遅延評価 (delayed evaluation または lazy evaluation) 」を行う処理系、たとえば関数型言語の Haskell では高速に実行することができます。また、Scheme でも delay と force を使って遅延評価を行うことができます。tarai のプログラムを見てください。x <= y のときに y を返しますが、このとき引数 z の値は必要ありませんね。引数 z の値は x > y のときに計算するようにすれば、無駄な計算を省略することができます。

なお、関数 tak は x <= y のときに z を返しているため、遅延評価で高速化することはできません。ご注意ください。

OCaml には遅延評価を行うための構文 lazy とモジュール Lazy が用意されています。また、完全ではありませんが、クロージャを使って遅延評価を行うこともできます。今回は Shiro さんWiLiKi にある Scheme:たらいまわしべんち を参考に、プログラムを作ってみましょう。次のリストを見てください。

リスト 6 : クロージャによる遅延評価

let rec tarai x y z =
  if x <= y then y
  else
    let zz = z () in
    tarai (tarai (x - 1) y (fun () -> zz))
          (tarai (y - 1) zz (fun () -> x))
          (fun () -> tarai (zz - 1) x (fun () -> y))

遅延評価したい処理をクロージャに包んで引数 z に渡します。そして、x > y のときに引数 z の関数を呼び出します。すると、クロージャ内の処理が評価されて z の値を求めることができます。たとえば、fun () -> 0 を z に渡す場合、z () とすると返り値は 0 になります。fun () -> x を渡せば、x に格納されている値が返されます。fun -> tarai ... を渡せば、関数 tarai が実行されてその値が返されるわけです。

関数 tarai の型は次のようになります。

val tarai : int -> int -> (unit -> int) -> int = <fun>

また、lazy 文を使うと、tarai は次のようになります。

リスト 7 : lazy による遅延評価

let rec tarai x y z =
  if x <= y then y
  else
    let zz = Lazy.force z in
    tarai (tarai (x - 1) y (lazy zz))
          (tarai (y - 1) zz (lazy x))
          (lazy (tarai (zz - 1) x (lazy y)))

lazy expr は、式 expr を評価せずに lazy_t というデータ (遅延オブジェクト) を返します。簡単な使用例を示しましょう。

# let a = lazy (10 + 20);;
val a : int lazy_t = <lazy>
# Lazy.force a;;
- : int = 30

lazy (10 + 20) の返り値を変数 a にセットします。このとき、式 10 + 20 は評価されていません。遅延オブジェクトの値を実際に求める関数が Lazy.force です。Lazy.force a を実行すると、式 10 + 20 を評価して値 30 を返します。

また、遅延オブジェクトは式の評価結果をキャッシュします。したがって、Lazy.force a を再度実行すると、同じ式を再評価することなく値を求めることができます。次の例を見てください。

# let a = lazy (print_string "eval"; 10 + 20);;
val a : int lazy_t = <lazy>
# Lazy.force a;;
eval- : int = 30
# Lazy.force a;;
- : int = 30

最初に Lazy.force a を実行すると、式 (print_string "eval"; 10 + 20) が評価されるので、画面に eval が表示されます。次に、Lazy.force a を実行すると、式を評価せずにキャッシュした値を返すので eval は表示されません。

lazy を使った場合、関数 tarai の型は次のようになります。

val tarai : int -> int -> int Lazy.t -> int = <fun>

それでは、実際に実行してみましょう。実行環境は Windows XP, celeron 1.40 GHz, ocamlc (version 3.10.0) です。

tarai 192 96 0
closure : 0.0031 [s]
lazy    : 0.0078 [s]

実行時間が速いので、今回は tarai 192 96 0 を 10 回実行した時間から 1 回の実行時間を求めました。tarai の場合、遅延評価の効果はとても大きいですね。lazy は処理がクロージャよりも複雑になるので、処理速度は少し遅くなるようです。

ところで、クロージャや lazy を使わなくても、関数 tarai を高速化する方法があります。C++:language&libraries (cppll)Akira Higuchi さん が書かれたC言語の tarai 関数はとても高速です。OCaml でプログラムすると次のようになります。

リスト 8 : tarai の遅延評価

let rec tarai x y z =
  if x <= y then y
  else tarai_lazy (tarai (x - 1) y z) (tarai (y - 1) z x) (z - 1) x y
and tarai_lazy x y xx yy zz =
  if x <= y then y
  else
    let z = tarai xx yy zz in
    tarai_lazy (tarai (x - 1) y z) (tarai (y - 1) z x) (z - 1) x y
val tarai : int -> int -> int -> int = <fun>
val tarai_lazy : int -> int -> int -> int -> int -> int = <fun>

関数 tarai_lazy の引数 xx, yy, zz で z の値を表すところがポイントです。つまり、z の計算に必要な値を引数に保持し、z の値が必要になったときに tarai(xx, yy, zz) で計算するわけです。実際に実行してみると tarai 192 96 0 は 0.0016 [s] になりました。Akira Higuchi さんに感謝いたします。


遅延ストリーム

「ストリーム (stream) 」はデータの流れを抽象化したデータ構造です。たとえば、ファイル入出力はストリームと考えることができます。また、リストを使ってストリームを表すこともできます。ただし、単純なリストでは有限個のデータの流れしか表すことができません。ところが、遅延評価を用いると擬似的に無限個のデータを表すことができるようになります。これを「遅延ストリーム」とか「遅延リスト」と呼びます。今回は遅延ストリームについて説明します。

●遅延ストリームの構造

遅延ストリームの基本的な考え方は、必要になったときに新しいデータを生成することです。このときに遅延評価を用います。具体的にはデータを生成する関数を用意し、それを遅延評価してストリームに格納しておきます。そして、必要になった時点で遅延評価しておいた関数を呼び出して値を求めればよいわけです。

遅延ストリームのデータ型は次のようになります。

リスト 9 : 遅延ストリームのデータ型

open Lazy;;

(* 例外 *)
exception Empty_stream

(* データ型 *)
type 'a stream = Nils | Cons of 'a * 'a stream lazy.t

データ型は 'a stream としました。Nils はストリームの終端を表します。無限ストリームだけを扱うのであれば Nils は必要ありません。Cons が遅延ストリームの本体を表していて、組の最初の要素が現時点での先頭データを表し、次の要素が遅延ストリームを生成する関数を格納する遅延オブジェクト (lazy.t) です。この要素を force することで、次の要素を格納した遅延ストリームを生成します。

●遅延ストリームの生成

それでは、遅延ストリームを生成する関数を作りましょう。たとえば、low から high までの整数列を生成するストリームは次のようにプログラムすることができます。

リスト 10 : 整数列を生成するストリーム

let rec intgen low high =
  if low > high then Nils
  else Cons (low, lazy (intgen (low + 1) high))

関数 intgen は遅延ストリームを生成して返します。Cons の第 1 要素が現時点でのデータになります。第 2 要素の遅延オブジェクトを force すると、intgen (low + 1) high が実行されて次のデータを格納した遅延ストリームが返されます。そして、その中の遅延オブジェクトを force すると、その次のデータを得ることができます。

関数 intgen の型は次のようになります。

val intgen : int -> int -> int stream = <fun>

簡単な例を示しましょう。

# let s0 = intgen 1 100;;
val s0 : int stream = Cons (1, <lazy>)
# let Cons(n1, s1) = s0;;
... 警告 (省略) ...
val n1 : int = 1
val s1 : int stream lazy_t = <lazy>
# let Cons(n2, s2) = force s1;;
... 警告 (省略) ...
val n2 : int = 2
val s2 : int stream lazy_t = <lazy>
# let Cons(n3, s3) = force s2;;
... 警告 (省略) ...
val n3 : int = 3
val s3 : int stream lazy_t = <lazy>

このように、第 2 要素の遅延オブジェクトを force することで、次々とデータを生成することができます。

もう一つ、簡単な例を示しましょう。フィボナッチ数列を生成する遅延ストリームを作ります。次のリストを見てください。

リスト 11 : フィボナッチ数列を生成する遅延ストリーム

let rec fibgen a b = Cons (a, lazy (fibgen b (a + b)))

関数 fibgen の型は次のようになります。

val fibgen : int -> int -> int stream = <fun>

関数 fibgen の引数 a がフィボナッチ数列の最初の項で、b が次の項です。したがって、遅延オブジェクトに fibgen b (a + b) を格納しておけば、force することでフィボナッチ数列を生成することができます。なお、この関数は無限ストリームを生成しますが、OCaml の整数 (int) には上限値があるので、際限なくフィボナッチ数列を生成できるわけではありません。ご注意ください。

●遅延ストリームの操作関数

次は遅延ストリームを操作する関数を作りましょう。最初は n 番目の要素を求める関数 stream_ref です。今回は先頭の要素を 1 番目とします。

リスト 12 : n 番目の要素を求める

let rec stream_ref s n =
  match s with
    Nils -> raise Empty_stream
  | Cons (x, _) when n = 1 -> x
  | Cons (_, tail) -> stream_ref (force tail) (n - 1)

関数 stream_ref の型は次のようになります。

val stream_ref : 'a stream -> int -> 'a = <fun>

stream_ref は Cons の第 2 要素 tail を force してデータを生成し、それを n 回繰り返すことで n 番目の要素を求めます。force tail は遅延ストリームを返すことに注意してください。あとは、stream_ref を n 回再帰呼び出しすればいいわけです。

ストリームから n 個の要素を取り出してリストに格納して返す関数 stream_take も同様にプログラムすることができます。

リスト 13 : n 個の要素を取り出す

let rec stream_take s n =
  match s with
    Nils -> raise Empty_stream
  | Cons(x, tail) -> if n = 1 then [x]
                        else x :: stream_take (force tail) (n - 1)

関数 stream_take の型は次のようになります。

val stream_take : 'a stream -> int -> 'a list = <fun>

stream_take を再帰呼び出ししてストリームのデータ x を生成します。そして、n が 1 の場合はリスト [x] を返し、そうでなければ x を stream_ref の返り値 (リスト) に追加します。

それでは、簡単な実行例を示しましょう。

# let s1 = fibgen 1 1;;
val s1 : int stream = Cons (1, <lazy>)
# for i = 1 to 10 do print_int (stream_ref s1 i); print_newline() done;;
1
1
2
3
5
8
13
21
34
55
- : unit = ()
# stream_take s1 10;;
- : int list = [1; 1; 2; 3; 5; 8; 13; 21; 34; 55]

変数 s1 にフィボナッチ数列を生成するストリームをセットします。stream_ref で順番に要素を 10 個取り出すと、その値はフィボナッチ数列になっていますね。同様に、stream_take で 10 個の要素を取り出すと、リストの要素はフィボナッチ数列になります。

●高階関数

ところで、遅延ストリームは高階関数も定義することができます。次のリストを見てください。

リスト 14 : 高階関数

(* マップ *)
let rec stream_map proc = function
  Nils -> Nils
| Cons (x, tail) -> Cons (proc x, lazy (stream_map proc (force tail)))

(* フィルター *)
let rec stream_filter pred = function
  Nils -> Nils
| Cons (x, tail) when pred x -> Cons(x, lazy (stream_filter pred (force tail)))
| Cons (_, tail) -> stream_filter pred (force tail)

(* 畳み込み *)
let rec stream_fold_left proc a = function
  Nils -> a
| Cons (x, tail) -> stream_fold_left proc (proc a x) (force tail)

let rec stream_fold_right proc a = function
  Nils -> a
| Cons (x, tail) -> proc x (stream_fold_right proc a (force tail))

関数の型は次のようになります。

val stream_map : ('a -> 'b) -> 'a stream -> 'b stream = <fun>
val stream_filter : ('a -> bool) -> 'a stream -> 'a stream = <fun>
val stream_fold_left : ('a -> 'b -> 'a) -> 'a -> 'b stream -> 'a = <fun>
val stream_fold_right : ('a -> 'b -> 'b) -> 'b -> 'a stream -> 'b = <fun>

stream_map と stream_filter は関数と遅延ストリームを受け取り、新しい遅延ストリームを生成して返します。stream_map は引数のストリームの要素に関数 proc を適用した結果を新しいストリームに格納して返します。stream_filter は述語 pred が真を返す要素だけを新しいストリームに格納して返します。

stream_fold_left と stream_fold_right は遅延ストリームに対して畳み込み処理を行います。無限ストリームの場合は処理が終了しないので注意してください。

簡単な実行例を示しましょう。

# let s1 = intgen 1 100;;
val s1 : int stream = Cons (1, <lazy>)
# let s2 = stream_map (fun x -> x * x) s1;;
val s2 : int stream = Cons (1, <lazy>)
# stream_take s2 10;;
- : int list = [1; 4; 9; 16; 25; 36; 49; 64; 81; 100]
# let s3 = stream_filter (fun x -> x mod 2 = 0) s1;;
val s3 : int stream = Cons (2, <lazy>)
# stream_take s3 10;;
- : int list = [2; 4; 6; 8; 10; 12; 14; 16; 18; 20]
# stream_fold_left (+) 0 s1;;
- : int = 5050
# stream_fold_right (+) 0 s1;;
- : int = 5050

変数 s1 に 1 から始まる整数列を生成するストリームをセットします。次に、s1 の要素を 2 乗するストリームを stream_map で生成して変数 s2 にセットします。stream_take で s2 から要素を 10 個取り出すと、s1 の要素を 2 乗した値になります。

s1 から偶数列のストリームを得るには、引数が偶数のときに真を返す述語を stream_filter に渡します。その返り値を変数 s3 にセットして、stream_take で 10 個の要素を取り出すと、リストの要素は 2 から 20 までの値になります。

s1 は有限個の遅延ストリームなので畳み込みを行うことができます。stream_fold_left と stream_fold_right で要素の合計値を求めると 5050 になります。

今回はここまでです。次回は遅延ストリームを使って素数や順列を求めてみましょう。

●参考文献 (URL)

  1. "Structure and Interpretation of Computer Programs (SICP)" 3.5 Streams

Copyright (C) 2008 Makoto Hiroi
All rights reserved.

[ PrevPage | OCaml | NextPage ]