[go: up one dir, main page]

Skip to content

Commit

Permalink
seq theorems, datatype induction
Browse files Browse the repository at this point in the history
  • Loading branch information
philzook58 committed Nov 8, 2024
1 parent 79d3bcf commit 29b1563
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 31 deletions.
21 changes: 18 additions & 3 deletions kdrag/notation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
"""Importing this module will add some syntactic sugar to smt.
"""
The `SortDispatch` system enables z3-sort-based dispatch akin to `functools.singledispatch`.
This is the mechanism for operator overloading in knuckledragger.
A special overloadable operation is the "well-formed" predicate `wf`.
Using the QForAll and QExists quantifiers will automatically insert `wf` calls for the appropriate sorts.
In this way, we can achieve an effect similar to refinement types.
Importing this module will add some syntactic sugar to smt.
- Expr overload by single dispatch
- Bool supports `&`, `|`, `~`
- Sorts support `>>` for ArraySort
- Datatypes support accessor notation
- Datatypes support accessor notation `l.is_cons`, `l.hd`, `l.tl` etc.
"""

import kdrag.smt as smt
Expand Down Expand Up @@ -178,7 +186,9 @@ def datatype_call(self, *args):

def Record(name, *fields, pred=None):
"""
Define a record datatype
Define a record datatype.
The optional argument `pred` will add a well-formedness condition to the record,
giving something akin to a refinement type.
"""
if name in records:
raise Exception("Record already defined", name)
Expand Down Expand Up @@ -211,6 +221,11 @@ def NewType(name, sort, pred=None):


def cond(*cases, default=None) -> smt.ExprRef:
"""
Helper for chained ifs defined by cases.
Each case is a tuple of a bool condition and a term.
If default is not given, a check is performed for totality.
"""
sort = cases[0][1].sort()
if default is None:
s = smt.Solver()
Expand Down
26 changes: 26 additions & 0 deletions kdrag/theories/datatypes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import kdrag.smt as smt
import kdrag as kd


def induct(DT: smt.DatatypeSortRef) -> kd.Proof:
    """Return the structural induction axiom for the algebraic datatype `DT`.

    A predicate ``P`` over ``DT`` is quantified as an array-sorted variable
    (``ArraySort(DT, Bool)``). The axiom states:

    - if, for every constructor, ``P`` holds of each application of that
      constructor whose recursive fields already satisfy ``P``,
    - then ``P`` holds of every value of ``DT``.

    A field counts as recursive only when its sort is exactly ``DT``.
    Fields of any other sort (including nested uses such as containers of
    ``DT``) are simply universally quantified.

    The resulting formula is admitted with
    ``kd.axiom(..., by="induction_axiom")``.
    """
    pred = smt.FreshConst(smt.ArraySort(DT, smt.BoolSort()), prefix="P")

    def premise_for(cons):
        # One premise per constructor: P holds of cons(a_0, ..., a_{k-1}).
        fields = [
            smt.FreshConst(cons.domain(k), prefix="a") for k in range(cons.arity())
        ]
        premise = pred[cons(*fields)]
        # Quantifiers are wrapped starting from the first field, so the last
        # field ends up outermost. For recursive fields, P[field] is passed to
        # kd.QForAll as an extra hypothesis (the induction hypothesis).
        for v in fields:
            premise = (
                kd.QForAll([v], pred[v], premise)
                if v.sort() == DT
                else kd.QForAll([v], premise)
            )
        return premise

    premises = [premise_for(DT.constructor(n)) for n in range(DT.num_constructors())]
    y = smt.FreshConst(DT, prefix="x")
    conclusion = kd.QForAll([y], pred[y])
    return kd.axiom(
        smt.ForAll([pred], smt.Implies(smt.And(premises), conclusion)),
        by="induction_axiom",
    )
File renamed without changes.
52 changes: 51 additions & 1 deletion kdrag/theories/seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""
Built in smtlib theory of finite sequences.
"""
# TODO: seq needs well formedness condition inherited from elements


def induct(T: smt.SortRef, P) -> kd.kernel.Proof:
Expand All @@ -19,3 +18,54 @@ def induct(T: smt.SortRef, P) -> kd.kernel.Proof:
) # -------------------------------------------------
== kd.QForAll([x], P(x))
)


class Seq:
    """Lemmas about the built-in SMT-LIB theory of finite sequences over `T`.

    Constructing ``Seq(T)`` proves, via ``kd.kernel.lemma``, a small library of
    facts about ``smt.SeqSort(T)`` and stores each proof as an attribute.

    Attributes (all set in ``__init__``):
        T: the element sort passed in.
        sort: ``smt.SeqSort(T)``.
        empty: the empty sequence ``smt.Empty(sort)``.
        concat_empty: ``Concat(empty, x) == x``.
        empty_concat: ``Concat(x, empty) == x``.
        concat_assoc: ``Concat(x, Concat(y, z)) == Concat(Concat(x, y), z)``.
        concat_length: ``Length(Concat(x, y)) == Length(x) + Length(y)``.
        length_empty: ``(Length(x) == 0) == (x == empty)``.
    """

    def __init__(self, T):
        self.T = T
        sort = smt.SeqSort(T)
        self.sort = sort
        empty = smt.Empty(sort)
        self.empty = empty
        x, y, z = smt.Consts("x y z", sort)
        # TODO: seq needs well formedness condition inherited from elements

        # `empty` is used throughout instead of rebuilding smt.Empty(sort);
        # both denote the same term.
        self.concat_empty = kd.kernel.lemma(
            kd.QForAll([x], smt.Concat(empty, x) == x)
        )
        self.empty_concat = kd.kernel.lemma(
            kd.QForAll([x], smt.Concat(x, empty) == x)
        )
        self.concat_assoc = kd.kernel.lemma(
            kd.QForAll(
                [x, y, z],
                smt.Concat(x, smt.Concat(y, z)) == smt.Concat(smt.Concat(x, y), z),
            )
        )
        self.concat_length = kd.kernel.lemma(
            kd.QForAll(
                [x, y], smt.Length(smt.Concat(x, y)) == smt.Length(x) + smt.Length(y)
            )
        )
        self.length_empty = kd.kernel.lemma(
            kd.QForAll([x], (smt.Length(x) == 0) == (x == empty))
        )

        # Disabled candidates. They were previously parked in bare string
        # literals (no-op expression statements). Each is recorded here with
        # the reason it cannot be proved as written:
        #
        # contains_concat_left:
        #   QForAll([x, y, z], Contains(x, z) == Contains(Concat(x, y), z))
        #   is false as an equality. Counterexample: x = empty,
        #   y = z = a unit sequence. Only the implication
        #   Contains(x, z) -> Contains(Concat(x, y), z) is plausible
        #   (unverified).
        # contains_concat_right:
        #   QForAll([x, y, z], Contains(y, z) == Contains(Concat(x, y), z))
        #   fails symmetrically (y = empty, x = z = a unit sequence).
        # contains_unit:
        #   Contains(Unit(x), y) == (x == y) is ill-sorted: x and y are
        #   sequences, so Unit(x) is a sequence of sequences. A version
        #   needs an element-sorted variable e, e.g.
        #   Contains(Unit(e), s) == Or(s == empty, s == Unit(e))
        #   (unverified).
        # contains_empty:
        #   Contains(Empty(T), x) == (x == Empty(T)) passes the element sort T
        #   where a sequence sort is needed. The intended statement is
        #   presumably Contains(empty, x) == (x == empty) (unverified).
        #
        # Not yet covered: InRe, Extract, IndexOf, LastIndexOf, prefixof,
        # replace, suffixof, SeqMap, SeqMapI, SeqFoldLeft, SeqFoldLeftI
22 changes: 0 additions & 22 deletions kdrag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,28 +423,6 @@ def lemma_db():
return db


def induct(DT: smt.DatatypeSortRef):
    """Build the structural induction schema for `DT` as a plain formula.

    Unlike a proof-returning variant, this returns the bare z3 formula.
    It has the shape ``ForAll P. (premise_1 & ... & premise_n) -> ForAll x. P[x]``,
    with one premise per constructor. The caller is responsible for how it is
    admitted.

    Only fields whose sort is exactly ``DT`` receive an induction hypothesis
    ``P[field]``. All other fields are plainly universally quantified.
    """
    motive = smt.FreshConst(smt.ArraySort(DT, smt.BoolSort()), prefix="P")
    premises = []
    for idx in range(DT.num_constructors()):
        ctor = DT.constructor(idx)
        vs = [smt.FreshConst(ctor.domain(j), prefix="a") for j in range(ctor.arity())]
        body = motive[ctor(*vs)]
        # Wrap from the first field outward; recursive fields get P[v] as a
        # hypothesis of their quantifier.
        for v in vs:
            if v.sort() == DT:
                body = kd.QForAll([v], motive[v], body)
            else:
                body = kd.QForAll([v], body)
        premises.append(body)
    target = smt.FreshConst(DT, prefix="x")
    goal = kd.QForAll([target], motive[target])
    return smt.ForAll([motive], smt.Implies(smt.And(premises), goal))


import os
import glob
import inspect
Expand Down
11 changes: 6 additions & 5 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@
import kdrag.smt as smt

import kdrag as kd
import kdrag.theories.nat
import kdrag.theories.datatypes.nat
import kdrag.theories.int
import kdrag.theories.real as R
import kdrag.theories.bitvec as bitvec
import kdrag.theories.complex as complex
import kdrag.theories.algebra.group as group

import kdrag.theories.datatypes as datatypes
import re

if smt.solver != smt.VAMPIRESOLVER:
import kdrag.theories.interval

import kdrag.theories.seq as ThSeq
import kdrag.theories.seq as seq

from kdrag import Calc
import kdrag.utils as utils
Expand Down Expand Up @@ -123,7 +123,8 @@ def test_record():


def test_seq():
    """Smoke-test the sequence theory: the induction principle and Seq lemmas."""
    # `ThSeq` and `seq` both alias kdrag.theories.seq, so a single induct
    # call covers it; we only check that something was produced.
    assert seq.induct(smt.IntSort(), lambda x: x == x) is not None
    th = seq.Seq(smt.IntSort())
    # Every stored lemma must be a kernel proof object.
    for pf in (
        th.concat_empty,
        th.empty_concat,
        th.concat_assoc,
        th.concat_length,
        th.length_empty,
    ):
        assert isinstance(pf, kd.kernel.Proof)


"""
Expand Down Expand Up @@ -221,7 +222,7 @@ def test_induct():
List.declare("nil")
List.declare("cons", ("head", smt.IntSort()), ("tail", List))
List = List.create()
assert kd.utils.induct(List) != None
assert datatypes.induct(List) != None


# TODO: test unsound
Expand Down

0 comments on commit 29b1563

Please sign in to comment.