(* Imp to VM compiler *)

(* Compiler for arithmetic expressions *)
module Compile_aexpr

  use int.Int
  use list.List
  use list.Length
  use list.Append
  use imp.Imp
  use vm.Vm
  use state.State
  use logic.Compiler_logic
  use specs.VM_instr_spec

  (* Compilation scheme: the code generated for an arithmetic expression
     puts the value of that expression on the stack.
     [aexpr_post a len] relates an initial machine state [ms] to the final
     state whose position component is [p+len], whose stack is the initial
     stack with [aeval m a] pushed on it, and whose third component [m]
     is unchanged. *)
  function aexpr_post (a:aexpr) (len:pos) : post 'a =
    fun _ p ms ms' ->
      match ms with
      | VMS _ stk mem -> ms' = VMS (p+len) (push (aeval mem a) stk) mem
      end
  meta rewrite_def function aexpr_post

  (* Compile the arithmetic expression [a] by structural recursion.
     Constants and variables map to a single instruction; a binary
     operation is the code of its left operand, then the code of its right
     operand, then the matching arithmetic instruction.
     The result is specified by the trivial precondition and by
     [aexpr_post a] instantiated with the length of the generated code. *)
  let rec compile_aexpr (a:aexpr) :  hl 'a
    ensures { result.pre --> trivial_pre }
    ensures { result.post --> aexpr_post a result.code.length }
    variant { a }
    = let gen = match a with
      | Anum k   -> $ iconstf k
      | Avar v   -> $ ivarf v
      | Aadd l r -> $ compile_aexpr l -- $ compile_aexpr r -- $ iaddf ()
      | Asub l r -> $ compile_aexpr l -- $ compile_aexpr r -- $ isubf ()
      | Amul l r -> $ compile_aexpr l -- $ compile_aexpr r -- $ imulf ()
      end in
      hoare trivial_pre gen (aexpr_post a gen.wcode.length)

  (* Check that the above specification indeed implies the
     natural one: wherever the returned code sits in a program [c] at
     position [p], running [c] from [VMS p s m] reaches the state located
     just after that code, with the value of [a] pushed on [s] and [m]
     unchanged. *)
  let compile_aexpr_natural (a:aexpr) : code
    ensures { forall c p s m. codeseq_at c p result ->
        transition_star c (VMS p s m)
                          (VMS (p + length result) (push (aeval m a) s) m) }
  = (* The auxiliary type parameter of [hl] is instantiated with [unit]. *)
    let res = compile_aexpr a : hl unit in
    (* The precondition of the generated code holds in every start state. *)
    assert { forall p s m. res.pre () p (VMS p s m) }; res.code

end


(* Compiler for boolean expressions. *)
module Compile_bexpr

  use int.Int
  use list.List
  use list.Length
  use list.Append
  use imp.Imp
  use vm.Vm
  use state.State
  use logic.Compiler_logic
  use specs.VM_instr_spec
  use Compile_aexpr

  (* Compilation scheme: the generated code performs a jump
     iff the boolean expression evaluates to [cond].
     The final state keeps the stack and the third component of the
     initial state; its position is [p + out_t] when [b] evaluates to
     [cond] and [p + out_f] otherwise. *)
  function bexpr_post (b:bexpr) (cond: bool) (out_t:ofs) (out_f:ofs) : post 'a =
    fun _ p ms ms' ->
      match ms with
      | VMS _ stk mem ->
          if beval mem b = cond
          then ms' = VMS (p + out_t) stk mem
          else ms' = VMS (p + out_f) stk mem
      end
  meta rewrite_def function bexpr_post

  (* Precondition stating that [b1] evaluates to [cond] in the third
     component of the current machine state; the other arguments are
     ignored. *)
  function exec_cond (b1:bexpr) (cond:bool) : pre 'a =
    fun _ _ ms ->
      match ms with
      | VMS _ _ mem -> beval mem b1 = cond
      end
  meta rewrite_def function exec_cond

  (* Compile the boolean expression [b] into code that, started at
     position [p], ends at [p + len + ofs] when [b] evaluates to [cond]
     and at [p + len] otherwise, where [len] is the length of the
     generated code (see the second postcondition). Stack and memory are
     left unchanged, as stated by [bexpr_post]. *)
  let rec compile_bexpr (b:bexpr) (cond:bool) (ofs:ofs) :  hl 'a
    ensures { result.pre --> trivial_pre }
    ensures { result.post --> let len = result.code.length in
      bexpr_post b cond (len + ofs) len }
    variant { b }
  = let c = match b with
    (* Constants: emit a branch by [ofs] exactly when the constant equals
       [cond], and [inil ()] otherwise (presumably the empty code
       sequence -- confirm in VM_instr_spec). *)
    | Btrue      -> $ if cond then ibranchf ofs else inil ()
    | Bfalse     -> $ if cond then inil () else ibranchf ofs
    (* Negation: same code as for [b1], with the jump condition flipped. *)
    | Bnot b1    -> $ compile_bexpr b1 (not cond) ofs
    | Band b1 b2 ->
      (* Code for [b2]; it is only reached when [b1] holds, which is
         recorded through [exec_cond b1 true]. *)
      let c2  = $ compile_bexpr b2 cond ofs % exec_cond b1 true in
      (* [b1] is compiled to jump when it is false. If [cond] is true the
         conjunction is then false and the jump only has to skip [c2];
         if [cond] is false the jump must reach the final target, i.e.
         [ofs] past the end of [c2]. This binding shadows the parameter
         [ofs] up to the end of this match arm. *)
      let ofs = if cond then length c2.wcode else ofs + length c2.wcode in
      $ compile_bexpr b1 false ofs -- c2
    (* Comparisons: evaluate both operands, then emit the conditional
       branch instruction selected by [cond]. *)
    | Beq a1 a2 -> $ compile_aexpr a1 -- $ compile_aexpr a2 --
                   $ if cond then ibeqf ofs else ibnef ofs
    | Ble a1 a2 -> $ compile_aexpr a1 -- $ compile_aexpr a2 --
                   $ if cond then iblef ofs else ibgtf ofs
    end in
    (* Here [ofs] is the function parameter again. *)
    let ghost post = bexpr_post b cond (c.wcode.length + ofs) c.wcode.length in
    hoare trivial_pre c post

  (* Check that the above specification implies the natural one:
     wherever the returned code sits in a program [c] at position [p],
     running [c] from [VMS p s m] reaches the position just after that
     code, plus [ofs] when [b] evaluates to [cond], with the stack and
     the memory unchanged. *)
  let compile_bexpr_natural (b:bexpr) (cond:bool) (ofs:ofs) : code
    ensures { forall c p s m. codeseq_at c p result ->
        transition_star c (VMS p s m)
           (VMS (p + length result + if beval m b = cond then ofs else 0) s m) }
  = (* The auxiliary type parameter of [hl] is instantiated with [unit]. *)
    let res = compile_bexpr b cond ofs : hl unit in
    (* The precondition of the generated code holds in every start state. *)
    assert { forall p s m. res.pre () p (VMS p s m) }; res.code

end


module Compile_com

  use int.Int
  use list.List
  use list.Length
  use list.Append
  use imp.Imp
  use vm.Vm
  use state.State
  use logic.Compiler_logic
  use specs.VM_instr_spec
  use Compile_aexpr
  use Compile_bexpr

  (* Compilation scheme: the generated code for a command
     simulates the command on the memory part of the machine state. *)
  (* Only terminating behavior is specified, so the precondition requires
     that the source command terminates from the initial memory: the
     machine must be at position [p] and some final memory must be
     reachable by [ceval]. *)
  function com_pre (cmd:com) : pre 'a =
    fun _ p ms ->
      match ms with
      | VMS pc _ mem -> p = pc /\ exists mem'. ceval mem cmd mem'
      end
  meta rewrite_def function com_pre

  (* Postcondition for the code of [cmd]: the position advances by [len],
     the stack is unchanged, and the final memory is one that [ceval]
     relates to the initial memory through [cmd]. *)
  function com_post (cmd:com) (len:pos) : post 'a =
    fun _ _ ms ms' ->
      match ms with
      | VMS pc stk mem ->
          match ms' with
          | VMS pc' stk' mem' ->
              pc' = pc + len /\ stk' = stk /\ ceval mem cmd mem'
          end
      end
  meta rewrite_def function com_post

  (* Variant of the condition precondition that reads the memory from the
     machine state stored as second component of the auxiliary value [x],
     instead of the current machine state: [b1] must evaluate to [cond]
     in that stored memory. *)
  function exec_cond_old (b1:bexpr) (cond:bool) : pre ('a,machine_state) =
    fun x _ _ ->
      match snd x with
      | VMS _ _ mem -> beval mem b1 = cond
      end
  meta rewrite_def function exec_cond_old

  (* Invariant for loop compilation: any intermediate state
     would evaluate to the same final state as the initial state.
     The initial state is the second component of the auxiliary value [x];
     the intermediate state must be at position [p], have the same stack
     as the initial state, and share with it a final memory for [c]. *)
  function loop_invariant (c:com) : pre ('a,machine_state) =
    fun x p msi ->
      match snd x with
      | VMS _ stk0 mem0 ->
          match msi with
          | VMS pci stki memi ->
              pci = p /\ stk0 = stki /\
              exists mfin. ceval mem0 c mfin /\ ceval memi c mfin
          end
      end
  meta rewrite_def function loop_invariant

  (* Variant relation for loop compilation: state [msj] is related to
     state [msi] when executing [c] from the memory of [msi] yields the
     memory of [msj] and [test] holds in the memory of [msi]. Positions
     and stacks are ignored. *)
  function loop_variant (c:com) (test:bexpr) : post 'a =
    fun _ _ msj msi -> let VMS _pj _sj mj = msj in let VMS _pi _si mi = msi in
       ceval mi c mj /\ beval mi test
  (* Fully applied restatement of the definition above, registered as a
     rewrite rule. It must be kept in sync with [loop_variant]. *)
  lemma loop_variant_lemma : forall c test,x:'a,p msj msi.
    loop_variant c test x p msj msi =
      let VMS _pj _sj mj = msj in let VMS _pi _si mi = msi in
      ceval mi c mj /\ beval mi test
  meta rewrite lemma loop_variant_lemma

  (* Well-foundedness of the loop variant.
     If the loop [Cwhile test c] evaluates from memory [mi] to some memory
     [mj], then every machine state whose memory is [mi] is accessible
     for the relation [loop_variant c test x p].
     The [by] clauses are proof indications: the first one is the step
     case (one iteration of the body followed by a run of the loop from
     [mj]); the innermost one states that every predecessor of
     [VMS pi si mi] for the variant has memory [mj].
     NOTE(review): that last fact presumably relies on [ceval] being
     deterministic -- confirm in the Imp theory. *)
  lemma loop_variant_acc : forall c test,x:'a,p mi mj.
    let wh = Cwhile test c in let var = (loop_variant c test x p) in
    (ceval mi wh mj -> forall pi si. acc var (VMS pi si mi))
    by forall pi si mi mj mf. ceval mi c mj /\ beval mi test ->
      ceval mj wh mf /\ (forall pj sj. acc var (VMS pj sj mj)) ->
      acc var (VMS pi si mi) by
      (forall pk sk mk. var (VMS pk sk mk) (VMS pi si mi) -> mk = mj)

  (* Compile the command [cmd] by structural recursion. The result is
     specified by [com_pre cmd] (the machine is at the start of the code
     and [cmd] terminates from the current memory) and by [com_post cmd]
     instantiated with the length of the generated code. *)
  let rec compile_com (cmd: com) : hl 'a
    ensures { result.pre --> com_pre cmd }
    ensures { result.post --> let len = result.code.length in com_post cmd len }
    variant  { cmd }
  = let res = match cmd with
    | Cskip              -> $ inil ()
    (* Evaluate the right-hand side, then store it into [x]. *)
    | Cassign x a        -> $ compile_aexpr a  -- $ isetvarf x
    | Cseq cmd1 cmd2     -> $ compile_com cmd1 -- $ compile_com cmd2
    (* Layout: test ; then-branch ; branch over else-branch ; else-branch.
       The test is compiled to jump when [cond] is false, by the length of
       the then-part, which lands on the else-branch. *)
    | Cif cond cmd1 cmd2 -> let code_false = compile_com cmd2 in
      let code_true = $ compile_com cmd1 -- $ ibranchf code_false.code.length in
      $ compile_bexpr cond false code_true.wcode.length --
      (code_true % exec_cond cond true) --
      ($ code_false % exec_cond_old cond false)
    (* Layout: test ; body ; backward branch to the test. *)
    | Cwhile test body  ->
      let code_body = compile_com body in
      (* Body plus one extra slot, presumably the backward branch
         instruction emitted by [ibranchf] below. *)
      let body_length = length code_body.code + 1 in
      (* The test jumps over body and backward branch when it is false. *)
      let code_test = compile_bexpr test false body_length in
      (* Total length of one iteration; the backward branch uses [- ofs]
         to return to the start of the test. *)
      let ofs = length code_test.code + body_length in
      let wp_while = $ code_test --
          ($ code_body -- $ ibranchf (- ofs)) % exec_cond test true in
      let ghost inv = loop_invariant cmd in
      let ghost var = loop_variant body test in
      (* NOTE(review): the leading [inil ()] adds no visible instruction
         here; its role with respect to [make_loop] should be confirmed
         against Compiler_logic. *)
      $ inil () -- make_loop wp_while inv (exec_cond test true) var
    end in
    hoare (com_pre cmd) res (com_post cmd res.wcode.length)

  (* Get back to the natural specification for the compiler: if [com]
     evaluates from memory [m] to memory [m'], then wherever the returned
     code sits in a program [c] at position [p], running [c] from
     [VMS p s m] reaches the position just after that code with the same
     stack [s] and memory [m']. *)
  let compile_com_natural (com: com) : code
    ensures { forall c p s m m'. ceval m com m' -> codeseq_at c p result ->
      transition_star c (VMS p s m) (VMS (p + length result) s m') }
  = (* The auxiliary type parameter of [hl] is instantiated with [unit]. *)
    let res = compile_com com : hl unit in
    (* Under the termination hypothesis the precondition of the generated
       code holds, and its postcondition pins down the final state. *)
    assert { forall c p s m m'. ceval m com m' -> codeseq_at c p res.code ->
     res.pre () p (VMS p s m) && (forall ms'. res.post () p (VMS p s m) ms' ->
      ms' = VMS (p + length res.code) s m') };
    (* Return the generated instruction sequence. *)
    res.code

  (* Compile a whole program: the code of [prog] followed by the final
     halting instruction. Whenever [prog] evaluates from [mi] to [mf],
     the resulting code satisfies [vm_terminates] for [mi] and [mf]. *)
  let compile_program (prog : com) : code
    ensures { forall  mi mf: state.
      ceval mi prog mf -> vm_terminates result mi mf }
  = let body = compile_com_natural prog in
    body ++ ihalt

  (* Execution test: compile a simple factorial program, e.g.
     X := 1; WHILE NOT (Y <= 0) DO X := X * Y; Y := Y - 1 DONE
     (why3 execute -L . compiler.mlw Compile_com.test) *)
  let test () : code =
    (* Program variables X and Y. *)
    let x = Id 0 in
    let y = Id 1 in
    (* X := 1 *)
    let init = Cassign x (Anum 1) in
    (* NOT (Y <= 0) *)
    let guard = Bnot (Ble (Avar y) (Anum 0)) in
    (* X := X * Y; Y := Y - 1 *)
    let mul_step = Cassign x (Amul (Avar x) (Avar y)) in
    let dec_step = Cassign y (Asub (Avar y) (Anum 1)) in
    let while_cmd = Cwhile guard (Cseq mul_step dec_step) in
    compile_program (Cseq init while_cmd)

  (* Second execution test: compile the loop WHILE TRUE DO SKIP DONE.
     The contract of [compile_program] only speaks about evaluations of
     the source program that reach a final memory. *)
  let test2 () : code =
    compile_program (Cwhile Btrue Cskip)

end


