SMTLIB as a Compiler IR I
I like SMT solvers. Compilers are cool. What kind of babies can they make?
A design trick that has led me to interesting places is to abuse the z3py AST more thoroughly than any sane person would. Z3 already has a very reasonable AST for describing logic, bitvector operations, functions, reals, and integers. But, if you do it right, in addition to just an AST, you also get semantics and a magic solver.
Compilers are nice because they are a pretty well specified problem that is actually useful. Reasoning principles and technology can be applied to make code faster. Bad reasoning can make output code buggy even when the input wasn’t.
There are at least two ways to approach what a compiler IR is:
- It is basically pure expressions that we start bolting stateful stuff onto
- It is basically imperative code / slightly abstracted assembly that we sometimes find pure subpieces of to help do optimization
I tend toward the first view, and that is the approach I’ll be taking today.
SSA is Functional Programming
One of my favorite papers is SSA is Functional Programming https://www.cs.princeton.edu/~appel/papers/ssafun.pdf . From this perspective, the core of SSA is a bunch of mutually defined recursive definitions.
Static single assignment (SSA) https://en.wikipedia.org/wiki/Static_single-assignment_form is a compiler IR form that people have noticed makes some compilation subproblems more straightforward. Each variable can be assigned only once. You can kind of get there by making any subsequent assignment go to a fresh variable and replacing later reads with reads of that fresh variable.
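As a rough illustration of the renaming trick, here is a minimal sketch for straight-line code only. The statement encoding (a destination plus a list of tokens) and the name `rename_straight_line` are made up for this example. Control-flow joins need phi nodes or block arguments, which this sketch does not handle.

```python
def rename_straight_line(stmts):
    """Rename (dest, tokens) statements so that every dest is assigned once.

    Tokens that name a previously assigned variable are rewritten to the
    latest version of that variable; everything else is left alone.
    """
    version = {}   # source variable -> its current SSA name
    counters = {}  # source variable -> next version number to hand out
    out = []
    for dest, tokens in stmts:
        # reads refer to the most recent definition
        new_tokens = [version.get(t, t) for t in tokens]
        n = counters.get(dest, 0)
        counters[dest] = n + 1
        fresh = f"{dest}{n}"
        version[dest] = fresh
        out.append((fresh, new_tokens))
    return out


prog = [
    ("k", ["0"]),
    ("k", ["k", "+", "1"]),
    ("j", ["k"]),
    ("k", ["k", "+", "2"]),
]
ssa = rename_straight_line(prog)
for dest, toks in ssa:
    print(dest, "=", " ".join(toks))
```

Each write to `k` gets its own name (`k0`, `k1`, `k2`), and the read in `j = k` picks up whichever version was live at that point.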
We can take the example program from the paper and write it in python
%%file /tmp/myfun.py
def myfun():
    i = 1
    j = 1
    k = 0
    while k < 100:
        if j < 20:
            j = i
            k = k+1
        else:
            j = k
            k = k + 2
    return j
print(myfun())
Overwriting /tmp/myfun.py
! python /tmp/myfun.py
1
For fun, I’ll build a CFG of this function using the py2cfg package
from py2cfg import CFGBuilder
cfg = CFGBuilder().build_from_file('myfun', '/tmp/myfun.py')
cfg.build_visual('exampleCFG', 'svg')
We can break this up into one function per block. Since python doesn’t have tail call optimization, this is a ludicrous thing to do from python’s perspective, but it does put the thing into a normal form.
You can see that each one of these functions corresponds to a block above. These programs compute the same thing.
def myfun():
    return loop(1,1,0)
def loop(i,j,k): # orange block
    return if_head(i,j,k) if k < 100 else done(i,j,k)
def if_head(i,j,k): # red block
    return then(i,j,k) if j < 20 else else_(i,j,k)
def then(i,j,k):
    return loop(i, i, k + 1)
def else_(i,j,k):
    return loop(i, k, k + 2)
def done(i,j,k): # green block
    return j
myfun()
1
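If the missing tail call optimization bothers you, one standard workaround is a trampoline. This is only a sketch of that idea and not something the rest of the post relies on: each block returns a description of the next jump instead of calling it, and a small driver loop (the hypothetical `run` below) does the dispatch. The `"jump"` and `"ret"` tags are made up for this example.

```python
def myfun():
    return ("jump", loop, (1, 1, 0))

def loop(i, j, k):
    return ("jump", if_head, (i, j, k)) if k < 100 else ("jump", done, (i, j, k))

def if_head(i, j, k):
    return ("jump", then, (i, j, k)) if j < 20 else ("jump", else_, (i, j, k))

def then(i, j, k):
    return ("jump", loop, (i, i, k + 1))

def else_(i, j, k):
    return ("jump", loop, (i, k, k + 2))

def done(i, j, k):
    return ("ret", j)

def run(entry, *args):
    """Keep following jumps until a block returns; the Python stack stays flat."""
    state = entry(*args)
    while state[0] == "jump":
        _, block, block_args = state
        state = block(*block_args)
    return state[1]

result = run(myfun)
print(result)
```

The blocks are still one function each, so the normal form is unchanged. Only the calling convention differs.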
Turning it Into SMT
One of the important features of my system Knuckledragger is that it supports definitions. These definitions are registered and can be unfolded via the z3 function substitute_funs.
We can replicate exactly this structure above and now we have a CFG-like thing in our logic thing. Neat!
from kdrag.all import *
Z = smt.IntSort()
# predeclare all our blocks so that we can recursively call them
myfun = smt.Function("myfun", Z)
loop = smt.Function("loop", Z,Z,Z, Z)
if_head = smt.Function("if_head", Z,Z,Z, Z)
then = smt.Function("then", Z,Z,Z, Z)
else_ = smt.Function("else_", Z,Z,Z, Z)
done = smt.Function("done", Z,Z,Z, Z)
i,j,k = smt.Ints("i j k")
myfun = kd.define("myfun", [], loop(1,1,0))
loop = kd.define("loop", [i,j,k], smt.If(k < 100, if_head(i,j,k), done(i,j,k)))
if_head = kd.define("if_head", [i,j,k], smt.If(j < 20, then(i,j,k), else_(i,j,k)))
then = kd.define("then", [i,j,k], loop(i, i, k + 1))
else_ = kd.define("else_", [i,j,k], loop(i, k, k + 2))
done = kd.define("done", [i,j,k], j)
define makes a definitional theorem for each of these. Here is if_head’s for example
if_head.defn
⊨ForAll([i, j, k], if_head(i, j, k) == If(j < 20, then(i, j, k), else_(i, j, k)))
The full_simp function interleaves z3.substitute_funs and z3.simplify until the expression stops changing. We can use it to run the program or any other concrete definitions.
kd.full_simp(kd.full_simp(myfun))
1
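To give a feel for the unfold-then-simplify loop without needing z3, here is a toy stand-in in plain Python. It is not how Knuckledragger implements full_simp. Terms are nested tuples whose head is an operator or function name, and `DEFS`, `subst`, `unfold`, `simplify`, and this `full_simp` are all names invented for the sketch.

```python
# name -> (parameters, body); the same six blocks as above
DEFS = {
    "myfun": ((), ("loop", 1, 1, 0)),
    "loop": (("i", "j", "k"),
             ("if", ("<", "k", 100), ("if_head", "i", "j", "k"), ("done", "i", "j", "k"))),
    "if_head": (("i", "j", "k"),
                ("if", ("<", "j", 20), ("then", "i", "j", "k"), ("else_", "i", "j", "k"))),
    "then": (("i", "j", "k"), ("loop", "i", "i", ("+", "k", 1))),
    "else_": (("i", "j", "k"), ("loop", "i", "k", ("+", "k", 2))),
    "done": (("i", "j", "k"), "j"),
}

def subst(e, env):
    """Replace variable names (strings in argument position) using env."""
    if isinstance(e, str):
        return env.get(e, e)
    if isinstance(e, tuple):
        return (e[0],) + tuple(subst(a, env) for a in e[1:])
    return e  # literal

def unfold(e, defs):
    """One pass: replace each call to a defined function by its body."""
    if not isinstance(e, tuple):
        return e
    head, args = e[0], tuple(unfold(a, defs) for a in e[1:])
    if head in defs:
        params, body = defs[head]
        return subst(body, dict(zip(params, args)))
    return (head,) + args

def simplify(e):
    """Fold constant arithmetic, comparisons, and ifs with a known condition."""
    if not isinstance(e, tuple):
        return e
    head, args = e[0], tuple(simplify(a) for a in e[1:])
    if head == "+" and all(isinstance(a, int) for a in args):
        return args[0] + args[1]
    if head == "<" and all(isinstance(a, int) for a in args):
        return args[0] < args[1]
    if head == "if" and isinstance(args[0], bool):
        return args[1] if args[0] else args[2]
    return (head,) + args

def full_simp(e, defs):
    """Alternate unfolding and simplification until nothing changes."""
    while True:
        e1 = simplify(unfold(e, defs))
        if e1 == e:
            return e
        e = e1

print(full_simp(("myfun",), DEFS))
```

The loop terminates here because every call has concrete arguments, so each `if` collapses to one branch before the next round of unfolding. On symbolic arguments this naive loop would unfold forever.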
Making it more IR-y
However, is this a compiler IR? It doesn’t super look like one.
Well, this is to some degree a matter of printing. If you use the right sigils and formatting, things look more like a compiler IR.
Compiler IRs typically have a sequence of simple operations. Operations are simple flat things like add %x, %y but not compound things like (add (add (add x y) z) (add x y)).
SMTLIB is a pure logic. There isn’t really a notion of sequencing or assignment as one might have in an imperative language or compiler IR. However, we can expand our expressions into such a sequence basically by traversing them in order and storing the subexpressions in a list.
It is common (to my understanding) that the “names” in the textual form of SSA rarely actually appear in the data structure of SSA. The variable is represented often by a pointer to the operation that produced it. They are basically the same thing or can be conflated/coerced to be the same thing.
In my printer I do the same thing. I print subexpressions using their id, which is a unique number z3 supplies via hash consing.
import kdrag.contrib.ir as ir
blk = ir.Block([Z,Z,Z], kd.kernel.defns[loop]._subst_fun_body)
ir.Block.of_defined_fun(loop)
^(Int,Int,Int):
%0 = < %var2, 100
%1 = if_head %var0, %var1, %var2
%2 = done %var0, %var1, %var2
%3 = if %0, %1, %2
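The flattening idea is small enough to sketch without z3. The version below works on nested tuples instead of z3 terms, and a dictionary stands in for hash consing so a shared subexpression is only emitted once. The function name `flatten` and the tuple encoding are assumptions of this sketch, and it uses a plain post-order walk instead of the worklist in my actual printer.

```python
def flatten(expr):
    """Flatten a nested tuple expression into a list of flat instructions.

    Each instruction is (op, operand, ...), where an operand is either a leaf
    (variable name or literal) or "%n", a reference to instruction n.
    """
    insns = []
    index = {}  # flat instruction -> position, so shared subterms appear once

    def go(e):
        if not isinstance(e, tuple):
            return str(e)
        # operands are flattened first, so they always precede their user
        flat = (e[0],) + tuple(go(a) for a in e[1:])
        if flat not in index:
            index[flat] = len(insns)
            insns.append(flat)
        return f"%{index[flat]}"

    go(expr)
    return insns


xy = ("add", "x", "y")
expr = ("add", ("add", xy, "z"), xy)  # (add (add (add x y) z) (add x y))
flat = flatten(expr)
for n, (op, *args) in enumerate(flat):
    print(f"%{n} = {op} {', '.join(args)}")
```

This prints three instructions, with `(add x y)` computed once as `%0` and referenced twice.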
I can also print a function consisting of mutually defined blocks. Really this isn’t doing anything much to the expressions. It’s a printing choice.
ir.Function.of_defined_funs([myfun.decl(), loop, if_head, then, else_, done])
fn myfun {
@myfun:
^():
%0 = loop 1, 1, 0
@loop:
^(Int,Int,Int):
%0 = < %var2, 100
%1 = if_head %var0, %var1, %var2
%2 = done %var0, %var1, %var2
%3 = if %0, %1, %2
@if_head:
^(Int,Int,Int):
%0 = < %var1, 20
%1 = then %var0, %var1, %var2
%2 = else_ %var0, %var1, %var2
%3 = if %0, %1, %2
@then:
^(Int,Int,Int):
%0 = + %var2, 1
%1 = loop %var0, %var0, %0
@else_:
^(Int,Int,Int):
%0 = + %var2, 2
%1 = loop %var0, %var2, %0
@done:
^(Int,Int,Int):
}
Bits and Bobbles
Next time, maybe I’ll talk about taking textual QBE and converting it to SMT by
- Turning phi nodes into block args. Kind of push them up to the blocks they came from.
- explicitizing memory (use the smt theory of arrays)
- CSE
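The first item can be sketched on a toy data model. This is not QBE's actual syntax or any parser output, just dictionaries I made up for illustration: each block lists its phi nodes as a map from predecessor label to incoming value, and the hypothetical `phis_to_block_args` pushes those values onto the jumps that reach the block.

```python
def phis_to_block_args(blocks):
    """Convert phi nodes into block parameters plus arguments on each jump.

    `blocks` maps label -> {"phis": {dest: {pred: value}}, "succs": [label]}.
    Returns label -> {"params": [dest], "jumps": [(succ, [arg, ...])]}.
    """
    out = {}
    for label, blk in blocks.items():
        jumps = []
        for succ in blk["succs"]:
            # for each phi in the successor, pass the value it selects on this edge
            args = [choices[label] for choices in blocks[succ]["phis"].values()]
            jumps.append((succ, args))
        out[label] = {"params": list(blk["phis"]), "jumps": jumps}
    return out


cfg = {
    "entry": {"phis": {}, "succs": ["loop"]},
    "loop": {
        "phis": {"k1": {"entry": "0", "body": "k2"}},
        "succs": ["body", "done"],
    },
    "body": {"phis": {}, "succs": ["loop"]},
    "done": {"phis": {}, "succs": []},
}
converted = phis_to_block_args(cfg)
for label, blk in converted.items():
    print(label, blk["params"], blk["jumps"])
```

The phi `k1 = phi(entry: 0, body: k2)` becomes a parameter `k1` on `loop`, with `entry` jumping to `loop(0)` and `body` jumping to `loop(k2)`.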
I think I like QBE.
Max has some great posts on IRs and SSA https://bernsteinbear.com/blog/ssa/ https://bernsteinbear.com/blog/irs/
Michel was pointing out to me that the SSA is FP paper mentions that the dominator structure can be reflected in nested let bindings. That’s pretty cool. I don’t really have let in Knuckledragger, sadly, since z3py doesn’t offer it. I kind of wish it did.
Modelling a CFG as a constrained horn clause is an alternative. It is the logic programming version of SSA is functional programming. It’s more predicaty than equational. https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/nbjorner-yurifest.pdf https://www.philipzucker.com/bap-chc/
SMTLIB dialect of MLIR https://mlir.llvm.org/docs/Dialects/SMT/ https://github.com/opencompl/xdsl-smt
MimIR https://www.arxiv.org/pdf/2411.07443 . Really neat. Trying to make a new MLIR-like thing based around the SSA is functional programming style
Whitequark and wanda prjunnamed https://mastodon.social/@whitequark/113970437064821618 https://github.com/prjunnamed/prjunnamed
One curious thing that I’m not sure what to do with is that in a more typical IR, the tail calls would be part of the if operation at the end. Instead I have them early which looks weird from an imperative perspective. But I tried changing my Block structure to look more like this and I didn’t like it.
What’s intriguing about using SMT as my IR is that maybe I can use it to verify optimizations or synthesize optimizations.
Writing a new AST is a lot of laborious bulk work. It leads to decision fatigue, and for blogging purposes it is too much bulk. Designing an AST is easily its own blog post. This is especially true in Python, where the language does not make it succinct to define new node types. Dataclasses help, but it isn’t great to keep having to write class over and over again.
Having said that, I do feel the pain in my project Knuckledragger. “If only I could change just this little thing about z3 or add this little feature.” But it also keeps me honest.
“The method of ‘postulating’ what we want has many advantages ; they are the same as the advantages of theft over honest toil.” - Russell
Ironically, the same crowd of people that may abhor cheating with axioms sometimes loves the idea of cheating by changing the nature of their logic to make this or that slicker or more automatic. This is from a certain perspective an even deeper version of postulating what you want to be true by fiat. And a logical system whose relation to a more conventional logic I can’t recognize is even more suspect and uneasy making than some funky axiom.
parsing block args qbe
Been tinkering on a variation of qbe that takes block args and explicit memory passing. Kind of neat.
import lark
from kdrag.all import *
from kdrag.contrib.ir import *
grammar = r"""
start : NL* funcdef NL*
funcdef: "function" GLOBAL NL? "{" NL block+ "}"
block: LABEL "(" [param_list] ")" NL instrs jump NL
instrs : instr*
instr: TEMP "=" BASETY OP operand ("," operand)* NL
?jump: call | ite | ret
ite : "ite" operand "," call "," call
// jmp: "jmp" call
//jnz: ("jnz" | "ite") val "," call "," call
ret: "ret" [operand] // make a call?
call : LABEL "(" [call_param_list] ")"
operand: SIGNED_INT | TEMP | GLOBAL
OP: /[A-Za-z][A-Za-z0-9]+/
param_list: param ("," param)*
param : BASETY TEMP
call_param_list : operand ("," operand)*
GLOBAL: /\$[A-Za-z_.][A-Za-z0-9_.]*/
LABEL: /@[A-Za-z_.][A-Za-z0-9_.]*/
TEMP: /%[A-Za-z_.][A-Za-z0-9_.]*/
BASETY: "w" | "l" | "s" | "d" | "m" | "b" // add m for memory
%import common.WS_INLINE
%import common.NEWLINE
%import common.ESCAPED_STRING
%import common.SIGNED_INT
%ignore WS_INLINE
%ignore /#[^\n]*/
NL: NEWLINE+
"""
PARSER = lark.Lark(grammar, start="start", parser="lalr")
example = r"""
function $foo{ #(m %mem){
@entry(w %x, w %y, m %mem)
%t =w bvadd %x, %y
%t2 =w bvadd %t, 42
ret %mem
@loop(m %mem, w %x)
%t3 =w bvneg %x
@entry(%t3)
}
"""
PARSER.parse(example)
basety = {
    "b" : smt.BoolSort(),
    "w" : smt.BitVecSort(32),
    "l" : smt.BitVecSort(64),
    "m" : smt.ArraySort(smt.BitVecSort(64), smt.BitVecSort(8)),
}
@lark.v_args(inline=True)
class FunctionTransformer(lark.Transformer):
    def __init__(self):
        self.labels : dict[str, smt.FuncDeclRef] = {}
        self.temps : dict[str, smt.ExprRef] = {}
        self.funcs : dict[str, smt.FuncDeclRef] = {}
    def get_temp(self, name : str):
        if name not in self.temps:
            raise Exception(f"Unknown temp {name}")
        return self.temps[name]
    def start(self, nl1, *funcs):
        return funcs[:-1] # nl
    def funcdef(self, name, *rest):
        #print("funcdef", rest)
        # block() returns (label, Block) tuples; everything else here is newline residue
        blocks = [b for b in rest if isinstance(b, tuple)]
        return Function("entry", {k : v for k,v in blocks})
    def block(self, label, params, nl1, insns, jmp, nl2):
        #print("block", label, "params", params, "insn", insns, "jm", jmp)
        return (label.value[1:], Block(params, insns + [jmp]))
    def BASETY(self, ty):
        return basety[ty.value]
    def param(self, ty, name):
        #print(name, name[1:])
        self.temps[name[1:]] = smt.Const(name[1:], ty)
        return smt.Const(name[1:], ty)
    def param_list(self, *params):
        return list(params)
    def operand(self, op):
        #print("operand", op)
        if isinstance(op, lark.Token):
            if op.type == "SIGNED_INT":
                return smt.BitVecVal(int(op.value), 32)
            elif op.type == "TEMP":
                #print(self.temps, op.value)
                return self.temps[op.value[1:]]
            elif op.type == "GLOBAL":
                return smt.Const(op.value[1:], smt.BitVecSort(64))
        else:
            raise Exception("Unknown operand type")
    def instrs(self, *instrs):
        return list(instrs)
    def instr(self, dest, ty, op, *operands):
        #print(operands)
        c = smt.Const(dest.value[1:], ty)
        #print(c)
        if op == "bvadd":
            expr = operands[0] + operands[1]
        elif op == "bvneg":
            expr = -operands[0]
        elif op == "bvsub":
            expr = operands[0] - operands[1]
        elif op == "bveq":
            # assign instead of returning early so the result is registered as a temp
            expr = operands[0] == operands[1]
        else:
            raise Exception(f"Unknown op {op}")
        dname = dest.value[1:]
        if dname in self.temps:
            raise Exception(f"Reassignment to {dname}")
        self.temps[dname] = expr
        return expr
    def call_param_list(self, *params):
        return list(params)
    def call(self, label, params):
        #print("call", label, params, self.temps)
        f = smt.Function(label.value[1:], *[p.sort() for p in params], Bottom)
        return f(*params)
    def ret(self, v):
        # v is already an expression (operand() resolved it), so use it directly
        return smt.Function("ret", v.sort(), Bottom)(v)
    def ite(self, cond, call_true, call_false):
        return smt.If(cond, call_true, call_false)
    def NL(self, token):
        return None
    #def TEMP(self, name):
    #    #print(name)
    #    #return self.temps[name]
def parse(s : str):
    tree = PARSER.parse(s)
    builder = FunctionTransformer()
    return builder.transform(tree)
parse(example)[0]
sumn = r"""
function $sumn {
@entry(w %n)
@loop(%n, 0)
@loop (w %i, w %acc)
%acc1 =w bvadd %acc, %i
%i1 =w bvsub %i, 1
%c =b bveq %i1, 0
ite %c, @done(%acc1), @loop(%i1, %acc1)
@done(w %res)
ret %res
}
"""
parse(sumn)[0].blocks["loop"]
Correspondence
loop(e1,e2,e3) and loop expects e1
Do I need to use th
%%file /tmp/myfun.s
.global myfun
.equ i, %rdi
.equ j, %rsi
.equ k, %rdx
loop:
ret # todo
myfun:
mov 1, i
mov 1, j
mov 0, k
jmp loop
nop
Overwriting /tmp/myfun.s
! gcc -c -o /tmp/myfun /tmp/myfun.s
! objdump -d /tmp/myfun
/tmp/myfun: file format elf64-x86-64
Disassembly of section .text:
0000000000000000 <loop>:
0: c3 ret
0000000000000001 <myfun>:
1: 48 8b 3c 25 01 00 00 mov 0x1,%rdi
8: 00
9: 48 8b 34 25 01 00 00 mov 0x1,%rsi
10: 00
11: 48 8b 14 25 00 00 00 mov 0x0,%rdx
18: 00
19: eb e5 jmp 0 <loop>
1b: 90 nop
import kdrag.contrib.pcode as pcode
ctx = pcode.BinaryContext("/tmp/myfun")
memstate0 = pcode.MemState.Const("memstate0")
memstate1 = ctx.execute_block(memstate0, ctx.loader.find_symbol("myfun").rebased_addr)
Unexpected SP conflict
[SimState(memstate=MemState((let ((a!1 (store64le (store64le (register memstate0)
&RDI
(select64le (ram memstate0) #x0000000000000001))
&RIP
#x0000000000400000)))
(let ((a!2 (store64le (store64le a!1
&RSI
(select64le (ram memstate0) #x0000000000000001))
&RDX
(select64le (ram memstate0) #x0000000000000000))))
(and (= CUR_RAM (ram memstate0)) (= CUR_REGFILE a!2))))), pc=(4194304, 0), path_cond=[])]
from typing import Callable
class Contract():
    decl : smt.FuncDeclRef
    cut : Callable[[list[smt.ExprRef], pcode.MemState], smt.BoolRef]
    #requires :
    #asserts :
loop_contract = lambda args, memstate: smt.And(args[0] == memstate.register("rax"), args[1] == memstate.register("rsi"), args[2] == memstate.register("rdx"))
loop(i,j,k) == loop(memstate.register("rax"), memstate.register("rsi"), memstate.register("rdx"))
smt.Implies(smt.ForAll([i,j,j,mem], i == mem.register(loop_high(i,j,k) == loop_low(mem), myfun_low() == myfun_high())
my_fun_high = smt.If(, loop_low(mem), )
def merge_states(states : list[SimState]):
    for state in states:
        jmp = smt.Function("addr_" + str(state.addr), MemState, MemState)
        acc = smt.If(state.path_cond, jmp(state.mem), acc)
    return acc
# predeclare all our blocks (a version with j dropped, so blocks take only i and k)
myfun2 = smt.Function("myfun2", Z)
loop2 = smt.Function("loop2", Z,Z, Z)
if_head2 = smt.Function("if_head2", Z,Z, Z)
then2 = smt.Function("then2", Z,Z, Z)
else2_ = smt.Function("else_2", Z,Z, Z) # unused below: if_head2 always goes to then2
done2 = smt.Function("done2", Z,Z, Z)
i,k = smt.Ints("i k")
myfun2 = kd.define("myfun2", [], loop2(1,0))
loop2 = kd.define("loop2", [i,k], smt.If(k < 100, if_head2(i,k), done2(i,k)))
if_head2 = kd.define("if_head2", [i,k], then2(i,k))
then2 = kd.define("then2", [i,k], loop2(i, k + 1))
done2 = kd.define("done2", [i,k], smt.IntVal(1))
Reflection appears broken and unreliable. Maybe I should revisit.
# Oh yeah. This is all gonna be mutually recursive. Hmm
from kdrag.all import *
from kdrag.reflect import reflect
Z = smt.IntSort()
#myfun = smt.Function("myfun", Z)
loop = smt.Function("loop", Z,Z,Z, Z)
if_head = smt.Function("if_head", Z,Z,Z, Z)
then = smt.Function("then", Z,Z,Z, Z)
else_ = smt.Function("else_", Z,Z,Z, Z)
done = smt.Function("done", Z,Z,Z, Z)
@reflect
def myfun() -> int:
    return loop(1,1,0)
def loop(i : int,j : int,k : int) -> int: # orange block
    return if_head(i,j,k) if k < 100 else done(i,j,k)
def if_head(i,j,k): # red block
    return then(i,j,k) if j < 20 else else_(i,j,k)
def then(i,j,k):
    return loop(i, i, k + 1)
def else_(i,j,k):
    return loop(i, k, k + 2)
def done(i,j,k): # green block
    return j
"""
SSA is Functional Programming by Andrew Appel
https://www.cs.princeton.edu/~appel/papers/ssafun.pdf
Functional programming and SSA can be put into close correspondence.
It is to some degree a matter of pretty printing.
The recipe is to define one function per block that takes in all the currently live variables as arguments.
These are also called "block arguments" and are a structural alternative to phi nodes.
SSA variables are then just references given to previous expressions.
A maximal `let` bound form can be written. https://en.wikipedia.org/wiki/A-normal_form
Jumps are calls to the other function blocks
"""
from dataclasses import dataclass, field
import kdrag as kd
import kdrag.smt as smt
from collections import defaultdict
def pp_sort(s: smt.SortRef) -> str:
    if isinstance(s, smt.BitVecSortRef):
        return f"bv{s.size()}"
    else:
        return str(s)
@dataclass
class Block:
    sig: list[smt.SortRef]
    insns: list[smt.ExprRef]
    @classmethod
    def of_defined_fun(cls, f: smt.FuncDeclRef) -> "Block":
        """
        >>> x, y = smt.Ints("x y")
        >>> f = kd.define("f809", [x,y], x + x + y)
        >>> Block.of_defined_fun(f)
        ^(Int,Int):
            %0 = + %var0, %var0
            %1 = + %0, %var1
        """
        defn = kd.kernel.defns.get(f)
        if defn is None:
            raise ValueError(f"Function {f} is not defined to knuckledragger")
        else:
            body = defn._subst_fun_body
            return cls.of_expr(body, sig=[f.domain(i) for i in range(f.arity())])
    @classmethod
    def of_expr(cls, e: smt.ExprRef, sig=[]) -> "Block":
        """
        >>> x,y = smt.BitVecs("x y", 64)
        >>> x,y = smt.Var(1, smt.BitVecSort(64)), smt.Var(0, smt.BitVecSort(64))
        >>> z = smt.BitVec("z", 64)
        >>> Block.of_expr(smt.If(True, (x + y)*42, x - y + z), [smt.BitVecSort(64), smt.BitVecSort(64)])
        ^(bv64,bv64):
            %0 = bvadd %var1, %var0
            %1 = bvmul %0, 42
            %2 = bvsub %var1, %var0
            %3 = bvadd %2, z
            %4 = if True, %1, %3
        """
        if not smt.is_if(e):
            insns = []
            seen = set()
            todo = [e]
        else:
            insns = [e]
            seen = set(e.children())
            todo = list(e.children())
        while todo:
            e = todo.pop()
            # if smt.is_const(e) and not kd.utils.is_value(e):
            #     args.append(e)
            if smt.is_var(e):
                pass
            elif smt.is_const(e):
                continue
            else:
                insns.append(e)
                for arg in e.children():
                    if arg not in seen:
                        seen.add(arg)
                        todo.append(arg)
        insns.reverse()
        return cls(sig=sig, insns=insns)
    def vname(self, e: smt.ExprRef) -> str:
        # if any(e.eq(v) for v in self.args):
        #     return str(e)
        if smt.is_var(e):
            return f"%var{smt.get_var_index(e)}"
        elif smt.is_const(e):
            return str(e)
        else:
            for i, insn in enumerate(self.insns):
                if e.eq(insn):
                    return f"%{i}"
            else:
                raise ValueError(f"Value {e} not found in block")
    def __repr__(self) -> str:
        # res = [f"^({",".join(str(arg) for arg in self.args)})"]
        res = [f"^({','.join(pp_sort(s) for s in self.sig)}):"]
        for i, insn in enumerate(self.insns):
            if isinstance(insn, smt.BitVecRef) and smt.is_bv_value(insn):
                rhs = str(insn) + f":{insn.size()}"
            elif kd.utils.is_value(insn):
                rhs = str(insn)
            else:
                rhs = f"{insn.decl().name()} {", ".join(self.vname(arg) for arg in insn.children())}"
            res.append(f"\t%{i} = {rhs}")
        return "\n".join(res)
    def succ_calls(self) -> list[smt.ExprRef]:
        jmp = self.insns[-1]
        if smt.is_if(jmp):
            return jmp.children()
        else:
            return [jmp]
type Label = str
@dataclass
class Function:
    """ """
    entry: Label # smt.FuncDeclRef?
    blocks: dict[Label, Block] # 0th block is entry. Or "entry" is entry? Naw. 0th.
    @classmethod
    def of_defined_funs(cls, funs: list[smt.FuncDeclRef]):
        blocks = {f.name(): Block.of_defined_fun(f) for f in funs}
        entry = funs[0].name()
        return cls(entry=entry, blocks=blocks)
    def calls_of(self) -> dict[str, list[tuple[Label, smt.ExprRef]]]:
        """
        Returns a mapping from labels to a list of calls to that label
        """
        p = defaultdict(list)
        for label, blk in self.blocks.items():
            for call in blk.succ_calls():
                p[call.decl().name()].append((label, call))
        return p
    def phis(self):
        """
        Return the analog of phi nodes: a mapping from each label to the
        per-parameter tuples of arguments passed by calls to that block
        """
        preds = self.calls_of()
        phis = {}
        for label, blk in self.blocks.items():
            phis[label] = zip(*[call.children() for _, call in preds[label]])
        return phis
    def __repr__(self) -> str:
        res = [f"fn {self.entry}" + "{"]
        for label, blk in self.blocks.items():
            res.append(f"@{label}:")
            res.append(str(blk))
        res.append("}")
        return "\n".join(res)
@dataclass
class Spec:
    pre: dict[str, smt.BoolRef] = field(default_factory=dict)
    post: dict[str, smt.BoolRef] = field(default_factory=dict)
    cut: dict[str, smt.BoolRef] = field(default_factory=dict)
# def sym_exec():
Bottom = smt.DeclareSort("Bottom")
ret64 = smt.Function("ret64", smt.BitVecSort(64), Bottom)