structure DefineUtils :> DefineUtils =
struct

open HolKernel boolLib bossLib
open Parse
open Term

(* Flags recording whether the infix fixities used by define_num_newtype and
 * define_monad have already been installed, so that set_fixity runs at most
 * once per flag. *)
val num_infixes_defined: bool ref = ref false
val monad_infixes_defined: bool ref = ref false

(* Tactic: look for an assumption of the shape  _ (_: α # β) = _  or, failing
 * that,  _ = _ (_: α # β) .  The operand of the application on the selected
 * side is case-split with Cases_on, unless it is already a syntactic pair, in
 * which case the attempt fails with NO_TAC. *)
val pair_case_tac =
  let
    (* side picks the half of the equation (lhs or rhs) holding the application. *)
    fun split_operand side asm =
      let
        val v = concl asm |> side |> dest_comb |> #2
      in
        if pairSyntax.is_pair v then NO_TAC else Cases_on `^v`
      end
  in
    qpat_assum `_ (_: α # β) = _` (split_operand lhs)
    ORELSE
    qpat_assum `_ = _ (_: α # β)` (split_operand rhs)
  end

(* Tactic: repeatedly perform one case split (FULL_CASE_TAC, falling back to
 * pair_case_tac) followed by fs[] and rw[], until the step no longer applies. *)
fun multi_case_tac goal =
  let
    val split_then_simp =
      (BasicProvers.FULL_CASE_TAC ORELSE pair_case_tac) >> fs[] >> rw[]
  in
    REPEAT split_then_simp goal
  end

(* Tactic: pop every assumption, rewrite each one with the theorems xs using
 * REWRITE_RULE, and assume the results again, iterating over the reversed
 * popped list. *)
fun rw_assums xs =
  let
    val reassume = assume_tac o REWRITE_RULE xs
  in
    POP_ASSUM_LIST (fn asms => map_every reassume (rev asms))
  end

(* Introduce name as a type abbreviation for the sum type inl_ty + inr_ty, and
 * overload the strings inl and inr on the sum theory's INL and INR constants
 * at the corresponding instance types. *)
fun define_sum(name : string, inl : string, inl_ty : hol_type,
               inr : string, inr_ty : hol_type) : unit =
  let
    val sum_ty = ``:^inl_ty + ^inr_ty``
    fun overload_injection (label, injection, arg_ty) =
      overload_on(label,
                  mk_thy_const{ Thy = "sum", Name = injection, Ty = arg_ty --> sum_ty })
  in
    type_abbrev(name, sum_ty);
    overload_injection(inl, "INL", inl_ty);
    overload_injection(inr, "INR", inr_ty)
  end

(* Define a datatype type_name with the single constructor ctor_name wrapping a
 * num, together with:
 *   - <type_name>_suc   successor on the wrapped number,
 *   - <type_name>_lt    strict ordering on the wrapped numbers,
 *   - <type_name>_dest  projection of the wrapped number,
 *   - stored theorems <type_name>_lt_suc[simp], <type_name>_lt_trans and
 *     <type_name>_lt_equiv[simp],
 *   - overloads of the infix symbols ≺ and ≼ for the new ordering. *)
fun define_num_newtype(type_name : string, ctor_name : string) : unit =
  let
    val _ = Datatype [QUOTE type_name, QUOTE " = ", QUOTE ctor_name, QUOTE " num"]

    val num_ty = mk_type("num", [])
    val bool_ty = mk_type("bool", [])
    val wrapped_ty = mk_type(type_name, [])
    val ctor = mk_const(ctor_name, num_ty --> wrapped_ty)

    val suc_name = type_name ^ "_suc"
    val lt_name = type_name ^ "_lt"
    val dest_name = type_name ^ "_dest"

    (* The constant's name is spliced in front of each quoted defining equation. *)
    val suc_def =
      Define (QUOTE suc_name :: ` x = case x of ^ctor n => ^ctor (SUC n)`)
    val lt_def =
      Define (QUOTE lt_name :: ` x y = case x of ^ctor m => case y of ^ctor n => m < n`)
    val dest_def =
      Define (QUOTE dest_name :: ` (^ctor n) = n`)

    val suc = mk_const(suc_name, wrapped_ty --> wrapped_ty)
    val lt = mk_const(lt_name, wrapped_ty --> wrapped_ty --> bool_ty)

    val _ = Q.store_thm(type_name ^ "_lt_suc[simp]",
      `∀x. ^lt x (^suc x)`,
      Cases >> rw[suc_def, lt_def])

    val _ = Q.store_thm(type_name ^ "_lt_trans",
      `∀x y z. ^lt x y ∧ ^lt y z ⇒ ^lt x z`,
      Cases >> Cases >> Cases >> rw[lt_def])

    val _ = Q.store_thm(type_name ^ "_lt_equiv[simp]",
      `∀x y. ^lt x y ⇒ ∃m n. (x = ^ctor m) ∧ (y = ^ctor n) ∧ m < n`,
      Cases >> Cases >> rw[lt_def])
  in
    (* Define the "precedes" (U+227A = ≺) and "precedes or equal to"
     * (U+227C = ≼) operators, which are different Unicode characters from
     * < and ≤. These new symbols are required for overloaded comparisons
     * because ≤ is not defined as an overload. The fixities are installed only
     * the first time this function runs. *)
    if !num_infixes_defined then ()
    else (
      set_fixity "≺" (Infix(NONASSOC, 450));
      set_fixity "≼" (Infix(NONASSOC, 450));
      num_infixes_defined := true
    );
    overload_on("≺", ``^lt``);
    overload_on("≼", ``λl r. (^lt l r) ∨ (l = r)``)
    (* TeX_notation {hol="≺", TeX=("\\ensuremath{\\prec}", 1)}; *)
  end

(* Variant of Define for quotations whose antiquotations are ML strings: each
 * antiquoted string is spliced into the quotation as literal text before the
 * whole fragment list is handed to Define. *)
fun QDefine(quoted : string frag list) =
  Define (map (fn QUOTE text => QUOTE text
                | ANTIQUOTE text => QUOTE text)
              quoted)

(* Type-variable names used by normalize_ty, in assignment order. *)
val greek = ["'a", "'b", "'c", "'d", "'e", "'f", "'g", "'h", "'i", "'j", "'k"]

(* Rename the type variables of ty, in the order type_vars lists them, to the
 * names in greek. Pairing stops at the shorter of the two lists, so any type
 * variables beyond the length of greek are left as they are. *)
fun normalize_ty (ty : hol_type) : hol_type =
  let
    val targets = map mk_vartype greek
    val renaming = ListPair.map (op |->) (type_vars ty, targets)
  in
    type_subst renaming ty
  end

(* Register a monad called name whose type constructor is ty, with unit return
 * and bind operation bind.
 *
 * Overloads return_<name> and bind_<name> on the given terms (constrained to
 * the normalized monadic types), then defines lift_<name>, join_<name>,
 * sequence_<name>, mapM_<name> and ap_<name> in terms of them. After each
 * definition mk_const is called at the expected type, which raises if the new
 * constant cannot be instantiated to it. Finally the generic names
 * (>>=, *>, return, monad_bind, monad_unitbind, liftM, join, sequence, mapM,
 * ap, <$>, APPLICATIVE_FAPPLY) are overloaded on this monad's operations. *)
fun define_monad(name : string, ty : hol_type -> hol_type, return : term, bind : term) : unit =
  let
    val alpha = mk_vartype "'aa"
    val beta = mk_vartype "'bb"

    fun list_of elem_ty = mk_type("list", [elem_ty])

    (* Type check only: the constant built by mk_const is discarded. *)
    fun expect_const (const_name, expected_ty) =
      let val _ = mk_const(const_name, expected_ty) in () end

    val return_ty = normalize_ty (alpha --> ty alpha)
    val return' = ``^return : ^(ty_antiq return_ty)``
    val return_name = "return_" ^ name
    val _ = overload_on(return_name, return')

    val bind_ty = normalize_ty (ty alpha --> (alpha --> ty beta) --> ty beta)
    val bind' = ``^bind : ^(ty_antiq bind_ty)``
    val bind_name = "bind_" ^ name
    val _ = overload_on(bind_name, bind')

    val lift_name = "lift_" ^ name
    val lift_def = QDefine `^lift_name f x = ^bind_name x (^return_name o f)`
    val () = expect_const(lift_name, (alpha --> beta) --> ty alpha --> ty beta)

    val join_name = "join_" ^ name
    val join_def = QDefine `^join_name m = ^bind_name m I`
    val () = expect_const(join_name, ty (ty alpha) --> ty alpha)

    val sequence_name = "sequence_" ^ name
    val sequence_def =
      QDefine `^sequence_name = FOLDR (λp q. ^bind_name p (λx. ^bind_name q (λy. ^return_name (x::y)))) (^return_name [])`
    val () = expect_const(sequence_name, list_of (ty alpha) --> ty (list_of alpha))

    val mapM_name = "mapM_" ^ name
    val mapM_def = QDefine `^mapM_name f xs = ^sequence_name (MAP f xs)`
    val () = expect_const(mapM_name,
                          (alpha --> ty beta) --> list_of alpha --> ty (list_of beta))

    val ap_name = "ap_" ^ name
    val ap_def =
      QDefine `^ap_name f arg = ^bind_name f (λf'. ^bind_name arg (λarg'. ^return_name (f' arg')))`
    val () = expect_const(ap_name, ty (alpha --> beta) --> ty alpha --> ty beta)

    (* Bind that ignores the first result; parsed afresh at each use. *)
    fun sequencing_bind () = ``λm1 m2. ^bind' m1 (K m2)``

    (* Generic names overloaded on constants that are looked up by parsing
     * their names, in this order. *)
    val derived_overloads =
      [("liftM", lift_name),
       ("join", join_name),
       ("sequence", sequence_name),
       ("mapM", mapM_name),
       ("ap", ap_name),
       ("<$>", lift_name),
       ("APPLICATIVE_FAPPLY", ap_name)]
  in
    (* The infix fixities are installed only the first time this runs. *)
    if !monad_infixes_defined then ()
    else (
      set_fixity ">>=" (Infixl 660);
      set_fixity "*>" (Infixl 660);
      set_fixity "<$>" (Infixl 661);
      monad_infixes_defined := true
    );
    overload_on(">>=", bind');
    overload_on("*>", sequencing_bind ());
    overload_on("return", return');
    overload_on("monad_bind", bind');
    overload_on("monad_unitbind", sequencing_bind ());
    List.app (fn (symbol, const_name) => overload_on(symbol, Term [QUOTE const_name]))
             derived_overloads
  end

end