(* cont.sml
 * Michael Sullivan
 * Continuations for the HOT compiler.
 *)

(* A signature for a "lame" monomorphic first-class continuation.
 * lamecc calls a function with an exn-expecting continuation.
 * The continuation is an ordinary function instead of a special cont type.
 * The normal objection to this (that it limits polymorphism) does
 * not apply, since the interface is not polymorphic at all anyway.
 * The goal is for the type to be stupidly simple.
 *)
signature LAME_CONT =
sig
  val lamecc : ((exn -> unit) -> exn) -> exn
end

(* Implementation of LAME_CONT for use with the HOT compiler.
 * It relies on a magic transformation done during the CPS conversion
 * stage of compilation. The CPS converter needs to detect
 * lambdas with a particular argument type and rewrite them to
 * be approximately lamecc.
 *
 * Specifically, the CPS converter does the following:
 * Whenever it detects a function that takes an argument of the type
 * (((unit * unit) * unit) * unit * unit) * ((exn -> unit) -> exn),
 * it discards the body of the function and replaces it with one that
 * calls the given function with the "current continuation".
 *
 * NOTE(review): without that transformation (e.g. under an ordinary SML
 * compiler), lamecc never calls its argument and just returns the
 * placeholder Fail exception from lamecc' below.
 *)
structure LameCont :> LAME_CONT =
struct
  (* Dummy type for indicating to the CPS converter which function to
   * translate. It needs to find the function to translate by its argument
   * type, but ((exn -> unit) -> exn) may well come up, especially in
   * programs using lamecc. Thus, we choose a "fairly implausible"
   * combination of unit types.
   * We should be /mostly/ safe using it: any other function taking an
   * argument of the full type described above would also be rewritten.
   *)
  type dummy = (((unit * unit) * unit) * unit * unit)
  val dummy_arg : dummy = ((((),()),()),(),())

  (* This is a "fake" version of the lamecc helper function. It does nothing
   * useful, but has the right type. The translator will convert it into one
   * that actually performs a call with the current continuation.
   * Do not change its argument type: that type is how the CPS converter
   * finds this function.
   *)
  fun lamecc' (_ : dummy, _ : (exn -> unit) -> exn) : exn = Fail "placeholder"

  (* Public entry point; only forwards to the function the converter rewrites. *)
  fun lamecc f = lamecc' (dummy_arg, f)
end

(* Alternative implementation using the SML/NJ-specific SMLofNJ.Cont,
 * disabled by being commented out.
 *)
(*
(* A LAME_CONT implementation on top of SML/NJ's callcc.
 * This is deeply silly.
 *)
structure LameContNJ :> LAME_CONT =
struct
  structure CC = SMLofNJ.Cont
  fun lamecc f = CC.callcc (fn k => f (CC.throw k))
end
*)

(* A signature for a straightforward polymorphic call/cc.
 * callcc f calls f with the current continuation; throw k v resumes k
 * with v and does not return to its caller.
 *)
signature CONT =
sig
  type 'a cont
  val callcc : ('a cont -> 'a) -> 'a
  val throw : 'a cont -> 'a -> 'b
end

(* Implementation of CONT on top of LameCont.
 * A continuation is a pair of the underlying exn-expecting continuation
 * and an injection function that wraps an 'a value in an exception that
 * belongs to the callcc invocation which created the pair.
 *)
structure Cont :> CONT =
struct
  type 'a cont = ((exn -> unit) * ('a -> exn))

  (* A fresh exception E is declared on every call (exception declarations
   * are generative), so an 'a value can travel through the monomorphic exn
   * channel of LameCont.lamecc and be unwrapped here.
   * Both ways of coming back out of lamecc yield an E value: either f
   * returns normally (wrapped by E right here), or the continuation is
   * thrown to (wrapped by throw, using this same E from the pair).
   * The val binding is non-exhaustive: it raises Bind if lamecc returns any
   * other exception, e.g. the placeholder Fail when the HOT CPS
   * transformation has not been applied.
   *)
  fun 'a callcc f =
    let
      exception E of 'a
      val (E x) = LameCont.lamecc (fn k => E (f (k, E)))
    in
      x
    end

  (* Resume continuation k with x by injecting x into the exception of the
   * callcc that created k. Note: E here is an ordinary variable bound to that
   * injection function, not an exception constructor.
   * k is expected never to return; the raise only gives the body its
   * polymorphic result type 'b.
   *)
  fun throw (k, E) x =
    let
      val () = k (E x)
    in
      raise Fail "You cannot get here"
    end
end

(*************************************************************************)

(* amb.sml
 * by Jacob Potter
 * based on a less-polymorphic version by Michael Sullivan
 *)

(* Nondeterministic choice with backtracking.
 * amb (x::xs) first returns x; amb nil backtracks to the most recent pending
 * choice, which then continues with the rest of its list. AmbFail is raised
 * when amb nil is called with no pending choices left.
 *)
signature AMB =
sig
  exception AmbFail
  val amb : 'a list -> 'a
end

structure Amb :> AMB =
struct
  structure CC = Cont

  exception AmbFail

  (* Stack of pending choice points, most recent first. It is global to the
   * structure and never reset, so choice points left behind by an earlier
   * search that succeeded stay here, and a later search that runs out of
   * choices will resume them.
   * NOTE(review): confirm that this cross-search backtracking is intended.
   *)
  val stack : exn CC.cont list ref = ref nil

  (* amb nil: pop the most recent choice point and resume it with AmbFail,
   * or raise AmbFail if there is none.
   * amb (x::xs): push the current continuation as a choice point and return
   * x. If that choice point is later resumed with AmbFail (by amb nil), the
   * callcc yields something other than E, so the remaining elements are
   * tried via amb xs.
   *)
  fun 'a amb nil =
        (case stack of
           ref nil => raise AmbFail
         | ref (last::rest) => ( stack := rest; CC.throw last AmbFail ) )
    | amb (x::xs) =
        let
          exception E of 'a
          (* Record the choice point first, then return x wrapped in E. *)
          fun next cont = (stack := cont :: (!stack); E x)
        in
          case CC.callcc next of
            E y => y
          | _ => amb xs
        end
end

(* Searches [1..5] for x = 4; expected result 4 (assuming lamecc behaves
 * as call/cc).
 *)
fun test0 () =
  let
    val x = Amb.amb [1, 2, 3, 4, 5]
    val _ = if not (x = 4) then Amb.amb nil else 0
  in
    x
  end

(* Searches x in [1,2,3] and y in [4,5,6] for x times y = 10;
 * expected result (2, 5).
 *)
fun test1 () =
  let
    val x = Amb.amb [1, 2, 3]
    val y = Amb.amb [4, 5, 6]
    val _ = if not (x*y = 10) then Amb.amb nil else 0
  in
    (x, y)
  end

(* Local versions of Basis list and string functions; presumably the Basis
 * available to the HOT compiler lacks them (TODO confirm). They shadow the
 * Basis names for the rest of this file.
 * foldl applies f to (element, accumulator) left to right, like List.foldl.
 *)
fun foldl _ z [] = z
  | foldl f z (x::xs) = foldl f (f (x, z)) xs

fun map _ [] = []
  | map f (x::xs) = (f x) :: map f xs

fun length l = foldl (fn (_,n) => n+1) 0 l

(* Joins the strings of a list, putting s between adjacent elements. *)
fun concatWith _ [] = ""
  | concatWith _ [x] = x
  | concatWith s (x::xs) = x ^ s ^ (concatWith s xs)

(*debugging*)
(* Printing helpers. Note: str shadows the Basis str (char -> string)
 * for the rest of this file.
 *)
val str = Int.toString
fun strtuple (n,m) = "("^(str n)^", "^(str m)^")"
fun strlist f l = "[" ^ (concatWith ", " (map f l)) ^ "]"
val strintlist = strlist str

(* Searches x in [2,3,4,5] and a list y whose length equals x;
 * expected result [1, 1, 1, 1, 1]. The local amb is bound with val to a
 * variable, so it stays polymorphic and is used at two element types.
 *)
fun test2 () =
  let
    val amb = Amb.amb
    val x = amb [ 2, 3, 4, 5 ]
    val y = amb [ [1], [1,1,1,1,1], [1,1,1,1,1,1,1] ]
  in
    if not (x = (length y)) then amb nil else y
  end

(* Test driver. Each search leaves its remaining choice points on the
 * internal stack of Amb, so the order of these top-level declarations
 * matters: a later search that failed completely would resume an earlier
 * one (presumably re-running the declarations that follow it; confirm
 * under the HOT compiler).
 *)
val x0 = test0 ()
val _ = print ((str x0) ^ "\n")
val x1 = test1 ()
val _ = print ((strtuple x1) ^ "\n")
val x2 = test2 ()
val _ = print ((strintlist x2) ^ "\n")