kdrag.printers.rust

Functions

compile_rust(fun_name, fun_code[, dir])

init_proj([proj_path])

of_sort(s)

rust_module_template(modname, fun_name, fun_code)

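kdrag.printers.rust.compile_rust(fun_name, fun_code, dir='/tmp/kdrag_rust')
    Write fun_code as a PyO3 module into dir/src/lib.rs, build it with maturin develop, and import the resulting module.
kdrag.printers.rust.init_proj(proj_path='/tmp/kdrag_rust')
    Create a Cargo library project at proj_path and add the pyo3 dependency with the extension-module feature.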
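kdrag.printers.rust.of_sort(s: SortRef)
    Return the Rust type name for an SMT sort: bool for the Boolean sort, uN for supported bit-vector widths; raise NotImplementedError otherwise.
Parameters:
  • s (SortRef)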

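kdrag.printers.rust.rust_module_template(modname: str, fun_name: str, fun_code: str)
    Return the Rust source of a PyO3 module named modname that marks fun_code as a #[pyfunction] and registers fun_name on the module.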
Parameters:
  • modname (str)

  • fun_name (str)

  • fun_code (str)

import kdrag.smt as smt
import subprocess
import importlib
import os


def of_sort(s: smt.SortRef):
    """Map an SMT sort to the Rust type used in generated PyO3 code.

    :param s: the sort to translate.
    :return: ``"bool"`` for the Boolean sort, or ``"u8"``, ``"u16"``,
        ``"u32"``, ``"u64"`` or ``"u128"`` for a bit-vector sort of that width.
    :raises NotImplementedError: for bit-vectors of any other width and for
        every other sort.
    """
    if s == smt.BoolSort():
        return "bool"
    if isinstance(s, smt.BitVecSortRef):
        size = s.size()
        # Only Rust's fixed-width primitive integers are supported; bit-vectors
        # are rendered as the unsigned `uN` type of the same width.
        if size in (8, 16, 32, 64, 128):
            return f"u{size}"
        raise NotImplementedError("No support for arbitrary Rust int sizes", size)
    raise NotImplementedError(f"Cannot convert {s} to Rust type")


# Default Cargo crate directory shared by init_proj and compile_rust.
default_dir: str = "/tmp/kdrag_rust"


def init_proj(proj_path=default_dir):
    """Create a Cargo library crate at ``proj_path`` that depends on PyO3.

    Runs ``cargo init --lib <proj_path>`` and then
    ``cargo add pyo3 --features extension-module`` against the crate's
    ``Cargo.toml``. The commands run in order, and the second is skipped if
    the first fails.

    NOTE(review): ``cargo init`` presumably refuses to run on a directory
    that already holds a Cargo package, so calling this twice on the same
    path likely raises. Confirm before relying on idempotence.

    :param proj_path: directory in which to create the crate.
    :raises RuntimeError: if either cargo command exits non-zero. The
        captured stderr and stdout are printed, and stderr is also included
        in the exception message.
    """
    cargofile = os.path.join(proj_path, "Cargo.toml")
    steps = [
        (
            ["cargo", "init", "--lib", proj_path],
            "Failed to initialize cargo project",
        ),
        (
            [
                "cargo",
                "add",
                "pyo3",
                "--features",
                "extension-module",
                "--manifest-path",
                cargofile,
            ],
            "Failed to add pyo3 dependency",
        ),
    ]
    for cmd, err_msg in steps:
        res = subprocess.run(cmd, capture_output=True)
        if res.returncode != 0:
            # errors="replace" so an odd byte in the tool output cannot raise
            # UnicodeDecodeError and hide the real failure.
            stderr = res.stderr.decode(errors="replace")
            stdout = res.stdout.decode(errors="replace")
            print(stderr)
            print(stdout)
            raise RuntimeError(f"{err_msg}:\n{stderr}")


def rust_module_template(modname: str, fun_name: str, fun_code: str):
    """Render the ``src/lib.rs`` source for a single-function PyO3 module.

    ``fun_code`` is emitted verbatim directly under a ``#[pyfunction]``
    attribute. A ``#[pymodule]`` initializer named ``modname`` then registers
    ``fun_name`` on the module.

    NOTE(review): ``fun_name`` is presumably the identifier of the ``fn``
    defined in ``fun_code``; nothing here checks that they agree.

    :param modname: name of the generated ``#[pymodule]`` function.
    :param fun_name: Rust function to register on the module.
    :param fun_code: full Rust source of that function.
    :return: the Rust source text, ending with a newline.
    """
    lines = [
        "use pyo3::prelude::*;",
        "use pyo3::wrap_pyfunction;",
        "",
        "#[pyfunction]",
        f"{fun_code}",
        "",
        "#[pymodule]",
        f"fn {modname}(m: &Bound<'_, PyModule>) -> PyResult<()> {{",
        f"    m.add_function(wrap_pyfunction!({fun_name}, m)?)?;",
        "    Ok(())",
        "}",
    ]
    return "\n".join(lines) + "\n"


def compile_rust(fun_name, fun_code, dir=default_dir):
    """Build ``fun_code`` into a PyO3 extension module and import it.

    Overwrites ``<dir>/src/lib.rs`` with a module that exports ``fun_name``
    (via ``rust_module_template``). It then runs ``maturin develop`` on
    ``<dir>/Cargo.toml`` and imports the resulting module.

    :param fun_name: name of the Rust function defined in ``fun_code``.
    :param fun_code: Rust source of the function to export.
    :param dir: Cargo crate directory, e.g. one created by ``init_proj``.
        Its basename, with ``-`` mapped to ``_``, is used as the module name.
        The parameter shadows the builtin ``dir``; the name is kept for
        compatibility.
    :return: the imported Python module.
    :raises RuntimeError: if ``maturin develop`` exits non-zero.

    NOTE: Python caches imported modules in ``sys.modules``. If a module with
    this name was already imported in this process, that cached module is
    returned and the rebuilt code is not picked up. Use a distinct ``dir``
    per function, or a fresh process.
    NOTE(review): ``maturin develop`` installs into the current Python
    environment; presumably a virtualenv must be active. Confirm.
    """
    # normpath strips a trailing separator, which would otherwise make
    # basename() return "" and yield an empty module name.
    crate_name = os.path.basename(os.path.normpath(dir))
    # A Rust identifier cannot contain '-', and Cargo names the library target
    # after the package with '-' replaced by '_'.
    mod_name = crate_name.replace("-", "_")
    cargofile = os.path.join(dir, "Cargo.toml")
    rs_file = os.path.join(dir, "src", "lib.rs")
    # Rust source files must be UTF-8 regardless of the platform locale.
    with open(rs_file, "w", encoding="utf-8") as f:
        f.write(rust_module_template(mod_name, fun_name, fun_code))
    # Compile Rust code into a shared object and install it with maturin.
    res = subprocess.run(["maturin", "develop", "-m", cargofile], capture_output=True)
    if res.returncode != 0:
        stderr = res.stderr.decode(errors="replace")
        print(stderr)
        print(res.stdout.decode(errors="replace"))
        raise RuntimeError(f"Failed to compile Rust code:\n{stderr}")
    # The module was installed after the interpreter started; path finders
    # may hold stale directory caches that would hide it.
    importlib.invalidate_caches()
    return importlib.import_module(mod_name)