It’s interesting, while I was writing https://www.philipzucker.com/brute_eggmt/ I was hating the post, but immediately after dumping it out I began to appreciate that I had achieved something quite useful.

This is a continuation of that one. By the time I implement the e-graph and e-matching, I rarely (never?) get to the point where I implement extraction or proof production. Here we go.

The code for the post is here https://github.com/philzook58/knuckledragger/blob/0419f27bbc6cecbd5d1156a568cebd329dd5cbb0/kdrag/solvers/egraph.py#L1

Rewrites / Simplify

Another nice thing this style of egraph gets you is the ability to do regular rewrites / reuse existing rewriting mechanisms that work over terms. Z3 has a built in simplifier. I can use it by traversing the term bank. This is a relative of the “extract and simplify” technique of Koehler, Trinder, Steuwer. https://arxiv.org/abs/2111.13040 . I don’t extract because there are always terms lying around corresponding to the eclasses.

    def simplify_terms(self):
        """
        Use built in simplifier to simplify all terms in the egraph.
        Similar to "extract and simplify".

        >>> E = EGraph()
        >>> x,y,z = smt.Ints('x y z')
        >>> E.add_term(4 + x + y + 7)
        >>> E.add_term(8 + x + y + 3)
        >>> E.simplify_terms()
        >>> assert E.find(8 + x + y + 3) == E.find(4 + x + y + 7)
        """
        todo = []
        for t in self.terms.values():
            t1 = smt.simplify(t)
            if not t1.eq(t):
                todo.append((t, t1))
        for t, t1 in todo:
            self.add_term(t1)
            self._union(t, t1)

Extraction

Extraction is kind of the spice that has made egraphs interesting to a larger number of users, because it is the thing that changes it from a theorem proving method Term -> Term -> Bool to a simplification method Term -> Term.

The basic bottom up dynamic programming approach is to traverse all your enodes and add together the sum of the cost of your children with the cost of the enode funcdecl itself. Once this converges, you can extract out the term by following the best choice on downwards.

I implemented this as a total brute force loop.

class EGraph():
...
    def extract(self, t0: smt.ExprRef, cost_fun=(lambda _: 1)):
        """
        Extract a term from the egraph.

        >>> E = EGraph()
        >>> x,y,z = smt.Ints('x y z')
        >>> E.add_term(x + y)
        >>> E.rebuild()
        >>> E.extract(x + y)
        x + y
        >>> _ = E.union(x + y, y)
        >>> E.rebuild()
        >>> E.extract(x + y)
        y
        """
        inf = float("inf")
        best_cost = defaultdict(lambda: inf)
        best = {}
        while True:
            done = True
            # Terms are taking the place of enodes.
            for t in self.terms.values():
                eid = self.find(t)
                cost = cost_fun(t) + sum(
                    [best_cost[self.find(c)] for c in t.children()]
                )  # cost_fun(t.decl()) ?
                if cost < best_cost[eid]:
                    best_cost[eid] = cost
                    best[eid] = t
                    done = False
            if done:
                break

        # @functools.cache
        def build_best(t):
            t1 = best[self.find(t)]
            return t1.decl()(*[build_best(c) for c in t1.children()])

        return build_best(t0)

from kdrag.solvers.egraph import EGraph
from kdrag.all import *
a = smt.Int("a") 
# kd.prove((a * 2)/ 2 == a)
E = EGraph()
t = (a * 2) / 2
E.add_term(t)
E.simplify_terms()
print(E.extract(t))
E
a





EGraph(roots=defaultdict(<class 'set'>, {Int: {1929, 324, 41}}), terms={265: (a*2)/2, 149: a*2, 324: a, 41: 2, 1929: 2*a}, uf={265: 324, 149: 1929}, solver=[], reasons={})
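The bottom-up cost propagation described in this section can also be seen without Z3 on a hand-built e-graph. This is only a sketch: the dict layout (`eclasses` mapping an eclass id to a list of `(op, children)` enodes) and the names `A`, `T`, `M` are made up for the example and are not how the kdrag `EGraph` stores things.

```python
import math

def extract(eclasses, root, cost=lambda op: 1):
    """eclasses: dict eclass id -> list of enodes (op, tuple of child eclass ids)."""
    best_cost = {eid: math.inf for eid in eclasses}
    best = {}
    changed = True
    while changed:  # brute force: sweep all enodes until no cost improves
        changed = False
        for eid, enodes in eclasses.items():
            for op, children in enodes:
                c = cost(op) + sum(best_cost[ch] for ch in children)
                if c < best_cost[eid]:
                    best_cost[eid], best[eid] = c, (op, children)
                    changed = True

    def build(eid):
        # follow the best choice downwards
        op, children = best[eid]
        return (op, *[build(ch) for ch in children])

    return build(root)

# (a * 2) / 2 has been merged with a, so both enodes live in eclass "A"
eclasses = {
    "A": [("/", ("M", "T")), ("a", ())],
    "T": [("2", ())],
    "M": [("*", ("A", "T"))],
}
print(extract(eclasses, "A"))  # ('a',)
print(extract(eclasses, "M"))  # ('*', ('a',), ('2',))
```

The cyclic reference between `A` and `M` is harmless here because an enode with a child of infinite cost can never win.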

Proofs

In my opinion, what counts as a proof can mean different things. https://www.philipzucker.com/proof_objects/ Proof objects can exist on a spectrum of “things to be filled in”. A bool is a proof object at the most insane end of the spectrum, but it isn’t at all crazy for proof objects to need some (hopefully polynomial and hopefully decidable) filling in.

A canonical example in my mind is that the proof object of the connectivity of two vertices in a graph is the path between them. This is quite literally the thing that a proof producing union find provides https://www.cs.upc.edu/~roberto/papers/rta05.pdf

One of the use cases of union finds is for fast connectivity querying in a graph (connectivity is an equivalence relation). It is used in Kruskal’s algorithm to find a minimum weight spanning tree https://en.wikipedia.org/wiki/Kruskal%27s_algorithm . It works by sorting the edges, and then adding them in order one by one into the union find. If an edge joins two previously unjoined components, it is part of the spanning tree. This spanning tree is the thing the proof producing union find holds. In equality saturation, you kind of want a least length proof, which means the least number of rule applications. This is basically what you get since you learn the rule applications in order.
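A minimal sketch of the Kruskal description above, using the same dict style of union find as elsewhere in this post. The function name `kruskal` and the `(weight, u, v)` edge format are choices made for this example only.

```python
def kruskal(edges):
    """edges: iterable of (weight, u, v). Returns the edges of the spanning forest."""
    uf = {}

    def find(x):
        while x in uf:
            x = uf[x]
        return x

    tree = []
    for w, u, v in sorted(edges):  # cheapest edges first
        ru, rv = find(u), find(v)
        if ru != rv:  # the edge joins two previously unjoined components
            uf[ru] = rv
            tree.append((w, u, v))
    return tree

edges = [(3, "a", "c"), (1, "a", "b"), (2, "b", "c"), (4, "c", "d")]
print(kruskal(edges))  # [(1, 'a', 'b'), (2, 'b', 'c'), (4, 'c', 'd')]
```

The edge `(3, "a", "c")` is dropped because `a` and `c` are already connected by the time it is considered, which is the same "not novel, not logged" behavior a proof producing union find has.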

As part of a general principle that traces = proof, a workable proof producing union find is to just log any union that was novel in a big list. You can then later postprocess this spanning tree to find the appropriate path between two nodes of interest.

uf = {}
def find(x):
    while x in uf:
        x = uf[x]
    return x
    
log = []
def union(x,y, reason=None):
    x1,y1 = find(x), find(y)
    if x1 != y1:
        uf[x1] = y1
        log.append((x,y,reason))

union(1,2, "for fun")
union(1,3, "by fiat")
union(2,3, "because I said so") # not logged because already connected
union(3,1, "to connect them")
print(uf)
print(log)

{1: 2, 2: 3}
[(1, 2, 'for fun'), (1, 3, 'by fiat')]
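The postprocessing step can be sketched as a breadth first search over the logged edges. The name `explain` and the choice of BFS are assumptions for illustration. Because the log only holds novel unions it forms a forest, so the path found is the only one.

```python
from collections import deque

def explain(log, x, y):
    """Return the logged unions connecting x to y, or None if they are not connected."""
    adj = {}
    for a, b, reason in log:
        adj.setdefault(a, []).append((b, (a, b, reason)))
        adj.setdefault(b, []).append((a, (a, b, reason)))
    prev = {x: None}  # node -> (previous node, edge used to reach it)
    queue = deque([x])
    while queue:
        n = queue.popleft()
        if n == y:
            path = []
            while prev[n] is not None:
                n, edge = prev[n]
                path.append(edge)
            return path[::-1]
        for m, edge in adj.get(n, []):
            if m not in prev:
                prev[m] = (n, edge)
                queue.append(m)
    return None

example_log = [(1, 2, "for fun"), (1, 3, "by fiat")]
print(explain(example_log, 2, 3))  # [(1, 2, 'for fun'), (1, 3, 'by fiat')]
```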

When you add rules and congruence closure into the mix, the reason field becomes more important because vertices aren’t atomic entities anymore. A Cong reason should hold the funcdecl applied and the children that are equal. A Rule reason data structure should hold the rule, the instantiations of the variables in the rule, and what term you think the left hand side is equal to perhaps, depending if you need to justify the existence of the term.
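One possible shape for these reasons, as a sketch. The class names `Cong` and `Rule`, their fields, and the helper `subgoals` are hypothetical and do not exist in kdrag.

```python
from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)
class Cong:
    decl: Any            # the function symbol applied on both sides
    lhs_children: tuple  # children of the left term
    rhs_children: tuple  # children of the right term, pairwise equal to the left

@dataclass(frozen=True)
class Rule:
    rule: Any     # the rule that fired
    subst: tuple  # (variable, term) pairs instantiating the rule
    lhs: Any      # the existing term the instantiated left hand side matched

def subgoals(reason):
    """Equalities that need their own explanation before this reason is justified."""
    if isinstance(reason, Cong):
        return [(a, b) for a, b in zip(reason.lhs_children, reason.rhs_children) if a != b]
    return []

print(subgoals(Cong("f", ("x", "a"), ("y", "a"))))  # [('x', 'y')]
```

A `Cong` reason pushes the question down to the children, which is where the recursion through the proof comes from.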

You may need to recursively understand the rewriting path that connects children of your term or know what made the rule you used fireable.

In my z3 based egraph, I have taken the principle to lean on z3 as much as possible. Z3 already has a proof producing union find. Z3 has proof terms, but it also has the notion of unsat cores, which are a fairly minimal set of asserted facts necessary to make a query unsatisfiable. Using these, I can get a proof skeleton consisting of the unsat cores that justify each level. Every node in the proof skeleton follows from its children via fairly straightforward quantifier free reasoning. The rules are dealt with externally and are logged with the rule that fired and the instantiation of the firing.

class EGraph:
...
    reasons: dict[int, object]
...
    def union(self, t1: smt.ExprRef, t2: smt.ExprRef, reason=None) -> bool:
        """
        Assert equal two terms in the EGraph.
        Note that this does not add the terms to the EGraph.

        >>> x,y,z = smt.Ints('x y z')
        >>> E = EGraph()
        >>> _ = E.union(x, y)
        >>> assert E.find(x) == E.find(y)
        """
        if self._union(t1, t2):
            if self.proof:
                p = smt.FreshConst(smt.BoolSort())
                self.reasons[p.get_id()] = (t1, t2, reason)
                self.solver.assert_and_track(t1 == t2, p)
            else:
                self.solver.add(t1 == t2)
            return True
        else:
            return False


    def get_proof(self, t1: smt.ExprRef, t2: smt.ExprRef) -> list[object]:
        """
        Get the proof of why t1 == t2 in the egraph.
        The reasons returned may require recursive calls of get_proof.


        >>> E = EGraph(proof=True)
        >>> x,y,z = smt.Ints('x y z')
        >>> E.add_term(x + y)
        >>> _ = E.union(x + y, y, reason="because I said so")
        >>> _ = E.union(x + y, x, reason="because I said so too")
        >>> _ = E.union(x + y, z, reason="because I said so three")
        >>> E.get_proof(x, y)
        [(x + y, y, 'because I said so'), (x + y, x, 'because I said so too')]

        """
        self.solver.push()
        self.solver.add(t1 != t2)
        res = self.solver.check()
        assert res == smt.unsat
        cores = self.solver.unsat_core()
        self.solver.pop()
        return [self.reasons[p.get_id()] for p in cores]

Hmm. Does this actually work? I could get non well founded proof trees when we recurse through the rules (?). If I keep my assertions minimal this shouldn’t be the case? Hmm.

You may be unsatisfied that these aren’t fully explicit rewrite proofs. Well,

  • with theory reasoning, unless the theory is an equational theory, what can you expect?
  • I have taken you pretty close to constructing a rewrite looking proof yourself. You don’t trust z3? It’s probably your verifier anyway. You can recheck these proofs in an independent SMT solver, and ITP tactics can probably handle this quantifier free reasoning without too much pain.

A natural thing for me to put in the reason field are Knuckledragger https://github.com/philzook58/knuckledragger proof objects and instantiation variables. I can then instantiate the knuckledragger proofs either eagerly or lazily. But you can use whatever reasons you want due to the wildness of python.

E = EGraph(proof=True)
x,y,z = smt.Ints("x y z")
f = smt.Function("f", smt.IntSort(), smt.IntSort())
t1 = f(f(f(x)))
E.add_term(t1)
E.union(f(x), x)
E.union(y, z)
E.union(f(x), x + y + z, )
E.get_proof(y + z, 0)
[(f(x), x, None), (f(x), x + y + z, None)]

Context

Z3 is intrinsically a more contextual engine than an egg-like implementation. We can assert equalities under assumptions using z3.Implies.

One way to turn the Z3 EMT into a contextual egraph is to just carry a list of contexts of interest. The idea of contexts is that sometimes you want to rewrite under If-then-elses in a way that only works under the assumptions you gain from being in that branch. For example If(x > 0, abs(x), 0) == If(x > 0, x, 0). See also datapath aware https://arxiv.org/pdf/2303.01839 and colored egraphs https://arxiv.org/abs/2305.19203

The contexts can be cleaned up for equivalence and subsumption during the rebuilding phase via seeing if solver.add(smt.Not(smt.Implies(ctx1,ctx2))) comes back unsat. The union find can be used for context equivalences, but a separate transitive structure is probably wise for subsumption. I don’t know how to make something all that better than brute force for maintaining an online transitive relation like subsumption. Surprisingly, I didn’t really need to retain the subsumption list. In ematching under context, I only return if there is a context that implies the discovered context in in_ctxs. What this means is that if I have And(p,q) in my contexts, the in_ctxs filter will pass the context p. In bottom up ematching, the implementation changes very little for ematching under context. It is interesting that adding contexts does not combinatorially explode bottom up ematching. The loop size remains largely dependent on the number of variables in your pattern whereas top down matching probably requires a scan over contexts.
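To see the `in_ctxs` filter without a solver, here is a toy sketch where a context is restricted to a conjunction of atoms, stored as a frozenset. That restriction is an assumption made only for this example: implication becomes a superset check, standing in for the Z3 query. The helper `implies` is hypothetical.

```python
def implies(ctx1, ctx2):
    """A conjunction of atoms implies another when it contains all of its atoms."""
    return ctx2 <= ctx1

def in_ctxs(ctxs, ctx):
    """True if some stored context implies the discovered context."""
    return any(implies(c, ctx) for c in ctxs)

p, q, r = "p", "q", "r"
ctxs = [frozenset({p, q}), frozenset({r})]
print(in_ctxs(ctxs, frozenset({p})))     # True: And(p, q) implies p
print(in_ctxs(ctxs, frozenset({p, r})))  # False: no single stored context implies And(p, r)
```

The real version asks the solver whether the disjunction of the implications is valid, which can also succeed by case splitting across several stored contexts. The toy version checks them one at a time.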

One probably does not want to let contexts explode.

I think I need to modify extract to also build for each context of interest and pull from these tables when you go under an If or lazy And Or. Another reason to maybe do extraction top down.

from kdrag.all import *
def subsumes(ctx1, ctx2): # __le__ for contexts
    s = smt.Solver()
    s.add(smt.Not(smt.Implies(ctx1, ctx2)))
    return s.check() == smt.unsat

p,q,r,s = smt.Bools("p q r s")
assert subsumes(p, p)
assert not subsumes(p, q)
assert subsumes(smt.And(p, q), p) # p and q is a stronger assumption than p
assert not subsumes(p, smt.And(p, q))
assert not subsumes(smt.And(p, q, s), smt.And(p, q, r)) # incomparable contexts

from kdrag.all import *
from kdrag.solvers.egraph import EGraph
from dataclasses import dataclass
import itertools
@dataclass
class CtxEGraph(EGraph):
    ctxs : set[int]
    def __init__(self):
        super().__init__()
        self.ctxs = set()
    def add_ctx(self, ctx : smt.BoolRef):
        assert isinstance(ctx, smt.BoolRef), "Context must be a BoolRef"
        self.ctxs.add(ctx.get_id())
        self.add_term(ctx)
    def ctx_union(self, ctx : smt.BoolRef, t1 : smt.ExprRef, t2 : smt.ExprRef):
        with self.solver:
            self.solver.add(smt.Not(ctx))
            res = self.solver.check()
        if res == smt.unsat:
            return self.union(t1, t2)
        else:
            with self.solver:
                self.solver.add(smt.Not(smt.Implies(ctx, t1 == t2)))
                res = self.solver.check()
            if res == smt.unsat:
                return False
            else:
                ctx = self.terms[self.find(ctx)]
                self.solver.add(smt.Implies(ctx, t1 == t2))
                return True
    def subsumes(self, ctx1 : smt.BoolRef, ctx2 : smt.BoolRef):
        with self.solver:
            self.solver.add(smt.Not(smt.Implies(ctx1, ctx2)))
            return self.solver.check() == smt.unsat
    def in_ctxs(self, ctx: smt.BoolRef) -> bool:
        with self.solver:
            self.solver.add(smt.Not(smt.Or(smt.Implies(self.terms[ctxid], ctx) for ctxid in self.ctxs)))
            return self.solver.check() == smt.unsat
    def ematch_ctx(
        self, vs: list[smt.ExprRef], ctxpat : smt.BoolRef, pat: smt.ExprRef
    ) -> list[list[smt.ExprRef]]:
        res = []
        for eids in itertools.product(*[self.roots[v.sort()] for v in vs]):
            ts = [self.terms[eid] for eid in eids]
            lhs = smt.substitute(pat, *zip(vs, ts))
            ctx = smt.substitute(ctxpat, *zip(vs, ts))
            if self.in_ctxs(ctx) and self.in_terms(lhs):
                res.append(ts)
        return res
    def rebuild(self):
        super().rebuild()
        self.ctxs = {self._find(ctxid) for ctxid in self.ctxs}


E = CtxEGraph()
x,y,z = smt.Reals("x y z")
t = smt.If(x > 0, smt.Sqrt(x**2), 3 + x + 4)
E.add_term(t)
E.simplify_terms()
p,q,r = smt.Bools("p q r")

for ctx,_,_ in E.ematch([p,y,z], smt.If(p, y, z)):
    E.add_ctx(ctx)
E.rebuild()

for c,q in E.ematch_ctx([ctx, y], ctx, smt.Sqrt(y**2)):
    E.ctx_union(c, smt.Sqrt(q**2), q)
E.rebuild()
# Ugh, extract has to be context aware.
# I'm cheating here.
E.extract(t)



If(x > 0, x, 7 + x)

Bits and Bobbles

Next time: Lambdas. Z3 does kind of sort of support lambdas. Slotted egraphs I think can be emulated by going to an eta-maximal form. Z3 simplify already evals lambdas. Nice.

Concerns: Ematching modulo theories might make WAY more matches, and not the intended ones. It may be a notch more declarative and hence harder to control. Also extraction only finds things you have reified terms of. There’s no free lunch.

Pavel released an interesting blog post with an alternative to bottom up matching https://pavpanchekha.com/blog/egraph-t.html

The single biggest mundane insight I’ve had with knuckledragger https://github.com/philzook58/knuckledragger is that centralizing over the z3 ast is a really good idea (or at least an idea very much to my taste). At a certain point I was centralizing my egraphs modulo theories stuff in its own repo https://github.com/philzook58/eggmt . I was going to build out a clone of the z3 data structures and maybe I still might have to. Z3’s ast is not extensible. However, the style I made in my recent post uses the z3 interned ids with python dictionaries to other bits. This is straightforward and extensible enough to probably work, using z3 as my common blackboard.

#smt.BitVec("a", 64)
#kd.prove((a * 2) / 2 ==  (a << 1) / 2)
#kd.prove(smt.Implies(smt.And(a < 2**100, a > 0), (a * 2) / 2 ==  a))

Contextual union find

A persistent union find vs a colored one. The colored one maintains many union finds, and the other union finds inherit from their parent union finds. The persistent union find lives in many worlds.

In a contextual union find, the “color” keys themselves are union finds. You can make a union find into a key by fully canonizing/compressing it and then using the usual techniques to turn maps/dictionaries into keys. Patricia trie or sorted assoc list.
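A minimal sketch of turning a union find into a key, assuming the list-of-parents layout where roots point to themselves. The function `canon_key` is hypothetical, and it uses a plain tuple rather than a Patricia trie or sorted assoc list. Every element is relabelled by the smallest member of its class, so two union finds describing the same partition give the same key no matter which roots they happened to pick.

```python
def canon_key(parent):
    """parent[i] is the parent of i, roots point to themselves. Returns a hashable key."""
    def find(x):
        while parent[x] != x:
            x = parent[x]
        return x

    smallest = {}
    for i in range(len(parent)):  # increasing order, so the first member seen is the minimum
        smallest.setdefault(find(i), i)
    return tuple(smallest[find(i)] for i in range(len(parent)))

uf1 = [1, 1, 3, 3]  # classes {0,1} {2,3}, roots 1 and 3
uf2 = [0, 0, 2, 2]  # same partition, roots 0 and 2
table = {canon_key(uf1): "some context data"}
print(canon_key(uf1))         # (0, 0, 2, 2)
print(table[canon_key(uf2)])  # some context data
```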

This inheritance structure is interesting.

from dataclasses import dataclass, field

@dataclass
class UF():
    uf : list[int] = field(default_factory=list)
    def makeset(self):
        eid = len(self.uf)
        self.uf.append(eid)
        return eid
    def find(self, x):
        while self.uf[x] != x:
            x = self.uf[x]
        return x
    def union(self, x, y):
        x,y = self.find(x), self.find(y)
        if x != y:
            self.uf[x] = y
        return y
    def rebuild(self):
        for i in range(len(self.uf)):
            self.uf[i] = self.find(i)

uf = UF()
for i in range(10):
    uf.makeset()
uf.union(0, 1)
uf.union(1, 2)
uf.rebuild()
uf
UF(uf=[2, 2, 2, 3, 4, 5, 6, 7, 8, 9])

A persistent union find retains all old versions of the union find

A colored / linked / refining / hierarchical union find has multiple refining union finds hanging around. Unions higher in the hierarchy are inherited by ufs lower in the hierarchy.

A contextual union find is like this with the added feature that the ufs are identified with contexts. When contexts become equivalent, the appropriate union finds should become equivalent. When a context is subsumed, it should inherit the equality information of the subsumee.

Subsumption is a better mechanism than lattices for contexts. While there is a lattice, it’s a weird one. Subsumption is more of a direct translation

Equating two nodes in a colored union find makes a new node? Multiparent. Maybe you have to ping pong until convergence. I think the contracting nature of the maps means this’ll work. You only have to ping pong up to the common ancestor. Equating two color nodes goes into a meta-union find of color nodes? You need to ping pong through the entire tree, or normalize everything in the root node. A meta meta union find? Really you can assert node1 <= node2.
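One way to read the ping pong idea, as a sketch and not a settled design: treat each parent union find as a labelling of ids, and alternate between the two, pushing the smallest label seen so far across each class until nothing changes. Labels only ever decrease, which is the contracting property that makes the loop terminate. The function name `join_by_ping_pong` and the list representation are assumptions for this example.

```python
def join_by_ping_pong(f, g):
    """f, g: lists mapping each id to a class label. Returns min labels of the join."""
    n = len(f)
    h = list(range(n))  # start with every id in its own class
    changed = True
    while changed:
        changed = False
        for labels in (f, g):  # alternate between the two parents
            best = {}
            for i in range(n):
                best[labels[i]] = min(best.get(labels[i], h[i]), h[i])
            for i in range(n):
                if best[labels[i]] < h[i]:
                    h[i] = best[labels[i]]
                    changed = True
    return h

f = [0, 1, 0, 1]  # classes {0,2} {1,3}
g = [0, 1, 2, 2]  # classes {0} {1} {2,3}
print(join_by_ping_pong(f, g))  # [0, 0, 0, 0]
```

In this example 1 only gets connected to 0 on the second round, through 3 and 2, so a single pass over each parent is not enough.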

class DiffUF(): ufs: list[UF] self.duf

class DiffUF(): parent | self.uf, self.duf # either points to equivalent parent UF or actually has the data and pointer to inheritance uf.

class DiffUF():
    def __init__(self, uf : UF):
        self.uf = uf
        self.duf = {}
    def makeset(self):
        return self.uf.makeset()
    def find(self, x):
        x = self.uf.find(x)
        while x in self.duf:
            x = self.duf[x]
        return x
    def union(self, x, y):
        x,y = self.find(x), self.find(y)
        if x != y:
            self.duf[x] = y
        return y
    def __hash__(self):
        return hash(tuple(sorted(self.duf.items())))
    def __le__(self, other):
        if len(self.duf) > len(other.duf): # fast check
            return False
        return all(other.find(k) == other.find(v) for k,v in self.duf.items())
    def __eq__(self, other):
        return len(self.duf) == len(other.duf) and self <= other # ? Right?
    def __or__(self, other): # over same base
        """
        Union find join. Finest partition greater than both.
        """
        assert id(self.uf) == id(other.uf)
        new = DiffUF(self.uf)
        new.duf = self.duf.copy() # or pick biggest one
        for k,v in other.duf.items():
            new.union(k,v)
        return new
    def __and__(self, other): # over same base
        """union find meet. Coarsest partition less than both."""
        assert id(self.uf) == id(other.uf)
        new = DiffUF(self.uf)
        for k,v in self.duf.items():
            if other.find(k) == other.find(v):
                new.union(k,v)
        return new
    def rebuild(self): # canon
        old_duf = self.duf # keep the old edges so they can be replayed below
        self.duf = {}
        self.uf.rebuild()
        for k,v in old_duf.items():
            self.union(k,v)
    
    

https://usr.lmf.cnrs.fr/~jcf/publis/puf-wml07.pdf A Persistent Union-Find Data Structure

Proof producing

  1. Tie breaking
  2. Path compress or no
  3. edge storage, attributed storage
  4. Pointers, array, or dict
  5. Lazy vs eager

A purely functional union find using purely functional dictionaries

Union finds solve connectivity in graphs. Proof producing UF stores a spanning tree also so that it can return a path when you ask for it.

Layered union finds are like layered theories M1 |= M2 |= M3 |= M4 ...

A canonizing uf would want all the children to be updated too.

@dataclass
class EagerUF():
    uf : list[int] = field(default_factory=list)
    def makeset(self):
        eid = len(self.uf)
        self.uf.append(eid)
        return eid
    def find(self, x):
        return self.uf[x]
    def union(self, x, y):
        x,y = self.find(x), self.find(y)
        if x != y:
            self.uf[x] = y
            self.rebuild()
        return y
    def rebuild(self):
        # find is a single lookup here, so flatten by composing the map with itself
        for i in range(len(self.uf)):
            self.uf[i] = self.uf[self.uf[i]]
class FrozenDict()

You don’t have to structurally canonicalize to implement a correct hash. But you kind of might as well

@dataclass(frozen=True)
class CanonUF():
    uf : tuple[int,...]
    @classmethod
    def create(cls):
        return CanonUF(())
    def makeset(self):
        # frozen, so hand back the new id together with a new CanonUF
        eid = len(self.uf)
        return eid, CanonUF(self.uf + (eid,))
    def find(self, x):
        return self.uf[x]
    def union(self, x, y):
        x,y = self.find(x), self.find(y)
        if x == y:
            return self
        # rebuild immediately. The smaller id wins so the tuple is structurally canonical
        lo, hi = min(x, y), max(x, y)
        return CanonUF(tuple(lo if r == hi else r for r in self.uf))
    def __hash__(self):
        return hash(self.uf)
    def __eq__(self, other):
        return isinstance(other, CanonUF) and self.uf == other.uf
    def rebuild(self):
        return self # already canonical after every union
class ContextUF():
    base : UF
    ctxs : dict[CanonUF, DiffUF] #  ctx |- res 
    def __init__(self):
        self.base = UF()
        self.ctxs = {}
    def rebuild(self):
        newctx = {}
        for ctx, uf in self.ctxs.items():
            ctx1 = ctx.rebuild()
            for ctx1, uf1 in self.ctxs.items():
                if ctx <= ctx1:
                    uf1.merge(uf)
            newctx[ctx.rebuild()] = uf.rebuild()
        self.base.rebuild()



Persistent

https://github.com/MagicStack/immutables https://github.com/tobgu/pyrsistent https://discuss.python.org/t/pep-603-adding-a-frozenmap-type-to-collections/2318/219

! python3 -m pip install immutables

Bits and Bobbles

I called it base. Are fibers leaking into my thinking?

I have meet and join of uf. A Heyting algebra if I can define an A => B? Probably the finest partition that when merged with A gives B or is less than B? C /\ A <= B

def diff(self, other):

partition refinement https://en.wikipedia.org/wiki/Partition_of_a_set

old


Simple union find

compressing

colored union finds - use min to ping pong. label the union find. DAG hierarchy is fine.

union finds as converging functions

proof producing union find

using dictionary vs using ids vs using pointers

The z3 egraph and doubly linked lists. If we want to retain down pointers it is annoying because there might be multiple children to one parent. But you can insert yourself into a doubly linked list via the dancing links technique. Hmm. Maybe this is why z3 does it this way. For fast backtracking https://z3prover.github.io/papers/z3internals.html#sec-equality-and-uninterpreted-functions

from dataclasses import dataclass, field
@dataclass
class UF():
    uf : list[int] = field(default_factory=list)
    def find(self, x):
        while x != self.uf[x]:
            x = self.uf[x]
        return x
    def makeset(self):
        self.uf.append(len(self.uf))
        return len(self.uf) - 1
    def union(self, x, y):
        self.uf[self.find(x)] = self.find(y)
        return self.find(y)

uf = UF()
a, b, c = (uf.makeset() for _ in range(3))

from dataclasses import dataclass
@dataclass
class Cell():
    name: str
    id: int
    parent: 'Cell'
    def __init__(self, name="_"):
        self.parent = self
        self.id = id(self)
        self.name = name
    def find(self):
        x = self
        while x.parent is not x:
            x = x.parent
        return x
    def union(self,y):
        self.find().parent = y.find()
        return y.find()
    
x = Cell("x")
y = Cell("y")
z = Cell("z")
print(x.parent)
x.union(y)
y.union(z)
print(x.find() == z.find())
print(x)
print(z)
Cell(name='x', id=136449433498864, parent=...)
True
Cell(name='x', id=136449433498864, parent=Cell(name='y', id=136449433498432, parent=Cell(name='z', id=136449433492432, parent=...)))
Cell(name='z', id=136449433492432, parent=...)
reasons = {}
trace = []
def union_reason(x, y, reason):
    rx, ry = find(x), find(y)
    if rx != ry: # only novel unions are recorded
        reasons[rx] = (x, y, reason)
        trace.append((x, y, reason)) # this is sufficient. we don't need to store find(x) and find(y)
        uf[rx] = ry
    return ry

def explain(x,y):
    ...

Visualizing as a graph. The union find is part of Kruskal’s algorithm https://en.wikipedia.org/wiki/Kruskal%27s_algorithm

So for example if you had a bunch of equalities and you know how painful each one was to get, you could devise a minimum spanning tree for that.

Term rewriting as a graph.

Secret congruence edges

import networkx as nx

# random graph with multiple components
G = nx.Graph()
#G.add_nodes_from(range(10))



# color the edges in the union find
# vectorized normalization.
# For egraph purposes, not being fully normalized isn't really a problem.

import numpy as np

uf = np.arange(10)
uf[8] = 0
uf[0] = 4

def normstep(uf):
    return uf[uf] 

normstep(uf)

def step2(uf):
    return uf[uf[uf]]
step2(uf)

array([4, 1, 2, 3, 4, 5, 6, 7, 4, 9])
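The same step can be repeated to a fixed point. Here is a sketch with plain lists so it runs without numpy. `[uf[p] for p in uf]` plays the role of `uf[uf]`, and the helper name `normalize` is made up for this example. Each step composes the map with itself, so path lengths roughly halve every round.

```python
def normalize(uf):
    """Repeat the compose-with-itself step until nothing changes."""
    steps = 0
    while True:
        new = [uf[p] for p in uf]  # same as uf[uf] on a numpy array
        if new == uf:
            return uf, steps
        uf, steps = new, steps + 1

uf = list(range(10))
uf[8] = 0
uf[0] = 4
print(normalize(uf))  # ([4, 1, 2, 3, 4, 5, 6, 7, 4, 9], 1)
```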