There is an approach to problem solving where you take simpler and simpler subproblems until they are quite obvious, and then you build back up to the thing you want.

People want an egraph that supports associativity and commutativty without blowing up.

A step in that direction is to look at what an AC hash cons looks like

These are other steps towards that goal:

Two Kinds of Nodes

One node is a regular ordered children node and the other is one where only the multiset of children is relevant. A simple way of implementing multisets is to use tuples where you only cosntruct them in sorted form.

from dataclasses import dataclass, field

type Id = int # Integer identitfiers for hashconsed nodes

@dataclass(frozen=True)
class App:
    f : str
    args : tuple[Id, ...]

@dataclass(frozen=True)
class ACApp:
    f : str
    args : tuple[Id, ...]
type Node = App | ACApp

This is very similar to the notion of “containers”. It can also be thought of as an adt

type term =
    | App of term list
    | ACApp of term multiset
    | PolyApp of term poly
    | SetApp of term set
    | ...

AC Hash Consing

The flattened sorted form of an AC symbol is a normal form.

Smart constructors can do this flattening and sorting. Because we can assume inductviely that things are already flattened, we don’t have to do a deep search

@dataclass
class ACHashCons:
    memo : dict[Node, Id] = field(default_factory=dict) # IndexMap
    nodes : dict[Id, Node] = field(default_factory=dict)

    def mk_node(self, node : Node) -> Id:
        if node in self.memo:
            return self.memo[node]
        else:
            id = len(self.nodes)
            self.nodes[id] = node
            self.memo[node] = id
            return id

    def mk_app(self, f : str, args : tuple[Id, ...]) -> Id:
        return self.mk_node(App(f, args))


    def mk_acapp(self, f : str, args : tuple[Id, ...]) -> Id:
        # immediate smart constructor flattening
        flatargs = [] 
        for id in args:
            assert isinstance(id, int)
            n = self.nodes[id]
            if isinstance(n, ACApp) and n.f == f:
                flatargs.extend(n.args)
            else:
                flatargs.append(id)
        flatargs.sort()
        return self.mk_node(ACApp(f, tuple(flatargs)))

hc = ACHashCons()
a = hc.mk_app("a", ())
b = hc.mk_app("b", ())
pab = hc.mk_acapp("+", (a, b))
pba = hc.mk_acapp("+", (b, a))
assert pab == pba
pbb = hc.mk_acapp("+", (b, b))
pabb = hc.mk_acapp("+", (pab, b))
pbba = hc.mk_acapp("+", (pbb, a))
 
assert pabb == pbba
assert len(hc.nodes[pabb].args) == 3

pabb2 = hc.mk_acapp("+", (a, b, b)) # We don't have to make the intermediate nodes
assert pabb == pabb2

We don’t have to make AC symbols two operands at a time. It may be better to only construct the ACNode one you have all it’s children.

The main time one wants to intern an AC symbol is when it is “sealed off” by being under a new symbol like f(a + b + c). Otherwise it may be better to keep the AC thing ephemeral like in ACId (see below).

AC Matching

AC matching does require search and can return multiple matching substitutions.

First a simple language of patterns

@dataclass(frozen=True)
class Var:
    name : str

@dataclass(frozen=True)
class PApp:
    f : str
    args : tuple['Pattern', ...]

@dataclass(frozen=True)
class PACApp:
    f : str
    args : tuple['Pattern', ...]

type Pattern = Var | PApp | PACApp

Then we implement the thing itself. At each AC node we need to enumerate all ways of splitting the children into partitions

A brute force way of doing so is this

import itertools
def partitions(s, n):
    for assign in itertools.product(range(n), repeat=len(s)):
        bins = [[] for _ in range(n)]
        for x,b in zip(s, assign):
            bins[b].append(x)
        if all(bin for bin in bins):
            yield tuple(tuple(bin) for bin in bins)

list(partitions("abc", 2))
[(('a', 'b'), ('c',)),
 (('a', 'c'), ('b',)),
 (('a',), ('b', 'c')),
 (('b', 'c'), ('a',)),
 (('b',), ('a', 'c')),
 (('c',), ('a', 'b'))]

Mixing Ephemeral and Interned Terms

It is useful to consider if you want multiple representations of terms. You may want a way to have both ephemeral and interned terms or multiple levels of interning. ACId is kind of a like an uninterned AC node. We produce these things ephemerally during matching and I think it’s kind of useful and interesting to have them.

from dataclasses import dataclass,field
import itertools
type Id = int

@dataclass(frozen=True)
class ACId:
    f : str
    args : tuple[Id, ...]

type SId = Id | ACId # structured Ids. ACApp may be uninterned / ephemeral
# obviously it hurts to keep making copies of the same struct, but they are noticeably different

This paper about the Kontroli term representation talks about long lived and short lived terms https://dl.acm.org/doi/10.1145/3573105.3575686 . I also mentioned it in my omelets need onions paper https://arxiv.org/abs/2504.14340 . Chris also mentioned usage of “scoped hash conses” https://github.com/bytecodealliance/wasmtime/blob/main/cranelift/codegen/src/scoped_hash_map.rs in the cranelift compiler.

Doing the matching

Inspired by Max’s nice version of ematching. There is some similarity between ematching in an egraph and ac-matching. They are both search based with an obvious finite thing to search out of. I pointed out this similarity here https://www.philipzucker.com/relational-ac-matching/ where I wondered if AC is as good a fit to the relational idea as ematching is.

Basically you traverse down the pattern. You need to try all partitions of the arguments when you hit an AC pattern.

def acmatch_rec(hc : ACHashCons, p : Pattern, id : SId, subst) -> list[dict[str, SId]]:
    if isinstance(id, ACId):
        n = id
    else:
        n = hc.nodes[id]
        if isinstance(n, ACApp): # normalize into ACId to simplify logic below
            n = ACId(n.f, n.args) # could move this below Var for efficiency, but whatev
    match p, n:
        case Var(x), _:
            # If variable check if binding is consistent
            if x in subst:
                if subst[x] == id: # something smarter for ACId?
                    return [subst]
                else:
                    return []
            else:
                newsubst = subst.copy()
                newsubst[x] = id
                return [newsubst]
        case PApp(pf, pargs), App(f, args) if pf == f and len(pargs) == len(args):
            substs = [subst]
            for parg, arg in zip(pargs, args):
                substs = [subst for subst0 in substs for subst in acmatch_rec(hc,parg, arg, subst0)]
            return substs
        case PACApp(pf, pargs), ACId(f, args) if pf == f:
            # we need to hand down temporary ACNodes
            # What about ACNode with 1 element? ACI easier?
            res = []
            # enumerating partitions more lazily would prune earlier.
            # I suppose the binary form of normal form of ac would do this for free...
            # No. Just _don't_  flatten your patterns.
            for subargs in partitions(args, len(pargs)):
                # We consider singleton usage +(a) to be the identity `plus(a) = a`
                subargs = map(lambda a: ACId(f, a) if len(a) > 1 else a[0], subargs)
                substs = [subst]
                for parg, arg in zip(pargs, subargs):
                    substs = [subst for subst0 in substs for subst in acmatch_rec(hc, parg, arg, subst0)]
                res.extend(substs)
            return res
        case _:
            return []

    
def acmatch(hc : ACHashCons, p : Pattern, id : Id) -> list[dict[str, Id]]:
    return acmatch_rec(hc, p, id, {})

Some example output

hc = ACHashCons()
a = hc.mk_app("a", ()) # id 0
b = hc.mk_app("b", ()) # id 1
c = hc.mk_app("c", ()) # id 2
pabc = hc.mk_acapp("+", (a, b, c)) # id 3
fa  = hc.mk_app("f", (a,)) # id 4
ffa = hc.mk_app("f", (fa,)) # id 5
paab = hc.mk_acapp("+", (a, a, b)) # id 6

acmatch(hc, PACApp("+", (Var("x"), Var("y"))), pabc)
[{'x': ACId(f='+', args=(0, 1)), 'y': 2},
 {'x': ACId(f='+', args=(0, 2)), 'y': 1},
 {'x': 0, 'y': ACId(f='+', args=(1, 2))},
 {'x': ACId(f='+', args=(1, 2)), 'y': 0},
 {'x': 1, 'y': ACId(f='+', args=(0, 2))},
 {'x': 2, 'y': ACId(f='+', args=(0, 1))}]
acmatch(hc, PACApp("+", (Var("x"), PACApp("+", (Var("z"), Var("y"))))), pabc)
[{'x': 0, 'z': 1, 'y': 2},
 {'x': 0, 'z': 2, 'y': 1},
 {'x': 1, 'z': 0, 'y': 2},
 {'x': 1, 'z': 2, 'y': 0},
 {'x': 2, 'z': 0, 'y': 1},
 {'x': 2, 'z': 1, 'y': 0}]
acmatch(hc, PACApp("+", (Var("x"), PACApp("+", (Var("x"), Var("y"))))), pabc)
[]
acmatch(hc, PACApp("+", (Var("x"), PACApp("+", (Var("x"), Var("y"))))), paab)
[{'x': 0, 'y': 1}, {'x': 0, 'y': 1}]
acmatch(hc, PACApp("+", (PApp("a", ()), PACApp("+", (PApp("b", ()), Var("y"))))), pabc)
[{'y': 2}]
acmatch(hc,PApp("f", (Var("x"),)), ffa)
[{'x': 4}]

Bits and Bobbles

AC congruence closure and AC ground completion are pre-existing things. AC term orderings are a thing (ACRPO and ACKBO). A nice thing about e-graphs compared to this is that it doesn’t require so many complex conceptual priors

AC terms are conceptually quite simple to normalize. You flatten and sort. This can be thought of as a special kind of node that stores a multiset of its children rather than a vector.

AC matching is more challenging than simple syntactic matching, but not bad conceptually.

We only generate matches where a single application of the AC symbol is the identity and there is no 0-arity application. It may be reasonable to consider also allowing identity elements (0 for + or 1 for *) into the match. This should’t change that much. Set-like matching also shouldn’t be that different.

Take microegg, replace AC with just a hash cons https://www.philipzucker.com/weighted_uf/ do this with lpo? just store nxn lpo results? Well, depth can short circuit it a little and max symbol There is also ACKBO

The +{x} singleton kind of has an idempotence thing to it. Pattern matching on that would branch on stripping it or not.

You can also do bottom up AC matching.

It’d be better to return the first bin and a generator of the remaining bins for early pruning. Whatever

  1. Unique pointers
  2. Reference counted pointers
  3. Interned Ids
  4. Multiple layers of referenced Ids (scoped hash consing). Perhaps a global and ephemeral arena.

You may want to intermix the representations. Hash conses can serve at least the following purposes:

  1. Make structural and physical/pointer equality the same thing. Make equality checks faster.
  2. Make a term bank you can search for stuff in
  3. Reduce duplication for reduced memory usage

You don’t need everything to be perfectly hash consed. You can do structural checks until you reach a point where you know perfect deduplication has been achieved.

For this reason it seems useful to have a notion of ACId sometimes in addition to a regular Id

import itertools
def partitions(s, n):
    for assign in itertools.product(range(n), repeat=len(s)):
        bins = [[] for _ in range(n)]
        for x,b in zip(s, assign):
            bins[b].append(x)
        # possibly filter on nonempty or singleton or do map into ACId here
        yield tuple(tuple(bin) for bin in bins)

list(partitions("abc", 2))
[(('a', 'b', 'c'), ()),
 (('a', 'b'), ('c',)),
 (('a', 'c'), ('b',)),
 (('a',), ('b', 'c')),
 (('b', 'c'), ('a',)),
 (('b',), ('a', 'c')),
 (('c',), ('a', 'b')),
 ((), ('a', 'b', 'c'))]

Inspired by Max’s nice version of ematching, here is an AC matcher in a similar form


I’m a little shaky about the length 0 and length 1 children case for ACNodes. It is not clear to me if they should be allowed or not.

from dataclasses import dataclass,field
import itertools
type Id = int

@dataclass(frozen=True)
class ACId:
    f : str
    args : tuple[Id, ...]

type SId = Id | ACId # structured Ids. ACApp may be uninterned / ephemeral
# obviously it hurts to keep making copies of the same struct, but they are noticeably different

@dataclass(frozen=True)
class App:
    f : str
    args : tuple[Id, ...]

@dataclass(frozen=True)
class ACApp:
    f : str
    args : tuple[Id, ...]
type Node = App | ACApp

@dataclass(frozen=True)
class Var:
    name : str

@dataclass(frozen=True)
class PApp:
    f : str
    args : tuple['Pattern', ...]

@dataclass(frozen=True)
class PACApp:
    f : str
    args : tuple['Pattern', ...]

type Pattern = Var | PApp | PACApp

@dataclass
class ACHashCons:
    memo : dict[Node, Id] = field(default_factory=dict) # IndexMap
    nodes : dict[Id, Node] = field(default_factory=dict)

    def mk_node(self, node : Node) -> Id:
        if node in self.memo:
            return self.memo[node]
        else:
            id = len(self.nodes)
            self.nodes[id] = node
            self.memo[node] = id
            return id

    def mk_app(self, f : str, args : tuple[SId, ...]) -> Id:
        args = tuple(self._mk_acapp(arg.f, arg.args) if isinstance(arg, ACId) else arg for arg in args)
        return self.mk_node(App(f, args))

    # maybe you should never make an ac_app directory, only mk_app should call it? Is this overcomplicating it?
    #def mk_acapp(self, f : str, args : tuple[SId, ...]) -> SId:
        #args = tuple(args if isinstance(arg, int) else self._mk_acapp(arg.f, arg.args) for arg in args)
        # should also flatten
        # Should it flatten by searching into nodes too? No i think
        #flatargs = []
    #    return ACId(f, args)

    def mk_acapp(self, f : str, args : tuple[Id, ...]) -> Id:
        flatargs = [] 
        for id in args:
            assert isinstance(id, int)
            n = self.nodes[id]
            if isinstance(n, ACApp) and n.f == f:
                flatargs.extend(n.args)
            else:
                flatargs.append(id)
        flatargs.sort()
        return self.mk_node(ACApp(f, tuple(flatargs)))

    def acmatch_rec(self, p : Pattern, id : SId, subst) -> list[dict[str, Id]]:
        if isinstance(id, ACId):
            n = id
        else:
            n = self.nodes[id]
            if isinstance(n, ACApp):
                n = ACId(n.f, n.args) # could move this below Var for efficiency, but whatev
        match p, n:
            case Var(x), _:
                if x in subst:
                    if subst[x] == id:
                        return [subst]
                    else:
                        return []
                else:
                    newsubst = subst.copy()
                    newsubst[x] = id
                    return [newsubst]
            case PApp(pf, pargs), App(f, args) if pf == f and len(pargs) == len(args):
                substs = [subst]
                for parg, arg in zip(pargs, args):
                    substs = [subst for subst0 in substs for subst in self.acmatch_rec(parg, arg, subst0)]
                return substs
            case PACApp(pf, pargs), ACId(f, args) if pf == f:
                # we need to hand down temporary ACNodes
                # What about ACNode with 1 element? ACI easier?
                res = []
                for subargs in partitions(args, len(pargs)):
                    subargs = map(lambda a: ACId(f, a), subargs)
                    substs = [subst]
                    for parg, arg in zip(pargs, subargs):
                        substs = [subst for subst0 in substs for subst in self.acmatch_rec(parg, arg, subst0)]
                    res.extend(substs)
                return res
            case _:
                return []

        
    def acmatch(self, p : Pattern, id : Id) -> list[dict[str, Id]]:
        # flatten pattern?
        return self.acmatch_rec(p, id, {})

#for partition in itertools.combinations(args, len(pargs)): # wrong
#    for subargs in itertools.permutations(partition

hc = ACHashCons()
a = hc.mk_app("a", ())
b = hc.mk_app("b", ())
c = hc.mk_app("c", ())
pabc = hc.mk_acapp("+", (a, b, c))
fa  = hc.mk_app("f", (a,))
ffa = hc.mk_app("f", (fa,))

hc.acmatch(PACApp("+", (Var("x"), Var("y"))), pabc)
hc.acmatch(PACApp("+", (Var("x"), PACApp("+", (Var("z"), Var("y"))))), pabc)
hc.acmatch(PACApp("+", (Var("x"), PACApp("+", (Var("x"), Var("y"))))), pabc)
hc.acmatch(PApp("f", (Var("x"),)), ffa)

[{'x': 4}]

Post hoc equalities may add new flattenings, so rebuilding involve AC flattening to make sure all AC flat forms exists. Rebuilding becomes way more expensive The Aegraph style of eagerly applying rewrites may avoid needing as much rebuilding though.

Rudi points out, what about right associated sorted regular hash consing. What about KBO on that?

Multisets of multiset rewriting. Take the https://www.philipzucker.com/multiset_rw/ Nested multisets Non well founded multisets? The AC to the ACI of non well founded sets

https://pdfs.semanticscholar.org/08a6/9eb42f087ce76876aef224d32f2109367307.pdf modular mutliset rewriting

ACKBO

http://cl-informatik.uibk.ac.at/software/ackbo/ https://arxiv.org/abs/1403.0406

Ok the simplest ordering they describe puts all AC simple low, so they want to get pushed to leaved

ACRPO

I still want to implement acrpo.

https://research.vu.nl/ws/portalfiles/portal/266879340/A_Lambda_Free_Higher_Order_Recursive_Path_Order.pdf implement lfrpo

Ground nominal term orderings.

Equivariance means that X = Y can’t be oriented. or add(X,Y) = add(Y,X) But we’re ground so we’re probably better behaved

Rudi noted that slots are lot of like unification vars that never get narrowed and have a disjointness constraint. https://egraphs.zulipchat.com/#narrow/channel/328972-general/topic/EGRAPHS.20Meeting.202025-03-20.3A.20Slotted.20E-Graphs/near/543056830

If t =g= s was useful, what about t <g< s. Hmm. Ordered modulo g. Then we could talk about x + y < y + x ?

2 <(-7)< -4 using shift group. Yes… also 2 <-8< -4 but not 2 <-6< -4 Combininig the inequality union find with the group union find.

An ordered group? https://en.wikipedia.org/wiki/Linearly_ordered_group https://ncatlab.org/nlab/show/ordered+group https://en.wikipedia.org/wiki/Partially_ordered_group

def perm_lpo(t, s) -> tuple[Order, G]:

https://memoryleak47.github.io/slotted-egraphs-demystified-i/ https://medium.com/@mineralsteins/ultimate-code-compression-with-slotted-egraph-74f1859c9df5

So we need an unfailing variant of knuth bendix

renamings aren’t bijective. They are 1-1.

https://www.cs.ru.nl/~cynthiakop/2024_isr/day2handout.pdf Termination: the higher-order recursive path ordering - Kop

Have I done dumb ground kb in python?

https://drops.dagstuhl.de/storage/00lipics/lipics-vol021-rta2013/LIPIcs.RTA.2013.319/LIPIcs.RTA.2013.319.pdf normalized completion revisited. 2013. Has a table of moidifed rules. hmm. acrpo better than ackbo?

https://www.jonmsterling.com/01JQ/ disentangling unfication and implciti coercions “nested pattern unification” https://andraskovacs.github.io/pdfs/wits23prez.pdf

class ACEGraph()
    def add_ac(self, f, args):
        flatargs = [[]]
        for id in args:
            this_arg = [[id]]
            for node in self.class[id]:
                match node:
                    case AC(f1, args1) if f1 == f:
                        this_arg.append(args1)
                    case _:
                        pass
            # expand out all ways of flattening
            flatargs = [prev + this for prev in flatargs for this in this_arg]
        children = []
        for args in flatargs: # maybe dedup first. Or whatever. add_prim will do that
            seld.add_prim(AC(f, sorted(args)))

        if AC(f, sorted(args)) in memo:
            return memo
        else:
            # do eager flattening
            # add in all E equivalent flattened terms
            for sub in powerset(args)
                if AC(f, sub) in memo:
                    id = mem(AC(f, sub))
                    AC
    def rebuild():
        add in flattenings
        
    def ematch(pattern, id)::
        for node in eclasses[id]:
            match node:
                case App
                case ACApp:
                    for all partitions:
                        match
  Cell In[14], line 1
    class ACEGraph()
                    ^
SyntaxError: expected ':'

A simple Manual Ground KB System

from dataclasses import dataclass
@dataclass(frozen=True)
class App():
    f : str
    args : tuple["Term", ...] = ()

    def __len__(self) -> int:
        return 1 + sum(len(arg) for arg in self.args)

    def __repr__(self) -> str:
        if self.args:
            return f"{self.f}({', '.join(repr(arg) for arg in self.args)})"
        else:
            return self.f

    def __le__(self, other: "App") -> bool:
        return self == other or self < other
    def __lt__(self, other: "App") -> bool:
        if self == other:
            return False
        ls,lo = len(self), len(other)
        if ls != lo:
            return ls < lo
        if self.f != other.f:
            return self.f < other.f
        for sa, oa in zip(self.args, other.args):
            if sa != oa:
                return sa < oa
        return False
    def subst(self, a, b):
        if self == a:
            return b
        return App(self.f, tuple(arg.subst(a,b) for arg in self.args))    

x = App("x")
y = App("y")
f = lambda a,b: App("f", (a,b))
assert f(x,y) > x
assert y > x
assert not x > y
assert not x < x

@dataclass
class KB():
    E : list[tuple[App,App]]
    R : list[tuple[App,App]]
    # Ordering

    def orient(self, n: int):
        (a,b) = self.E.pop(n)
        assert b != a
        if b > a:
            a,b = b,a
        self.R.append((a,b))
        return self
    
    def triv(self, n : int):
        (a,b) = self.E.pop(n)
        assert a == b
        return self
    
    def simp(self, en : int, rn : int):
        (ae,be) = self.E.pop(en)
        (ar,br) = self.R[rn]
        self.E.append((ae.subst(ar,br), be.subst(ar,br)))
        return self
    
    def deduce(self, u : App, rleft : int, rright: int):
        (arl,brl) = self.R[rleft]
        (arr,brr) = self.R[rright]
        self.E.append((u.subst(arl,brl),
                       u.subst(arr,brr)))
        return self
    
    def rsimp(self, rn : int, rm : int):
        (ar,br) = self.R[rn]
        (am,bm) = self.R.pop(rm)
        self.R.append((am, bm.subst(ar,br)))
        return self

    def lsimp(self, rn : int):
        # TODO: need to dignify with encompassment condition
        self.E.append(self.R.pop(rn))
        return self



kb = KB([(f(x,y), y), (f(x,y), x)], [])

kb
kb.orient(0)
kb.simp(0,0)
kb
kb.orient(0)
kb.rsimp(1,0)
kb.deduce(f(x,y), 0, 1)
kb.orient(0)
kb.lsimp(1)
kb.simp(0,0)
kb.simp(0,1)
kb.triv(0)
KB(E=[], R=[(y, x), (f(x, x), x)])
# ok but it hsould be schedules
from dataclasses import dataclass, field

counter = 0
def fresh_counter() -> int:
    global counter
    counter += 1
    return counter

@dataclass(frozen=True)
class Slot():
    name : int = field(default_factory=fresh_counter)

@dataclass(frozen=True)
class App():
    f : str
    args : tuple["Term", ...] = ()

type Term = App | Slot