kdrag.solvers.egraph

Classes

EGraph([proof])

class kdrag.solvers.egraph.EGraph(proof=False)

Bases: object

add_term(t: ExprRef) → None

Recursively add a term and its subterms to the egraph.

Parameters:

t (ExprRef)

Return type:

None

copy()

Copy the egraph. This is a shallow copy, so the terms are not copied.

>>> E = EGraph()
>>> f = smt.Function('f', smt.IntSort(), smt.IntSort())
>>> x,y,z = smt.Ints('x y z')
>>> E.add_term(f(x))
>>> E.add_term(f(y))
>>> _ = E.union(x,y)
>>> assert E.find(f(x)) != E.find(f(y))
>>> E2 = E.copy()
>>> _ = E2.rebuild()
>>> assert E2.find(f(x)) == E2.find(f(y))
>>> assert E.find(f(x)) != E.find(f(y))
dot(filename: str = 'egraph') → Digraph

Create graphviz representation of the egraph.

>>> E = EGraph()
>>> x,y,z = smt.Ints("x y z")
>>> E.add_term(x + y)
>>> E.union(y,z)
True
>>> _ = E.rebuild()
>>> _ = E.dot()
Parameters:

filename (str)

Return type:

Digraph

eclasses() → defaultdict[int, defaultdict[FuncDeclRef, set[tuple[int]]]]

Return a dictionary mapping each e-class id to its e-nodes, grouped by function declaration; each e-node is stored as a tuple of argument e-class ids.

>>> E = EGraph()
>>> x,y,z = smt.Ints("x y z")
>>> E.add_term(x + y)
>>> E.union(y,z)
True
>>> _ = E.rebuild()
>>> _ = E.eclasses()
Return type:

defaultdict[int, defaultdict[FuncDeclRef, set[tuple[int]]]]
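
The nested return type can be hard to read at a glance. The sketch below builds a structure of the same shape in plain Python, as an illustration only: the `enodes` list is hypothetical, strings stand in for `FuncDeclRef` values, and this is not how kdrag populates the dictionary.

```python
from collections import defaultdict

# Hypothetical flat list of e-nodes: (class id, head symbol, child class ids).
# Strings stand in for z3 FuncDeclRef values.
enodes = [
    (0, "x", ()),
    (1, "y", ()),
    (1, "z", ()),      # y and z were unioned, so they share class 1
    (2, "+", (0, 1)),  # x + y, with children given as class ids
]

# Outer key: class id. Inner key: head symbol. Value: set of argument tuples.
classes = defaultdict(lambda: defaultdict(set))
for cid, head, children in enodes:
    classes[cid][head].add(children)

heads_in_class_1 = sorted(classes[1])  # symbols that live in class 1
plus_nodes = classes[2]["+"]           # argument tuples of the + nodes in class 2
```

Reading `classes[2]["+"]` gives the argument tuples of every `+` node in class 2, with each argument named by its class id rather than by a concrete term.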

ematch(vs: list[ExprRef], pat: ExprRef) → list[list[ExprRef]]

Find all matches of pat in the egraph.

>>> E = EGraph()
>>> f = smt.Function('f', smt.IntSort(), smt.IntSort())
>>> x,y,z = smt.Ints('x y z')
>>> E.add_term(f(x))
>>> _ = E.union(f(x), x)
>>> _ = E.rebuild()
>>> E.ematch([y], f(f(y)))
[[x]]
Parameters:
  • vs (list[ExprRef])

  • pat (ExprRef)

Return type:

list[list[ExprRef]]

extract(t0: ExprRef, cost_fun=<function EGraph.<lambda>>)

Extract a term from the egraph.

>>> E = EGraph()
>>> x,y,z = smt.Ints('x y z')
>>> E.add_term(x + y)
>>> _ = E.rebuild()
>>> E.extract(x + y)
x + y
>>> _ = E.union(x + y, y)
>>> _ = E.rebuild()
>>> E.extract(x + y)
y
Parameters:

t0 (ExprRef)
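
To make the idea of cost-based extraction concrete, here is a minimal sketch in plain Python. It is not kdrag's algorithm: the `classes` layout and the `extract_smallest` helper are hypothetical, and the cost is assumed to be the node count of the term.

```python
# Toy extraction: pick the smallest term in each class by iterating to a fixpoint.
# The data layout and function name are hypothetical, not part of kdrag.

def extract_smallest(classes, root):
    # classes: class id -> list of e-nodes (head, tuple of child class ids)
    best = {}  # class id -> (cost, term as nested tuple)
    changed = True
    while changed:
        changed = False
        for cid, nodes in classes.items():
            for head, children in nodes:
                if any(c not in best for c in children):
                    continue  # some child has no known term yet
                cost = 1 + sum(best[c][0] for c in children)
                if cid not in best or cost < best[cid][0]:
                    term = (head,) + tuple(best[c][1] for c in children)
                    best[cid] = (cost, term)
                    changed = True
    return best[root][1]

# Class 0 holds x; class 1 holds both y and x + y (after a union of x + y with y).
classes = {
    0: [("x", ())],
    1: [("+", (0, 1)), ("y", ())],
}
smallest = extract_smallest(classes, 1)
```

As in the doctest above, once `x + y` and `y` share a class the cheaper representative `y` is chosen. The loop terminates because each update strictly lowers a positive integer cost.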

find(t: ExprRef) → int

Get canonical id of term in egraph.

Parameters:

t (ExprRef)

Return type:

int

get_proof(t1: ExprRef, t2: ExprRef) → list[object]

Get the proof of why t1 == t2 in the egraph. The reasons returned may themselves require recursive calls to get_proof.

>>> E = EGraph(proof=True)
>>> x,y,z = smt.Ints('x y z')
>>> E.add_term(x + y)
>>> _ = E.union(x + y, y, reason="because I said so")
>>> _ = E.union(x + y, x, reason="because I said so too")
>>> _ = E.union(x + y, z, reason="because I said so three")
>>> list(sorted(E.get_proof(x, y), key=lambda x: x[2]))
[(x + y, y, 'because I said so'), (x + y, x, 'because I said so too')]
Parameters:
  • t1 (ExprRef)

  • t2 (ExprRef)

Return type:

list[object]
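
One common way to think about such a proof is as a path through the recorded unions. The sketch below illustrates that idea in plain Python under the assumption that each union is logged as an `(lhs, rhs, reason)` triple; the `explain` helper is hypothetical, and kdrag may compute its proofs differently.

```python
from collections import deque

def explain(edges, a, b):
    # edges: list of (lhs, rhs, reason) triples, one per recorded union.
    adj = {}
    for e in edges:
        lhs, rhs, _ = e
        adj.setdefault(lhs, []).append((rhs, e))
        adj.setdefault(rhs, []).append((lhs, e))
    # Breadth-first search for a chain of unions linking a to b.
    queue = deque([(a, [])])
    seen = {a}
    while queue:
        node, path = queue.popleft()
        if node == b:
            return path
        for nxt, e in adj.get(node, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, path + [e]))
    return None  # a and b were never connected

# Strings stand in for terms; the reasons mirror the doctest above.
edges = [
    ("x + y", "y", "because I said so"),
    ("x + y", "x", "because I said so too"),
    ("x + y", "z", "because I said so three"),
]
proof = explain(edges, "x", "y")
```

The union with `z` does not appear in the result because it is not needed to connect `x` and `y`, which matches the two-step proof in the doctest.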

in_terms(t: ExprRef) → bool

Semantically check whether t is in the term bank.

>>> x,y,z = smt.Ints('x y z')
>>> E = EGraph()
>>> E.add_term(x + y)
>>> assert E.in_terms(x)
>>> assert not E.in_terms(z)
Parameters:

t (ExprRef)

Return type:

bool

is_eq(t1: ExprRef, t2: ExprRef) → bool

Check if two terms are equal in the EGraph.

>>> x,y,z = smt.Ints('x y z')
>>> E = EGraph()
>>> _ = E.union(x, y)
>>> assert E.is_eq(x, y)
>>> assert not E.is_eq(x, z)
Parameters:
  • t1 (ExprRef)

  • t2 (ExprRef)

Return type:

bool

iter(vs: list[ExprRef])
Parameters:

vs (list[ExprRef])

reasons: dict[int, object]
rebuild() → list[tuple[ExprRef, ExprRef]]
>>> E = EGraph()
>>> f = smt.Function('f', smt.IntSort(), smt.IntSort())
>>> x,y,z = smt.Ints('x y z')
>>> E.add_term(f(x))
>>> E.add_term(f(y))
>>> _ = E.union(x,y)
>>> assert E.find(f(x)) != E.find(f(y))
>>> E.rebuild()
[(f(...), f(...))]
>>> assert E.find(f(x)) == E.find(f(y))
Return type:

list[tuple[ExprRef, ExprRef]]
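
The doctest shows that `union(x, y)` alone does not merge `f(x)` and `f(y)`; that happens in `rebuild`. The sketch below illustrates this congruence step with a toy union-find in plain Python. It is not kdrag's implementation: `ToyEGraph` and its tuple-based terms are hypothetical.

```python
# Toy congruence closure illustrating the idea behind rebuild().
# Terms are nested tuples such as ("f", ("x",)); all names are hypothetical.

class ToyEGraph:
    def __init__(self):
        self.parent = {}  # union-find: term -> parent term

    def add_term(self, t):
        # Recursively register the term and its children.
        if t not in self.parent:
            self.parent[t] = t
            for child in t[1:]:
                self.add_term(child)

    def find(self, t):
        # Follow parent pointers to the class representative.
        while self.parent[t] != t:
            t = self.parent[t]
        return t

    def union(self, a, b):
        self.add_term(a)
        self.add_term(b)
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        self.parent[ra] = rb
        return True

    def rebuild(self):
        # Merge terms with the same head whose arguments are already equal,
        # repeating until no new merges appear.
        merged = []
        changed = True
        while changed:
            changed = False
            seen = {}  # (head, canonical args) -> first term with that signature
            for t in list(self.parent):
                sig = (t[0],) + tuple(self.find(c) for c in t[1:])
                other = seen.setdefault(sig, t)
                if other != t and self.union(other, t):
                    merged.append((other, t))
                    changed = True
        return merged


x, y = ("x",), ("y",)
fx, fy = ("f", x), ("f", y)

E = ToyEGraph()
E.add_term(fx)
E.add_term(fy)
E.union(x, y)
before = E.find(fx) == E.find(fy)  # congruence not yet propagated
merged = E.rebuild()
after = E.find(fx) == E.find(fy)   # f(x) and f(y) now share a class
```

Deferring congruence to a separate pass lets many unions be batched before the more expensive signature scan runs.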

roots: defaultdict[SortRef, set[int]]
rw(sorts: list[SortRef], f)

Bottom-up ematch and rewrite. f should take one argument per sort in sorts and return a tuple of two terms (lhs, rhs).

>>> E = EGraph()
>>> f = smt.Function('f', smt.IntSort(), smt.IntSort())
>>> x,y,z = smt.Ints('x y z')
>>> E.add_term(f(x))
>>> E.rw([smt.IntSort()], lambda v: (f(v), v + 1))
>>> assert E.find(f(x)) == E.find(x + 1)
Parameters:

sorts (list[SortRef])

simplify_terms()

Use the built-in simplifier to simplify all terms in the egraph. Similar to “extract and simplify”.

>>> E = EGraph()
>>> x,y,z = smt.Ints('x y z')
>>> E.add_term(4 + x + y + 7)
>>> E.add_term(8 + x + y + 3)
>>> E.simplify_terms()
>>> assert E.find(8 + x + y + 3) == E.find(4 + x + y + 7)
solver: Solver
terms: dict[int, ExprRef]
uf: dict[int, int]
union(t1: ExprRef, t2: ExprRef, add=True, reason=None) → bool

Assert that two terms are equal in the EGraph. Note that this does not add the terms to the EGraph.

>>> x,y,z = smt.Ints('x y z')
>>> E = EGraph()
>>> _ = E.union(x, y)
>>> assert E.find(x) == E.find(y)
Parameters:
  • t1 (ExprRef)

  • t2 (ExprRef)

Return type:

bool