kdrag.solvers.egraph
Classes
- class kdrag.solvers.egraph.EGraph(proof=False)
Bases: object
- add_term(t: ExprRef) → None
Recursively add a term, and all of its subterms, to the egraph.
- Parameters:
t (ExprRef)
- Return type:
None
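The following minimal pure-Python hash-consing sketch illustrates what "recursively add" means. It models terms as strings (constants) or tuples (op, *args) rather than z3 ExprRef values. The TermBank class is hypothetical and is not part of kdrag.

```python
from typing import Dict, Tuple, Union

# Hypothetical term model (not kdrag's): a constant is a string,
# an application is a tuple (op, arg1, arg2, ...).
Term = Union[str, Tuple]


class TermBank:
    """Assigns each distinct term a small integer id (hash-consing)."""

    def __init__(self) -> None:
        self.ids: Dict[Term, int] = {}    # term -> id
        self.terms: Dict[int, Term] = {}  # id -> term

    def add_term(self, t: Term) -> int:
        # Add the children first, so every subterm ends up in the bank.
        if isinstance(t, tuple):
            for arg in t[1:]:
                self.add_term(arg)
        # Structurally equal terms are stored only once.
        if t not in self.ids:
            new_id = len(self.ids)
            self.ids[t] = new_id
            self.terms[new_id] = t
        return self.ids[t]

    def in_terms(self, t: Term) -> bool:
        # Syntactic membership only; kdrag's in_terms is described as semantic.
        return t in self.ids


bank = TermBank()
bank.add_term(("+", "x", "y"))
print(bank.in_terms("x"), bank.in_terms("z"))  # True False
print(len(bank.terms))  # 3 (x, y and x + y)
```

Adding the same term twice returns the existing id, so subterms shared between terms are stored once.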
- copy()
Copy the egraph. This is a shallow copy, so the terms are not copied.
>>> E = EGraph()
>>> f = smt.Function('f', smt.IntSort(), smt.IntSort())
>>> x,y,z = smt.Ints('x y z')
>>> E.add_term(f(x))
>>> E.add_term(f(y))
>>> _ = E.union(x,y)
>>> assert E.find(f(x)) != E.find(f(y))
>>> E2 = E.copy()
>>> _ = E2.rebuild()
>>> assert E2.find(f(x)) == E2.find(f(y))
>>> assert E.find(f(x)) != E.find(f(y))
- dot(filename: str = 'egraph') → Digraph
Create graphviz representation of the egraph.
>>> E = EGraph()
>>> x,y,z = smt.Ints("x y z")
>>> E.add_term(x + y)
>>> E.union(y,z)
True
>>> _ = E.rebuild()
>>> _ = E.dot()
- Parameters:
filename (str)
- Return type:
Digraph
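- eclasses() → defaultdict[int, defaultdict[FuncDeclRef, set[tuple[int]]]]
Returns a dictionary that maps each canonical e-class id to the e-nodes in that class. Within a class, each function declaration (FuncDeclRef) maps to the set of tuples of child e-class ids to which it is applied.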
>>> E = EGraph()
>>> x,y,z = smt.Ints("x y z")
>>> E.add_term(x + y)
>>> E.union(y,z)
True
>>> _ = E.rebuild()
>>> _ = E.eclasses()
- Return type:
defaultdict[int, defaultdict[FuncDeclRef, set[tuple[int]]]]
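The nested return type is easier to read with a concrete instance. The sketch below builds the same shape of dictionary from a hypothetical table of e-nodes. Plain strings stand in for FuncDeclRef, and the union-find state is written by hand to reflect union(y, z). It mirrors the doctest above but is not kdrag's code.

```python
from collections import defaultdict

# Hypothetical e-nodes: node id -> (function symbol, child node ids).
# Constants such as x are 0-ary applications with no children.
nodes = {
    0: ("x", ()),
    1: ("y", ()),
    2: ("z", ()),
    3: ("+", (0, 1)),  # x + y
}
# Union-find parents after union(y, z): y's class is represented by z (id 2).
parent = {0: 0, 1: 2, 2: 2, 3: 3}


def find(i):
    while parent[i] != i:
        i = parent[i]
    return i


def eclasses():
    # class id -> {function symbol -> {tuple of canonical child class ids}}
    classes = defaultdict(lambda: defaultdict(set))
    for i, (op, args) in nodes.items():
        classes[find(i)][op].add(tuple(find(a) for a in args))
    return classes


ec = eclasses()
print(dict(ec[2]))  # {'y': {()}, 'z': {()}}
print(dict(ec[3]))  # {'+': {(0, 2)}}
```

Children are stored by canonical class id, so x + y is recorded as + applied to (class of x, class of z).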
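- ematch(vs: list[ExprRef], pat: ExprRef) → list[list[ExprRef]]
Find all matches of the pattern pat in the egraph. The terms in vs are the pattern variables; each match is a list giving one term per variable in vs.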
>>> E = EGraph()
>>> f = smt.Function('f', smt.IntSort(), smt.IntSort())
>>> x,y,z = smt.Ints('x y z')
>>> E.add_term(f(x))
>>> _ = E.union(f(x), x)
>>> _ = E.rebuild()
>>> E.ematch([y], f(f(y)))
[[x]]
- Parameters:
vs (list[ExprRef])
pat (ExprRef)
- Return type:
list[list[ExprRef]]
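A naive top-down e-matching sketch helps show what "all matches" means. It runs over a nested e-class dictionary of the shape eclasses returns. Pattern variables are plain strings, other pattern nodes are tuples (op, *subpatterns), and matches are returned as e-class ids rather than terms. All of this is an assumption for illustration, not kdrag's implementation.

```python
from typing import Dict, Iterator, List, Set, Tuple

# Hypothetical e-class table: class id -> {op: {tuple of child class ids}}.
Classes = Dict[int, Dict[str, Set[Tuple[int, ...]]]]


def match(classes: Classes, pat, cid: int, subst: Dict[str, int]) -> Iterator[Dict[str, int]]:
    """Yield every extension of subst under which pat matches e-class cid."""
    if isinstance(pat, str):  # a pattern variable
        if pat not in subst:
            yield {**subst, pat: cid}
        elif subst[pat] == cid:  # a repeated variable must hit the same class
            yield subst
        return
    op, *subpats = pat
    for args in classes[cid].get(op, set()):
        if len(args) != len(subpats):
            continue
        # Match children left to right, threading the substitutions through.
        partial = [subst]
        for sub, child in zip(subpats, args):
            partial = [s2 for s in partial for s2 in match(classes, sub, child, s)]
        yield from partial


def ematch(classes: Classes, vs: List[str], pat) -> List[List[int]]:
    results = []
    for cid in classes:  # try the pattern at the root of every e-class
        for s in match(classes, pat, cid, {}):
            results.append([s[v] for v in vs])
    return results


# Mirrors the doctest: after union(f(x), x), class 0 contains both x and f(x).
classes = {0: {"x": {()}, "f": {(0,)}}}
print(ematch(classes, ["y"], ("f", ("f", "y"))))  # [[0]]
```

Because class 0 contains f applied to itself, f(f(y)) matches with y bound to the class of x, even though f(f(x)) was never added as a term.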
- extract(t0: ExprRef, cost_fun=<function EGraph.<lambda>>)
Extract a representative term from the equivalence class of t0, preferring a low-cost term according to cost_fun.
>>> E = EGraph()
>>> x,y,z = smt.Ints('x y z')
>>> E.add_term(x + y)
>>> _ = E.rebuild()
>>> E.extract(x + y)
x + y
>>> _ = E.union(x + y, y)
>>> _ = E.rebuild()
>>> E.extract(x + y)
y
- Parameters:
t0 (ExprRef)
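One common way to implement cost-based extraction is a bottom-up fixed point. The algorithm repeatedly computes, for every e-class, the cheapest term that can be built from its e-nodes given the current best children. The sketch below applies this to a hypothetical nested-dictionary representation, using term size as the default cost. Whether kdrag's extract uses exactly this scheme or this default cost is an assumption.

```python
from typing import Callable, Dict, List, Set, Tuple

# Hypothetical e-class table: class id -> {op: {tuple of child class ids}}.
Classes = Dict[int, Dict[str, Set[Tuple[int, ...]]]]


def term_size(op: str, child_costs: List[int]) -> int:
    # Default cost: one per node, i.e. the size of the extracted term.
    return 1 + sum(child_costs)


def extract(classes: Classes, root: int, cost: Callable[[str, List[int]], int] = term_size):
    best = {}  # class id -> (cost, cheapest term found so far)
    changed = True
    # Terminates for non-negative integer costs, which can only decrease.
    while changed:
        changed = False
        for cid, enodes in classes.items():
            for op, arg_set in enodes.items():
                for args in arg_set:
                    if not all(a in best for a in args):
                        continue  # some child class has no extracted term yet
                    c = cost(op, [best[a][0] for a in args])
                    if cid not in best or c < best[cid][0]:
                        term = (op, *(best[a][1] for a in args)) if args else op
                        best[cid] = (c, term)
                        changed = True
    return best[root][1]


# Mirrors the doctest: union(x + y, y) puts "+" and "y" in the same class 1.
merged = {0: {"x": {()}}, 1: {"y": {()}, "+": {(0, 1)}}}
print(extract(merged, 1))  # y
plain = {0: {"x": {()}}, 1: {"y": {()}}, 2: {"+": {(0, 1)}}}
print(extract(plain, 2))  # ('+', 'x', 'y')
```

Iterating to a fixed point, rather than recursing from the root, avoids infinite loops on cyclic e-classes such as the one created by union(x + y, y).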
- find(t: ExprRef) → int
Get the canonical e-class id of a term in the egraph.
- Parameters:
t (ExprRef)
- Return type:
int
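Canonical ids like these usually come from a union-find structure. The following minimal sketch uses path compression. The UnionFind class is hypothetical, and returning False when both ids are already merged is this sketch's convention, not necessarily kdrag's.

```python
from typing import Dict


class UnionFind:
    def __init__(self) -> None:
        self.parent: Dict[int, int] = {}

    def find(self, i: int) -> int:
        self.parent.setdefault(i, i)  # unseen ids start as their own root
        root = i
        while self.parent[root] != root:
            root = self.parent[root]
        # Path compression: point every node on the path directly at the root.
        while self.parent[i] != root:
            nxt = self.parent[i]
            self.parent[i] = root
            i = nxt
        return root

    def union(self, a: int, b: int) -> bool:
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False  # already in the same class
        self.parent[ra] = rb
        return True


uf = UnionFind()
print(uf.union(0, 1), uf.union(1, 2), uf.union(0, 2))  # True True False
print(uf.find(0) == uf.find(2))  # True
```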
- get_proof(t1: ExprRef, t2: ExprRef) → list[object]
Get the proof of why t1 == t2 holds in the egraph. The returned reasons may themselves require recursive calls to get_proof.
>>> E = EGraph(proof=True)
>>> x,y,z = smt.Ints('x y z')
>>> E.add_term(x + y)
>>> _ = E.union(x + y, y, reason="because I said so")
>>> _ = E.union(x + y, x, reason="because I said so too")
>>> _ = E.union(x + y, z, reason="because I said so three")
>>> list(sorted(E.get_proof(x, y), key=lambda x: x[2]))
[(x + y, y, 'because I said so'), (x + y, x, 'because I said so too')]
- Parameters:
t1 (ExprRef)
t2 (ExprRef)
- Return type:
list[object]
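Proof-producing e-graphs typically record each union as an edge labeled with its reason. They then explain t1 == t2 by the path of edges that connects the two terms. The sketch below finds that path with a breadth-first search, using plain strings as terms. The ProofLog class is hypothetical. It also ignores congruence edges, whose explanation is where the recursive get_proof calls mentioned above would come in.

```python
from collections import defaultdict, deque


class ProofLog:
    def __init__(self):
        # term -> list of (neighbor, edge); edges are (t1, t2, reason)
        self.adj = defaultdict(list)

    def union(self, a, b, reason):
        edge = (a, b, reason)
        self.adj[a].append((b, edge))
        self.adj[b].append((a, edge))

    def get_proof(self, t1, t2):
        # Breadth-first search from t1, remembering the edge used to reach each term.
        prev = {t1: None}
        queue = deque([t1])
        while queue:
            t = queue.popleft()
            if t == t2:
                break
            for nbr, edge in self.adj[t]:
                if nbr not in prev:
                    prev[nbr] = (t, edge)
                    queue.append(nbr)
        if t2 not in prev:
            return None  # t1 and t2 were never connected by unions
        # Walk back from t2 to t1, collecting the labeled edges.
        path = []
        t = t2
        while prev[t] is not None:
            t, edge = prev[t]
            path.append(edge)
        return path[::-1]


log = ProofLog()
log.union("x+y", "y", "because I said so")
log.union("x+y", "x", "because I said so too")
log.union("x+y", "z", "because I said so three")
print(log.get_proof("x", "y"))
```

As in the doctest, the union with z does not appear in the proof of x == y, because it is not on the path between them.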
- in_terms(t: ExprRef) → bool
Semantically check whether t is in the term bank.
>>> x,y,z = smt.Ints('x y z')
>>> E = EGraph()
>>> E.add_term(x + y)
>>> assert E.in_terms(x)
>>> assert not E.in_terms(z)
- Parameters:
t (ExprRef)
- Return type:
bool
- is_eq(t1: ExprRef, t2: ExprRef) → bool
Check if two terms are equal in the EGraph.
>>> x,y,z = smt.Ints('x y z')
>>> E = EGraph()
>>> _ = E.union(x, y)
>>> assert E.is_eq(x, y)
>>> assert not E.is_eq(x, z)
- Parameters:
t1 (ExprRef)
t2 (ExprRef)
- Return type:
bool
- iter(vs: list[ExprRef])
- Parameters:
vs (list[ExprRef])
- reasons: dict[int, object]
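- rebuild() → list[tuple[ExprRef, ExprRef]]
Restore congruence after calls to union. For example, f(x) and f(y) end up in the same e-class once x and y are. Returns the pairs of terms merged during the rebuild.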
>>> E = EGraph()
>>> f = smt.Function('f', smt.IntSort(), smt.IntSort())
>>> x,y,z = smt.Ints('x y z')
>>> E.add_term(f(x))
>>> E.add_term(f(y))
>>> _ = E.union(x,y)
>>> assert E.find(f(x)) != E.find(f(y))
>>> E.rebuild()
[(f(...), f(...))]
>>> assert E.find(f(x)) == E.find(f(y))
- Return type:
list[tuple[ExprRef, ExprRef]]
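Rebuilding restores congruence. Whenever two e-nodes share a function symbol and have canonically equal children, their classes are merged, and the process repeats until nothing changes. A naive sketch of that loop follows, using a hypothetical TinyEGraph over (op, child ids) nodes. It returns merged id pairs where kdrag's rebuild returns term pairs. Real implementations typically use a worklist rather than rescanning every node.

```python
from typing import Dict, List, Tuple


class TinyEGraph:
    def __init__(self) -> None:
        self.nodes: Dict[int, Tuple] = {}  # id -> (op, child ids)
        self.ids: Dict[Tuple, int] = {}    # (op, child ids) -> id (hash-consing)
        self.parent: Dict[int, int] = {}   # union-find parents

    def add(self, op: str, *args: int) -> int:
        key = (op, args)
        if key not in self.ids:
            i = len(self.nodes)
            self.nodes[i] = key
            self.ids[key] = i
            self.parent[i] = i
        return self.ids[key]

    def find(self, i: int) -> int:
        while self.parent[i] != i:
            i = self.parent[i]
        return i

    def union(self, a: int, b: int) -> bool:
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        self.parent[ra] = rb
        return True

    def rebuild(self) -> List[Tuple[int, int]]:
        merged = []
        changed = True
        while changed:
            changed = False
            seen = {}  # canonical signature -> first node id with it
            for i, (op, args) in self.nodes.items():
                sig = (op, tuple(self.find(a) for a in args))
                j = seen.setdefault(sig, i)
                # Same symbol and canonical children: congruent, so merge.
                if j != i and self.union(i, j):
                    merged.append((j, i))
                    changed = True
        return merged


E = TinyEGraph()
x, y = E.add("x"), E.add("y")
fx, fy = E.add("f", x), E.add("f", y)
E.union(x, y)
print(E.find(fx) == E.find(fy))  # False
merged = E.rebuild()
print(merged)  # [(2, 3)]
print(E.find(fx) == E.find(fy))  # True
```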
- roots: defaultdict[SortRef, set[int]]
- rw(sorts: list[SortRef], f)
Bottom-up e-match and rewrite. f should take one argument per sort in sorts and return a tuple of two terms (lhs, rhs).
>>> E = EGraph()
>>> f = smt.Function('f', smt.IntSort(), smt.IntSort())
>>> x,y,z = smt.Ints('x y z')
>>> E.add_term(f(x))
>>> E.rw([smt.IntSort()], lambda v: (f(v), v + 1))
>>> assert E.find(f(x)) == E.find(x + 1)
- Parameters:
sorts (list[SortRef])
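Bottom-up e-matching turns the search around. Instead of matching lhs against the egraph, it instantiates the rule's variables with terms already in the egraph, builds lhs and rhs, and keeps the instances whose lhs is present. A single-sorted sketch follows. The TinyEGraph class, the arity argument (standing in for the sorts list), and the purely syntactic lhs check are simplifying assumptions, not kdrag's behavior.

```python
import itertools
from typing import Callable, Dict, Tuple


class TinyEGraph:
    """Hypothetical, single-sorted: terms are strings or (op, *args) tuples."""

    def __init__(self) -> None:
        self.ids: Dict[object, int] = {}
        self.parent: Dict[int, int] = {}

    def add_term(self, t) -> int:
        if isinstance(t, tuple):
            for arg in t[1:]:
                self.add_term(arg)
        if t not in self.ids:
            i = len(self.ids)
            self.ids[t] = i
            self.parent[i] = i
        return self.ids[t]

    def find(self, i: int) -> int:
        while self.parent[i] != i:
            i = self.parent[i]
        return i

    def union(self, a: int, b: int) -> bool:
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self.parent[ra] = rb
        return ra != rb


def rw(E: TinyEGraph, arity: int, f: Callable[..., Tuple[object, object]]) -> None:
    # Snapshot the current terms: terms added during this round are not tried.
    candidates = list(E.ids)
    for vs in itertools.product(candidates, repeat=arity):
        lhs, rhs = f(*vs)
        if lhs in E.ids:  # rewrite only instances whose lhs already exists
            E.union(E.ids[lhs], E.add_term(rhs))


# Mirrors the doctest: rewrite f(v) to v + 1 for every existing term v.
E = TinyEGraph()
E.add_term(("f", "x"))
rw(E, 1, lambda v: (("f", v), ("+", v, "1")))
print(E.find(E.ids[("f", "x")]) == E.find(E.ids[("+", "x", "1")]))  # True
```

Enumerating every tuple of candidate terms grows as (number of terms) ** arity. This cost is the price of not needing a top-down matcher.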
- simplify_terms()
Use built in simplifier to simplify all terms in the egraph. Similar to “extract and simplify”.
>>> E = EGraph()
>>> x,y,z = smt.Ints('x y z')
>>> E.add_term(4 + x + y + 7)
>>> E.add_term(8 + x + y + 3)
>>> E.simplify_terms()
>>> assert E.find(8 + x + y + 3) == E.find(4 + x + y + 7)
- solver: Solver
- terms: dict[int, ExprRef]
- uf: dict[int, int]
- union(t1: ExprRef, t2: ExprRef, add=True, reason=None) → bool
Assert that two terms are equal in the EGraph. Note that this does not add the terms to the EGraph.
>>> x,y,z = smt.Ints('x y z')
>>> E = EGraph()
>>> _ = E.union(x, y)
>>> assert E.find(x) == E.find(y)
- Parameters:
t1 (ExprRef)
t2 (ExprRef)
- Return type:
bool