Friday, January 23, 2015

Dependent typing from a C++ perspective

Coq has a feature called dependent typing. In Coq, among other things, dependent types are used to represent propositions (mathematical statements) and proofs of those propositions. This post explains how propositions are represented as types in Coq, and what happens if the same tricks are applied in C++ (the short answer is that it doesn't work).

Hints at the power of static typing

First let me try to convince you that maybe it is possible to represent proofs using static typing. Unless you have heard of the Curry-Howard correspondence, the idea of representing a proof as a type probably sounds crazy. Proofs can be incredibly complex, and at first glance it doesn't seem like there is enough complexity in a static typing system to represent that.

For a C++ developer, the rules for static typing are complicated, but once they are built into the compiler, it doesn't feel like the compiler is really doing a lot of computation to apply them. It seems like you could almost find an upper bound for the amount of computation the compiler does for any given expression.

Many times I've quickly found a class I want to use in the Java API docs, but then spent an hour trying to figure out how to even create an instance of it. The class I want is built through a static factory method. The factory method wants 2 other classes instantiated first. The constructors for those classes want other objects built first, and so on. This gives a hint of how static typing can get complex.

With Java the complexity is pretty much capped by the number of classes/interfaces in the library. The static typing system gets even more powerful with C++ templates. Now new types can be recursively built, and even used to do things like template metaprogramming.

Informal proof

We'll use an example proof to show how dependent typing works. The statement uses natural numbers, which are the integers greater than or equal to 0 (0, 1, 2, and so on). The statement to prove is really simple: $\text{forall } n (\text{where $n$ is a natural number}), 1 \leq n + 1$.

Now let's prove that statement with a traditional proof by induction (if you don't know what induction is, it will make sense soon). For understandability, this proof is written in a very explicit manner, which unfortunately makes it much longer.

Let $n$ be an arbitrary natural number. Either $n = 0$ or $n \geq 1$.
  • Let's take the first case, where $n = 0$. This is called the base case. In this case, $1 \leq n + 1$ becomes $1 \leq 0 + 1$, which is trivially true.
  • Next we prove the second case. This is called the inductive step. We prove the second case by proving that the statement is true when $n = n' + 1$, assuming that the statement is already proved for $n'$. Note that $n'$ is just a variable name; it isn't a derivative.

    Again, we are trying to prove $1 \leq n + 1$, which is the same as $1 \leq n' + 1 + 1$. Our assumption says that $1 \leq n' + 1$.

    Trivially, $n' + 1 \leq n' + 1 + 1$.
    So $1 \leq n' + 1 + 1$.
$\blacksquare$

We tend to learn about induction as a set pattern without really learning why it works. The induction proof can be thought of as a set of steps to generate a specific proof for any specific natural number, $n$.
  1. Use the base case to get a proof for when $a = 0$. If $n = 0$, use this proof and stop here.
  2. While $a < n$, apply the induction step to turn the proof for $a$ into a proof for $a + 1$.
Let's say that we want to use the proof for $1 \leq n + 1$ when $n = 3$. First the base case is used to generate a proof for $a = 0$. Next the induction step is run. The induction step uses the proof for $a = 0$ to generate a new proof for $a = 1$. The induction step is run 2 more times until a proof is generated for $n = 3$.

The induction proof can be thought of as a short program that takes $n$ as an input parameter, then it unrolls the induction steps to generate a proof of the statement for $n$. The program returns that generated proof. Now the proof is starting to sound like a computer program!
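
To make that concrete, here is a deliberately trivial sketch (the Proof type and the function names are mine; they are not used elsewhere in this post). The "proof" carries no information yet, but the control flow is exactly the unrolling described above.
// Stand-in for "a proof that 1 <= a + 1". It records nothing yet.
struct Proof
{
};

// Proof for a = 0 (the base case).
Proof base_case()
{
    return Proof();
}

// Proof for a + 1, given a proof for a (the induction step).
Proof induction_step(const Proof &)
{
    return Proof();
}

// The induction proof viewed as a program: start from the base case and
// unroll the induction step until the proof covers the requested n.
Proof prove_1_le_n_plus_1(unsigned n)
{
    Proof p = base_case();
    for (unsigned a = 0; a < n; ++a)
    {
        p = induction_step(p);
    }
    return p;
}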

Natural numbers

Before jumping into how the proof translates into C++, we need to be more formal about how natural numbers are defined. The unsigned integer types in C++ have an upper bound of $2^{64}$ or $2^{128}$ depending on the compiler. In math, the set of natural numbers does not have an upper bound, so we can't use C++'s built-in integer types to represent natural numbers.

Coq uses the Peano construction for natural numbers, so we'll use the same thing in C++. Note Peano is pronounced "pay-a-no" in native Italian but often bastardized to "piano" in English.

With Peano numbers, $0$ is axiomatically defined (meaning it exists, but the concept can't be broken down further). Then there is a successor constructor (for now, think of it as a function) that adds $1$ to any natural number. So $2$ is represented as $0+1+1$.

In C++, those rules can be encoded using template metaprogramming. Of course the downside to using metaprogramming is that the natural numbers can only be defined at compile time.

// A generic base for all natural numbers
struct nat
{
};

// A type representing 0
struct nat_0 : public nat
{
};

// A template that adds 1 to a natural number.
// So 3 is defined as nat_S<nat_S<nat_S<nat_0> > >
template <class Pred>
struct nat_S : public nat
{
};

This post isn't about Coq syntax, but for the sake of completeness, I'll include the equivalent Coq code. The nat_0 constructor is simply called O (the letter) and nat_S is called S.
Inductive nat : Set :=
  | O : nat
  | S : nat -> nat.
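
One convenience helper is worth sketching now: the base case of the formal proof near the end of this post writes the number 1 as int_to_nat<1>::type, but that helper is never defined in the post. Here is a minimal definition of my own that is consistent with that usage (int_to_nat<1>::type expands to nat_S<nat_0>):
// Hypothetical helper (not part of the original definitions): converts a
// compile-time unsigned integer into the corresponding Peano type.
template <unsigned n>
struct int_to_nat
{
    typedef nat_S<typename int_to_nat<n - 1>::type> type;
};

// The conversion bottoms out at 0, which maps to nat_0.
template <>
struct int_to_nat<0>
{
    typedef nat_0 type;
};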

Pattern matching

In Coq, the S constructor (nat_S in C++) is a function in the sense that you can pass it a natural number and it will return that number plus 1. However, the constructor is more than a regular function because pattern matching can be performed on the result. Pattern matching means that you can break apart a natural number to determine whether it was constructed with O or with S. If it was constructed with S, you can determine what the parameter to S was.

Fortunately the pattern matching for the constructor maps well to the template structure we used for the C++ definitions. The pattern matching can be performed using template specialization or function overloading. In either case, one can determine whether a natural number was built with nat_0 or nat_S.

Let's use function overloading to define an equality operator for natural numbers.
// The compiler tries first to use this overload because the parameters are
// derived types (nat_0 instead of nat), and this overload does not use
// templates.
constexpr bool operator == (const nat_0 &, const nat_0 &)
{
    return true;
}

// If the first overload doesn't apply, the compiler tries this one next, because
// this overload uses derived types for its parameters (nat_S instead of nat).
template <class Pred1, class Pred2>
constexpr bool operator == (const nat_S<Pred1> &, const nat_S<Pred2> &)
{
    return Pred1() == Pred2();
}

// Finally if nothing else works, the compiler will use this overload. This
// overload is least preferred because the compiler has to cast the parameters
// to the base type.
constexpr bool operator == (const nat &, const nat &)
{
    return false;
}
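
To see the overload resolution in action, here are a few compile-time checks (these assertions are mine, not part of the original code):
// 0 == 0: exact match on the non-template nat_0 overload.
static_assert(nat_0{} == nat_0{}, "0 must equal 0");
// 1 == 1: the nat_S overload recurses down to the nat_0 overload.
static_assert(nat_S<nat_0>{} == nat_S<nat_0>{}, "1 must equal 1");
// 0 == 1: neither specific overload applies, so the nat fallback returns false.
static_assert(!(nat_0{} == nat_S<nat_0>{}), "0 must not equal 1");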

Coq's standard definition of equality is more complicated than the C++ example. Here is what the simplified equality function above maps to in Coq:
Fixpoint nat_eq (m n: nat) : bool :=
  match m,n with
    | O, O => true
    | S m', S n' => nat_eq m' n'
    | _, _ => false
  end.

As another example of pattern matching, we'll define the pred function. The pred function performs a truncated decrement: if the input parameter is 1 or larger, the function subtracts 1 from it; otherwise 0 is returned. The pred function and the equality operator from above can also be validated against each other.
// The compiler tries this overload first
constexpr nat_0 pred(const nat_0 &n)
{
    return n;
}

// The compiler tries this overload second
template <typename Pred>
constexpr Pred pred(const nat_S<Pred> &n)
{
    return Pred();
}

// Show that pred(0) returns 0
static_assert(pred(nat_0{}) == nat_0{}, "must be equal");
// Show that pred(1) != 1
static_assert(!(pred(nat_S<nat_0>{}) == nat_S<nat_0>{}), "must not be equal");

This time the C++ code was a direct translation of the standard Coq pred function.
Definition pred (n : nat) :=
  match n with
  | 0 => n
  | S p => p
  end.

Less-than-or-equal binary relation

So far natural numbers, incrementing, and decrementing have been precisely defined. The last definition used by the proof is $\le$. In math, $\le$ is called a binary relation because it describes a property between two operands (that is $m \leq n$).

In the C++ example, the $\leq$ relation is translated into a templated class called le. The two operands to compare are passed to the le class as template parameters.

In C++, <= is an operator that returns a true/false boolean. The le class does something quite different: instead of producing a bool, an instance of le<m,n> can only be constructed if the proposition is true, namely m <= n.

Using public functions (constructors or factory methods), there are only two ways to construct an instance of le<m,n>. An instance of le<n,n> can be constructed using the le_n() factory function. As the signature implies, this function can only build le<m,n> when m = n.

The other factory function, le_S(), constructs le<m, n+1>, but it requires an instance of le<m, n>. So, for example, le<1, 3> could be constructed with these steps (see the sketch after the class definition below):
  1. Create le<1,1> using le_n()
  2. Create le<1,2> using le_S()
  3. Create le<1,3> using le_S()

// Can only instantiate the le class if m is less than or equal to n. m and n
// are peano numbers.
template<class m, class n>
class le
{
private:
    le()
    {
    }

    // There are only 2 ways to externally create the le class. The le_n case
    // covers when m = n.
    template<class o>
    friend le<o, o> le_n();

    // The le_S case lets one prove that m <= n + 1, but only if m <= n.
    template<class o, class p>
    friend le<o, nat_S<p> > le_S(const le<o, p> &);
};

template<class m>
le<m, m> le_n()
{
    return le<m, m>();
}

template<class m, class n>
le<m, nat_S<n> > le_S(const le<m, n> &)
{
    return le<m, nat_S<n> >();
}
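
Here is a sketch of the three construction steps listed above (the typedef and variable names are my own shorthand):
// Shorthand for the Peano encodings of 1, 2, and 3.
typedef nat_S<nat_0> one;
typedef nat_S<one> two;
typedef nat_S<two> three;

// Step 1: le<1,1> via le_n().
le<one, one> step1 = le_n<one>();
// Step 2: le<1,2> via le_S(), using step 1.
le<one, two> step2 = le_S(step1);
// Step 3: le<1,3> via le_S(), using step 2.
le<one, three> step3 = le_S(step2);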

If the proposition is false, namely m > n, there simply are no public constructors or factory functions to build le<m,n>.

Let's say m > n. You can still name the type le<m,n>, but there is simply no way to instantiate it using public constructors or factory functions. A function could have a parameter of type le<m,n>, but since you can't instantiate le<m,n>, you can't call the function.
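
For example (an illustration of my own), a function can demand a proof that 2 <= 1 by taking a le<2,1> parameter, but since no such value can ever be built, the function can never be called:
// The parameter type claims that 2 <= 1. No public constructor or factory
// function can produce such a value, so there is no way to call this function.
void requires_two_le_one(const le<nat_S<nat_S<nat_0> >, nat_S<nat_0> > &)
{
}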

Again, here is the equivalent Coq code:
Inductive le (m : nat) : nat -> Prop :=
  | le_n : le m m
  | le_S : forall n : nat, le m n -> le m (S n).

Formal proof

Now we have the pieces to try implementing the proof in C++. Again the proof statement is $\text{forall } n (\text{where $n$ is a natural number}), 1 \leq n + 1$.

As discussed under the informal proof section, because of the $\text{forall } n$ part of the proposition, this proof can be thought of as a proof generator: a function that takes $n$ as a parameter and returns a proof for that specific $n$.

In C++ we'll call this function le_plus_1. The induction is implemented using pattern matching through function overloading.

Let's start with the base case where $n = 0$. For the base case, the statement simplifies as follows:
\[ \begin{aligned}
  1 &\leq n + 1\\
  1 &\leq 0 + 1\\
  1 &\leq 1\\
\end{aligned} \]
Now the base case can be implemented in C++ as:
le<nat_S<nat_0>, nat_S<nat_0>> le_plus_1(const nat_0 &)
{
    return le_n<int_to_nat<1>::type>();
}
The induction step has to prove $1 \leq (n + 1) + 1$, which translates to instantiating the type le<1, (n + 1) + 1>, or spelled out in the Peano encoding, le<nat_S<nat_0>, nat_S<nat_S<N>> >, where N represents $n$. The induction step is given $1 \leq n + 1$ as an assumption. In C++, obtaining the assumption translates into using recursion: le_plus_1($n+1$) calls le_plus_1($n$).
template <typename N>
le<nat_S<nat_0>, nat_S<nat_S<N>> > le_plus_1(const nat_S<N> &n)
{
    return le_S(le_plus_1(pred(n)));
}

In Coq, the proof looks like:
Lemma le_plus_1 : forall n:nat, 1 <= n + 1.
Proof.
  induction n.
  + (* Base case. Prove 1 <= 0 + 1 *)
    exact (le_n 1).
  + (* Induction case. Prove 1 <= S n + 1, given the assumption as IHn *)
    apply le_S.
    exact IHn.
Qed.

How the proof breaks in C++

The le_plus_1 function acts like a proof generator. Given a specific $n$, it returns a proof for that $n$. The problem is that le_plus_1 is written as a template function, and templates use ad hoc polymorphism. That means the C++ compiler does not validate the function until it is instantiated (given specific values for its template parameters).

So le_plus_1 generates proofs for specific values of n, and the compiler validates those generated proofs. The compiler does not validate that le_plus_1 works for all n; it only validates the specific values of n that the proof is instantiated with.

Hypothetically, the base case could be tested and correct while the induction step is untested and incorrect. To validate the proof, each branch needs its own test case. Granted, once the correct input values for a branch are determined, the validation is simple: make sure it compiles.
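
Concretely, those test cases might look like this (a sketch; the variable names are mine). Each line forces the compiler to instantiate, and therefore type-check, the generated proof for one specific n:
// Each instantiation only validates the proof for that particular n.
auto test_n0 = le_plus_1(nat_0{});                // checks 1 <= 0 + 1
auto test_n1 = le_plus_1(nat_S<nat_0>{});         // checks 1 <= 1 + 1
auto test_n2 = le_plus_1(nat_S<nat_S<nat_0> >{}); // checks 1 <= 2 + 1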

On the other hand, Coq uses dependent typing instead of ad hoc polymorphism. One advantage is that the Coq compiler automatically validates that le_plus_1 works for all values of n, without having to give test cases.

With C++, all template metaprogramming runs at compile time. There is no way to let the user input an integer at runtime and convert it to the proper instance of our nat class. With Coq, new nat values can be created at runtime, and they can be passed to functions that use dependent typing.
