Bellmaniac Compiler

Kuat Yessenov

Domain: Cache-oblivious algorithms

A cache-oblivious algorithm is designed to take advantage of a CPU cache without having the size of the cache as a parameter.

Example: Strassen's algorithm for matrix multiplication.

Such algorithms typically work by subdividing the problem recursively until a subproblem fits into the cache.
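To make the recursive-subdivision idea concrete, here is a minimal Python sketch of divide-and-conquer matrix multiplication. It is the plain eight-product recursion, not Strassen's seven-product scheme, and it assumes square matrices whose size is a power of two; `matmul`, `matmul_add` and the base-case size `CUTOFF` are illustrative names. No cache parameter appears anywhere: at some recursion depth the three submatrices fit in cache, whatever the cache size is.

```python
# Recursive (cache-oblivious style) matrix multiplication.
# Assumptions: square matrices, size a power of two, stored as lists of lists.
# CUTOFF is a hypothetical base-case size; it is NOT a cache parameter.
CUTOFF = 2


def matmul_add(A, B, C, ai, aj, bi, bj, ci, cj, n):
    """C[ci:ci+n, cj:cj+n] += A[ai:ai+n, aj:aj+n] @ B[bi:bi+n, bj:bj+n]."""
    if n <= CUTOFF:
        # Base case: ordinary triple loop on a small block.
        for i in range(n):
            for j in range(n):
                acc = 0
                for k in range(n):
                    acc += A[ai + i][aj + k] * B[bi + k][bj + j]
                C[ci + i][cj + j] += acc
        return
    h = n // 2
    # Output quadrant (di, dj) accumulates A(di, dk) * B(dk, dj) for dk in {0, h}.
    for di in (0, h):
        for dj in (0, h):
            for dk in (0, h):
                matmul_add(A, B, C, ai + di, aj + dk, bi + dk, bj + dj,
                           ci + di, cj + dj, h)


def matmul(A, B):
    n = len(A)
    C = [[0] * n for _ in range(n)]
    matmul_add(A, B, C, 0, 0, 0, 0, 0, 0, n)
    return C
```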

Cache-oblivious model:

  1. RAM machine model augmented with a second-level cache of $M$ words, organized in lines of $B$ words, under the tall-cache assumption $M \geq B^2$
  2. Each cache miss has a cost; the total number of misses is the cache complexity
  3. Goal 1: maintain optimal running time while minimizing cache complexity
  4. Goal 2: be optimal for all $M$ and $B$, without knowing their values
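The cost model can be explored with a toy simulator. This sketch assumes a fully associative cache with LRU replacement (the ideal-cache model uses optimal replacement, so LRU is a simplification) and counts the misses of two traversals of a row-major $n \times n$ array; `count_misses`, `row_major` and `col_major` are hypothetical names.

```python
from collections import OrderedDict


def count_misses(addresses, M, B):
    """Misses of a fully associative LRU cache holding M words in lines of B words.
    LRU is a stand-in for the ideal cache's optimal replacement policy."""
    capacity = M // B          # number of cache lines
    cache = OrderedDict()      # line id -> None, least recently used first
    misses = 0
    for a in addresses:
        line = a // B
        if line in cache:
            cache.move_to_end(line)            # hit: mark as most recently used
        else:
            misses += 1
            cache[line] = None
            if len(cache) > capacity:
                cache.popitem(last=False)      # evict least recently used line
    return misses


def row_major(n):
    """Word addresses of a row-by-row scan of a row-major n x n array."""
    return [i * n + j for i in range(n) for j in range(n)]


def col_major(n):
    """Word addresses of a column-by-column scan of the same array."""
    return [i * n + j for j in range(n) for i in range(n)]
```

With $n = 64$, $M = 256$ and $B = 16$ (so $M = B^2$), the row-by-row scan misses once per line, $n^2/B = 256$ times, while the column-by-column scan misses on every one of its 4096 accesses.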

Example: Parenthesis problem (Yao 1980)

Compute the minimum cost of parenthesizing $n$ elements, given a binary product as the primitive operation (e.g. concatenation of sets of strings).

Input

$$ c[i, i+1] = x_i $$
$$ c[i, j] = \min_{i \lt k \lt j}\left\{ c[i, k] + c[k, j] + w(i, k, j) \right\} $$
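As a reference point for what the generated code must compute, the following sketch evaluates the recurrence bottom-up in order of increasing $j - i$, so that $c[i,k]$ and $c[k,j]$ are available before $c[i,j]$. It assumes $x$ is a Python list with $x_i$ at index $i$ and $w$ a Python function; the name `parenthesis` and the choice $n =$ `len(x) + 1` are illustrative, not the compiler's. This is the straightforward $O(n^3)$ loop, with none of the cache-oblivious structure.

```python
import math


def parenthesis(x, w):
    """Bottom-up evaluation of
         c[i][i+1] = x[i]
         c[i][j]   = min over i < k < j of c[i][k] + c[k][j] + w(i, k, j)
    for 0 <= i < j < n, with n = len(x) + 1 (an assumption of this sketch)."""
    n = len(x) + 1
    c = [[math.inf] * n for _ in range(n)]   # cells outside i < j stay unused
    for i in range(n - 1):
        c[i][i + 1] = x[i]
    for d in range(2, n):                    # d = j - i, increasing
        for i in range(n - d):
            j = i + d
            # k - i < d and j - k < d, so both operands are already final
            c[i][j] = min(c[i][k] + c[k][j] + w(i, k, j) for k in range(i + 1, j))
    return c
```

For example, matrix-chain multiplication is the instance $x_i = 0$, $w(i,k,j) = p_i p_k p_j$ for dimensions $p$; with $p = (10, 30, 5, 60)$ the entry `c[0][3]` is 4500.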

Output

autogenerated code

Why deductive?

  1. correct-by-construction
  2. automated performance analysis
  3. automated lowering
  4. opportunity for synthesis

Compiler overview

Problem definition

DSL for recurrence relations

Refinement script

Tactics for stepwise refinement

Lowering

Compilation to low-level imperative code (we use Python+numpy)

Scala DSL: Parenthesis problem

$$ c[i, j] = \min_{i \lt k \lt j}\left\{ c[i, k] + c[k, j] + w(i, k, j) \right\} $$


val par = Algorithm(c, i :: j :: Nil,
// pre-condition
0 <= i and i < n and i < j and j < n,
// recursive definition
IF ((i === j-1) -> x(i)) 
ELSE
  Reduce(c(i, k) + c(k, j) + w(i, k, j) 
  where k in Range(i+1, j)))

// add to the environment
input(w); input(x); input(n)
add(par, j-i)

Functions are stateless and provably terminating, and they use only integer operations and monadic operators (reduce, zero, plus).

Proof environment

A collection of mutually recursive functions together with:

  • a refinement relation: \(A \sqsubseteq B\) if A can be replaced by B in any context while preserving semantics.
  • a restriction relation: pre-condition strengthening

Initially: \(\{par\}\)

Goal: \(\{A_{par}, B_{par}, C_{par}, \ldots\}\)

Full derivation of parenthesis

Refinement script

A sequence of tactic applications over functions in the environment. Each tactic produces a new function that refines the input function (and may spawn new functions).

Each tactic must verify the correctness of its refinement, so the environment is consistent at every step.

Example


val R = Algorithm(r, List(i, j), par.pre, 
IF ((i === j-1) -> x(i)) ELSE Zero)
add(R, 0)

// first, manually rewrite the body of c to use r
// $ just makes a fresh name for a variable
val par0 = manual($, 
// Op is the monadic operation used in Reduce
Op(Reduce(c(i, k) + c(k, j) + w(i, k, j) 
  where k in Range(i+1, j)), r(i, j)),
// hint for the prover ($$ refers to the prover): unfold R in the body above
$$.unfold($, R))(par)
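The manual step claims that `par` is refined by the rewritten body Op(Reduce(...), r(i, j)). One way to see why is to evaluate both versions on concrete data. The sketch below assumes the monad of the parenthesis problem is min with Zero $= +\infty$ (the identity of min); `make_par`, `make_par0`, `op` and `reduce_` are hypothetical helpers. It only tests small instances, whereas the tactic discharges the equivalence symbolically for all inputs.

```python
import math

ZERO = math.inf   # assumed Zero: identity element of Op = min


def op(a, b):
    """Assumed monadic Op for the parenthesis problem."""
    return min(a, b)


def reduce_(values):
    """Reduce = fold of Op starting from Zero (empty range gives Zero)."""
    acc = ZERO
    for v in values:
        acc = op(acc, v)
    return acc


def make_par(x, w):
    """Original definition: IF (i == j-1) -> x(i) ELSE Reduce(...)."""
    memo = {}

    def c(i, j):
        if (i, j) not in memo:
            if i == j - 1:
                memo[(i, j)] = x[i]
            else:
                memo[(i, j)] = reduce_(c(i, k) + c(k, j) + w(i, k, j)
                                       for k in range(i + 1, j))
        return memo[(i, j)]
    return c


def make_par0(x, w):
    """Rewritten body: Op(Reduce(...), r(i, j)), with no IF."""
    def r(i, j):
        # R: IF (i == j-1) -> x(i) ELSE Zero
        return x[i] if i == j - 1 else ZERO

    memo = {}

    def c(i, j):
        if (i, j) not in memo:
            memo[(i, j)] = op(reduce_(c(i, k) + c(k, j) + w(i, k, j)
                                      for k in range(i + 1, j)), r(i, j))
        return memo[(i, j)]
    return c
```

The two agree because when $j = i + 1$ the range is empty, so Reduce yields Zero and Op returns $x_i$; otherwise $r(i, j)$ is Zero and Op returns the reduction unchanged.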

SMT prover technology

Supports: SAT, ILP, uninterpreted function symbols. Limited expressive power but reliable.

Missing: inductive reasoning, folding/unfolding, modeling of higher-order functions, axiomatization of monadic operators, synthesis, ...

Think of it as reasoning about symbolic expressions modulo SAT and integer arithmetic.

Tactic repertoire

Basic

  • manual: rewrite algorithm body and prove equivalence
  • unfold: unfold a definition of an algorithm
  • splitRange: divide list comprehension into two parts
  • guard: add an "if" case for a special value
  • relax: generalize pre-condition
  • introduce: add parameters

Contextual

  • selfRefine, refine: replace call to a function by its refinement
  • specialize: replace call to a function by its restriction

Complex

  • genApp: generalize function use to a parameter
  • split: generate restrictions based on a partitioning scheme
  • rewrite: rewrite function arguments and prove by induction
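The `splitRange` tactic relies on Op being associative with Zero as its identity. For the min-based reduction of the parenthesis problem this can be checked exhaustively on small ranges. In the sketch below, `reduce_min` and `split_range` are hypothetical names, and Range(lo, hi) is taken to be half-open, which is consistent with the pre-condition $i < j$ in `par`.

```python
import math


def reduce_min(f, lo, hi):
    """Reduce with Op = min over Range(lo, hi) = lo, ..., hi-1.
    An empty range yields Zero = +inf (assumed identity of min)."""
    return min((f(k) for k in range(lo, hi)), default=math.inf)


def split_range(f, lo, hi, mid):
    """Right-hand side of a splitRange rewrite at `mid`:
    Op(Reduce over [lo, mid), Reduce over [mid, hi))."""
    assert lo <= mid <= hi
    return min(reduce_min(f, lo, mid), reduce_min(f, mid, hi))
```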

Rewrite example

Tactic:


val d110 = rewrite("d110", D, 
$$.splitRange($, k, n/4), $$.unfold($, D))(
i->(i-n/4), n->n/2, 
r->D.gen(2)(i, j, n/2, w>>(n/4,0,0), r>>(n/4,0), s>>(n/4,0), t),
w->(w>>(n/4,n/4,0)),
s->(s>>(n/4,n/4)),
t->(t>>(n/4,0))
)(d010)          

Input and output:


def d010(i, j, n, w, r, s, t):
    # pre-condition: n/4 <= i < n/2 and 0 <= j < n/4
    assert 0 <= i < n/2 and 0 <= j < n/2 and not (i < n/4) and j < n/4
    return d0(i, j, n, w, r, s, t)

def d110(i, j, n, w, r, s, t):
    assert 0 <= i < n/2 and 0 <= j < n/2 and not (i < n/4) and j < n/4
    # d0 on a half-size instance, with arguments shifted as the tactic prescribes
    return d0((i - n/4), j, n/2,
              lambda _v203, _v204, _v205: w(_v203 + n/4, _v204 + n/4, _v205),
              lambda i201, j202: d0(i201, j202, n/2,
                                    lambda _v194, _v195, _v196: w(_v194 + n/4, _v195, _v196),
                                    lambda _v197, _v198: r(_v197 + n/4, _v198),
                                    lambda _v199, _v200: s(_v199 + n/4, _v200),
                                    t),
              lambda _v206, _v207: s(_v206 + n/4, _v207 + n/4),
              lambda _v208, _v209: t(_v208 + n/4, _v209))
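Judging from the generated `d110` code, `f>>(o1, ..., om)` denotes $f$ with its arguments shifted: `w>>(n/4,n/4,0)` becomes `lambda a, b, c: w(a + n/4, b + n/4, c)`. A hypothetical Python `shift` helper with that reading:

```python
def shift(f, *offsets):
    """Hypothetical analogue of f >> (o1, ..., om): the function
    (a1, ..., am) -> f(a1 + o1, ..., am + om)."""
    def shifted(*args):
        assert len(args) == len(offsets), "arity must match the offset tuple"
        return f(*(a + o for a, o in zip(args, offsets)))
    return shifted
```

Shifts compose additively (`shift(shift(f, a), b)` behaves like `shift(f, a + b)`), which is what lets a rewrite push offsets through nested calls such as the argument passed for `r`.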

Proof structure: unfold \(d_0\) once, apply the induction hypothesis to the input, and use the hints* (the prover calls passed to the tactic) to prove symbolic equivalence.

Analysis of tactic applications

Algorithm     rewrite   split   splitRange*   time (s)   # steps
Floyd         17        2       0             11         70
Gap           12        3       12            20         40
Parenthesis   10        3       9             2          40

Synthesis potential

Improve the prover to automate basic tactics (e.g. unfold, guard).

Improve the synthesizer to invoke contextual tactics given the current environment.

Synthesize complex tactics, e.g.:


val List(c1, c000, c001, c011) =
  split("c1", n < 4, i < n/2, j < n/2)(c0)
val b100 = rewrite("b100", b0,
  $$.splitRange($, Var("k1"), n/4),
  $$.unfold($, D))(
  i->i,
  j->(j-n/4),
  n->n/2,
  w->(w>>(0,n/4,n/4)),
  w1->(w1>>(0,0,n/4)),
  t->(t>>(n/4,n/4)),
  // make d a function of i, j and pass the following arguments to d
  r->D.gen(2)(i, j-n/4, n/2, w1>>(0,n/4,n/2),
    r>>(0,n/2), s>>(0,n/4), bij>>(n/4,n/2))
)(b000)
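The `split` tactic partitions the domain of `c0` with the predicates `n < 4`, `i < n/2` and `j < n/2`. Reading the suffixes of `c000`, `c001`, `c011` as one bit for `i` and one for `j` (our interpretation of the names), only three quadrants appear because the pre-condition $i < j$ leaves the quadrant with $i \geq n/2$, $j < n/2$ empty. A small sketch that enumerates the cells of each part for a power-of-two $n$ (the helper name is hypothetical):

```python
def quadrant_cells(n):
    """Group the cells of c0's domain (0 <= i < j < n) by the split predicates
    i < n/2 and j < n/2. Key = bit for i + bit for j (0: lower half, 1: upper half)."""
    parts = {}
    for i in range(n):
        for j in range(i + 1, n):
            key = ("0" if i < n // 2 else "1") + ("0" if j < n // 2 else "1")
            parts.setdefault(key, []).append((i, j))
    return parts
```

For $n = 8$ the parts are `"00"` and `"11"` with 6 cells each (the two diagonal triangles) and `"01"` with 16 cells (the full off-diagonal quadrant); there is no `"10"` part.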

Lowering of proof environment to low-level code

Input: collection of mutually recursive math functions.

Output: sequential code without \(\lambda\)s

Lowering phases

  • Full refinement, inlining: push all algorithms to the finest level
  • Add offsets as parameters: $$ \lambda i j . f(i + i_0, j + j_0) $$
  • Common subexpression elimination: reuse computation across split partitions
  • Memory read/write bound analysis: use SMT to compute loop bounds, etc.
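The offset-parameter phase can be illustrated on a toy computation: the same block sum written once with a shifted closure, as in the refinement stage, and once with explicit integer offsets into a shared table, in the style of the `oi`, `oj` parameters of the generated `c0`/`c1`. Both helpers are hypothetical and only illustrate the transformation.

```python
def block_sum_closure(g, n):
    """Before lowering: g is a shifted view, e.g. lambda i, j: f(i + i0, j + j0)."""
    return sum(g(i, j) for i in range(n) for j in range(n))


def block_sum_offsets(T, oi, oj, n):
    """After lowering: no lambda; offsets are plain integer parameters
    and the shared table T is indexed directly."""
    return sum(T[i + oi][j + oj] for i in range(n) for j in range(n))
```

Passing integers instead of closures is what lets the generated code avoid $\lambda$s entirely and address one global array.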

Projection of \(c_0\) to \(i, j\)


def c0(T, oi, oj, n, w, r, w_0, w_1, w_2, r_0, r_1):
    # i = -i0 runs from n-1 down to 0; j runs from i+1 up to n-1
    for i0 in xrange(((-1)*n + 1),1):
        for j0 in xrange(((-1)*i0 + 1),n):
            i = (0 - i0)
            j = j0
            T[(i + oi), (j + oj)] = ...
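The index arithmetic in `c0` is easier to check in isolation. This sketch replays the same loop nest (with Python 3 `range` and the offsets dropped), records the visited cells, and checks that every $c[i,k]$ and $c[k,j]$ with $i < k < j$ is written before $c[i,j]$. The right-hand side of the assignment, elided above, is not needed for this check; the helper names are hypothetical.

```python
def c0_order(n):
    """Cells (i, j) in the order visited by c0's loop nest (offsets omitted)."""
    order = []
    for i0 in range(-n + 1, 1):
        for j0 in range(-i0 + 1, n):
            order.append((-i0, j0))
    return order


def respects_dependencies(order):
    """True if c[i][k] and c[k][j] (i < k < j) are visited before c[i][j]."""
    pos = {cell: t for t, cell in enumerate(order)}
    return all(pos[(i, k)] < pos[(i, j)] and pos[(k, j)] < pos[(i, j)]
               for (i, j) in order for k in range(i + 1, j))
```

So `i` runs downward and `j` upward from `i+1`, a valid evaluation order for the parenthesis recurrence; an order that visits $c[0,2]$ before $c[1,2]$ would not be.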

Projection of \(c_1\) to \(i, j\)


def c1(T, oi, oj, n, w, r, w_0, w_1, w_2, r_0, r_1):
    if (n < 4):
        # base case: small instances run the iterative kernel c0
        c0(T, oi, oj, n, w, r, w_0, w_1, w_2, r_0, r_1)
        return
    # recurse on the two diagonal quadrants, then b1 for the off-diagonal one
    c1(T, oi, oj, n/2, w, r, w_0, w_1, w_2, r_0, r_1)
    c1(T, (oi + n/2), (oj + n/2), n/2, w, r, (n/2 + w_0), (n/2 + w_1), (n/2 + w_2), (n/2 + r_0), (n/2 + r_1))
    b1(T, oi, oj, n, w, r, T, T, w, w_0, w_1, w_2, r_0, r_1, oi, oj, oi, oj, w_0, w_1, w_2)
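To see that the recursion in `c1` covers the triangle $0 \le i < j < n$ exactly once, one can record its calls and expand them into cells. The sketch assumes that `b1(T, oi, oj, n, ...)` fills the top-right quadrant of the $n \times n$ block at `(oi, oj)`; that is our reading of the split into `c000`, `c001`, `c011`, not something stated in the generated code. The helper names are hypothetical.

```python
def c1_calls(oi, oj, n, calls):
    """Record the kernel calls made by c1 (all other arguments omitted)."""
    if n < 4:
        calls.append(("c0", oi, oj, n))
        return
    c1_calls(oi, oj, n // 2, calls)
    c1_calls(oi + n // 2, oj + n // 2, n // 2, calls)
    calls.append(("b1", oi, oj, n))


def covered_cells(n):
    """Cells written by the recorded calls, assuming b1(oi, oj, m) fills
    the top-right quadrant of the m x m block at (oi, oj)."""
    calls = []
    c1_calls(0, 0, n, calls)
    cells = []
    for kind, oi, oj, m in calls:
        if kind == "c0":
            # c0 fills the strict upper triangle (j > i) of its diagonal block
            cells += [(oi + i, oj + j) for i in range(m) for j in range(i + 1, m)]
        else:
            h = m // 2
            cells += [(oi + i, oj + h + j) for i in range(h) for j in range(h)]
    return cells
```

The call order also matters: cells in the top-right quadrant depend on cells of both diagonal quadrants, which is consistent with `c1` recursing on those first and calling `b1` last.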