Abstract
Well-founded recursion generalises both strong induction and recursion. In this post we’ll:
- Discuss the (informal) notion of a well-founded relation.
- Show how to use well-founded relations (initially just with the < relation on ℕ) to define recursive functions and prove theorems
  - using the equation compiler and
  - by writing proof terms.
- Explain why rfl proofs don’t work for functions defined by well-founded recursion.
- Employ sneaky equation compiler tricks to simplify function writing.
- Explain how to use the using_well_founded command
  - to work with custom relations via rel_tac and
  - to create decreasing proofs via dec_tac.
In subsequent posts, I plan to (1) use well-founded recursion to show that the quicksort algorithm sorts a list and (2) discuss the formal definition of well_founded.
The idea of well-founded recursion
In our post on recursors, we saw that the standard way to define a function in Lean is via structural recursion. For instance, to define a function f on ℕ, we must define f 0 and, for each n : ℕ, we must specify how f (n+1) depends on f n.
Not all functions are most naturally defined this way. Consider a log-like function \(\mathrm{lg} : \mathbb N \to \mathbb N\) defined so that
$$
\mathrm{lg}(n)=
\begin{cases}
0, & \text{ if } n = 0, \\
1 + \mathrm{lg}(n/2), & \text{if } 0 < n.
\end{cases}
$$
Thus, recalling that we round down when performing division on \(\mathbb N\),
$$
\mathrm{lg}(7) = 1 + \mathrm{lg}(3) = 1 + (1 + \mathrm{lg}(1)) = 1 + (1 + (1 + \mathrm{lg}(0))) = 3.
$$
This computation terminates because it depends on computing \(\mathrm{lg}(n)\) for each \(n\) in the finite decreasing chain \(7 > 3 > 1 > 0\).
More generally, every decreasing chain of natural numbers is finite. We say that the \(<\) relation on \(\mathbb N\) is well founded.
This isn’t (syntactically) the same as the definition of well founded given in Lean. We’ll come to the Lean definition in a later post.
Moreover, as we’ll see in this post, the chain of equalities above cannot be proved by reflexivity.
Not all relations are well founded. In particular, the \(<\) relation on \(\mathbb Z\) is not well founded: you can easily find an infinitely long decreasing sequence of integers.
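For instance, the map \(n \mapsto -n\) gives the infinite decreasing chain \(0 > -1 > -2 > \cdots\) in \(\mathbb Z\). As a quick sketch (assuming the imports mentioned below), each step of this chain is easily verified:
example : ∀ n : ℕ, (-(n + 1) : ℤ) < -n := λ n, by linarith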
Lean has user-friendly mechanisms for:
- Proving that a given relation is well founded.
- Using well-founded relations to define functions.
As we explore these, make sure you have import tactic data.nat.parity tactic.induction at the top of your Lean file.
The equation compiler I
The equation compiler makes it (relatively) easy to write functions that depend on well-founded recursion.
So what exactly is the equation compiler? It’s the piece of software that takes declarations given using pattern matching, match expressions, etc. and compiles them down to Lean terms. For example, the function lg below is compiled to:
def lg._main._pack : Π (x : ℕ), (λ (x : ℕ), ℕ) x :=
λ (x : ℕ),
has_well_founded.wf.fix
(λ (x : ℕ),
x.cases_on
(id_rhs ((Π (y : ℕ), has_well_founded.r y 0 → ℕ) → ℕ)
(λ (F : Π (y : ℕ), has_well_founded.r y 0 → ℕ), 0))
(λ (n : ℕ),
id_rhs ((Π (y : ℕ), has_well_founded.r y (n + 1) → ℕ) → ℕ)
(λ (F : Π (y : ℕ), has_well_founded.r y (n + 1) → ℕ),
have this : (n + 1) / 2 < n + 1, from _,
1 + F ((n + 1) / 2) _))) x
You can see this first by typing #print lg, which shows:
def lg : ℕ → ℕ := lg._main
Likewise, #print lg._main
shows
def lg._main : ℕ → ℕ := λ (ᾰ : ℕ), lg._main._pack ᾰ
And #print lg._main._pack gives the term at the top of this pop-out.
Here’s our function lg:
def lg : ℕ → ℕ
| 0 := 0
| (x + 1) :=
have (x + 1) / 2 < (x + 1), from nat.div_lt_self' x 0,
1 + lg ((x + 1)/2)
lg is defined so that lg 0 = 0 and lg (x + 1) = 1 + lg ((x + 1)/2) for every natural number x. But what’s happening with the have line?
We’re defining lg (x + 1) in terms of lg ((x + 1)/2). But this will only lead to a finite computation if \((x + 1)/2 < x + 1\). The have line provides Lean’s equation compiler with a proof of this fact.
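For reference, the mathlib lemma we used has (approximately) the following statement; instantiating b with 0 yields exactly the required inequality:
#check @nat.div_lt_self'
-- nat.div_lt_self' : ∀ (n b : ℕ), (n + 1) / (b + 2) < n + 1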
If you remove the have line, Lean complains, ‘failed to prove recursive application is decreasing’. At the bottom of the error message, there’s some useful information:
nested exception message:
default_dec_tac failed
state:
lg : ℕ → ℕ,
x : ℕ
⊢ (x + 1) / 2 < x + 1
The last line shows what’s needed to prove that the recursive application is decreasing.
Now we can use the virtual machine to compute values of lg. For instance, #eval lg 32 gives 6 and #eval lg 270 gives 9. It looks like lg is related to the base-2 logarithm function. We’ll make this precise and prove part of the assertion later.
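As a further sanity check (a sketch, assuming the definition above compiles), we can evaluate lg on an initial segment of ℕ in one go:
#eval (list.range 10).map lg -- expect [0, 1, 2, 2, 3, 3, 3, 3, 4, 4]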
This was a simple example, but in other situations, we may need to give the equation compiler more information. For one thing, we may want to define a function on a type other than ℕ. In that case < (more precisely, nat.lt) will not be the correct relation. We’ll see later how to provide this information.
For now, let’s prove simple statements regarding lg. Our first goal is to prove lg 0 = 0. Were lg defined by structural recursion, this result could easily be proved by rfl, as in the example below:
def jane : ℕ → ℕ
| 0 := 2
| (x + 1) := 3 * jane x
example : jane 3 = 54 := rfl
But, for reasons that will be made clear in the next section, rfl doesn’t give a proof that lg 0 = 0. However, a rw will work.
lemma lg_zero : lg 0 = 0 := by { rw lg }
lemma lg_one : lg 1 = 1 := by { erw [lg, lg_zero] }
Above, erw is a more aggressive (more eager?) version of rw. We use it here so that Lean understands lg ((0 + 1) / 2) = lg 0.
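Continuing in this vein, something like the following should work, since the numeral 2 unfolds definitionally to 1 + 1 (a sketch, not the only way to do it):
lemma lg_two : lg 2 = 2 := by { show lg (1 + 1) = 2, erw [lg, lg_one] }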
Exercises
- Write a function ex1 such that
$$
\mathrm{ex1}(n)=
\begin{cases}
1, & \text{if } n = 0, \\
n\times \mathrm{ex1}(n/3), & \text{otherwise}.
\end{cases}
$$
- Prove that ex1 0 = 1 and that ex1 4 = 4.
Well-founded recursion by hand
The equation compiler does a marvellous job of hiding the details of the construction. In this section, we’ll uncover some of these details. Feel free to skip to the next section if this is not of interest to you.
The function lg_by_hand is extensionally equal to lg.
def myF : Π (x : ℕ) (h : Π (y : ℕ), y < x → ℕ), ℕ
| 0 _ := 0
| (x + 1) h := 1 + h ((x + 1) / 2) (nat.div_lt_self' _ _)
def lg_by_hand := well_founded.fix nat.lt_wf myF
There are two parts to this definition. The actual computation is specified in myF. Before we go into the theory, compare this definition with that of lg above. The term h ((x + 1) / 2) (nat.div_lt_self' x 0) combines both the expression lg ((x + 1) / 2) and the inequality proved in the have line of lg.
The Lean function well_founded.fix takes (1) a proof hwf that a relation r is well founded (in this case, the theorem nat.lt_wf shows that < is well founded) and (2) a function F : Π x, (Π y, r y x → C y) → C x, where the motive C defines the type of the function being defined. It returns a function that takes each x to a term of type C x. In our case, myF takes the place of F. The motive is λ (a : ℕ), ℕ.
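It’s instructive to inspect the signature directly. The output below is roughly what #check reports (the exact order of the implicit arguments may differ):
#check @well_founded.fix
-- well_founded.fix : Π {α : Sort u} {C : α → Sort v} {r : α → α → Prop},
--   well_founded r → (Π (x : α), (Π (y : α), r y x → C y) → C x) → Π (x : α), C x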
Crucially, the function returned by well_founded.fix (which we’ll henceforth abbreviate to fix) satisfies a ‘fixpoint equation’. Specifically,
fix hwf F x = F x (λ y h, fix hwf F y)
The theorem well_founded.fix_eq proves this assertion. In our case, we have:
lemma lg_by_hand_eq :
∀ x, lg_by_hand x = myF x (λ y h, lg_by_hand y) := well_founded.fix_eq _ _
This makes clear the sense in which myF x _ encodes the value of lg_by_hand x. We can use this theorem to compute lg_by_hand x for particular x.
lemma lg_zero_bh : lg_by_hand 0 = 0 := lg_by_hand_eq 0
lemma lg_one_bh : lg_by_hand 1 = 1 :=
by { erw [lg_by_hand_eq, myF, lg_zero_bh] }
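Although rfl-style computation fails (as we’re about to see), the virtual machine evaluates lg_by_hand without difficulty, since proofs are erased at runtime. We’d expect agreement with lg:
#eval lg_by_hand 32 -- expect 6, as for lg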
Moreover, we can now see why rfl cannot prove lg 0 = 0 (or lg_by_hand 0 = 0). It’s because lg 0 isn’t definitionally equal to 0! What we can prove with rfl is the following:
example : lg_by_hand 0 = well_founded.fix nat.lt_wf myF 0 := rfl
Exercises
Repeat the exercises concerning ex1 from the previous section, but using well_founded.fix to define your function. Take care not to re-use ex1 in your definitions!
Underhanded tricks
Our aim in writing lg was to give a function that behaves like the base-2 logarithm. But it’s out by 1. A better definition would be:
$$
\mathrm{lg_2}(n)=
\begin{cases}
0, & \text{ if } n \le 1, \\
1 + \mathrm{lg_2}(n/2), & \text{otherwise}.
\end{cases}
$$
In the definition of lg, the argument split nicely into two cases: it’s either 0 or n+1, for some natural number n. This matches the inductive definition of ℕ.
How do we use the equation compiler to give a definition of lg2? One option is to use three patterns:
def lg2 : ℕ → ℕ
| 0 := 0
| 1 := 0
| (n + 2) :=
have h : (n + 2) / 2 = n / 2 + 1 :=
nat.add_div_of_dvd_left (dvd_refl 2),
have n / 2 + 1 < n + 2 := h ▸ nat.div_lt_self' _ _,
1 + lg2 ((n + 2) / 2)
This would get tedious if the first case of our (mathematical) function \(\mathrm{lg_2}\) were \(\mathrm{lg_2}(n) = 0\) for \(n \le 100\). It would be impossible if the first case were \(\mathrm{lg_2}(n) = 0\) for \(\mathrm{even}(n)\).
Trick 1
Our first underhanded trick (not a technical term) is to use one pattern! Below, we use the pattern n to match any input.
def lg2' : ℕ → ℕ
| n := if h : n ≤ 1 then 0 else
have n / 2 < n, from
nat.div_lt_self (by linarith) (nat.le_refl 2),
1 + lg2' (n / 2)
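Assuming the definition above compiles, a few quick evaluations suggest that lg2' computes the floor of the base-2 logarithm:
#eval lg2' 1 -- expect 0
#eval lg2' 16 -- expect 4
#eval lg2' 17 -- expect 4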
Trick 2
Everything we’ve done so far has been in term mode. Nothing in tactic mode. However, it’s fine to use tactic mode as long as the desired inequality is proved in term mode.
Fortunately, Lean permits (nested) switching between modes. The following examples illustrate the principle, though there is little to be gained by mode-switching in these simple cases.
In lg2'', we prove the inequality in term mode then switch to tactic mode. The exact tactic returns us to term mode. We provide the same term as before.
def lg2'' : ℕ → ℕ
| n := if h : n ≤ 1 then 0 else
have n / 2 < n,
from nat.div_lt_self (by linarith) (nat.le_refl 2),
by { exact 1 + lg2'' (n / 2) }
By contrast, lg2''' starts (after an initial pattern match and if statement) in tactic mode. We immediately switch back to term mode via exact.
def lg2''' : ℕ → ℕ
| n := if h : n ≤ 1 then 0 else
begin
exact have n / 2 < n,
from nat.div_lt_self (by linarith) (nat.le_refl 2),
1 + lg2''' (n / 2),
end
Trick 3: dec_tac
In fact, we can separate out the decreasing proof entirely, using the using_well_founded command and its dec_tac field.
def lg2_iv : ℕ → ℕ
| n := if h : n ≤ 1 then 0 else 1 + lg2_iv (n / 2)
using_well_founded
{ dec_tac :=
`[exact nat.div_lt_self (by linarith) (nat.le_refl 2)] }
dec_tac is a field of the structure well_founded_tactics. It should be a tactic; more precisely, a term of type tactic unit. The odd-looking backtick-bracket notation creates such a term from an ordinary ‘interactive mode’ tactic.
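For instance, the following sketch should type-check, packaging the interactive linarith tactic as a term:
example : tactic unit := `[linarith]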
Via this construction, we can (if desired) push almost an entire proof or definition into tactic mode.
Exercises
- Using underhanded tricks, write a function ex2 given by
$$
\mathrm{ex2}(n) =
\begin{cases}
0, & \text{if } n \text{ is odd or } n = 0, \\
1 + \mathrm{ex2}(n/2), & \text{otherwise}.
\end{cases}
$$
- Prove that ex2 0 = 0 and that ex2 4 = 2.
- Do the same, but using well_founded.fix to define your function (you can still use underhanded tricks, albeit indirectly).
Proofs by well-founded recursion
As you may have guessed by playing with lg, we expect, for positive \(n\):
$$\mathrm{lg}(n) = \lfloor\log_2(n)\rfloor + 1.$$
That is,
$$(n+1) < 2^{\mathrm{lg}(n+1)} \le 2(n+1),$$
for every natural number \(n\). In this section, we’ll prove the first of these inequalities. To start with, we’ll write the inequality as a predicate.
def lg_ineq (n : ℕ) : Prop := n + 1 < 2 ^ lg (n + 1)
Next we’ll give four proofs (!) of the desired result: one by hand and three using the equation compiler.
It transpires that we’ll need a couple of preliminary results:
lemma two_mul_succ_div_two {m : ℕ} : (2 * m + 1) / 2 = m :=
begin
rw [nat.succ_div, if_neg], norm_num,
rintros ⟨k, h⟩, exact nat.two_mul_ne_two_mul_add_one h.symm,
end
lemma two_mul_succ_succ {m : ℕ} : 2 * m + 1 + 1 = 2 * (m + 1) := by linarith
Proofs by hand
The ‘by hand’ proof uses well_founded.fix. The result and its proof are:
lemma lg_lemma : ∀ (x : ℕ), x + 1 < 2 ^ lg (x + 1) :=
well_founded.fix nat.lt_wf lg_lemma_aux
Here, we require:
lg_lemma_aux :
∀ (x : ℕ), (∀ (y : ℕ), y < x → lg_ineq y) → lg_ineq x
To prove this result, we can case split on x (considering whether we have a term of shape 0 or x + 1) and then decompose the second case according to whether x is odd or even. Recall that h ‘carries’ with it proofs of the result for all y < x.
lemma lg_lemma_aux (x : ℕ) (h : Π (y : ℕ), y < x → lg_ineq y) : lg_ineq x :=
begin
cases x,
{ rw [lg_ineq, lg_one], norm_num, }, -- base case
dsimp [lg_ineq] at h ⊢,
rcases nat.even_or_odd x with ⟨m, rfl⟩ | ⟨m, rfl⟩,
{ have h₄ : m < 2 * m + 1, by linarith,
specialize h m h₄, rw [nat.succ_eq_add_one, lg, pow_add],
rw two_mul_succ_succ, norm_num, exact h },
{ have h₄ : m < 2 * m + 1 + 1, by linarith,
specialize h m h₄, rw [lg, pow_add],
rw [two_mul_succ_succ, two_mul_succ_div_two], linarith },
end
Note the have h₄ terms (in green) above that are needed to justify that the recursive application is decreasing.
Equation compiler proof 1
Alternatively, we can use the equation compiler. Our first approach nests tactic blocks inside term mode, with the decreasing proofs outside.
lemma lg_lemma2 : ∀ (x : ℕ), x + 1 < 2 ^ lg (x + 1)
| 0 := by { rw lg_one, norm_num, }
| (x + 1) := or.elim (nat.even_or_odd x)
( λ ⟨m, hm⟩,
have m < x + 1, by linarith, -- needed for wf recursion
begin
specialize lg_lemma2 m, rw [hm, lg, pow_add],
rw two_mul_succ_succ, norm_num, exact lg_lemma2,
end )
( λ ⟨m, hm⟩,
have m < x + 1, by linarith, -- needed for wf recursion
begin
specialize lg_lemma2 m, rw [hm, lg, pow_add],
rw [two_mul_succ_succ, two_mul_succ_div_two], linarith,
end )
If you aren’t fond of term mode, this proof may seem the most awkward as it uses or.elim and pattern-matching lambda abstraction.
Equation compiler proof 2
The next proof is a bit nicer. We put everything into tactic mode, popping out to term mode only for the decreasing proof and the actual application of recursion.
lemma lg_lemma2' : ∀ (x : ℕ), x + 1 < 2 ^ lg (x + 1)
| 0 := by { rw lg_one, norm_num, }
| (x + 1) :=
begin
cases (nat.even_or_odd x),
{ rcases h with ⟨m, hm⟩,
rw [hm, lg, pow_add],
rw two_mul_succ_succ, norm_num,
exact have m < x + 1, by linarith,
lg_lemma2' m, },
{ rcases h with ⟨m, hm⟩,
rw [hm, lg, pow_add],
rw [two_mul_succ_succ, two_mul_succ_div_two],
exact have m < x + 1, by linarith,
show _, by { specialize lg_lemma2' m, linarith }, }
end
Note the weird construction on the penultimate line. We’re in term mode here but push back into tactic mode using by.
Equation compiler proof 3
My favourite proof has (almost) everything in tactic mode. The decreasing proof is postponed to the final line via using_well_founded and dec_tac.
lemma lg_lemma2'' : ∀ (x : ℕ), x + 1 < 2 ^ lg (x + 1)
| 0 := by { rw lg_one, norm_num, }
| (x + 1) :=
begin
cases (nat.even_or_odd x),
{ rcases h with ⟨m, hm⟩, specialize lg_lemma2'' m,
rw [hm, lg, pow_add],
rw two_mul_succ_succ, norm_num, exact lg_lemma2'', },
{ rcases h with ⟨m, hm⟩,
specialize lg_lemma2'' m, rw [hm, lg, pow_add],
rw [two_mul_succ_succ, two_mul_succ_div_two], linarith }
end
using_well_founded
{ dec_tac := `[exact show m < x + 1, by linarith] }
It’s worth observing that, in this case, the single tactic provided to dec_tac replaces both of the decreasing proofs that appeared in each of the previous results.
Exercises
- Prove the corresponding upper bound result, viz. \(2^{\mathrm{lg}(n+1)} \le 2(n+1)\), for each \(n \in \mathbb N\). This should be easy.
- Determine and prove the correct lower bound result for the function lg2 defined previously.
Using other relations with rel_tac
The < relation on ℕ isn’t always the correct well-founded relation for a given application. Indeed, we may not be defining a function on ℕ!
A nice example is given in mathlib. For natural numbers n and k, min_fac_aux n k is an auxiliary function used in the definition of the minimal prime factor of n. It is a verified trial-division algorithm.
$$
m_f (n, k) =
\begin{cases}
n, & \text{if } n < k^2, \\
k, & \text{if } k \mid n, \\
m_f (n, k+2), & \text{otherwise}.
\end{cases}
$$
At first blush, it looks as though this defines a non-terminating function. The value of \(m_f(n, k)\) depends on the value of \(m_f(n, k + 2)\) but \(k + 2\) is greater than \(k\). Before we freak out, let’s try to calculate \(m_f(77, 3)\).
$$
m_f (77, 3) = m_f (77, 5) = m_f (77, 7) = 7.
$$
In this example, the ‘\(k\) value’ increased by 2 per step and the calculation terminated when we reached a value \(k\) for which \(k \mid 77\). However, we cannot guarantee that such a \(k\) will be reached.
Here’s another example:
$$
m_f(11, 0) = m_f(11, 2) = m_f(11, 4) = 11.
$$
The termination of the calculation above arises when we reach \(k\) such that \(n < k^2\).
Indeed, it’s clear that, no matter what the initial value of \(k\), by repeatedly adding \(2\) to \(k\), we will ultimately reach a situation where \(n < k^2\).
Using this observation, we’ll construct a well-founded relation.
Finding a well-founded relation
Our aim is to find a function \(f\) (which may depend on \(n\)) such that \(f(k +2) < f(k)\) for every \(k\). Intuitively, this shows that something is decreasing with every recursive application of \(m_f\).
This intuition is shored up by a theorem. Let f : α → ℕ be a function on α (called a measure in the context of well-founded relations). This function induces a relation ≺ on α defined so that a ≺ b means f a < f b. It’s a theorem (called measure_wf in Lean) that any such relation will be well founded.
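In Lean, the measure and the theorem look (approximately) like this; measure f is just inv_image (<) f:
#check @measure -- measure : Π {α : Sort u}, (α → ℕ) → α → α → Prop
#check @measure_wf -- measure_wf : ∀ {α : Sort u} (f : α → ℕ), well_founded (measure f)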
A first approach might be to try \(f(k) = 1 / k\). Surely \(1 / (k+2) < 1 / k\)? Actually no! As we’re dealing with natural number division, \(1 / k = 0\) for every \(k > 1\). Though the induced relation is well founded, it’s useless because we cannot prove the recursive application is decreasing.
What about \(-k\)? If we were working with integers and not natural numbers, this would work as \(-(k + 2) < -k\). Unfortunately, \(-k\) is \(0\) for every natural number \(k\).
However, all is not lost! Let’s take \(f(k) = \sqrt n + 2 - k\). In our recursive application of \(m_f(n, k)\), we are guaranteed that \(n \ge k^2\). Thus, \(k \le \sqrt n\). From this, it follows that
$$
f(k + 2) = \sqrt n - k < \sqrt n + 2 - k = f(k).
$$
Putting this together, we have the following definition. Note the use of the ‘single pattern’ trick. We match every input argument with the single pattern k.
open nat
def min_fac_aux (n : ℕ) : ℕ → ℕ | k :=
if h : n < k * k then n else
if k ∣ n then k else
have sqrt n - k < sqrt n + 2 - k, -- needed for wf recursion
by { rw nat.sub_lt_sub_right_iff,
  { exact lt_trans (lt_add_one _) (lt_add_one _) },
  { rw nat.le_sqrt, exact le_of_not_gt h } },
min_fac_aux (k + 2)
using_well_founded {
rel_tac := λ _ _,
`[exact ⟨_, measure_wf (λ k, sqrt n + 2 - k)⟩]}
A new character, rel_tac, enters the stage. Here, as in most situations, we let Lean fill in the first two arguments by type inference. The last argument is a tactic whose purpose is to synthesize an instance of has_well_founded α. That is, it must (1) provide a relation r and (2) prove that r is well founded. As with dec_tac, we use the backtick-square bracket notation to produce a term of type tactic unit.
In the example above, the tactic is:
exact ⟨_, measure_wf (λ k, sqrt n + 2 - k)⟩
We use a wildcard _ to let Lean fill in the relation. As intimated earlier, measure_wf (λ k, sqrt n + 2 - k) is a proof that the relation induced by the function λ k, sqrt n + 2 - k is well founded.
It’s worth noting that using_well_founded is always at work in the background. If the user doesn’t supply a dec_tac or rel_tac field, the command falls back on default tactics.
Note also that we can supply both dec_tac and rel_tac fields.
Finally, here is how mathlib defines the min_fac function using the auxiliary function above.
def min_fac : ℕ → ℕ
| 0 := 2
| 1 := 1
| (n+2) := if 2 ∣ n then 2 else min_fac_aux (n + 2) 3
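As a sanity check (assuming the definitions above compile), the following evaluations agree with our hand calculations:
#eval min_fac 77 -- expect 7
#eval min_fac 11 -- expect 11
#eval min_fac 12 -- expect 2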
Proofs with a custom relation
So we have a min_fac_aux function. Great. But how do we know it does anything useful? We should prove, at the very least, that min_fac_aux n k is a factor of n.
As you might expect by now, a proof using this custom relation is not so different in structure from the construction of the function itself. We’ll extract the proof that the recursion is decreasing from our above definition and note it as a separate lemma.
open nat
lemma min_fac_lemma (n k : ℕ) (h : ¬ n < k * k) :
sqrt n - k < sqrt n + 2 - k :=
begin
rw nat.sub_lt_sub_right_iff,
{ exact lt_trans (lt_add_one _) (lt_add_one _) },
{ rw nat.le_sqrt, exact le_of_not_gt h },
end
Using this result, we complete our proof.
lemma min_fac_dvd (n : ℕ) : ∀ (k : ℕ), (min_fac_aux n k) ∣ n
| k := if h : n < k * k then by { rw min_fac_aux, simp [h] } else
if hk : k ∣ n then by { rw min_fac_aux, simp [h, hk] } else
have _ := min_fac_lemma n k h,
by { rw min_fac_aux, simp [h, hk], exact min_fac_dvd (k+2) }
using_well_founded { rel_tac := λ _ _,
`[exact ⟨_, measure_wf (λ k, sqrt n + 2 - k)⟩]}
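With min_fac_dvd in hand, particular instances follow by simple application; for example:
example : min_fac_aux 77 3 ∣ 77 := min_fac_dvd 77 3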
Exercises
- Read about the Fermat factorisation method. Write a function fermat_fac to implement this method. Your function should rely on an auxiliary function fermat_fac_aux, defined via well-founded recursion.
- Prove that fermat_fac n is a factor of n, for every natural number n.