Well-founded recursion

Abstract

Well-founded recursion generalises both strong induction and recursion. In this post we’ll:

  1. Discuss the (informal) notion of a well-founded relation.
  2. Show how to use well-founded relations (initially just with the < relation on ℕ) to define recursive functions and prove theorems
    • using the equation compiler and
    • by writing proof terms.
  3. Explain why rfl proofs don’t work for functions defined by well-founded recursion.
  4. Employ sneaky equation compiler tricks to simplify function writing.
  5. Explain how to use the using_well_founded command
    • to work with custom relations via rel_tac and
    • to create decreasing proofs via dec_tac.

In subsequent posts, I plan to (1) use well-founded recursion to show that the quicksort algorithm sorts a list and (2) discuss the formal definition of well_founded.

The idea of well-founded recursion

In our post on recursors, we saw that the standard way to define a function in Lean is via structural recursion. For instance, to define a function f on ℕ, we must define f 0 and, for each n : ℕ, we must specify how f (n+1) depends on f n.

Not all functions are most naturally defined this way. Consider a log-like function \(\mathrm{lg} : \mathbb N \to \mathbb N\) defined so that

$$
\mathrm{lg}(n)=
\begin{cases}
0, & \text{ if } n = 0, \\
1 + \mathrm{lg}(n/2), & \text{if } 0 < n.
\end{cases}
$$

Thus, recalling that we round down when performing division on \(\mathbb N\),
$$
\mathrm{lg}(7) = 1 + \mathrm{lg}(3) = 1 + (1 + \mathrm{lg}(1)) = 1 + (1 + (1 + \mathrm{lg}(0))) = 3.
$$

This computation terminates because it depends on computing \(\mathrm{lg}(n)\) for each \(n\) in the finite decreasing chain \(7 > 3 > 1 > 0\).

More generally, every decreasing chain of natural numbers is finite. We say that the \(<\) relation on \(\mathbb N\) is well founded.

This isn’t (syntactically) the same as the definition of well founded given in Lean. We’ll come to the Lean definition in a later post.
Moreover, as we’ll see in this post, the chain of equalities above cannot be proved by reflexivity.

Not all relations are well founded. In particular, the \(<\) relation on \(\mathbb Z\) is not well founded: you can easily find an infinitely long decreasing sequence of integers.
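For instance, one such sequence (written out purely for illustration) is

$$
0 > -1 > -2 > -3 > \cdots,
$$

which never terminates, so a recursive definition that descends along \(<\) on \(\mathbb Z\) need not bottom out.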

Lean has user-friendly mechanisms for:

  1. Proving that a given relation is well founded.
  2. Using well-founded relations to define functions.

As we explore these, make sure you have import tactic data.nat.parity tactic.induction at the top of your Lean file.

The equation compiler I

The equation compiler makes it (relatively) easy to write functions that depend on well-founded recursion.

So what exactly is the equation compiler? It’s the piece of software that takes declarations given using pattern matching, match expressions, etc. and compiles them down to Lean terms. For example, the function lg below is compiled to:

def lg._main._pack : Π (x : ℕ), (λ (x : ℕ), ℕ) x :=
λ (x : ℕ),
has_well_founded.wf.fix
(λ (x : ℕ),
x.cases_on
(id_rhs ((Π (y : ℕ), has_well_founded.r y 0 → ℕ) → ℕ)
  (λ (F : Π (y : ℕ), has_well_founded.r y 0 → ℕ), 0))
(λ (n : ℕ),
id_rhs ((Π (y : ℕ), has_well_founded.r y (n + 1) → ℕ) → ℕ)
  (λ (F : Π (y : ℕ), has_well_founded.r y (n + 1) → ℕ),
      have this : (n + 1) / 2 < n + 1, from _,
      1 + F ((n + 1) / 2) _))) x

You can see this first by typing #print lg, which shows:

def lg : ℕ → ℕ := lg._main

Likewise, #print lg._main shows

def lg._main : ℕ → ℕ := λ (ᾰ : ℕ), lg._main._pack ᾰ

And #print lg._main._pack gives the term at the top of this pop-out.

Here’s our function lg:

def lg : ℕ → ℕ
| 0 := 0
| (x + 1) :=
  have (x + 1) / 2 < (x + 1), from nat.div_lt_self' x 0,
    1 + lg ((x + 1)/2)

lg is defined so that lg 0 = 0 and lg (x + 1) = 1 + lg ((x + 1)/2) for every natural number x. But what’s happening with the have line?

We’re defining lg (x + 1) in terms of lg ((x + 1)/2). But this will only lead to a finite computation if \((x + 1)/2 < x + 1\). The have line provides Lean’s equation compiler with a proof of this fact.

If you remove the have line, Lean complains, ‘failed to prove recursive application is decreasing’. At the bottom of the error message, there’s some useful information:

nested exception message:
default_dec_tac failed
state:
lg : ℕ → ℕ,
x : ℕ
⊢ (x + 1) / 2 < x + 1

The last line shows what’s needed to prove that the recursive application is decreasing.

Now we can use the virtual machine to compute values of lg. For instance, #eval lg 32 gives 6 and #eval lg 270 gives 9. It looks like lg is related to the base-2 logarithm function. We’ll make this precise and prove part of the assertion later.
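As a quick check, the commands below evaluate lg at the inputs just mentioned, plus the value 7 from the hand computation earlier. This is only a sketch and assumes the definition of lg given above is in scope. The values in the comments come from halving repeatedly until reaching 0.

```lean
#eval lg 7    -- 3, via the chain 7 > 3 > 1 > 0
#eval lg 32   -- 6, via the chain 32 > 16 > 8 > 4 > 2 > 1 > 0
#eval lg 270  -- 9
```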

This was a simple example, but in other situations, we may need to give the equation compiler more information. For one thing, we may want to define a function on a type other than ℕ. In that case < (more precisely, nat.lt) will not be the correct relation. We’ll see later how to provide this information.

For now, let’s prove simple statements regarding lg. Our first goal is to prove lg 0 = 0. Were lg defined by structural recursion, this result could easily be proved by rfl, as in the example below:

def jane : ℕ → ℕ
| 0 := 2
| (x + 1) := 3 * jane x

example : jane 3 = 54 := rfl

But, for reasons that will be made clear in the next section, rfl doesn’t give a proof that lg 0 = 0. However, a rw will work.

lemma lg_zero : lg 0 = 0 := by { rw lg }

lemma lg_one : lg 1 = 1 := by { erw [lg, lg_zero] }

Above, erw is a variant of rw that unifies more aggressively, unfolding semireducible definitions where rw would not. We use it here so that Lean understands lg ((0 + 1) / 2) = lg 0.
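It may help to see what rw lg rewrites with. As far as I understand, the Lean 3 equation compiler generates one 'equation lemma' per pattern, with names of the form lg.equations._eqn_1, and rw lg uses these lemmas. Here is a sketch for inspecting them, assuming lg is defined as above (the exact form in which Lean displays each statement may differ):

```lean
-- Equation lemma for the pattern `0`; it should state that `lg 0 = 0`.
#check @lg.equations._eqn_1

-- Equation lemma for the pattern `x + 1`; it should relate `lg (x + 1)`
-- to `1 + lg ((x + 1) / 2)`.
#check @lg.equations._eqn_2

-- Using the first equation lemma by name, rather than via `rw lg`.
example : lg 0 = 0 := lg.equations._eqn_1
```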

Exercises

  1. Write a function ex1 such that
    $$
    \mathrm{ex1}(n)=
    \begin{cases}
    1 ,& \text{if } n = 0, \\
    n\times \mathrm{ex1}(n/3), & \text{otherwise}.
    \end{cases}
    $$
  2. Prove that ex1 0 = 1 and that ex1 4 = 4.

Well-founded recursion by hand

The equation compiler does a marvellous job of hiding the details of the construction. In this section, we’ll uncover some of these details. Feel free to skip to the next section if this is not of interest to you.

The function lg_by_hand is extensionally equal to lg.

def myF : Π (x : ℕ) (h : Π (y : ℕ), y < x → ℕ), ℕ
| 0 _ := 0
| (x + 1) h := 1 + h ((x + 1) / 2) (nat.div_lt_self' _ _)

def lg_by_hand := well_founded.fix nat.lt_wf myF

There are two parts to this definition. The actual computation is specified in myF. Before we go into the theory, compare this definition with that of lg above. The term h ((x + 1) / 2) (nat.div_lt_self' x 0) combines both the expression lg ((x + 1) / 2) and the inequality proved in the have line of lg.

The Lean function well_founded.fix takes 1) a proof hwf that a relation r is well founded (in this case, the theorem nat.lt_wf shows that < is well founded) and 2) a function F : Π x, (Π y, r y x → C y) → C x, where the motive C defines the type of the function being defined. It returns a function that takes each x to a term of type C x. In our case, myF takes the place of F. The motive is λ (a : ℕ), ℕ.

Crucially, the function returned by well_founded.fix (which we’ll henceforth abbreviate to fix) satisfies a ‘fixpoint equation’. Specifically,

fix hwf F x = F x (λ y h, fix hwf F y)

The theorem well_founded.fix_eq proves this assertion. In our case, we have:

lemma lg_by_hand_eq :
∀ x, lg_by_hand x = myF x (λ y h, lg_by_hand y) := well_founded.fix_eq _ _

This makes clear the sense in which myF x _ encodes the value of lg_by_hand x. We can use this theorem to compute lg_by_hand x for particular x.

lemma lg_zero_bh : lg_by_hand 0 = 0 := lg_by_hand_eq 0

lemma lg_one_bh : lg_by_hand 1 = 1 :=
by { erw [lg_by_hand_eq, myF, lg_zero_bh] }
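Written out by hand, the rewriting steps in lg_one_bh correspond to the following chain of equalities (a sketch of the reasoning, not something Lean prints):

$$
\begin{aligned}
\mathtt{lg\_by\_hand}\ 1 &= \mathtt{myF}\ 1\ (\lambda\, y\, h,\ \mathtt{lg\_by\_hand}\ y) && \text{(fixpoint equation)} \\
&= 1 + \mathtt{lg\_by\_hand}\ ((0 + 1) / 2) && \text{(second clause of } \mathtt{myF}\text{)} \\
&= 1 + \mathtt{lg\_by\_hand}\ 0 && \text{(natural-number division)} \\
&= 1 + 0 = 1 && \text{(by } \mathtt{lg\_zero\_bh}\text{)}.
\end{aligned}
$$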

Moreover, we can now see why rfl cannot prove lg 0 = 0 (or lg_by_hand 0 = 0). It’s because lg 0 isn’t definitionally equal to 0! What we can prove with rfl is the following:

example : lg_by_hand 0 = well_founded.fix nat.lt_wf myF 0 := rfl

Exercises

Repeat the exercises concerning ex1 from the previous section, but using well_founded.fix to define your function. Take care not to re-use ex1 in your definitions!

Underhanded tricks

Our aim in writing lg is to give a function that behaves like the base-2 logarithm. But it’s out by 1. A better definition would be:

$$
\mathrm{lg_2}(n)=
\begin{cases}
0, & \text{ if } n \le 1, \\
1 + \mathrm{lg_2}(n/2), & \text{otherwise}.
\end{cases}
$$

In the definition of lg, the argument split nicely into two cases: it’s either 0 or n+1, for some natural number n. This matches the inductive definition of ℕ.

How do we use the equation compiler to give a definition of lg2? One option is to use three patterns:

def lg2 : ℕ → ℕ
| 0 := 0
| 1 := 0
| (n + 2) :=
  have h : (n + 2) / 2 = n / 2 + 1 :=
    nat.add_div_of_dvd_left (dvd_refl 2),
  have n / 2 + 1 < n + 2 := h ▸ nat.div_lt_self' _ _,
    1 + lg2 ((n + 2) / 2)

This would get tedious if the first case of our (mathematical) function \(\mathrm{lg_2}\) were \(\mathrm{lg_2}(n) = 0\) for \(n \le 100\). It would be impossible if the first case were \(\mathrm{lg_2}(n) = 0\) for \(\mathrm{even}(n)\).

Trick 1

Our first underhanded trick (not a technical term) is to use one pattern! Below, we use the pattern n to match any input.

def lg2' : ℕ → ℕ
| n := if h : n ≤ 1 then 0 else
 have n / 2 < n, from
  nat.div_lt_self (by linarith) (nat.le_refl 2),
  1 + lg2' (n / 2)
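To confirm that the off-by-one has gone, we can evaluate both versions. This sketch assumes lg2 and lg2' are defined as above. The expected values in the comments count how many halvings it takes to reach 1.

```lean
#eval lg2 32    -- 5, via 32 > 16 > 8 > 4 > 2 > 1
#eval lg2' 1    -- 0, since 1 ≤ 1
#eval lg2' 32   -- 5
#eval lg2' 270  -- 8, since 2 ^ 8 = 256 ≤ 270 < 512 = 2 ^ 9
```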

Trick 2

Everything we’ve done so far has been in term mode. Nothing in tactic mode. However, it’s fine to use tactic mode as long as the desired inequality is proved in term mode.

Fortunately, Lean permits (nested) switching between modes. The following examples illustrate the principle, though there is little to be gained by mode-switching in these simple cases.

In lg2'', we prove the inequality in term mode then switch to tactic mode. The exact tactic returns us to term mode. We provide the same term as before.

def lg2'' : ℕ → ℕ
| n := if h : n ≤ 1 then 0 else
 have n / 2 < n,
   from nat.div_lt_self (by linarith) (nat.le_refl 2),
 by { exact 1 + lg2'' (n / 2) }

By contrast, lg2''' starts (after an initial pattern match and if statement) in tactic mode. We immediately switch back to term mode via exact.

def lg2''' : ℕ → ℕ
| n := if h : n ≤ 1 then 0 else
begin
  exact have n / 2 < n,
    from nat.div_lt_self (by linarith) (nat.le_refl 2),
  1 + lg2''' (n / 2),
end

Trick 3: dec_tac

In fact, we can separate the proof of the inequality from the body of the definition, using the using_well_founded command and its dec_tac field.

def lg2_iv : ℕ → ℕ
| n := if h : n ≤ 1 then 0 else 1 + lg2_iv (n / 2)
using_well_founded
 { dec_tac :=
`[exact nat.div_lt_self (by linarith) (nat.le_refl 2)] }

dec_tac is a field of the structure well_founded_tactics. It should be a tactic, or more precisely a term of type tactic unit. The odd-looking backtick-bracket notation creates such a term from an ordinary ‘interactive mode’ tactic.
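Since the backtick-bracket expression is just a term, it can also be given a name and reused. The following is a minimal sketch under that assumption. The names half_dec_tac and lg2_v are hypothetical, introduced only for this illustration.

```lean
-- Hypothetical name: a reusable tactic proving `n / 2 < n`.
-- It expects a hypothesis such as `h : ¬ n ≤ 1` in the context for `linarith`.
meta def half_dec_tac : tactic unit :=
`[exact nat.div_lt_self (by linarith) (nat.le_refl 2)]

-- Same function as `lg2_iv`, with the decreasing tactic supplied by name.
def lg2_v : ℕ → ℕ
| n := if h : n ≤ 1 then 0 else 1 + lg2_v (n / 2)
using_well_founded { dec_tac := half_dec_tac }
```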

Via this construction, we can (if desired) push almost an entire proof or definition into tactic mode.

Exercises

  1. Using underhanded tricks, write a function ex2 given by
    $$
    \mathrm{ex2}(n) =
    \begin{cases}
    0, & \text{if } n \text{ is odd or } n = 0, \\
    1 + \mathrm{ex2}(n/2), & \text{otherwise}.
    \end{cases}
    $$
  2. Prove that ex2 0 = 0 and that ex2 4 = 2.
  3. Do the same, but using well_founded.fix to define your function (you can still use underhanded tricks, albeit indirectly).

Proofs by well-founded recursion

As you may have guessed by playing with lg, we expect that, for every \(n \ge 1\):

$$\mathrm{lg}(n) = \lfloor\log_2(n)\rfloor + 1.$$

That is,

$$(n+1) < 2^{\mathrm{lg}(n+1)} \le 2(n+1),$$

for every natural number \(n\). In this section, we’ll prove the first of these inequalities. To start with, we’ll write the inequality as a predicate.

def lg_ineq (n : ℕ) : Prop := n + 1 < 2 ^ lg (n + 1)
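As a concrete instance, take \(n = 6\). The hand computation near the start of the post gives \(\mathrm{lg}(7) = 3\), so lg_ineq 6 unfolds to the first inequality in

$$
7 < 2^{\mathrm{lg}(7)} = 2^3 = 8 \le 14 = 2 \cdot 7,
$$

and the second inequality is the corresponding instance of the upper bound.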

Next we’ll give four proofs (!) of the desired result: one by hand and three using the equation compiler.

It transpires that we’ll need a couple of preliminary results:

lemma two_mul_succ_div_two {m : ℕ} : (2 * m + 1) / 2 = m :=
begin
  rw [nat.succ_div, if_neg], norm_num,
  rintros ⟨k, h⟩, exact nat.two_mul_ne_two_mul_add_one h.symm,
end

lemma two_mul_succ_succ {m : ℕ} : 2 * m + 1 + 1 = 2 * (m + 1) := by linarith
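To see what these two lemmas say, here they are specialised to m = 3. This is only a sketch: Lean should infer the implicit argument m from the shape of each statement.

```lean
-- (2 * 3 + 1) / 2 = 3, i.e. 7 / 2 = 3 with natural-number division.
example : (2 * 3 + 1) / 2 = 3 := two_mul_succ_div_two

-- 2 * 3 + 1 + 1 = 2 * (3 + 1), i.e. 8 = 8.
example : 2 * 3 + 1 + 1 = 2 * (3 + 1) := two_mul_succ_succ
```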

Proofs by hand

The ‘by hand’ proof uses well_founded.fix. The result and its proof are

lemma lg_lemma : ∀ (x : ℕ), x + 1 < 2 ^ lg (x + 1) := well_founded.fix nat.lt_wf lg_lemma_aux

Here, we require:

lg_lemma_aux :
∀ (x : ℕ), (∀ (y : ℕ), y < x → lg_ineq y) → lg_ineq x

To prove this result, we can case split on x (considering whether we have a term of shape 0 or x + 1) and then split the second case according to whether x is odd or even. Recall that h ‘carries’ with it proofs of the result for all y < x.

lemma lg_lemma_aux (x : ℕ) (h : Π (y : ℕ), y < x → lg_ineq y) : lg_ineq x :=
begin
  cases x,
  { rw [lg_ineq, lg_one], norm_num, }, -- base case
  dsimp [lg_ineq] at h ⊢,
  rcases nat.even_or_odd x with ⟨m, rfl⟩ | ⟨m, rfl⟩,
  { have h₄ : m < 2 * m + 1, by linarith,
    specialize h m h₄, rw [nat.succ_eq_add_one, lg, pow_add],
    rw two_mul_succ_succ, norm_num, exact h },
  { have h₄ : m < 2 * m + 1 + 1, by linarith,
    specialize h m h₄, rw [lg, pow_add],
    rw [two_mul_succ_succ, two_mul_succ_div_two], linarith }, 
end

Note the have h₄ terms above, which are needed to justify that the recursive application is decreasing.

Equation compiler proof 1

Alternatively, we can use the equation compiler. Our first approach nests tactic blocks inside a term-mode proof, with each decreasing proof stated in term mode outside its tactic block.

lemma lg_lemma2 : ∀ (x : ℕ), x + 1 < 2 ^ lg (x + 1)
| 0 := by { rw lg_one, norm_num, }
| (x + 1) := or.elim (nat.even_or_odd x)
( λ ⟨m, hm⟩,
  have m < x + 1, by linarith, -- needed for wf recursion
  begin
    specialize lg_lemma2 m, rw [hm, lg, pow_add],
    rw two_mul_succ_succ, norm_num, exact lg_lemma2,
  end )
( λ ⟨m, hm⟩,
  have m < x + 1, by linarith, -- needed for wf recursion
  begin
    specialize lg_lemma2 m, rw [hm, lg, pow_add],
    rw [two_mul_succ_succ, two_mul_succ_div_two], linarith,
  end )

If you aren’t fond of term mode, this proof may seem the most awkward as it uses or.elim and pattern-matching lambda abstraction.

Equation compiler proof 2

The next proof is a bit nicer. We put everything into tactic mode, pushing out to term mode only for the decreasing proof and the actual application of recursion.

lemma lg_lemma2' : ∀ (x : ℕ), x + 1 < 2 ^ lg (x + 1)
| 0 := by { rw lg_one, norm_num, }
| (x + 1) :=
begin
  cases (nat.even_or_odd x),
  { rcases h with ⟨m, hm⟩,
    rw [hm, lg, pow_add],
    rw two_mul_succ_succ, norm_num,
    exact have m < x + 1, by linarith,
      lg_lemma2' m, },
  { rcases h with ⟨m, hm⟩,
    rw [hm, lg, pow_add],
    rw [two_mul_succ_succ, two_mul_succ_div_two],
    exact have m < x + 1, by linarith,
     show _, by { specialize lg_lemma2' m, linarith }, }
end

Note the weird construction on the penultimate line. We’re in term-mode here but push back into tactic mode using by.

Equation compiler proof 3

My favourite proof has (almost) everything in tactic mode. The decreasing proof is postponed to the final line via using_well_founded and dec_tac.

lemma lg_lemma2'' : ∀ (x : ℕ), x + 1 < 2 ^ lg (x + 1)
| 0 := by { rw lg_one, norm_num, }
| (x + 1) :=
begin
  cases (nat.even_or_odd x),
  { rcases h with ⟨m, hm⟩, specialize lg_lemma2'' m,
    rw [hm, lg, pow_add],
    rw two_mul_succ_succ, norm_num, exact lg_lemma2'', },
  { rcases h with ⟨m, hm⟩,
    specialize lg_lemma2'' m, rw [hm, lg, pow_add],
    rw [two_mul_succ_succ, two_mul_succ_div_two], linarith }
end
using_well_founded
{ dec_tac := `[exact show m < x + 1, by linarith] }

It’s worth observing that, in this case, the tactic provided to dec_tac supplants both decreasing proofs in each of the previous results.

Exercises

  1. Prove the corresponding upper bound result, viz. \(2^{\mathrm{lg}(n+1)} \le 2(n+1)\), for each \(n \in \mathbb N\). This should be easy.
  2. Determine and prove the correct lower bound result for the function lg2 defined previously.

Using other relations with rel_tac

The < relation on ℕ isn’t always the correct well-founded relation for a given application. Indeed, we may not be defining a function on ℕ!

A nice example is given in mathlib. For natural numbers n and k, min_fac_aux n k is an auxiliary function used in the definition of the minimal prime factor of n. It is a verified trial-division algorithm.

$$
m_f (n, k) =
\begin{cases}
n, & \text{if } n < k^2, \\
k, & \text{if } k \mid n, \\
m_f (n, k+2), & \text{otherwise}.
\end{cases}
$$

At first blush, it looks as though this defines a non-terminating function. The value of \(m_f(n, k)\) depends on the value of \(m_f(n, k + 2)\) but \(k + 2\) is greater than \(k\). Before we freak out, let’s try to calculate \(m_f(77, 3)\).

$$
m_f (77, 3) = m_f (77, 5) = m_f (77, 7) = 7.
$$

In this example, the ‘\(k\) value’ increased by 2 per step and the calculation terminated when we reached a value \(k\) for which \(k \mid 77\). However, we cannot guarantee that such a \(k\) will be reached.
Here’s another example:

$$
m_f(11, 0) = m_f(11, 2) = m_f(11, 4) = 11
$$

This calculation terminates when we reach \(k = 4\), for which \(n < k^2\) (here \(11 < 16\)), so the first case of the definition returns \(n = 11\).

Indeed, it’s clear that, no matter what the initial value of \(k\), by repeatedly adding \(2\) to \(k\), we will ultimately reach a situation where \(n < k^2\).

Using this observation, we’ll construct a well-founded relation.

Finding a well-founded relation

Our aim is to find a function \(f\) (which may depend on \(n\)) such that \(f(k +2) < f(k)\) for every \(k\). Intuitively, this shows that something is decreasing with every recursive application of \(m_f\).

This intuition is shored up by a theorem. Let f : α → ℕ be a function on α (called a measure in the context of well-founded relations). This function induces a relation ≺ on α defined so that a ≺ b means f a < f b. It’s a theorem (called measure_wf in Lean) that any such relation will be well founded.
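In Lean 3 the induced relation is written measure f. The following sketch states the two facts just described. The list-length example at the end is my own illustration rather than something used later in this post.

```lean
-- `measure f` relates `a` to `b` exactly when `f a < f b`.
example {α : Type} (f : α → ℕ) (a b : α) :
  measure f a b ↔ f a < f b := iff.rfl

-- `measure_wf f` is the proof that this relation is well founded.
example {α : Type} (f : α → ℕ) : well_founded (measure f) := measure_wf f

-- Illustration: lists of natural numbers, compared by length.
example : well_founded (measure (list.length : list ℕ → ℕ)) := measure_wf _
```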

A first approach might be to try \(f(k) = 1 / k\). Surely \(1 / (k+2) < 1 / k\)? Actually no! As we’re dealing with natural number division, \(1 / k = 0\) for every non-zero \(k\). Though the induced relation is well-founded, it’s useless because we cannot prove the recursive application is decreasing.

What about \(-k\)? If we were working with integers and not natural numbers, this would work as \(-(k + 2) < -k\). Unfortunately, \(-k\) is \(0\) for every natural number \(k\).

However, all is not lost! Let’s take \(f(k) = \sqrt n + 2 - k\), where \(\sqrt n\) denotes the natural-number square root and the subtraction is natural-number subtraction. In our recursive application of \(m_f(n, k)\), we are guaranteed that \(n \ge k^2\). Thus, \(k \le \sqrt n\). From this, it follows that
$$
f(k + 2) = \sqrt n - k < \sqrt n + 2 - k = f(k).
$$
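For example, in the calculation of \(m_f(77, 3)\) above, the natural-number square root of 77 is 8 (since \(64 \le 77 < 81\)), so the measure takes the values

$$
f(3) = 8 + 2 - 3 = 7 \;>\; f(5) = 8 + 2 - 5 = 5 \;>\; f(7) = 8 + 2 - 7 = 3,
$$

which decrease strictly along the recursive calls.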

Putting this together, we have the following definition. Note the use of the ‘single pattern’ trick. We match every input argument with the single pattern k.

open nat

def min_fac_aux (n : ℕ) : ℕ → ℕ | k :=
if h : n < k * k then n else
if k ∣ n then k else
have sqrt n - k < sqrt n + 2 - k, -- needed for wf recursion
{ rw nat.sub_lt_sub_right_iff,
  { exact lt_trans (lt_add_one _) (lt_add_one _) },
  { rw nat.le_sqrt, exact le_of_not_gt h } },
min_fac_aux (k + 2)
using_well_founded {
  rel_tac := λ _ _,
    `[exact ⟨_, measure_wf (λ k, sqrt n + 2 - k)⟩]}

A new character, rel_tac, enters the stage. It is a function of two arguments that returns a tactic; here, as in most situations, we ignore both arguments (hence the λ _ _). The purpose of the tactic is to synthesize an instance of has_well_founded α. That is, it must (1) provide a relation r and (2) prove that r is well founded. As with dec_tac, we use the backtick-square bracket notation to produce a term of type tactic unit.

In the example above, the tactic is:

exact ⟨_, measure_wf (λ k, sqrt n + 2 - k)⟩

We use a wildcard _ to let Lean fill in the relation. As intimated earlier,

measure_wf (λ k, sqrt n + 2 - k)

is a proof that the relation induced by the function λ k, sqrt n + 2 - k is well founded.

It’s worth noting that using_well_founded is always at work in the background. If the user doesn’t supply a dec_tac or rel_tac field, the command falls back on default tactics.

Note also that we can supply both dec_tac and rel_tac fields.

Finally, here is how mathlib defines the min_fac function using the auxiliary function above.

def min_fac : ℕ → ℕ
  | 0 := 2
  | 1 := 1
  | (n+2) := if 2 ∣ n then 2 else min_fac_aux (n + 2) 3
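We can check these definitions against the hand calculations from earlier in this section. The sketch below assumes that the min_fac_aux and min_fac defined above are the ones in scope (rather than, say, mathlib's own versions in the nat namespace).

```lean
#eval min_fac_aux 77 3  -- 7, matching m_f(77, 3)
#eval min_fac_aux 11 0  -- 11, since 11 < 4 * 4 when k reaches 4
#eval min_fac 77        -- 7, via min_fac_aux 77 3
```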

Proofs with a custom relation

So we have a min_fac_aux function. Great. But how do we know it does anything useful? We should prove, at the very least, that min_fac_aux n k is a factor of n.

As you might expect by now, a proof using this custom relation is not so different in structure from the construction of the function itself. We’ll extract the proof that the recursion is decreasing from our above definition and note it as a separate lemma.

open nat

lemma min_fac_lemma (n k : ℕ) (h : ¬ n < k * k) :
    sqrt n - k < sqrt n + 2 - k :=
begin
  rw nat.sub_lt_sub_right_iff,
  { exact lt_trans (lt_add_one _) (lt_add_one _) },
  { rw nat.le_sqrt, exact le_of_not_gt h },
end

Using this result, we complete our proof.

lemma min_fac_dvd (n : ℕ) : ∀ (k : ℕ), (min_fac_aux n k) ∣ n
| k := if h : n < k * k then by { rw min_fac_aux, simp [h] } else 
  if hk : k ∣ n then by { rw min_fac_aux, simp [h, hk] } else
  have _ := min_fac_lemma n k h, 
  by { rw min_fac_aux, simp [h, hk], exact min_fac_dvd (k+2) }
  using_well_founded { rel_tac := λ _ _,
    `[exact ⟨_, measure_wf (λ k, sqrt n + 2 - k)⟩]}
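Once proved, min_fac_dvd can be applied like any other lemma. Here is a brief sketch of its use, assuming the definitions above are in scope:

```lean
-- The value computed by trial division from k = 3 divides 77.
example : min_fac_aux 77 3 ∣ 77 := min_fac_dvd 77 3

-- The same holds for any n, whatever the starting value of k.
example (n k : ℕ) : min_fac_aux n k ∣ n := min_fac_dvd n k
```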

Exercises

  1. Read about the Fermat factorisation method. Write a function fermat_fac to implement this method. Your function should rely on an auxiliary function fermat_fac_aux, defined via well-founded recursion.
  2. Prove that fermat_fac n is a factor of n, for every natural number n.
