kdrag.kernel

The kernel holds the core proof datatypes and core inference rules. By and large, all proofs must flow through this module.

Module Attributes

defns

defns holds the definitional axioms for function symbols.

Functions

Inductive(name)

Declare datatypes with auto-generated induction principles.

SchemaVar(prefix, sort)

Generate a fresh schema variable.

axiom(thm[, by])

Assert an axiom.

compose(ab, bc)

Compose two implications.

consider(x)

The purpose of this is to seed the solver with interesting terms.

define(name, args, body[, lift_lambda])

Create a non-recursive definition.

define_fix(name, args, retsort, fix_lam)

Create a recursive definition.

einstan(thm)

Skolemize an existential quantifier.

forget(ts, pf)

"Forget" a term using existentials.

forget2(ts, thm)

"Forget" a term using existentials.

fresh_const(q)

Generate fresh constants with the same sorts as the quantifier's bound variables.

generalize(vs, pf)

Generalize a theorem with respect to a list of schema variables.

herb(thm)

Herbrandize a theorem.

induct_inductive(x, P)

Build a basic induction principle for an algebraic datatype.

instan(ts, pf)

Instantiate a universally quantified formula.

instan2(ts, thm)

Instantiate a universally quantified formula: forall xs, P(xs) -> P(ts). This is forall elimination.

is_defined(x)

Determine whether the head of an expression has a recorded definition.

is_proof(p)

Check whether p is a Proof object.

is_schema_var(v)

Check if a variable is a schema variable.

modus(ab, a)

Modus ponens

prove(thm[, by, admit, timeout, dump, solver])

Prove a theorem using a list of previously proved lemmas.

skolem(pf)

Skolemize an existential quantifier.

substitute_schema_vars(pf, *subst)

Substitute schematic variables in a theorem.

Classes

Defn(name, args, body, ax)

A record storing a definition.

Judgement()

Judgements should be constructed by smart constructors.

Proof(thm, reason[, admit])

It is unlikely that users should be accessing the Proof constructor directly.

Exceptions

LemmaError

class kdrag.kernel.Defn(name: str, args: list[ExprRef], body: ExprRef, ax: Proof)

Bases: object

A record storing a definition. It is useful to record definitions as special axioms because we often must unfold them.

Parameters:
  • name (str)

  • args (list[ExprRef])

  • body (ExprRef)

  • ax (Proof)

args: list[ExprRef]
ax: Proof
body: ExprRef
name: str

kdrag.kernel.Inductive(name: str) -> Datatype

Declare datatypes with auto-generated induction principles. This is a wrapper around z3.Datatype.

>>> Nat = Inductive("Nat")
>>> Nat.declare("zero")
>>> Nat.declare("succ", ("pred", Nat))
>>> Nat = Nat.create()
>>> Nat.succ(Nat.zero)
succ(zero)
Parameters:

name (str)

Return type:

Datatype

class kdrag.kernel.Judgement

Bases: object

Judgements should be constructed by smart constructors. Having an object of supertype judgement represents having shown some kind of truth. Judgements are the things that go above and below inference lines in a proof system. Don’t worry about it. It is just nice to have a name for the concept.

See:
  • https://en.wikipedia.org/wiki/Judgment_(mathematical_logic)

  • https://mathoverflow.net/questions/254518/what-exactly-is-a-judgement

  • https://ncatlab.org/nlab/show/judgment

exception kdrag.kernel.LemmaError

Bases: Exception

add_note(object, /)

Exception.add_note(note) – add a note to the exception

args
with_traceback(object, /)

Exception.with_traceback(tb) – set self.__traceback__ to tb and return self.

class kdrag.kernel.Proof(thm: BoolRef, reason: list[Any], admit: bool = False)

Bases: Judgement

It is unlikely that users should be accessing the Proof constructor directly. This is not ironclad. If you really want the Proof constructor, I can’t stop you.

Parameters:
  • thm (BoolRef)

  • reason (list[Any])

  • admit (bool)
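The module description says that, by and large, all proofs must flow through the kernel; the usual way to enforce that is the LCF approach of letting only a few trusted functions build proof objects. The sketch below models the idea with plain Python values instead of Z3 terms. ToyProof, toy_axiom and _KERNEL_TOKEN are hypothetical names, not kdrag's API, and like the Proof constructor the guard is deliberately not ironclad.

```python
# Toy, stdlib-only sketch of an LCF-style "private" proof constructor.
# Not kdrag's implementation; all names here are hypothetical.
from dataclasses import dataclass, field
from typing import Any

# Only the smart constructor below passes this token. Anyone can import it,
# so this protection is a convention, not a guarantee.
_KERNEL_TOKEN = object()


@dataclass(frozen=True)
class ToyProof:
    """A theorem together with the reason it is believed to hold."""

    thm: Any
    reason: tuple
    token: object = field(default=None, repr=False, compare=False)

    def __post_init__(self):
        # Reject construction that bypasses the smart constructor.
        if self.token is not _KERNEL_TOKEN:
            raise ValueError("ToyProof is private; use toy_axiom")


def toy_axiom(thm: Any, by: str = "axiom") -> ToyProof:
    """Smart constructor: the sanctioned way to create a ToyProof."""
    return ToyProof(thm, (by,), _KERNEL_TOKEN)


p = toy_axiom(("implies", "a", "b"))
print(p)  # ToyProof(thm=('implies', 'a', 'b'), reason=('axiom',))
```

Calling ToyProof directly without the token raises ValueError, which is the behaviour the "I can't stop you" remark above concedes a real Python kernel can only approximate.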

__call__(*args: ExprRef | Proof)
>>> x,y = smt.Ints("x y")
>>> p = prove(smt.ForAll([y], smt.ForAll([x], x >= x - 1)))
>>> p(x)
|- ForAll(x, x >= x - 1)
>>> p(x, smt.IntVal(7))
|- 7 >= 7 - 1
>>> a,b,c = smt.Bools("a b c")
>>> ab = prove(smt.Implies(a,smt.Implies(a, a)))
>>> a = axiom(a)
>>> ab(a)
|- Implies(a, a)
>>> ab(a,a)
|- a
Parameters:

args (ExprRef | Proof)
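Calling a proof consumes its arguments left to right: a ForAll theorem takes as many terms as it binds (forall elimination, as in instan) and an implication takes one proof of its hypothesis (modus ponens, as in modus). The stdlib-only sketch below reproduces that loop on bare formulas, dispatching on the shape of the accumulated theorem at each step. The tuple encoding, head, subst and apply_proof are hypothetical; a proof argument is represented by the formula it proves, and substitution ignores variable capture.

```python
# Toy, stdlib-only sketch of how Proof.__call__ consumes its arguments.
# Not kdrag's implementation: formulas are nested tuples, a "proof" argument is
# represented by the formula it proves, and substitution ignores variable capture.
from typing import Any


def head(f: Any) -> Any:
    """Tag of a compound formula such as ("forall", ...), or None for atoms."""
    return f[0] if isinstance(f, tuple) and f else None


def subst(t: Any, env: dict) -> Any:
    """Replace variable names (strings) according to env."""
    if isinstance(t, str):
        return env.get(t, t)
    if isinstance(t, tuple):
        return tuple(subst(s, env) for s in t)
    return t


def apply_proof(thm: Any, *args: Any) -> Any:
    acc, n = thm, 0
    while n < len(args):
        if head(acc) == "forall":  # ("forall", (x1, ..., xk), body)
            xs, body = acc[1], acc[2]
            chunk = args[n : n + len(xs)]
            if len(chunk) != len(xs):
                raise TypeError("not enough terms to instantiate the quantifier")
            acc = subst(body, dict(zip(xs, chunk)))  # forall elimination
            n += len(xs)
        elif head(acc) == "implies":  # ("implies", hyp, concl)
            if args[n] != acc[1]:
                raise TypeError(f"argument does not prove the hypothesis {acc[1]!r}")
            acc = acc[2]  # modus ponens
            n += 1
        else:
            raise TypeError(f"cannot apply a proof of {acc!r}")
    return acc


thm = ("forall", ("x",), ("implies", ("P", "x"), ("Q", "x")))
print(apply_proof(thm, "c"))              # ('implies', ('P', 'c'), ('Q', 'c'))
print(apply_proof(thm, "c", ("P", "c")))  # ('Q', 'c')
```

Dispatching on the accumulated theorem rather than the original one is what lets a mixed chain (a quantifier followed by an implication) be applied in a single call.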

admit: bool = False

forall(schema_vars: list[ExprRef]) -> Proof

Generalize a proof involving schematic variables generated by SchemaVar.

>>> x = SchemaVar("x", smt.IntSort())
>>> prove(x + 1 > x).forall([x])
|- ForAll(x!..., x!... + 1 > x!...)
Parameters:

schema_vars (list[ExprRef])

Return type:

Proof

reason: list[Any]
thm: BoolRef

kdrag.kernel.SchemaVar(prefix: str, sort: SortRef) -> ExprRef

Generate a fresh schema variable.

>>> SchemaVar("x", smt.IntSort()).schema_evidence
_SchemaVarEvidence(v=x!...)
Parameters:
  • prefix (str)

  • sort (SortRef)

Return type:

ExprRef
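SchemaVar combines two ideas: a fresh name (the x!... suffix in the doctests) and an evidence object that is_schema_var can check later. The stdlib-only sketch below models both with dataclasses. ToyVar, SchemaEvidence, schema_var and is_schema_var are hypothetical stand-ins; kdrag itself works with Z3 constants and its private _SchemaVarEvidence class.

```python
# Toy, stdlib-only sketch of schema variables with attached evidence.
# Not kdrag's implementation; all names here are hypothetical.
import itertools
from dataclasses import dataclass
from typing import Optional

_fresh = itertools.count()


@dataclass(frozen=True)
class SchemaEvidence:
    """Records which variable name this evidence vouches for."""

    name: str


@dataclass(frozen=True)
class ToyVar:
    name: str
    sort: str
    schema_evidence: Optional[SchemaEvidence] = None


def schema_var(prefix: str, sort: str) -> ToyVar:
    # A global counter makes every generated name fresh: x!0, x!1, ...
    name = f"{prefix}!{next(_fresh)}"
    return ToyVar(name, sort, SchemaEvidence(name))


def is_schema_var(v: object) -> bool:
    # A schema variable carries evidence that matches its own name.
    return (
        isinstance(v, ToyVar)
        and v.schema_evidence is not None
        and v.schema_evidence.name == v.name
    )


x = schema_var("x", "Int")
y = schema_var("x", "Int")
print(x.name, y.name)  # x!0 x!1
print(is_schema_var(x), is_schema_var(ToyVar("x", "Int")))  # True False
```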

kdrag.kernel.axiom(thm: BoolRef, by=['axiom']) -> Proof

Assert an axiom.

Axioms are necessary and useful, but they must be used with great care: a single inconsistent axiom makes every formula provable.

Parameters:
  • thm (BoolRef) – The axiom to assert.

  • by – A Python object explaining why the axiom should exist, often a string.

Return type:

Proof

kdrag.kernel.compose(ab: Proof, bc: Proof) -> Proof

Compose two implications: from |- Implies(a, b) and |- Implies(b, c), derive |- Implies(a, c). Useful for chaining implications.

>>> a,b,c = smt.Bools("a b c")
>>> ab = axiom(smt.Implies(a, b))
>>> bc = axiom(smt.Implies(b, c))
>>> compose(ab, bc)
|- Implies(a, c)
Parameters:
  • ab (Proof)

  • bc (Proof)

Return type:

Proof
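A sound compose rule needs one side condition: the conclusion of ab must be exactly the hypothesis of bc. The stdlib-only sketch below checks it on tuple-encoded formulas; the encoding and toy_compose are hypothetical, and kdrag, which works on Z3 terms, may perform the check differently.

```python
# Toy, stdlib-only sketch of composing implications; not kdrag's implementation.
from typing import Any


def toy_compose(ab: Any, bc: Any) -> Any:
    """From ("implies", a, b) and ("implies", b, c), build ("implies", a, c)."""
    for f in (ab, bc):
        if not (isinstance(f, tuple) and len(f) == 3 and f[0] == "implies"):
            raise ValueError(f"not an implication: {f!r}")
    if ab[2] != bc[1]:
        # Chaining is only valid when the middle formulas agree.
        raise ValueError("conclusion of ab does not match hypothesis of bc")
    return ("implies", ab[1], bc[2])


print(toy_compose(("implies", "a", "b"), ("implies", "b", "c")))  # ('implies', 'a', 'c')
```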

kdrag.kernel.consider(x: ExprRef) -> Proof

The purpose of this is to seed the solver with interesting terms. It is an axiom schema: we may give a fresh name to any term, which makes it an “anonymous” form of define. Pointing out the interesting terms is sometimes the essence of a proof.

Parameters:

x (ExprRef)

Return type:

Proof

kdrag.kernel.define(name: str, args: list[ExprRef], body: ExprRef, lift_lambda=False) -> FuncDeclRef

Create a non-recursive definition, useful for shorthand and abstraction. It does not currently defend against ill-formed definitions. TODO: check for bad circularity and record dependencies.

Parameters:
  • name (str) – The name of the term to define.

  • args (list[ExprRef]) – The arguments of the term.

  • defn – The definition of the term.

  • body (ExprRef)

Returns:

A tuple of the defined term and the proof of the definition.

Return type:

tuple[smt.FuncDeclRef, Proof]

kdrag.kernel.define_fix(name: str, args: list[ExprRef], retsort, fix_lam) -> FuncDeclRef

Define a recursive definition.

Parameters:
  • name (str)

  • args (list[ExprRef])

Return type:

FuncDeclRef
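Recording each definition as a Defn lets later steps check whether a term's head symbol is defined and unfold it. The stdlib-only sketch below models such a registry. ToyDefn, toy_defns, toy_define, toy_is_defined and toy_unfold are hypothetical; keys are plain names rather than FuncDeclRef objects, and terms are nested tuples.

```python
# Toy, stdlib-only sketch of a definitions registry; not kdrag's implementation.
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyDefn:
    name: str
    args: tuple
    body: Any
    ax: Any  # the definitional axiom, kept as a formula in this toy


toy_defns: dict = {}


def subst(t: Any, env: dict) -> Any:
    """Replace variable names (strings) according to env."""
    if isinstance(t, str):
        return env.get(t, t)
    if isinstance(t, tuple):
        return tuple(subst(s, env) for s in t)
    return t


def toy_define(name: str, args: list, body: Any) -> str:
    # Definitional axiom: forall args, name(args) == body
    ax = ("forall", tuple(args), ("==", (name, *args), body))
    toy_defns[name] = ToyDefn(name, tuple(args), body, ax)
    return name


def toy_is_defined(t: Any) -> bool:
    return isinstance(t, tuple) and len(t) > 0 and t[0] in toy_defns


def toy_unfold(t: Any) -> Any:
    """Replace a defined application by its body with the arguments substituted."""
    d = toy_defns[t[0]]
    return subst(d.body, dict(zip(d.args, t[1:])))


sqr = toy_define("sqr", ["x"], ("*", "x", "x"))
print(toy_unfold((sqr, "3")))  # ('*', '3', '3')
```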

kdrag.kernel.defns: dict[FuncDeclRef, Defn] = {}

defns holds the definitional axioms for function symbols.

kdrag.kernel.einstan(thm: QuantifierRef) -> tuple[list[ExprRef], Proof]

Skolemize an existential quantifier. Returns fresh constants cs together with a proof of (exists xs, P(xs)) -> P(cs). See https://en.wikipedia.org/wiki/Existential_instantiation

Parameters:

thm (QuantifierRef)

Return type:

tuple[list[ExprRef], Proof]

kdrag.kernel.forget(ts: Iterable[ExprRef], pf: Proof) -> Proof

“Forget” a term using existentials. This is existential introduction: from |- P(ts), derive |- exists xs, P(xs). It could be derived from forget2.

Parameters:
  • ts (Iterable[ExprRef])

  • pf (Proof)

Return type:

Proof

kdrag.kernel.forget2(ts: Sequence[ExprRef], thm: QuantifierRef) -> Proof

“Forget” a term using existentials. This is existential introduction: P(ts) -> exists xs, P(xs). Here thm is an existential formula, and ts are the terms to substitute for its bound variables. forget follows easily. See https://en.wikipedia.org/wiki/Existential_generalization

Parameters:
  • ts (Sequence[ExprRef])

  • thm (QuantifierRef)

Return type:

Proof
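forget2 states existential introduction for given witnesses, and forget obtains the existential itself by abstracting the terms into fresh variables and discharging that implication. The stdlib-only sketch below follows this route on tuple-encoded formulas; subst, abstract, toy_forget2 and toy_forget are hypothetical, fresh variables come from a counter, and variable capture is ignored.

```python
# Toy, stdlib-only sketch of existential introduction; not kdrag's implementation.
import itertools
from typing import Any, Sequence

_fresh = itertools.count()


def subst(t: Any, env: dict) -> Any:
    """Replace variable names (strings) according to env."""
    if isinstance(t, str):
        return env.get(t, t)
    if isinstance(t, tuple):
        return tuple(subst(s, env) for s in t)
    return t


def abstract(t: Any, env: dict) -> Any:
    """Replace whole subterms that are keys of env (the reverse of subst)."""
    if t in env:
        return env[t]
    if isinstance(t, tuple):
        return tuple(abstract(s, env) for s in t)
    return t


def toy_forget2(ts: Sequence[Any], thm: Any) -> Any:
    """P(ts) -> exists xs, P(xs), for thm = ("exists", xs, P)."""
    tag, xs, body = thm
    if tag != "exists" or len(xs) != len(ts):
        raise ValueError("expected an existential binding len(ts) variables")
    return ("implies", subst(body, dict(zip(xs, ts))), thm)


def toy_forget(ts: Sequence[Any], proved: Any) -> Any:
    """From a proved P(ts), derive exists xs, P(xs)."""
    xs = tuple(f"x!{next(_fresh)}" for _ in ts)
    imp = toy_forget2(ts, ("exists", xs, abstract(proved, dict(zip(ts, xs)))))
    if imp[1] != proved:  # modus ponens side condition
        raise ValueError("abstraction did not round-trip")
    return imp[2]


print(toy_forget(["5"], (">", "5", "0")))  # ('exists', ('x!0',), ('>', 'x!0', '0'))
```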

kdrag.kernel.fresh_const(q: QuantifierRef)

Generate fresh constants with the same sorts as the quantifier's bound variables.

Parameters:

q (QuantifierRef)

kdrag.kernel.generalize(vs: list[ExprRef], pf: Proof) -> Proof

Generalize a theorem with respect to a list of schema variables. This introduces a universal quantifier for schema variables.

>>> x = SchemaVar("x", smt.IntSort())
>>> y = SchemaVar("y", smt.IntSort())
>>> generalize([x, y], prove(x == x))
|- ForAll([x!..., y!...], x!... == x!...)
Parameters:
  • vs (list[ExprRef])

  • pf (Proof)

Return type:

Proof

kdrag.kernel.herb(thm: QuantifierRef) -> tuple[list[ExprRef], Proof]

Herbrandize a theorem. To prove a universal statement, it is sufficient to prove it for fresh constants: herb returns fresh constants cs together with a proof of P(cs) -> forall xs, P(xs). Note: perhaps a lambda-ized form would be better, returning the variables and a lambda that could receive |- P[vars].

Parameters:

thm (QuantifierRef)

Return type:

tuple[list[ExprRef], Proof]
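einstan and herb both replace bound variables by fresh constants; they differ in the direction of the resulting implication. The stdlib-only sketch below shows the two shapes side by side on tuple-encoded quantifiers. toy_einstan, toy_herb and fresh_consts are hypothetical names, and both rules are justified only because the constants are fresh.

```python
# Toy, stdlib-only sketch contrasting einstan and herb; not kdrag's implementation.
# Quantifiers are ("exists" | "forall", (x1, ..., xk), body).
import itertools
from typing import Any

_fresh = itertools.count()


def subst(t: Any, env: dict) -> Any:
    """Replace variable names (strings) according to env."""
    if isinstance(t, str):
        return env.get(t, t)
    if isinstance(t, tuple):
        return tuple(subst(s, env) for s in t)
    return t


def fresh_consts(q: Any) -> list:
    # One fresh constant per bound variable; split("!") drops earlier suffixes,
    # echoing the comment in kdrag's fresh_const.
    return [f"{x.split('!')[0]}!{next(_fresh)}" for x in q[1]]


def toy_einstan(thm: Any) -> tuple:
    """Return fresh cs and (exists xs, P(xs)) -> P(cs)."""
    if thm[0] != "exists":
        raise ValueError("einstan expects an existential")
    cs = fresh_consts(thm)
    return cs, ("implies", thm, subst(thm[2], dict(zip(thm[1], cs))))


def toy_herb(thm: Any) -> tuple:
    """Return fresh cs and P(cs) -> (forall xs, P(xs))."""
    if thm[0] != "forall":
        raise ValueError("herb expects a universal")
    cs = fresh_consts(thm)
    return cs, ("implies", subst(thm[2], dict(zip(thm[1], cs))), thm)


cs, imp = toy_herb(("forall", ("x",), ("Q", "x")))
print(cs, imp)  # ['x!0'] ('implies', ('Q', 'x!0'), ('forall', ('x',), ('Q', 'x')))
```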

kdrag.kernel.induct_inductive(x: DatatypeRef, P: QuantifierRef) -> Proof

Build a basic induction principle for an algebraic datatype.

Parameters:
  • x (DatatypeRef)

  • P (QuantifierRef)

Return type:

Proof

kdrag.kernel.instan(ts: Sequence[ExprRef], pf: Proof) -> Proof

Instantiate a proven universally quantified formula: from |- forall xs, P(xs), derive |- P(ts). This is forall elimination.

Parameters:
  • ts (Sequence[ExprRef])

  • pf (Proof)

Return type:

Proof

kdrag.kernel.instan2(ts: Sequence[ExprRef], thm: BoolRef) -> Proof

Instantiate a universally quantified formula: forall xs, P(xs) -> P(ts). This is forall elimination.

Parameters:
  • ts (Sequence[ExprRef])

  • thm (BoolRef)

Return type:

Proof

kdrag.kernel.is_defined(x: ExprRef) -> bool

Determine whether the head of an expression has a recorded definition.

Parameters:

x (ExprRef)

Return type:

bool

kdrag.kernel.is_proof(p: Proof) -> bool

Check whether p is a Proof object.

Parameters:

p (Proof)

Return type:

bool

kdrag.kernel.is_schema_var(v: ExprRef) -> bool

Check if a variable is a schema variable. Schema variables are generated by SchemaVar and carry a schema_evidence attribute holding a _SchemaVarEvidence.

>>> is_schema_var(SchemaVar("x", smt.IntSort()))
True
Parameters:

v (ExprRef)

Return type:

bool

kdrag.kernel.modus(ab: Proof, a: Proof) -> Proof

Modus ponens: from |- Implies(a, b) and |- a, derive |- b.

>>> a,b = smt.Bools("a b")
>>> ab = axiom(smt.Implies(a, b))
>>> a = axiom(a)
>>> modus(ab, a)
|- b
Parameters:
  • ab (Proof)

  • a (Proof)

Return type:

Proof
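In the toy model below, modus ponens is a purely syntactic check: the second theorem must be exactly the hypothesis of the first. The sketch is stdlib-only; toy_modus and the tuple encoding are hypothetical, and a proof is represented by the formula it proves.

```python
# Toy, stdlib-only sketch of modus ponens; not kdrag's implementation.
from typing import Any


def toy_modus(ab: Any, a: Any) -> Any:
    """From ("implies", a, b) and a, conclude b; reject a mismatched antecedent."""
    if not (isinstance(ab, tuple) and len(ab) == 3 and ab[0] == "implies"):
        raise ValueError(f"not an implication: {ab!r}")
    if ab[1] != a:
        raise ValueError(f"antecedent {ab[1]!r} does not match {a!r}")
    return ab[2]


print(toy_modus(("implies", "a", "b"), "a"))  # b
```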

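kdrag.kernel.prove(thm: BoolRef, by: Proof | Iterable[Proof] = [], admit=False, timeout=1000, dump=False, solver=None) -> Proof

Prove a theorem using a list of previously proved lemmas.

In essence, this proves Implies(And(by), thm): the theorems of the lemmas in by are added to the solver along with Not(thm), and thm is accepted only if the solver reports unsat. A sat result raises LemmaError with a countermodel, and any other result (such as unknown after a timeout) also raises LemmaError.

Parameters:
  • thm (smt.BoolRef) – The theorem to prove.

  • by (Proof | Iterable[Proof]) – A previously proved lemma, or an iterable of them. Any element that is not a Proof raises LemmaError.

  • admit (bool) – If True, admit the theorem without proof. This logs a warning and fails unless config.admit_enabled is set.

  • timeout (int) – Value for the solver's "timeout" option (milliseconds for Z3 solvers).

  • dump (bool) – If True, print the solver's assertions as an s-expression before checking.

  • solver – A solver factory called with no arguments. Defaults to config.solver.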

Returns:

A proof object of thm

Return type:

Proof

>>> prove(smt.BoolVal(True))
|- True
>>> prove(smt.RealVal(1) >= smt.RealVal(0))
|- 1 >= 0
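prove works by refutation: it asserts the lemma theorems, adds the negation of thm, and accepts thm only if the solver reports unsat; a sat result raises LemmaError with a countermodel. The stdlib-only sketch below reproduces that control flow for propositional formulas, with a brute-force truth-table search standing in for Z3. toy_prove, ToyLemmaError and the tuple encoding are hypothetical, not kdrag's implementation.

```python
# Toy, stdlib-only sketch of refutation-based proving; not kdrag's implementation.
# Formulas: atoms are strings; compound formulas are ("not", p),
# ("and", p, q), ("or", p, q) and ("implies", p, q).
import itertools
from typing import Any, Iterable


class ToyLemmaError(Exception):
    pass


def atoms(f: Any) -> set:
    if isinstance(f, str):
        return {f}
    return set().union(*(atoms(g) for g in f[1:]))


def holds(f: Any, model: dict) -> bool:
    if isinstance(f, str):
        return model[f]
    op, *xs = f
    vals = [holds(x, model) for x in xs]
    if op == "not":
        return not vals[0]
    if op == "and":
        return vals[0] and vals[1]
    if op == "or":
        return vals[0] or vals[1]
    if op == "implies":
        return (not vals[0]) or vals[1]
    raise ValueError(f"unknown connective {op!r}")


def toy_prove(thm: Any, by: Iterable[Any] = ()) -> Any:
    """Accept thm iff the lemmas together with ("not", thm) have no model."""
    hyps = list(by) + [("not", thm)]
    names = sorted(set().union(*(atoms(h) for h in hyps)))
    for values in itertools.product([False, True], repeat=len(names)):
        model = dict(zip(names, values))
        if all(holds(h, model) for h in hyps):  # "sat": a countermodel exists
            raise ToyLemmaError(thm, "Countermodel", model)
    return thm  # "unsat": thm follows from the lemmas


print(toy_prove("b", by=[("implies", "a", "b"), "a"]))  # b
```

Z3 replaces the exhaustive search with decision procedures, which is why the real prove needs a timeout and can also end in unknown.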

kdrag.kernel.skolem(pf: Proof) -> tuple[list[ExprRef], Proof]

Skolemize an existential quantifier: given |- exists xs, P(xs), returns fresh constants cs and |- P(cs).

Parameters:

pf (Proof)

Return type:

tuple[list[ExprRef], Proof]

kdrag.kernel.substitute_schema_vars(pf: Proof, *subst) -> Proof

Substitute schematic variables in a theorem. This is a single step, instead of generalizing to a ForAll and then eliminating it.

>>> x = SchemaVar("x", smt.IntSort())
>>> y = SchemaVar("y", smt.IntSort())
>>> substitute_schema_vars(prove(x == x), (x, smt.IntVal(42)), (y, smt.IntVal(43)))
|- 42 == 42
Parameters:

pf (Proof)

Return type:

Proof
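substitute_schema_vars should give the same result as generalizing over the schema variables and then instantiating the resulting ForAll, only in one step. The stdlib-only sketch below checks that equivalence on tuple-encoded formulas. The names are hypothetical, and marking schema variables with a leading "?" is purely an illustrative convention, not how kdrag identifies them.

```python
# Toy, stdlib-only sketch: one-step substitution vs. generalize-then-instantiate.
# Not kdrag's implementation; all names here are hypothetical.
from typing import Any


def subst(t: Any, env: dict) -> Any:
    """Replace variable names (strings) according to env."""
    if isinstance(t, str):
        return env.get(t, t)
    if isinstance(t, tuple):
        return tuple(subst(s, env) for s in t)
    return t


def is_schema(v: Any) -> bool:
    # Illustrative convention only: schema variables start with "?".
    return isinstance(v, str) and v.startswith("?")


def toy_generalize(vs: list, thm: Any) -> Any:
    if not all(is_schema(v) for v in vs):
        raise ValueError("can only generalize schema variables")
    return ("forall", tuple(vs), thm)


def toy_instan(ts: list, q: Any) -> Any:
    tag, vs, body = q
    if tag != "forall" or len(vs) != len(ts):
        raise ValueError("expected a ForAll binding len(ts) variables")
    return subst(body, dict(zip(vs, ts)))


def toy_substitute_schema_vars(thm: Any, *pairs: tuple) -> Any:
    """Substitute (schema_var, term) pairs directly, in one step."""
    if not all(is_schema(v) for v, _ in pairs):
        raise ValueError("can only substitute schema variables")
    return subst(thm, dict(pairs))


thm = ("==", "?x", "?x")
two_step = toy_instan(["42", "43"], toy_generalize(["?x", "?y"], thm))
one_step = toy_substitute_schema_vars(thm, ("?x", "42"), ("?y", "43"))
print(one_step, one_step == two_step)  # ('==', '42', '42') True
```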

"""
The kernel holds the core proof datatypes and core inference rules. By and large, all proofs must flow through this module.
"""

import kdrag as kd
import kdrag.smt as smt
from dataclasses import dataclass
from typing import Any, Iterable, Sequence
import logging
from . import config

logger = logging.getLogger("knuckledragger")


class Judgement:
    """
    Judgements should be constructed by smart constructors.
    Having an object of supertype judgement represents having shown some kind of truth.
    Judgements are the things that go above and below inference lines in a proof system.
    Don't worry about it. It is just nice to have a name for the concept.

    See:
    - https://en.wikipedia.org/wiki/Judgment_(mathematical_logic)
    - https://mathoverflow.net/questions/254518/what-exactly-is-a-judgement
    - https://ncatlab.org/nlab/show/judgment
    """


@dataclass(frozen=True)
class Proof(Judgement):
    """
    It is unlikely that users should be accessing the `Proof` constructor directly.
    This is not ironclad. If you really want the Proof constructor, I can't stop you.
    """

    thm: smt.BoolRef
    reason: list[Any]
    admit: bool = False

    def __post_init__(self):
        if self.admit and not config.admit_enabled:
            raise ValueError(
                self.thm, "was called with admit=True but config.admit_enabled=False"
            )

    def __hash__(self) -> int:
        return hash(self.thm)

    def _repr_html_(self):
        return "&#8870;" + repr(self.thm)

    def __repr__(self):
        return "|- " + repr(self.thm)

    def forall(self, schema_vars: list[smt.ExprRef]) -> "Proof":
        """
        Generalize a proof involving schematic variables generated by SchemaVar.

        >>> x = SchemaVar("x", smt.IntSort())
        >>> prove(x + 1 > x).forall([x])
        |- ForAll(x!..., x!... + 1 > x!...)
        """
        return generalize(schema_vars, self)

    def __call__(self, *args: "smt.ExprRef | Proof"):
        """

        >>> x,y = smt.Ints("x y")
        >>> p = prove(smt.ForAll([y], smt.ForAll([x], x >= x - 1)))
        >>> p(x)
        |- ForAll(x, x >= x - 1)
        >>> p(x, smt.IntVal(7))
        |- 7 >= 7 - 1

        >>> a,b,c = smt.Bools("a b c")
        >>> ab = prove(smt.Implies(a,smt.Implies(a, a)))
        >>> a = axiom(a)
        >>> ab(a)
        |- Implies(a, a)
        >>> ab(a,a)
        |- a
        """
        # Note: Not trusted code. Trusted code is in `instan` and `modus`
        acc = self
        n = 0
        while n < len(args):
            # Dispatch on the accumulated theorem, not the original one, so that
            # mixed chains such as ForAll(x, Implies(P(x), Q(x))) can be applied as p(t, pf).
            if isinstance(acc.thm, smt.QuantifierRef) and acc.thm.is_forall():
                i = acc.thm.num_vars()
                acc = instan(args[n : n + i], acc)  # type: ignore
                n += i
            elif smt.is_implies(acc.thm):
                x = args[n]
                n += 1
                assert isinstance(x, kd.Proof), "Can only apply implication to kd.Proof"
                acc = modus(acc, x)
            else:
                raise TypeError(
                    f"Can only apply proofs of ForAll or Implies theorems, not {acc.thm}. "
                    "Use instan for forall quantifiers or modus for implications."
                )
        return acc


"""
Proof_new = Proof.__new__

def sin_check(cls, thm, reason, admit=False, i_am_a_sinner=False):
    if admit and not config.admit_enabled:
        raise ValueError(
            thm, "was called with admit=True but config.admit_enabled=False"
        )
    if not i_am_a_sinner:
        raise ValueError("Proof is private. Use `kd.prove` or `kd.axiom`")
    return Proof_new(cls, thm, list(reason), admit)


Proof.__new__ = sin_check
"""


def is_proof(p: Proof) -> bool:
    return isinstance(p, Proof)


class LemmaError(Exception):
    pass


def prove(
    thm: smt.BoolRef,
    by: Proof | Iterable[Proof] = [],
    admit=False,
    timeout=1000,
    dump=False,
    solver=None,
) -> Proof:
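    """Prove a theorem using a list of previously proved lemmas.

    In essence `prove(Implies(And(by), thm))`.

    Args:
        thm (smt.BoolRef): The theorem to prove.
        by (Proof | Iterable[Proof]): A previously proved lemma, or an iterable of them.
        admit (bool): If True, admit the theorem without proof.
        timeout (int): Value for the solver's "timeout" option.
        dump (bool): If True, print the solver's assertions before checking.
        solver: A solver factory called with no arguments. Defaults to `config.solver`.

    Returns:
        Proof: A proof object of thm

    Raises:
        LemmaError: If an element of `by` is not a Proof, or the solver does not return unsat.

    >>> prove(smt.BoolVal(True))
    |- True
    >>> prove(smt.RealVal(1) >= smt.RealVal(0))
    |- 1 >= 0
    """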
    if isinstance(by, Proof):
        by = [by]
    if admit:
        logger.warning("Admitting lemma {}".format(thm))
        return Proof(thm, list(by), admit=True)
    else:
        if solver is None:
            s = config.solver()  # type: ignore
        else:
            s = solver()
        s.set("timeout", timeout)
        for p in by:
            if not isinstance(p, Proof):
                raise LemmaError("In by reasons:", p, "is not a Proof object")
            s.add(p.thm)
        s.add(smt.Not(thm))
        if dump:
            print(s.sexpr())
        res = s.check()
        if res != smt.unsat:
            if res == smt.sat:
                raise LemmaError(thm, "Countermodel", s.model())
            raise LemmaError("prove", thm, res)
        else:
            return Proof(thm, list(by), False)
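

# Illustrative sketch, not part of the kernel: `prove` establishes
# `Implies(And(by), thm)` by refutation, checking that `And(by) & Not(thm)` is
# unsatisfiable. For propositional formulas the same check can be done by
# brute force over truth assignments. Modeling formulas as Python callables
# from an assignment dict to bool is an assumption of this sketch, and
# `_sketch_refute` is a hypothetical helper.
def _sketch_refute(atoms, hyps, goal):
    """Return None if `hyps` entail `goal`; otherwise return a countermodel dict."""
    import itertools

    for values in itertools.product([False, True], repeat=len(atoms)):
        model = dict(zip(atoms, values))
        # A model of And(hyps) & Not(goal) is a countermodel, like `s.model()` above.
        if all(h(model) for h in hyps) and not goal(model):
            return model
    return None  # no countermodel: the analogue of `unsat`


# Usage: from `a -> b` and `a`, `b` follows; from `a -> b` alone it does not.
#   imp = lambda m: (not m["a"]) or m["b"]
#   _sketch_refute(["a", "b"], [imp, lambda m: m["a"]], lambda m: m["b"])  # None
#   _sketch_refute(["a", "b"], [imp], lambda m: m["b"])  # {'a': False, 'b': False}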


def axiom(thm: smt.BoolRef, by=["axiom"]) -> Proof:
    """Assert an axiom.

    Axioms are necessary and useful, but must be used with great care:
    an inconsistent axiom makes every statement provable.

    Args:
        thm: The axiom to assert.
        by: A Python object explaining why the axiom should exist. Often a string explaining the axiom.
    """
    return Proof(thm, by)


@dataclass(frozen=True)
class Defn:
    """
    A record storing a definition. Definitions are recorded as special axioms
    because they often must be unfolded.

    Fields: the defined `name`, its formal `args`, the defining `body`,
    and `ax`, the definitional axiom.
    """

    name: str
    args: list[smt.ExprRef]
    body: smt.ExprRef
    ax: Proof


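_datatypes = {}  # Datatypes declared via Inductive, keyed by name
defns: dict[smt.FuncDeclRef, Defn] = {}
"""
defns maps each defined function symbol to its `Defn` record, which holds its definitional axiom.
"""
# Convenience accessors: `f.defn` and `f(...).defn` return the definitional axiom of `f`.
smt.FuncDeclRef.defn = property(lambda self: defns[self].ax)
smt.ExprRef.defn = property(lambda self: defns[self.decl()].ax)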


def is_defined(x: smt.ExprRef) -> bool:
    """
    Determine whether the head symbol of an expression has a recorded definition.
    """
    return smt.is_app(x) and x.decl() in defns


def fresh_const(q: smt.QuantifierRef):
    """Generate fresh constants of same sort as quantifier."""
    # .split("!")[0] strips earlier freshness suffixes (e.g. x!0!1) so names do not accumulate them
    return [
        smt.FreshConst(q.var_sort(i), prefix=q.var_name(i).split("!")[0])
        for i in range(q.num_vars())
    ]


def define(
    name: str, args: list[smt.ExprRef], body: smt.ExprRef, lift_lambda=False
) -> smt.FuncDeclRef:
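    """
    Define a non-recursive function symbol by an equation. Useful for shorthand and abstraction.
    Does not currently defend against ill-formed definitions.
    TODO: Check for bad circularity, record dependencies

    Args:
        name: The name of the function symbol to define.
        args: The arguments of the function.
        body: The right-hand side of the definition.
        lift_lambda: If True and `body` is a lambda, state the definitional axiom
            pointwise so that it contains no lambda.

    Returns:
        smt.FuncDeclRef: The defined function symbol, or the constant `f()` if `args` is empty.
        The definitional axiom is recorded in `defns` and is available as `f.defn`.
    """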
    sorts = [arg.sort() for arg in args] + [body.sort()]
    f = smt.Function(name, *sorts)

    # TODO: This is getting too hairy for the kernel? Reassess. Maybe just a lambda flag? Autolift?
    if lift_lambda and isinstance(body, smt.QuantifierRef) and body.is_lambda():
        # Avoid lambdas in the definitional axiom by stating it pointwise:
        # forall args, vs. f(args)[vs] == body(vs)
        vs = fresh_const(body)
        def_ax = axiom(
            smt.ForAll(
                args + vs,
                smt.Eq(
                    f(*args)[tuple(vs)], smt.substitute_vars(body.body(), *reversed(vs))
                ),
            ),
            by="definition",
        )
    elif len(args) == 0:
        def_ax = axiom(smt.Eq(f(), body), by="definition")
    else:
        def_ax = axiom(smt.ForAll(args, smt.Eq(f(*args), body)), by="definition")
    # Redefinitions are reported with a soft warning rather than rejected.
    defn = Defn(name, args, body, def_ax)
    if f in defns and not defns[f].ax.thm.eq(def_ax.thm):
        print("WARNING: Redefining function", f, "from", defns[f].ax, "to", def_ax.thm)
    defns[f] = defn
    if len(args) == 0:
        return f()  # Convenience
    else:
        return f


def define_fix(name: str, args: list[smt.ExprRef], retsort, fix_lam) -> smt.FuncDeclRef:
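    """
    Define a recursive function symbol.

    `fix_lam` receives a callable standing for the function being defined and must
    return the body of the definition in terms of `args`. Termination is not checked.
    """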
    sorts = [arg.sort() for arg in args]
    sorts.append(retsort)
    f = smt.Function(name, *sorts)

    # Wrapper that records the argument tuples of recursive calls
    # (not yet used; see the termination TODO below).
    calls = set()

    def record_f(*xs):
        calls.add(xs)
        return f(*xs)

    defn = define(name, args, fix_lam(record_f))
    # TODO: check for well foundedness/termination, custom induction principle.
    return defn


def consider(x: smt.ExprRef) -> Proof:
    """
    Seed the solver with an interesting term.
    This is an axiom schema: any term `x` may be given a fresh name `c`, yielding `|- c == x`.
    It is an "anonymous" form of `define`.
    Pointing out the interesting terms is sometimes the essence of a proof.
    """
    return axiom(smt.Eq(smt.FreshConst(x.sort(), prefix="consider"), x))


def instan(ts: Sequence[smt.ExprRef], pf: Proof) -> Proof:
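    """
    Instantiate a universally quantified formula.
    This is forall elimination: from `|- forall xs, P(xs)` derive `|- P(ts)`.
    `ts` must supply one term per bound variable, in declaration order.
    """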
    assert (
        is_proof(pf)
        and isinstance(pf.thm, smt.QuantifierRef)
        and pf.thm.is_forall()
        and len(ts) == pf.thm.num_vars()
    )

    return axiom(smt.substitute_vars(pf.thm.body(), *reversed(ts)), [pf])
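

# Illustrative sketch, not part of the kernel, of why `instan` and friends pass
# `reversed(ts)` to `smt.substitute_vars`. Z3 numbers bound variables with de
# Bruijn indices counted from the innermost binder, so in `ForAll([x, y], body)`
# the last-bound `y` is `Var(0)` and `x` is `Var(1)`, while
# `substitute_vars(t, *ms)` replaces `Var(i)` with `ms[i]`. The nested-tuple
# term encoding with ("var", i) leaves is a hypothetical stand-in for Z3 terms.
def _sketch_substitute_vars(t, *ms):
    """Replace every ("var", i) leaf of a nested-tuple term with ms[i]."""
    if isinstance(t, tuple) and len(t) == 2 and t[0] == "var":
        return ms[t[1]]
    if isinstance(t, tuple):
        return tuple(_sketch_substitute_vars(s, *ms) for s in t)
    return t


# Usage: the body of ForAll([x, y], x - y) is ("-", ("var", 1), ("var", 0)).
# Instantiating with ts = [5, 3] (x := 5, y := 3) needs the reversal:
#   _sketch_substitute_vars(("-", ("var", 1), ("var", 0)), *reversed([5, 3]))  # ("-", 5, 3)
# Without it, the terms land on the wrong variables and give ("-", 3, 5).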


def instan2(ts: Sequence[smt.ExprRef], thm: smt.BoolRef) -> Proof:
    """
    Instantiate a universally quantified formula, as a theorem:
    `(forall xs, P(xs)) -> P(ts)`.
    This is forall elimination. Unlike `instan`, no proof of the quantified formula is required.
    """
    assert (
        isinstance(thm, smt.QuantifierRef)
        and thm.is_forall()
        and len(ts) == thm.num_vars()
    )

    return axiom(
        smt.Implies(thm, smt.substitute_vars(thm.body(), *reversed(ts))),
        ["forall_elim"],
    )


def forget(ts: Iterable[smt.ExprRef], pf: Proof) -> Proof:
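    """
    "Forget" a term using existentials. This is existential introduction:
    from `|- P(ts)` derive `|- exists vs, P(vs)`, where every occurrence of each `t` is
    replaced by a fresh variable `v`.
    This could be derived from forget2.
    """
    # Hmm. I seem to have rarely been using this
    assert is_proof(pf)
    ts = list(ts)  # ts is traversed twice below; materialize it in case it is a generator
    vs = [smt.FreshConst(t.sort()) for t in ts]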
    return axiom(smt.Exists(vs, smt.substitute(pf.thm, *zip(ts, vs))), ["forget", pf])


def forget2(ts: Sequence[smt.ExprRef], thm: smt.QuantifierRef) -> Proof:
    """
    "Forget" a term using existentials. This is existential introduction:
    `P(ts) -> exists xs, P(xs)`.
    `thm` is the existential formula `exists xs, P(xs)`, and `ts` are the terms substituted for `xs`.
    `forget` follows easily.
    https://en.wikipedia.org/wiki/Existential_generalization
    """
    assert smt.is_quantifier(thm) and thm.is_exists() and len(ts) == thm.num_vars()
    return axiom(
        smt.Implies(smt.substitute_vars(thm.body(), *reversed(ts)), thm),
        ["exists_intro"],
    )


def einstan(thm: smt.QuantifierRef) -> tuple[list[smt.ExprRef], Proof]:
    """
    Skolemize an existential quantifier.
    Returns fresh constants `cs` and `|- (exists xs, P(xs)) -> P(cs)`.
    https://en.wikipedia.org/wiki/Existential_instantiation
    """
    assert smt.is_quantifier(thm) and thm.is_exists()

    skolems = fresh_const(thm)
    return skolems, axiom(
        smt.Implies(thm, smt.substitute_vars(thm.body(), *reversed(skolems))),
        ["einstan"],
    )


def skolem(pf: Proof) -> tuple[list[smt.ExprRef], Proof]:
    """
    Skolemize an existential quantifier.
    From `|- exists xs, P(xs)` derive `|- P(cs)` for fresh constants `cs`; returns `(cs, proof)`.
    """
    # TODO: Hmm. Maybe we don't need to have a Proof? Lessen this to thm.
    assert is_proof(pf) and isinstance(pf.thm, smt.QuantifierRef) and pf.thm.is_exists()

    skolems = fresh_const(pf.thm)
    return skolems, axiom(
        smt.substitute_vars(pf.thm.body(), *reversed(skolems)), ["skolem", pf]
    )


def herb(thm: smt.QuantifierRef) -> tuple[list[smt.ExprRef], Proof]:
    """
    Herbrandize a universally quantified formula.
    To prove `forall xs, P(xs)` it suffices to prove `P(cs)` for fresh constants `cs`.
    Returns `cs` and `|- P(cs) -> forall xs, P(xs)`.
    Note: Perhaps a lambda-ized form is better? Return the vars and a lambda that could receive `|- P[vars]`.
    """
    assert smt.is_quantifier(thm) and thm.is_forall()
    herbs = fresh_const(thm)  # We could mark these as schema variables? Useful?
    return herbs, axiom(
        smt.Implies(smt.substitute_vars(thm.body(), *reversed(herbs)), thm),
        ["herband"],
    )


def modus(ab: Proof, a: Proof) -> Proof:
    """
    Modus ponens: from `|- a -> b` and `|- a`, derive `|- b`.

    >>> a,b = smt.Bools("a b")
    >>> ab = axiom(smt.Implies(a, b))
    >>> a = axiom(a)
    >>> modus(ab, a)
    |- b
    """
    assert isinstance(ab, Proof) and isinstance(a, Proof)
    assert smt.is_implies(ab.thm) and ab.thm.arg(0).eq(a.thm)
    return axiom(ab.thm.arg(1), ["modus", ab, a])


def compose(ab: Proof, bc: Proof) -> Proof:
    """
    Compose two implications. Useful for chaining implications.

    >>> a,b,c = smt.Bools("a b c")
    >>> ab = axiom(smt.Implies(a, b))
    >>> bc = axiom(smt.Implies(b, c))
    >>> compose(ab, bc)
    |- Implies(a, c)
    """
    assert isinstance(ab, Proof) and isinstance(bc, Proof)
    assert smt.is_implies(ab.thm) and smt.is_implies(bc.thm)
    assert ab.thm.arg(1).eq(bc.thm.arg(0))
    return axiom(smt.Implies(ab.thm.arg(0), bc.thm.arg(1)), ["compose", ab, bc])
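

# Illustrative sketch, not part of the kernel, of the shape checks `modus` and
# `compose` perform before issuing a conclusion. Formulas are encoded, for this
# example only, as string atoms or tuples ("=>", a, b); `_sketch_modus` and
# `_sketch_compose` are hypothetical helpers returning bare formulas, not Proofs.
# They raise ValueError rather than using assert so the checks survive `python -O`.
def _sketch_modus(ab, a):
    """From `a => b` and `a`, conclude `b`."""
    if not (isinstance(ab, tuple) and ab[0] == "=>" and ab[1] == a):
        raise ValueError("modus: premise does not match the implication")
    return ab[2]


def _sketch_compose(ab, bc):
    """From `a => b` and `b => c`, conclude `a => c`."""
    for p in (ab, bc):
        if not (isinstance(p, tuple) and p[0] == "=>"):
            raise ValueError("compose: expected an implication")
    if ab[2] != bc[1]:
        raise ValueError("compose: middle formulas differ")
    return ("=>", ab[1], bc[2])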


def induct_inductive(x: smt.DatatypeRef, P: smt.QuantifierRef) -> Proof:
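    """Build a basic structural induction principle for an algebraic datatype.

    For each constructor `c(a1, ..., an)`, the hypothesis states that `P` holds of
    `c(a1, ..., an)` whenever it holds of every argument `ai` whose sort is the datatype.
    The returned axiom states that these hypotheses together imply `P(x)`.
    """
    DT = x.sort()
    assert isinstance(DT, smt.DatatypeSortRef)
    # TODO: Check that P is a predicate on DT, e.g.
    # isinstance(P, smt.QuantifierRef) and P.is_lambda(). Actually it should just be an array sort.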
    hyps = []
    for i in range(DT.num_constructors()):
        constructor = DT.constructor(i)
        args = [
            smt.FreshConst(constructor.domain(j), prefix=DT.accessor(i, j).name())
            for j in range(constructor.arity())
        ]
        head = P(constructor(*args))
        body = [P(arg) for arg in args if arg.sort() == DT]
        if len(args) == 0:
            hyps.append(head)
        else:
            hyps.append(kd.QForAll(args, *body, head))
    conc = P(x)
    return axiom(smt.Implies(smt.And(hyps), conc), by="induction_axiom_schema")
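

# Illustrative sketch, not part of the kernel, of the hypotheses that
# `induct_inductive` assembles. A datatype is described here, hypothetically,
# as a list of (constructor, [(field, sort_name), ...]) pairs, and formulas are
# rendered as strings. As in the kernel, only fields whose sort is the datatype
# itself contribute an induction hypothesis.
def _sketch_induction_hyps(dt_name, constructors):
    """Return one hypothesis string per constructor of `dt_name`."""
    hyps = []
    for cons, fields in constructors:
        args = [field for field, _ in fields]
        if not args:
            hyps.append(f"P({cons})")
            continue
        head = f"P({cons}({', '.join(args)}))"
        recursive = [f"P({field})" for field, sort in fields if sort == dt_name]
        body = f"{' & '.join(recursive)} -> {head}" if recursive else head
        hyps.append(f"forall {', '.join(args)}. {body}")
    return hyps


# Usage for Nat = zero | succ(pred: Nat):
#   _sketch_induction_hyps("Nat", [("zero", []), ("succ", [("pred", "Nat")])])
#   # ["P(zero)", "forall pred. P(pred) -> P(succ(pred))"]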


def Inductive(name: str) -> smt.Datatype:
    """
    Declare datatypes with auto-generated induction principles. Wrapper around z3.Datatype.
    If `name` is already in use, a suffix `!n` is appended to make it unique.

    >>> Nat = Inductive("Nat")
    >>> Nat.declare("zero")
    >>> Nat.declare("succ", ("pred", Nat))
    >>> Nat = Nat.create()
    >>> Nat.succ(Nat.zero)
    succ(zero)
    """
    counter = 0
    n = name
    while n in _datatypes:
        counter += 1
        n = name + "!" + str(counter)
    name = n
    assert name not in _datatypes
    dt = smt.Datatype(name)
    oldcreate = dt.create

    def create():
        dt = oldcreate()
        # Reject duplicate constructor and field names; they make accessors ambiguous and cause confusion.
        names = set()
        for i in range(dt.num_constructors()):
            cons = dt.constructor(i)
            n = cons.name()
            if n in names:
                raise Exception("Duplicate constructor name", n)
            names.add(n)
            for j in range(cons.arity()):
                n = dt.accessor(i, j).name()
                if n in names:
                    raise Exception("Duplicate field name", n)
                names.add(n)
        kd.notation.induct.register(dt, induct_inductive)
        _datatypes[name] = dt
        smt.sort_registry[dt.get_id()] = dt
        return dt

    dt.create = create
    return dt


# Experimental Schema Vars


@dataclass(frozen=True)
class _SchemaVarEvidence(Judgement):
    """
    Do not instantiate this class directly.
    Use `SchemaVar`. This class should always be created with a fresh variable.
    Holding a value of this type is evidence, analogous to holding a `Proof`, that the
    variable was generated freshly and hence is generic (schematic).

    One can prove theorems using this variable as a constant, but to generalize over it
    one must supply the evidence that it was originally generated freshly.
    """

    v: smt.ExprRef


def is_schema_var(v: smt.ExprRef) -> bool:
    """
    Check if a variable is a schema variable.
    Schema variables are generated by `SchemaVar` and carry a `schema_evidence` attribute
    holding a `_SchemaVarEvidence` for that same variable.

    >>> is_schema_var(SchemaVar("x", smt.IntSort()))
    True
    >>> is_schema_var(smt.Int("x"))
    False
    """
    evidence = getattr(v, "schema_evidence", None)
    return isinstance(evidence, _SchemaVarEvidence) and evidence.v.eq(v)


def SchemaVar(prefix: str, sort: smt.SortRef) -> smt.ExprRef:
    """
    Generate a fresh schema variable of the given sort, tagged with evidence of its freshness.

    >>> SchemaVar("x", smt.IntSort()).schema_evidence
    _SchemaVarEvidence(v=x!...)
    """
    v = smt.FreshConst(sort, prefix=prefix)
    v.schema_evidence = _SchemaVarEvidence(
        v
    )  # Is cyclic reference a garbage collection problem?
    return v


def generalize(vs: list[smt.ExprRef], pf: Proof) -> Proof:
    """
    Generalize a theorem with respect to a list of schema variables.
    From `|- P(vs)`, where every `v` in `vs` is a schema variable, derive `|- forall vs, P(vs)`.

    >>> x = SchemaVar("x", smt.IntSort())
    >>> y = SchemaVar("y", smt.IntSort())
    >>> generalize([x, y], prove(x == x))
    |- ForAll([x!..., y!...], x!... == x!...)
    """
    assert all(is_schema_var(v) for v in vs)
    assert isinstance(pf, Proof)
    return axiom(smt.ForAll(vs, pf.thm), by=["generalize", vs, pf])


def substitute_schema_vars(pf: Proof, *subst) -> Proof:
    """
    Substitute schematic variables in a theorem.
    This is a single step instead of generalizing to a ForAll and then eliminating it.

    >>> x = SchemaVar("x", smt.IntSort())
    >>> y = SchemaVar("y", smt.IntSort())
    >>> substitute_schema_vars(prove(x == x), (x, smt.IntVal(42)), (y, smt.IntVal(43)))
    |- 42 == 42
    """
    assert all(is_schema_var(v) for v, t in subst) and isinstance(pf, Proof)
    return axiom(
        smt.substitute(pf.thm, *subst), by=["substitute_schema_vars", pf, subst]
    )
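

# Illustrative sketch, not part of the kernel, of the evidence-token pattern
# behind `SchemaVar` and `generalize`: a frozen evidence record is created only
# alongside a freshly generated variable, and generalization insists on
# receiving that record for each variable. Variables are plain strings here,
# and every `_sketch` name is hypothetical.
import itertools
from dataclasses import dataclass

_sketch_fresh_ids = itertools.count()


@dataclass(frozen=True)
class _SketchEvidence:
    v: str


def _sketch_schema_var(prefix):
    """Return a fresh variable name together with evidence of its freshness."""
    v = f"{prefix}!{next(_sketch_fresh_ids)}"
    return v, _SketchEvidence(v)


def _sketch_generalize(evidence, thm):
    """Quantify `thm` over the variables witnessed by `evidence`."""
    if not all(isinstance(e, _SketchEvidence) for e in evidence):
        raise ValueError("generalize: missing freshness evidence")
    return f"forall {', '.join(e.v for e in evidence)}. {thm}"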