... is an algorithm designed to take advantage of a CPU cache without having the cache size as an explicit parameter.
Example: Strassen's algorithm for matrix multiplication.
Such algorithms typically work by subdividing the problem recursively until a subproblem fits into the cache.
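A minimal sketch of the recursive-subdivision idea, assuming square matrices whose size is a power of two. This is plain divide-and-conquer multiplication with eight half-size products, not Strassen's seven-product scheme. The name `rec_matmul` and the `base` cutoff are hypothetical. `base` only bounds recursion overhead and is not tuned to any cache size, which is what keeps the scheme cache-oblivious.

```python
import numpy as np


def rec_matmul(A, B, C, base=16):
    """Accumulate A @ B into C by recursively splitting every matrix into quadrants.

    Assumes square numpy arrays whose size is a power of two. `base` is a
    hypothetical cutoff that only limits recursion overhead; it is not derived
    from any cache size.
    """
    n = A.shape[0]
    if n <= base:
        C += A @ B  # small subproblem: multiply directly into C (C may be a view)
        return
    h = n // 2
    # Basic slices are views, so recursive calls write straight into C's quadrants.
    A11, A12, A21, A22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
    B11, B12, B21, B22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]
    C11, C12, C21, C22 = C[:h, :h], C[:h, h:], C[h:, :h], C[h:, h:]
    # Eight half-size products (Strassen would use seven, at the cost of extra additions).
    for X, Y, Z in ((A11, B11, C11), (A12, B21, C11),
                    (A11, B12, C12), (A12, B22, C12),
                    (A21, B11, C21), (A22, B21, C21),
                    (A21, B12, C22), (A22, B22, C22)):
        rec_matmul(X, Y, Z, base)


rng = np.random.default_rng(0)
A = rng.random((64, 64))
B = rng.random((64, 64))
C = np.zeros((64, 64))
rec_matmul(A, B, C)
print(np.allclose(C, A @ B))  # True
```

Every level of the recursion halves the working set. Eventually a subproblem fits in each level of the memory hierarchy, whatever its size.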
Cache-oblivious model:
Compute the minimum cost of parenthesizing $n$ elements, given a binary product as the primitive (e.g. concatenation of sets of strings).
DSL for recurrence relations
Tactics for stepwise refinement
Compilation to low-level imperative code (we use Python+numpy)
$$ c[i, j] = \begin{cases} x(i) & \text{if } j = i+1 \\ \min_{i \lt k \lt j}\left\{ c[i, k] + c[k, j] + w(i, k, j) \right\} & \text{otherwise} \end{cases} $$
```scala
val par = Algorithm(c, i :: j :: Nil,
  // pre-condition
  0 <= i and i < n and i < j and j < n,
  // recursive definition
  IF ((i === j-1) -> x(i))
  ELSE
    Reduce(c(i, k) + c(k, j) + w(i, k, j)
      where k in Range(i+1, j)))

// add to the environment
input(w); input(x); input(n)
add(par, j-i)
```
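The specification can be read directly as a bottom-up loop. Below is a minimal, deliberately non-cache-oblivious Python sketch. It assumes `x` and `w` are Python callables, and the name `parenthesis` is hypothetical. Cells are filled in order of increasing j - i, so every c[i, k] and c[k, j] on the right-hand side is already final. The matrix-chain instance in the usage lines is only an illustration: x(i) = 0 and w(i, k, j) = d[i] * d[k] * d[j] for a list d of matrix dimensions.

```python
import math


def parenthesis(n, x, w):
    """Bottom-up evaluation of c[i][j] for 0 <= i < j < n.

    c[i][i+1] = x(i); for j > i + 1,
    c[i][j] = min over i < k < j of c[i][k] + c[k][j] + w(i, k, j).
    """
    c = [[math.inf] * n for _ in range(n)]
    for d in range(1, n):          # d = j - i, the same quantity passed to add(par, j-i)
        for i in range(n - d):
            j = i + d
            if d == 1:
                c[i][j] = x(i)
            else:
                c[i][j] = min(c[i][k] + c[k][j] + w(i, k, j) for k in range(i + 1, j))
    return c


dims = [10, 30, 5, 60]  # A1: 10x30, A2: 30x5, A3: 5x60
c = parenthesis(len(dims), lambda i: 0, lambda i, k, j: dims[i] * dims[k] * dims[j])
print(c[0][len(dims) - 1])  # 4500, i.e. (A1 A2) A3
```

This version touches memory in a cache-unfriendly diagonal order. That access pattern is what the derivation below replaces with a recursive, cache-oblivious one.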
Functions are stateless, provably terminating, and use only integer operations and monadic operators (reduce, zero, plus).
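For Parenthesis, the monadic operators can be instantiated with the (min, +) structure. Below is a minimal Python sketch of that instantiation. The names `ZERO`, `plus` and `reduce_range` are hypothetical stand-ins for the DSL's zero, plus and reduce. It also assumes `Range(lo, hi)` is half-open, as the precondition i < j suggests.

```python
import math
from functools import reduce

# Hypothetical stand-ins for the DSL's monadic operators, instantiated for the
# (min, +) structure that Parenthesis uses.
ZERO = math.inf  # identity for plus: plus(ZERO, v) == v


def plus(a, b):
    """Combine two candidate costs (the DSL's plus / Op): keep the cheaper one."""
    return min(a, b)


def reduce_range(f, lo, hi):
    """Fold plus over f(k) for k in the half-open range [lo, hi); empty -> ZERO."""
    return reduce(plus, (f(k) for k in range(lo, hi)), ZERO)


print(reduce_range(lambda k: (k - 2) ** 2 - 4, 0, 6))  # -4, attained at k = 2
print(reduce_range(lambda k: k, 3, 3))                 # inf: empty range gives ZERO
```

Because an empty reduction yields ZERO and ZERO is the identity of plus, combining a reduction with another term never needs a special case for empty ranges.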
A collection of mutually recursive functions together with:
Initially: \(\{par\}\)
Goal: \(\{A_{par}, B_{par}, C_{par}, \ldots\}\)
Full derivation of parenthesis
A sequence of tactic applications over functions in the environment. Each tactic produces a new function that refines the input function (and may spawn new functions).
Each tactic must verify the correctness of its refinement, so the environment is consistent at every step.
```scala
val R = Algorithm(r, List(i, j), par.pre,
  IF ((i === j-1) -> x(i)) ELSE Zero)
add(R, 0)

// first, manually change the expression of c to use r
// ($ just makes a fresh name for a variable)
val par0 = manual($,
  // Op is the monadic operation used in Reduce
  Op(Reduce(c(i, k) + c(k, j) + w(i, k, j)
      where k in Range(i+1, j)), r(i, j)),
  // tell the prover ($$ marks a hint for the prover) to unfold r above
  $$.unfold($, R))(par)
```
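Here is why par0 refines par. When j = i + 1 the Reduce ranges over an empty set and yields Zero, so Op returns r(i, j) = x(i). Otherwise r(i, j) = Zero and Op returns the Reduce. The framework discharges this with the prover after unfolding R. The Python sketch below only spot-checks the same equivalence on random inputs. It assumes Op = min and Zero = +inf, matching the min in the recurrence, and the names `make_pair` and `check_refinement` are hypothetical. Testing is not a proof.

```python
import math
import random
from functools import lru_cache


def make_pair(n, x, w):
    """Return (par, par0): the original definition and its refinement via r."""

    @lru_cache(maxsize=None)
    def par(i, j):
        # IF (i === j-1) -> x(i) ELSE Reduce(...)
        if i == j - 1:
            return x(i)
        return min(par(i, k) + par(k, j) + w(i, k, j) for k in range(i + 1, j))

    def r(i, j):
        # R: IF (i === j-1) -> x(i) ELSE Zero, with Zero = +inf for min
        return x(i) if i == j - 1 else math.inf

    @lru_cache(maxsize=None)
    def par0(i, j):
        # Op(Reduce(...), r(i, j)) with Op = min; the Reduce is empty when j == i + 1
        red = min((par0(i, k) + par0(k, j) + w(i, k, j) for k in range(i + 1, j)),
                  default=math.inf)
        return min(red, r(i, j))

    return par, par0


def check_refinement(n, seed):
    """Compare par and par0 on every (i, j) with 0 <= i < j < n for random x, w."""
    rng = random.Random(seed)
    xs = [rng.randint(0, 9) for _ in range(n)]
    ws = [[[rng.randint(0, 9) for _ in range(n)] for _ in range(n)] for _ in range(n)]
    par, par0 = make_pair(n, lambda i: xs[i], lambda i, k, j: ws[i][k][j])
    return all(par(i, j) == par0(i, j) for i in range(n) for j in range(i + 1, n))


print(all(check_refinement(8, seed) for seed in range(20)))  # True
```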
Supports: SAT, ILP, and uninterpreted function symbols. Limited expressive power, but reliable.
Missing: inductive reasoning, folding/unfolding, modeling of higher-order functions, axiomatization of monadic operators, synthesis, ...
Think of it as reasoning about symbolic expressions modulo SAT and integer arithmetic.
Tactic:
```scala
val d110 = rewrite("d110", D,
    $$.splitRange($, k, n/4), $$.unfold($, D))(
  i->(i-n/4), n->n/2,
  r->D.gen(2)(i, j, n/2, w>>(n/4,0,0), r>>(n/4,0), s>>(n/4,0), t),
  w->(w>>(n/4,n/4,0)),
  s->(s>>(n/4,n/4)),
  t->(t>>(n/4,0))
)(d010)
```
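The `$$.splitRange($, k, n/4)` hint asks the prover to split a reduction range at `n/4`. The sketch below assumes it denotes the usual identity: the reduction over [lo, hi) equals plus of the reductions over [lo, mid) and [mid, hi), for lo <= mid <= hi. It spot-checks that identity for the min operator. The names `reduce_min` and `split_range_holds` are hypothetical.

```python
import math
import random


def reduce_min(f, lo, hi):
    """min of f(k) over the half-open range [lo, hi); +inf when the range is empty."""
    return min((f(k) for k in range(lo, hi)), default=math.inf)


def split_range_holds(f, lo, mid, hi):
    """Check Reduce over [lo, hi) == min(Reduce over [lo, mid), Reduce over [mid, hi))."""
    assert lo <= mid <= hi
    return reduce_min(f, lo, hi) == min(reduce_min(f, lo, mid), reduce_min(f, mid, hi))


rng = random.Random(0)
values = [rng.randint(-50, 50) for _ in range(32)]
# Every split point, including mid == lo and mid == hi (one half empty).
print(all(split_range_holds(lambda k: values[k], 0, m, 32) for m in range(33)))  # True
```

The identity needs only associativity of plus and ZERO as its identity. That is why it can be stated once as a hint and reused across Floyd, Gap and Parenthesis.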
Input and output:
```python
def d010(i, j, n, w, r, s, t):
    assert (((((0 <= i) and (i < n/2)) and (0 <= j)) and (j < n/2)) and ((not (i < n/4)) and (j < n/4)))
    return d0(i, j, n, w, r, s, t)

def d110(i, j, n, w, r, s, t):
    assert (((((0 <= i) and (i < n/2)) and (0 <= j)) and (j < n/2)) and ((not (i < n/4)) and (j < n/4)))
    return d0((i - n/4), j, n/2,
              (lambda _v203, _v204, _v205: w((_v203 + n/4), (_v204 + n/4), _v205)),
              (lambda i201, j202: d0(i201, j202, n/2,
                                     (lambda _v194, _v195, _v196: w((_v194 + n/4), _v195, _v196)),
                                     (lambda _v197, _v198: r((_v197 + n/4), _v198)),
                                     (lambda _v199, _v200: s((_v199 + n/4), _v200)),
                                     t)),
              (lambda _v206, _v207: s((_v206 + n/4), (_v207 + n/4))),
              (lambda _v208, _v209: t((_v208 + n/4), _v209)))
```
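Comparing the tactic with its output suggests that `f>>(o1, ..., ok)` denotes f with each index argument shifted by the matching offset. For example, `w>>(n/4,n/4,0)` becomes `lambda _v203, _v204, _v205: w((_v203 + n/4), (_v204 + n/4), _v205)`. Below is a hypothetical helper `shift` that reproduces this reading. It is an interpretation of the generated code, not the framework's own definition.

```python
def shift(f, *offsets):
    """Hypothetical model of f>>(o1, ..., ok): g(a1, ..., ak) = f(a1 + o1, ..., ak + ok)."""
    def g(*idx):
        assert len(idx) == len(offsets), "one offset per index argument"
        return f(*(a + o for a, o in zip(idx, offsets)))
    return g


def record(a, b, c):
    """Stand-in for w that just returns the indices it was called with."""
    return (a, b, c)


n = 8
# Same as the generated (lambda _v203, _v204, _v205: w((_v203 + n/4), (_v204 + n/4), _v205))
w_shifted = shift(record, n // 4, n // 4, 0)
print(w_shifted(0, 1, 2))  # (2, 3, 2)
```

Under this reading, each substitution in the tactic (e.g. `s->(s>>(n/4,n/4))`) re-indexes a sub-array of the original input. This lets d0 be reused unchanged on a quadrant of half the size.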
Proof structure: unfold \(d_0\) once, apply the induction hypothesis to the input, and use hints* to prove symbolic equivalence.
| Algorithm | rewrite | split | splitRange* | time (s) | # steps |
|---|---|---|---|---|---|
| Floyd | 17 | 2 | 0 | 11 | 70 |
| Gap | 12 | 3 | 12 | 20 | 40 |
| Parenthesis | 10 | 3 | 9 | 2 | 40 |
Improve the prover to automate basic tactics (e.g. unfold, guard).
Improve the synthesizer to invoke contextual tactics given the current environment.
Synthesize complex tactics.
```scala
val List(c1, c000, c001, c011) =
  split("c1", n < 4, i < n/2, j < n/2)(c0)
```
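The split yields only three quadrant functions. Their names appear to encode the i-half and the j-half in the last two digits. Compare d010, whose precondition puts i in the upper and j in the lower quarter of the n/2 block. Under that reading, the missing c010 (i >= n/2, j < n/2) is empty under the precondition i < j. A small hypothetical check (`quadrant`, `quadrant_sizes`) counts the valid index pairs per quadrant:

```python
from collections import Counter


def quadrant(i, j, n):
    """Label (i, j) by the guards i < n/2 and j < n/2: '0' = lower half, '1' = upper half."""
    return ("0" if i < n // 2 else "1") + ("0" if j < n // 2 else "1")


def quadrant_sizes(n):
    """Count index pairs with 0 <= i < j < n that fall into each quadrant."""
    return Counter(quadrant(i, j, n) for i in range(n) for j in range(i + 1, n))


print(sorted(quadrant_sizes(8).items()))  # [('00', 6), ('01', 16), ('11', 6)]
```

The two diagonal quadrants are smaller instances of the same triangular problem. The off-diagonal quadrant is a full square, which is why it gets a different function.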
```scala
val b100 = rewrite("b100", b0,
    $$.splitRange($, Var("k1"), n/4),
    $$.unfold($, D))(
  i->i,
  j->(j-n/4),
  n->n/2,
  w->(w>>(0,n/4,n/4)),
  w1->(w1>>(0,0,n/4)),
  t->(t>>(n/4,n/4)),
  // make D a function of i, j and pass the following arguments to D
  r->D.gen(2)(i, j-n/4, n/2, w1>>(0,n/4,n/2),
    r>>(0,n/2), s>>(0,n/4), bij>>(n/4,n/2))
)(b000)
```
Input: collection of mutually recursive math functions.
Output: sequential code without \(\lambda\)s
```python
def c0(T, oi, oj, n, w, r, w_0, w_1, w_2, r_0, r_1):
    # i runs from n-1 down to 0; j runs from i+1 up to n-1
    for i0 in xrange(((-1)*n + 1),1):
        for j0 in xrange(((-1)*i0 + 1),n):
            i = (0 - i0)
            j = j0
            T[(i + oi), (j + oj)] = ...

def c1(T, oi, oj, n, w, r, w_0, w_1, w_2, r_0, r_1):
    if (n < 4):
        # base case: small blocks use the iterative loop nest
        c0(T, oi, oj, n, w, r, w_0, w_1, w_2, r_0, r_1)
        return
    # top-left half (same offsets)
    c1(T, oi, oj, n/2, w, r, w_0, w_1, w_2, r_0, r_1)
    # bottom-right half (all offsets shifted by n/2)
    c1(T, (oi + n/2), (oj + n/2), n/2, w, r, (n/2 + w_0), (n/2 + w_1), (n/2 + w_2), (n/2 + r_0), (n/2 + r_1))
    # remaining off-diagonal block, handled by b1 (not shown)
    b1(T, oi, oj, n, w, r, T, T, w, w_0, w_1, w_2, r_0, r_1, oi, oj, oi, oj, w_0, w_1, w_2)
```
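c0's loop order is a valid evaluation order for the recurrence. When T[i][j] is computed, T[i][k] for k < j was written earlier in the same inner loop. T[k][j] for k > i was written in an earlier outer iteration. The sketch below mirrors c0's two loops in Python 3 with offsets oi = oj = 0. It assumes the elided loop body evaluates the Parenthesis recurrence, and the name `c0_order_fill` is hypothetical.

```python
import math


def c0_order_fill(n, x, w):
    """Fill T in c0's loop order: i from n-1 down to 0, then j from i+1 up to n-1.

    Assumes the elided body of c0 evaluates the Parenthesis recurrence with
    offsets oi = oj = 0.
    """
    T = [[math.inf] * n for _ in range(n)]
    for i0 in range(-n + 1, 1):        # mirrors xrange(((-1)*n + 1), 1)
        i = -i0
        for j in range(i + 1, n):      # mirrors xrange(((-1)*i0 + 1), n)
            if j == i + 1:
                T[i][j] = x(i)
            else:
                # row i (k < j) and rows k > i are already final at this point
                T[i][j] = min(T[i][k] + T[k][j] + w(i, k, j) for k in range(i + 1, j))
    return T


dims = [40, 20, 30, 10, 30]  # matrix-chain instance, used only as an illustration
T = c0_order_fill(len(dims), lambda i: 0, lambda i, k, j: dims[i] * dims[k] * dims[j])
print(T[0][len(dims) - 1])  # 26000
```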