Contextual Union Finds
Something that is desired in egraph rewriting is rewriting under assumptions.
The canonical example of this is rewriting inside the branches of an if-then-else: if x = y and y != 0 then x/y else x+2. Obviously, in the then-branch we know that x/y can be reduced to the constant 1. However, we do not know that x = y globally. Another case that Eytan showed was max(x,y) - min(x,y) = abs(x - y), where we may want to split into the cases x > y and x <= y.
You may also want assumptions just to see if you can make progress on an expression, and then output those assumptions later (people do this sort of thing to try to simplify traces from a symbolic executor by assuming non-aliasing of addresses). Or you may want assumptions for the subcases of an inductive proof.
The technique of assume nodes https://arxiv.org/pdf/2303.01839 gives a way to encode this into an egraph rewriting system. The colored egraph work https://arxiv.org/abs/2305.19203 tries to bake this in.
As is often the case, I think there is a lot to be learned by taking a step back to look at the simpler case of a contextual union find. I think it’s actually fairly straightforward (ignoring that I’ve done it wrong and not known it for years). The colored egraph paper mentions but does not go into much detail about the union find.
The basic idea is to maintain a hierarchy of union finds. Unions asserted into the child union finds should not mutate the parents, but find operations inside the child may have to find inside the parent.
If you can assume the parent union find stays fixed, that simplifies things. Then a persistent union find https://usr.lmf.cnrs.fr/~jcf/publis/puf-wml07.pdf (or a union find using persistent hash maps) may be acceptable. This is the sort of thing that occurs in a backtracking solver.
But we basically want to handle the case where new equalities are being discovered both in the global union find and in the child union finds, and the children receive updates from the parent but not vice versa. This feels natural in a saturating solver like equality saturation.
Sparse and Dense Union Finds
There are at least 3 flavors of union find. One flavor uses refcells, another uses a vector arena, and a third uses hashmaps.
I like the latter 2 more because they give you a handle on the entire union find as a single entity, which can be useful for sweeping if need be.
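For contrast, here is a minimal sketch of the first flavor, where each element is its own heap object with a mutable parent pointer. Python has no refcells, so a mutable attribute on a hypothetical `Node` class stands in for one. Notice there is no single object that is "the union find", which is what makes sweeping over it awkward.

```python
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional

# Hypothetical "refcell" flavor: every element owns a mutable parent pointer.
@dataclass(eq=False)  # eq=False keeps identity equality and hashing
class Node:
    name: str
    parent: Optional[Node] = None  # None marks a root

def find(n: Node) -> Node:
    while n.parent is not None:
        n = n.parent
    return n

def union(a: Node, b: Node) -> Node:
    a, b = find(a), find(b)
    if a is not b:
        a.parent = b  # no ids to compare, so b simply becomes the root
    return b

x, y, z = Node("x"), Node("y"), Node("z")
union(x, y)
assert find(x) is find(y)
assert find(z) is z
```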
Dense and Vectory
This is the vector arena style. It’s nice that it only requires a vector and hence has fast lookup. A root of the union find is represented by a self reference loop. One could also use None.
```python
from dataclasses import dataclass, field

@dataclass
class UFArena():
    parents : list[int] = field(default_factory=list)
    def makeset(self):
        eid = len(self.parents)
        self.parents.append(eid)  # a fresh root points at itself
        return eid
    def find(self, x : int):
        while self.parents[x] != x:
            x = self.parents[x]
        return x
    def union(self, x : int, y : int):
        x,y = self.find(x), self.find(y)
        if x != y:
            if x < y:
                x,y = y,x
            self.parents[x] = y  # the smaller id becomes the root
        return y
    def rebuild(self):
        # fully compress: point every id directly at its root
        for i in range(len(self.parents)):
            self.parents[i] = self.find(i)

uf = UFArena()
x,y,z = [uf.makeset() for i in range(3)]
uf.union(x,y)
uf
```
UFArena(parents=[0, 0, 2])
Sparse and HashMappy
This is a different style. The vector above is in a sense being used as dict[int,int]. What is nice about the hashmap style is that it is more space efficient if you have very sparse unions, and also that it supports arbitrary hashable objects as keys. Roots are represented by not being a key in the hashmap.
```python
@dataclass
class UFDict():
    uf : dict[object,object] = field(default_factory=dict)
    def find(self, x):
        # anything that is not a key is a root
        while x in self.uf:
            x = self.uf[x]
        return x
    def union(self, x, y):
        x,y = self.find(x), self.find(y)
        if x != y:
            # min/max tie-breaking assumes the keys are mutually comparable
            y,x = min(x,y), max(x,y)
            self.uf[x] = y
        return y
    def rebuild(self):
        # fully compress; assigning to existing keys while iterating is safe
        for k in self.uf.keys():
            self.uf[k] = self.find(k)
    def items(self):
        return self.uf.items()

uf = UFDict()
uf.union(0,1)
uf.union(0,2)
uf
```
UFDict(uf={1: 0, 2: 0})
Opaque Context (Color)
The idea is to have a big union find and many smaller child union finds derived from it. You can choose to union into the big boy, in which case all the little guys inherit those unions, or union into a little guy, in which case the big boy is unchanged.
This is trickier to get right than you might think. I thought this made sense for years until I actually went to implement it and saw it give wrong results. I think what I have now makes sense.
You might think you can just do parent.find(child.find(x)), but this gives false negatives (it doesn’t normalize things that should be equal to equal ids). A counterexample is in the next section. Likewise for child.find(parent.find(x)). And likewise for iterating between the two until reaching a fixed point (I’m not even sure this is guaranteed to converge).
```python
# an incorrect fixpoint loop
def find(x):
    while True:
        y = child.find(x)
        y = parent.find(y)
        if x == y:
            return x
        x = y
```
Likewise even for searching up both union find trees (following both the child and the parent pointer out of every node reached) rather than eagerly walking one chain at a time. Likewise for avoiding path compression.
I think a version that works is to maintain the ability to traverse the classes in the little guy and call bigboy.find on all of them and take the minimum. This is more expensive than a simple union find, but what ya gonna do.
```python
def find(x):
    # child.eqset(x) enumerates every id in x's class in the child union find
    return min(parent.find(y) for y in child.eqset(x))
```
Note that it may be the case (and it often is) that false negatives are acceptable. In equality saturation, failure to see two things are equal may mean we just haven’t discovered they are equal yet. Confirmed disequality is a separate mechanism https://dl.acm.org/doi/10.1145/3704913 . If that is your case, maybe just use one of the strategies above. It may also be that the equality saturation rewrite rules paper over any false negative issues, in that they keep reasserting unions until you reach some correct fixed point.
Example Issue
As an example of the problem, consider a starting state where the parent arena and the child dict are
parent: [0,1,2,3]
child: {}
We then receive 1=2=3 in the context uf
parent: [0,1,2,3]
child: {3 : 2, 2 : 1}
If we then receive union(0,2) in the parent context we get to
parent: [0,1,0,3]
child: {3 : 2, 2 : 1}
Now if we eagerly follow the child union find first on find(3), we get 3 -> 2 -> 1, and parent.find(1) = 1. This misses the pathway from 2 -> 0 in the parent union find, even though 3 = 2 = 0 holds in the context.
I believe this sort of problem can be cooked up for any fixed scheme of bouncing around between the parent and child union finds.
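The example can be replayed directly. This is a standalone sketch, not part of the implementation below: `parent` and `child` hold the final state above as raw pointer maps, and the hypothetical helpers `naive_find`, `parent_first_find`, and `fixpoint_find` implement the three eager strategies. A correct find would send all of 0, 1, 2, 3 to 0.

```python
# Final state of the example above, as raw pointer maps.
parent = [0, 1, 0, 3]   # parent arena after union(0, 2)
child = {3: 2, 2: 1}    # child dict after 1 = 2 = 3

def parent_find(x):
    while parent[x] != x:
        x = parent[x]
    return x

def child_find(x):
    while x in child:
        x = child[x]
    return x

def naive_find(x):          # parent.find(child.find(x))
    return parent_find(child_find(x))

def parent_first_find(x):   # child.find(parent.find(x))
    return child_find(parent_find(x))

def fixpoint_find(x):       # alternate child then parent until nothing changes
    while True:
        y = parent_find(child_find(x))
        if y == x:
            return x
        x = y

print(naive_find(3), parent_first_find(3), fixpoint_find(3), naive_find(0))
# prints: 1 1 1 0
```

All three strategies leave 3 stuck at 1. The parent-first one even splits the child's own class: parent_first_find(2) is 0 while parent_first_find(3) is 1.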
An Implementation
The set of ids in an equivalence class is kind of like a semigroup or “analysis” or lattice result, in that you can merge them when you call union. This is what you need whenever you have data keyed by an eid.
This set can also be maintained by splicing together linked lists https://z3prover.github.io/papers/z3internals.html#sec-equality-and-uninterpreted-functions (in the form of a “sibling” arena here sibling : list[int] https://mastodon.gamedev.place/@harold/114599334531897851 ), or by using a unode tree like in the aegraph. Using a python set is the simplest. This set can also be reduced with respect to the parent union find as part of a compression and/or rebuilding.
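A sketch of the sibling-arena idea, assuming the z3-style trick of keeping each class as a circular linked list: `sibling[i]` is the next member of i's class, a fresh id is a one-element cycle, and union splices two disjoint cycles into one by swapping the successors of the two roots. `UFSiblings` and `eqset` are hypothetical names; the implementation below uses a Python set instead.

```python
from dataclasses import dataclass, field

@dataclass
class UFSiblings:
    parents: list[int] = field(default_factory=list)
    sibling: list[int] = field(default_factory=list)  # circular "next member" links

    def makeset(self) -> int:
        eid = len(self.parents)
        self.parents.append(eid)
        self.sibling.append(eid)  # singleton cycle
        return eid

    def find(self, x: int) -> int:
        while self.parents[x] != x:
            x = self.parents[x]
        return x

    def union(self, x: int, y: int) -> int:
        x, y = self.find(x), self.find(y)
        if x != y:
            if x < y:
                x, y = y, x
            self.parents[x] = y
            # Swapping successors of nodes in two disjoint cycles merges them: O(1).
            self.sibling[x], self.sibling[y] = self.sibling[y], self.sibling[x]
        return y

    def eqset(self, x: int):
        """Yield every member of x's class by walking the cycle once."""
        y = x
        while True:
            yield y
            y = self.sibling[y]
            if y == x:
                break

uf = UFSiblings()
ids = [uf.makeset() for _ in range(4)]
uf.union(1, 2)
uf.union(2, 3)
assert sorted(uf.eqset(3)) == [1, 2, 3]
assert list(uf.eqset(0)) == [0]
```

Compared to a set per root, the splice costs O(1) per union and no extra allocation, at the price of a walk to enumerate.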
```python
from collections import defaultdict

@dataclass
class UFContext():
    parentuf : UFArena
    uf : dict[object,object] = field(default_factory=dict)
    # ids[r] holds the other members of child root r's class.
    # Could also use a linked list based or tree based enumerator
    ids : dict[object, set[object]] = field(default_factory=lambda: defaultdict(set))
    def makeset(self):
        x = self.parentuf.makeset()
        return x
    def find(self, x):
        # walk to the root of x's class in the child
        while x in self.uf:
            x = self.uf[x]
        # minimum parent-level representative over the child class of x.
        # ids.get avoids creating entries. We could compress ys with respect to parents
        ys = self.ids.get(x, ())
        return min([self.parentuf.find(x)] + [self.parentuf.find(y) for y in ys])
    def union(self, x, y):
        x,y = self.find(x), self.find(y)
        if x != y:
            y,x = min(x,y), max(x,y)
            self.uf[x] = y
            # merge the enumerated classes
            self.ids[y] |= self.ids[x]
            self.ids[y].add(x)
        return y
    def rebuild(self):
        for k in self.uf.keys():
            self.uf[k] = self.find(k)

uf0 = UFArena()
uf1 = UFContext(uf0)
x,y,z,w = [uf1.makeset() for i in range(4)]
uf1
uf1.union(y,z)
uf1.union(z,w)
uf1
uf0.union(x,z)  # a union in the parent that the child must inherit
uf1
uf1.find(x)
assert uf1.find(w) == uf1.find(z)  # the case from the Example Issue; this holds here
uf1
```
UFContext(parentuf=UFArena(parents=[0, 1, 0, 3]), uf={2: 1, 3: 1}, ids=defaultdict(<class 'set'>, {1: {2, 3}, 2: set(), 3: set()}))
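A cheap way to gain confidence in any of these contextual finds is to compare against a brute-force reference: replay the parent's equations and the child's equations together into one flat union find, whose classes are by definition what the context should see. `reference_find` and `disagreements` are hypothetical helpers for this, not part of the implementation above.

```python
def reference_find(parent_eqs, child_eqs):
    """Reference semantics for one context: the equivalence relation generated
    by the parent's equations together with the child's. Min-id tie-breaking
    makes find return the minimum id of each class."""
    uf = {}

    def find(x):
        while x in uf:
            x = uf[x]
        return x

    for a, b in [*parent_eqs, *child_eqs]:
        a, b = find(a), find(b)
        if a != b:
            uf[max(a, b)] = min(a, b)
    return find

def disagreements(ctx_find, ref_find, ids):
    """Pairs (i, j) on which a candidate contextual find and the reference
    disagree about equality: false negatives or false positives."""
    ids = list(ids)
    return [(i, j) for i in ids for j in ids
            if i < j and (ctx_find(i) == ctx_find(j)) != (ref_find(i) == ref_find(j))]

# child union(0,1), child union(2,3), then a parent union(1,2)
ref = reference_find(parent_eqs=[(1, 2)], child_eqs=[(0, 1), (2, 3)])
assert ref(0) == ref(1) == ref(2) == ref(3) == 0
```

On that same input the reference puts 0 through 3 in one class. Tracing UFContext.find above by hand on it (child unions first, then uf0.union(1,2)) gives find(0) == 0 but find(3) == 1, because each child class only consults the parent finds of its own members. So a parent edge bridging two separate child classes seems worth checking with something like disagreements(uf1.find, ref, range(4)).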
Structural Canonization of Union Finds for Keys
What would be nice, though, is to not have opaque contexts. We want contexts that are labelled by what we assumed to make them.
We want to talk about atomic contextual equations like {a = b, c = d} |= a = c, where {a = b, c = d} is the label of the context. If we ever produce two labels that are the same by different means, like a = b, b = c and b = a, c = a, we want the data structure to recognize that they name the same context.
A general strategy for hashing / indexing this kind of thing is to structurally canonize the object in question https://www.philipzucker.com/hashing-modulo/ . Sets can be canonized for example as a sorted and deduped list.
The simple structurally canonical form of an equivalence relation is to
- make a union find with a deterministic parent method (here I use minimum eid to tie break who becomes parent of whom)
- fully compress it
- Use a structurally canonical form of dictionaries, here a sorted and deduped association list.
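The three steps can be sketched as a standalone function. `canonical_key` is a hypothetical helper, a stripped-down version of what `make_key` does below, minus the normalization through biguf. It shows that a = b, b = c and b = a, c = a (and the same equations in another order) land on the same key.

```python
def canonical_key(eqs):
    """Canonical form of the equivalence relation generated by eqs."""
    uf = {}

    def find(x):
        while x in uf:
            x = uf[x]
        return x

    # 1. union with a deterministic tie-break: the minimum id becomes the parent
    for l, r in eqs:
        l, r = find(l), find(r)
        if l != r:
            uf[max(l, r)] = min(l, r)
    # 2. fully compress, and 3. emit a sorted association list (keys are unique)
    return tuple(sorted((k, find(k)) for k in uf))

a, b, c = 0, 1, 2
k1 = canonical_key([(a, b), (b, c)])
k2 = canonical_key([(b, a), (c, a)])
k3 = canonical_key([(b, c), (a, b)])  # before compression this has 2 -> 1 -> 0
assert k1 == k2 == k3 == ((1, 0), (2, 0))
```

With min tie-breaking every root is the minimum of its class, so after full compression each non-minimal id maps to its class minimum, which depends only on the relation and not on how it was built.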
We can then have a contextual union find that contains a biguf and lots of little context_ufs keyed by their canonical union find.
Conceptually I think this is a chain of union finds big -> key -> context_uf. I insert the key into the context_uf immediately upon construction.
The keys of context_ufs can grow stale if more unions happen in biguf. If so, a rebuild may discover that two keys actually need to be merged. This does not deal with the case that one context may subsume another, which has to be handled separately in rebuilding.
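Here is a small sketch of how a key goes stale, assuming keys are recomputed by first mapping every id through biguf's find. `key_wrt` is a hypothetical helper and biguf is modeled as a bare parent-pointer list. The contexts {0 = 1} and {0 = 2} have distinct keys until biguf learns 1 = 2, after which both recanonize to the same key and should be merged.

```python
def key_wrt(eqs, big):
    """Canonical key of a context's equations, normalized through the big
    union find `big` (a parent-pointer list with min-id roots)."""
    def big_find(x):
        while big[x] != x:
            x = big[x]
        return x

    uf = {}

    def find(x):
        while x in uf:
            x = uf[x]
        return x

    for l, r in eqs:
        l, r = find(big_find(l)), find(big_find(r))
        if l != r:
            uf[max(l, r)] = min(l, r)
    return tuple(sorted((k, find(k)) for k in uf))

big = [0, 1, 2]
ctx_a, ctx_b = [(0, 1)], [(0, 2)]
print(key_wrt(ctx_a, big), key_wrt(ctx_b, big))  # ((1, 0),) ((2, 0),)
big[2] = 1  # biguf learns 1 = 2; the smaller id 1 stays the root
print(key_wrt(ctx_a, big), key_wrt(ctx_b, big))  # ((1, 0),) ((1, 0),)
```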
```python
from collections import defaultdict

type CanonUF = object  # Python 3.12+ alias: a key is a sorted tuple of (id, root) pairs

@dataclass
class UFContextKeyed():
    biguf : UFArena = field(default_factory=UFArena)
    context_ufs : dict[CanonUF, UFContext] = field(default_factory=dict)
    def makeset(self):
        x = self.biguf.makeset()
        return x
    def make_key(self, *eqs):
        # normalize through biguf, union with min tie-breaking,
        # fully compress, then sort the association list
        uf = UFDict()
        for l,r in eqs:
            uf.union(self.biguf.find(l), self.biguf.find(r))
        uf.rebuild()
        return tuple(sorted(uf.items()))
    def make_context(self, *eqs):
        key = self.make_key(*eqs)
        uf = self.context_ufs.get(key)
        if uf is None:
            uf = UFContext(self.biguf)
            for l,r in eqs:
                uf.union(l,r)
            self.context_ufs[key] = uf
        return key, uf
    def find(self, ctx, x):
        return self.context_ufs[ctx].find(x)
    def union(self, x, y, ctx=None):
        if ctx is None:
            return self.biguf.union(x,y)
        else:
            return self.context_ufs[ctx].union(x,y)
    def rebuild(self):
        # rebuild keys, merging contexts on key collision
        ufnew = UFContextKeyed()
        ufnew.biguf = self.biguf
        for k, uf0 in self.context_ufs.items():
            # recanonize key wrt biguf
            knew, uf = ufnew.make_context(*k)
            # replay the old context's child-level unions
            for x,y in uf0.uf.items():
                uf.union(x,y)
        return ufnew

uf = UFContextKeyed()
x,y,z,w = [uf.makeset() for _ in range(4)]
key, uf1 = uf.make_context((x,y))
uf.union(y, z, ctx=key)
uf
```
UFContextKeyed(biguf=UFArena(parents=[0, 1, 2, 3]), context_ufs={((1, 0),): UFContext(parentuf=UFArena(parents=[0, 1, 2, 3]), uf={1: 0, 2: 0}, ids=defaultdict(<class 'set'>, {0: {1, 2}, 1: set(), 2: set()}))})
Bits and Bobbles
For more on union find variations: https://www.philipzucker.com/prim_level_uf/
I think what I call a “levelled” union find or “scoped” union find is distinct from the above. The intent there is that the “little guy” intentionally infects the big guy. Tie breaking is controlled such that it is always ok to clean up the little guy at any point but still retain the implied transitive equalities to ids that belong to the big guy. In the levelled union find, the level is an intrinsic part of the eid. You may or may not choose to do things this way in the contextual union find. One can choose to always makeset new identifiers that are immediately unioned to old identifiers.
Likewise, I think “inequality union finds” are a separate notion. https://www.philipzucker.com/asymmetric_complete/ These support notions of refinement. There is some relationship here though, in that different contexts may be in a subsumption relationship to each other: p |= a = b, q |= a = c, p => q. If p described exactly the conditions where your expressions are well defined, that would look a lot like refinement.
It is also possible to have a multi-layer hierarchy of derived union finds in the form of a tree or DAG (?). You need to maintain the set of eids to look through at each layer as you move up.
After this went out another design occurred to me. What one could do is just inform all the child union finds whenever the parent receives a union. This is tricky enough that I’m not sure it’s right.
```python
@dataclass
class ContextUF():
    biguf : UFArena = field(default_factory=UFArena)
    context_ufs : dict[object, UFDict] = field(default_factory=dict)
    def find(self, x, ctx=None):
        if ctx is None:
            return self.biguf.find(x)
        else:
            return self.biguf.find(self.context_ufs[ctx].find(x))
    def union(self, x, y, ctx=None):
        if ctx is None:
            x1,y1 = self.biguf.find(x), self.biguf.find(y)
            if x1 != y1:
                # inform everyone. O(n)ish in number of contexts
                self.biguf.union(x,y)
                for uf in self.context_ufs.values():
                    # fresh names so x, y stay intact for the next context
                    cx,cy = self.biguf.find(uf.find(x)), self.biguf.find(uf.find(y))
                    # Fix confluence failure?
                    if cx != cy:
                        uf.union(cx,cy)
                        uf.union(x1,y1)
                        uf.union(cx, x1)
        else:
            return self.context_ufs[ctx].union(x, y)
```
The contextual union find. Show the counterexamples. Normalizing union finds as keys.
https://github.com/eytans/easter-egg/blob/master/src/colors.rs
The basic structure of a single context.
You unfortunately do need to do some search-like stuff in find if you want to avoid false negatives. It is possible for the parent union find to receive an update such that any strategy of eagerly finding and bouncing around between the parent and child union finds misses the pathway to the truly canonical node.
Having said all that, with path compression the search doesn’t have to be paid for over and over, so maybe it’s not all bad.
This is a microcosm of theory combination actually, in that it is harder to combine (union) rewrite rule sets than you might think. Rebuilding is running completion again. Min is a mutually compatible ordering.
Do I really need to maintain an explicit enumeration of each class? This sucks?
Two Failures
At some point I thought just calling find on the child uf and then the parent uf would work. It does not.
Then I thought a form of search during find might be sufficient. It is not. This search isn’t really that much better than maintaining the eclass set anyway.
I dunno. I may be missing something nice to do. If you figure it out please do tell. I will tell you that some things you try that feel intuitively fine are wrong.
The following is wrong
```python
@dataclass
class UFContext():
    parentuf : UFArena
    uf : dict[object,object] = field(default_factory=dict)
    def makeset(self):
        return self.parentuf.makeset()
    def find(self, x):
        # search along both child and parent pointers, return the min id seen
        seen = set([x])
        todo = [x]
        while todo:
            x = todo.pop()
            y = self.uf.get(x)
            if y is not None and y not in seen:
                seen.add(y)
                todo.append(y)
            y = self.parentuf.parents[x]
            if y != x and y not in seen:
                seen.add(y)
                todo.append(y)
        y = min(seen)
        #for x in seen: # might as well path compress
        #    if y != x:
        #        self.uf[x] = y
        return y
    def union(self, x, y):
        x,y = self.find(x), self.find(y)
        if x != y:
            y,x = min(x,y), max(x,y)
            self.uf[x] = y
        return y
    def rebuild(self):
        for k in self.uf.keys():
            self.uf[k] = self.find(k)

uf0 = UFArena()
x,y,z = [uf0.makeset() for i in range(3)]
uf1 = UFContext(uf0)
uf2 = UFContext(uf0)
uf1.union(x,y)
uf1
assert uf1.find(x) == uf1.find(y)
assert uf0.find(x) != uf0.find(y)
assert uf2.find(x) != uf2.find(y)
uf0.union(y,z) # contexts inherit
assert uf2.find(y) == uf2.find(z)
assert uf1.find(x) == uf1.find(z)
uf1
```
UFContext(parentuf=UFArena(parents=[0, 1, 1]), uf={1: 0})
```python
uf0 = UFArena()
x,y,z,w = [uf0.makeset() for i in range(4)]
uf1 = UFContext(uf0)
uf1.union(y,z)
uf1.union(z,w)
uf1
```
UFContext(parentuf=UFArena(parents=[0, 1, 2, 3]), uf={2: 1, 3: 1})
```python
uf0.union(x,z)
uf1
uf1.find(x)
uf1.find(w) # uh oh! returns 1, while find(x) returns 0
uf1
```
UFContext(parentuf=UFArena(parents=[0, 1, 0, 3]), uf={2: 1, 3: 1})
```python
uf1.rebuild()
uf1.rebuild()
uf1.find(w)
uf1
```
UFContext(parentuf=UFArena(parents=[0, 1, 0, 3]), uf={2: 0, 3: 1})
A version I thought made sense was a DeltaUF that just found in self, then parent. This misses stuff.
```python
@dataclass
class DeltaUF():
    parent_uf : UFArena
    duf : dict[object,object] = field(default_factory=dict)
    def find(self, x):
        #while x in self.duf:
        #    x = self.duf[x]
        #return self.parent_uf.find(x)
        # bounce between the delta (child) and the parent until nothing changes
        while True:
            cur_x = x
            while x in self.duf:
                x = self.duf[x]
            x = self.parent_uf.find(x)
            if x == cur_x:
                return x
    def union(self, x, y):
        x,y = self.find(x), self.find(y)
        if x != y:
            if x < y:
                x,y = y,x
            self.duf[x] = y
        return y
    def rebuild(self):
        for k in self.duf.keys(): # There's some monkey business here about accidental self loops
            self.duf[k] = self.find(k)
```