🦀 Functional Rust

333: Async Recursion

Difficulty: 4 Level: Expert

Recursive async functions need `Box::pin` (or a similar indirection) because every future's size must be known at compile time.

The Problem This Solves

When you write a recursive function in Rust, each call gets its own stack frame at runtime, so recursion depth never affects the size of any type. An `async fn` is different: the compiler turns it into an anonymous state-machine type that stores every local variable kept alive across an `.await`, including the inner future currently being awaited. If the function awaits a call to itself, that type would have to contain itself, which gives it infinite size, so the compiler rejects it (error E0733).

The fix is heap allocation. Instead of returning the compiler-generated future directly, you return `Pin<Box<dyn Future<Output = T>>>`: a fat pointer (data pointer plus vtable pointer) of fixed size that points to the future on the heap. Without any extra crate, you write `Box::pin(async move { ... })` by hand; the `async-recursion` crate's `#[async_recursion]` attribute generates this boilerplate for you.

This pattern shows up whenever you traverse a recursive data structure asynchronously: JSON tree parsing, directory scanning, graph traversal, or, as in this example, computing properties of a binary tree.
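Graph traversal, one of the uses listed above, follows the same shape as the tree code. This is a minimal sketch under stated assumptions: the graph is a hypothetical adjacency list in a `HashMap<u32, Vec<u32>>`, and `dfs` and the busy-polling `block_on` are illustrative helpers rather than part of this example. The visit-order `Vec` doubles as the visited set, and the mutable borrow is simply reborrowed for each recursive call.

```rust
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Same alias as in this example: a heap-pinned future that may borrow for 'a.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;

// Hypothetical DFS: records nodes in visit order; `order` is also the visited set.
fn dfs<'a>(
    graph: &'a HashMap<u32, Vec<u32>>,
    node: u32,
    order: &'a mut Vec<u32>,
) -> BoxFuture<'a, ()> {
    Box::pin(async move {
        if order.contains(&node) {
            return; // already visited: stops cycles from recursing forever
        }
        order.push(node);
        if let Some(neighbors) = graph.get(&node) {
            for &next in neighbors {
                // `order` is implicitly reborrowed for the duration of this call.
                dfs(graph, next, order).await;
            }
        }
    })
}

// Minimal busy-polling executor with a no-op waker (illustrative only).
struct NoopWaker;
impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    loop {
        if let Poll::Ready(v) = fut.as_mut().poll(&mut cx) {
            return v;
        }
    }
}

fn main() {
    let mut graph: HashMap<u32, Vec<u32>> = HashMap::new();
    graph.insert(1, vec![2, 3]);
    graph.insert(2, vec![4]);
    graph.insert(3, vec![4]);
    graph.insert(4, vec![1]); // back edge: 4 -> 1 forms a cycle
    let mut order = Vec::new();
    block_on(dfs(&graph, 1, &mut order));
    println!("visit order: {:?}", order);
}
```

Starting from node 1, this visits 1, 2, 4, 3; the back edge from 4 to 1 is cut off by the visited check. For large graphs a `HashSet` would be the usual choice for the visited check.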

The Intuition

In Python asyncio:
async def async_sum(tree):
    if tree is None: return 0
    return tree.value + await async_sum(tree.left) + await async_sum(tree.right)
Python hides the complexity: every coroutine object is already heap-allocated. Rust exposes it because it normally stores futures inline (on the stack, or inside the state of the parent future) for performance. `Box::pin` is you saying "ok, put this one on the heap." Think of `Pin<Box<dyn Future>>` as Rust's equivalent of a manually heap-allocated coroutine.
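To make the stack-versus-heap point concrete, the sketch below measures futures with `std::mem::size_of_val`. The `fill` function is hypothetical, not part of this example. Any local kept alive across an `.await` becomes part of the future's fixed-size state, while a `Pin<Box<dyn Future>>` handle is two machine words no matter what it points to. Exact inline sizes vary by compiler version, so the sketch only checks a lower bound and the pointer width.

```rust
use std::future::{self, Future};
use std::mem::{size_of, size_of_val};
use std::pin::Pin;

// Hypothetical future whose state must hold a 1 KiB buffer across an .await.
async fn fill(byte: u8) -> u8 {
    let buf = [byte; 1024];
    future::ready(()).await; // suspension point: `buf` is stored in the future's state
    buf[1023]
}

fn main() {
    // Creating the future runs nothing; it is just a value of a fixed-size type.
    let inline = fill(7);
    println!("inline future: {} bytes", size_of_val(&inline));

    // Boxing moves the state machine to the heap. As a trait object the handle
    // is a fat pointer: one data pointer plus one vtable pointer.
    let boxed: Pin<Box<dyn Future<Output = u8>>> = Box::pin(fill(7));
    println!("boxed handle: {} bytes", size_of_val(&boxed));

    assert!(size_of_val(&inline) >= 1024);
    assert_eq!(size_of_val(&boxed), 2 * size_of::<usize>());
}
```

This two-word handle is what breaks the infinite-size cycle: a recursive future stores only the box, not another copy of itself.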

How It Works in Rust

// Type alias for brevity: a heap-allocated, pinned future that may borrow data for 'a
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;

fn async_sum(t: &Tree) -> BoxFuture<'_, i64> {
    // Box::pin moves the async block's future onto the heap and pins it there
    Box::pin(async move {
        match t {
            Tree::Leaf => 0,
            Tree::Node { value, left, right } =>
                // .await works normally inside the box
                *value as i64 + async_sum(left).await + async_sum(right).await,
        }
    })
}
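Why `async move`? The parameter `t` is a local variable of type `&Tree`. Without `move`, the async block would borrow that local variable, and the returned future would outlive it, so the compiler rejects the code. With `move`, the block copies the reference itself into the future, which then borrows only the tree, for the lifetime `'_`. The minimal `block_on` executor at the bottom drives the futures to completion; in a real project you'd use `#[tokio::main]` or `tokio::runtime::Runtime::block_on`.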
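The tree futures never actually suspend, so the executor's loop finishes on its first poll. The sketch below varies both pieces under stated assumptions: the no-op waker is built with the safe `std::task::Wake` trait instead of a hand-written `RawWakerVTable`, and a hypothetical `YieldOnce` future returns `Pending` once, so the hypothetical recursive `countdown` genuinely suspends at every level. It still busy-polls; a real runtime such as Tokio parks the thread until the waker fires.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// No-op waker built from the safe `Wake` trait: no RawWakerVTable, no `unsafe`.
struct NoopWaker;
impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    loop {
        if let Poll::Ready(v) = fut.as_mut().poll(&mut cx) {
            return v;
        }
        // Pending: busy-poll again. A real runtime would sleep until `wake` is called.
    }
}

// Hypothetical future: returns Pending on its first poll and Ready on the second.
struct YieldOnce(bool);

impl Future for YieldOnce {
    type Output = ();
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        if self.0 {
            Poll::Ready(())
        } else {
            self.0 = true;
            cx.waker().wake_by_ref(); // ask to be polled again
            Poll::Pending
        }
    }
}

// Hypothetical recursive future that really suspends once per level.
fn countdown(n: u32) -> Pin<Box<dyn Future<Output = u32>>> {
    Box::pin(async move {
        if n == 0 {
            return 0;
        }
        YieldOnce(false).await;
        1 + countdown(n - 1).await
    })
}

fn main() {
    println!("levels: {}", block_on(countdown(5)));
}
```

Because each level returns `Pending` before recursing, the outer `loop` in `block_on` runs several times here, which exercises the re-polling path that the tree example never reaches.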

What This Unlocks

Key Differences

| Concept | OCaml | Rust |
|---|---|---|
| Recursive async | Direct recursion with Lwt (`let rec f x = ... >>= f y`) | Must use `Box::pin(async { ... })` |
| Stack vs heap | Lwt always heap-allocates continuations | Rust normally stack-allocates; `Box` opts into heap |
| Return type | `'a Lwt.t` (always a pointer) | `Pin<Box<dyn Future<Output = T>>>` (explicit) |
| `async-recursion` crate | N/A | `#[async_recursion]` generates `Box::pin` automatically |

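The first table row can be narrowed slightly: since Rust 1.77, an `async fn` may call itself directly, much like OCaml's `let rec`, as long as each recursive call goes through an indirection such as `Box::pin`. The boxed return-type pattern in the full listing still works and is what older compilers require. A minimal sketch, assuming Rust 1.77 or newer; `fib` and the busy-polling `block_on` are illustrative helpers, not part of this example.

```rust
use std::future::Future;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Requires Rust 1.77+. Recursion in an `async fn` is accepted because each
// recursive call is boxed, so the outer future stores only a pointer to it.
async fn fib(n: u32) -> u64 {
    match n {
        0 => 0,
        1 => 1,
        _ => Box::pin(fib(n - 1)).await + Box::pin(fib(n - 2)).await,
    }
}

// Minimal busy-polling executor with a no-op waker (illustrative only).
struct NoopWaker;
impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    loop {
        if let Poll::Ready(v) = fut.as_mut().poll(&mut cx) {
            return v;
        }
    }
}

fn main() {
    println!("fib(10) = {}", block_on(fib(10)));
}
```

Removing either `Box::pin` brings back error E0733, since the future would again have to contain itself.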
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;

#[derive(Debug)]
enum Tree { Leaf, Node { value: i32, left: Box<Tree>, right: Box<Tree> } }

impl Tree {
    fn leaf() -> Box<Self> { Box::new(Self::Leaf) }
    fn node(v: i32, l: Box<Self>, r: Box<Self>) -> Box<Self> { Box::new(Self::Node{value:v,left:l,right:r}) }
}

fn async_sum(t: &Tree) -> BoxFuture<'_, i64> {
    Box::pin(async move {
        match t {
            Tree::Leaf => 0,
            Tree::Node{value,left,right} => *value as i64 + async_sum(left).await + async_sum(right).await,
        }
    })
}

fn async_depth(t: &Tree) -> BoxFuture<'_, usize> {
    Box::pin(async move {
        match t {
            Tree::Leaf => 0,
            Tree::Node{left,right,..} => 1 + async_depth(left).await.max(async_depth(right).await),
        }
    })
}

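fn block_on<F: Future>(fut: F) -> F::Output {
    use std::task::{RawWaker, RawWakerVTable, Waker};
    // No-op waker: clone returns another no-op waker; wake, wake_by_ref and drop do nothing.
    unsafe fn clone(p: *const ()) -> RawWaker { RawWaker::new(p, &VTABLE) }
    unsafe fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    // SAFETY: the vtable functions never dereference the (null) data pointer.
    let waker = unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) };
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    // Busy-poll until the future completes; fine here because nothing ever returns Pending.
    loop {
        if let Poll::Ready(v) = fut.as_mut().poll(&mut cx) {
            return v;
        }
    }
}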

fn sample() -> Box<Tree> {
    //         1
    //       /   \
    //      2     3
    //     / \   /
    //    4   5 6
    Tree::node(1,
        Tree::node(2,
            Tree::node(4, Tree::leaf(), Tree::leaf()),
            Tree::node(5, Tree::leaf(), Tree::leaf())),
        Tree::node(3,
            Tree::node(6, Tree::leaf(), Tree::leaf()),
            Tree::leaf()))
}

fn main() {
    let t = sample();
    println!("Sum: {}", block_on(async_sum(&t)));
    println!("Depth: {}", block_on(async_depth(&t)));
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test] fn leaf_sum_zero() { assert_eq!(block_on(async_sum(&Tree::Leaf)), 0); }
    #[test] fn tree_sum() { assert_eq!(block_on(async_sum(&sample())), 21); }
    #[test] fn tree_depth() { assert_eq!(block_on(async_depth(&sample())), 3); }
}
(* OCaml: recursive tree computation *)

type tree = Leaf | Node of int * tree * tree

let rec sum = function
  | Leaf -> 0
  | Node (v,l,r) -> v + sum l + sum r

let rec depth = function
  | Leaf -> 0
  | Node (_,l,r) -> 1 + max (depth l) (depth r)

let t = Node(1, Node(2, Node(4,Leaf,Leaf), Node(5,Leaf,Leaf)), Node(3, Node(6,Leaf,Leaf), Leaf))

let () = Printf.printf "Sum: %d, Depth: %d\n" (sum t) (depth t)
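The same boxed-future pattern also works when each call returns an owned collection rather than a number. A minimal sketch, assuming a copy of the Rust `Tree` from this example plus a hypothetical `async_in_order` helper that lists node values left to right; the `leaf`/`node` constructors and the busy-polling `block_on` are illustrative stand-ins for the ones in the listing. For the sample tree the expected order is 4, 2, 5, 1, 6, 3.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;

// Same shape as the example's Tree.
enum Tree {
    Leaf,
    Node { value: i32, left: Box<Tree>, right: Box<Tree> },
}

fn leaf() -> Box<Tree> {
    Box::new(Tree::Leaf)
}

fn node(value: i32, left: Box<Tree>, right: Box<Tree>) -> Box<Tree> {
    Box::new(Tree::Node { value, left, right })
}

// Hypothetical helper: collects node values in-order (left, node, right).
fn async_in_order(t: &Tree) -> BoxFuture<'_, Vec<i32>> {
    Box::pin(async move {
        match t {
            Tree::Leaf => Vec::new(),
            Tree::Node { value, left, right } => {
                let mut out = async_in_order(left).await;
                out.push(*value);
                out.extend(async_in_order(right).await);
                out
            }
        }
    })
}

// Minimal busy-polling executor with a no-op waker (illustrative only).
struct NoopWaker;
impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    loop {
        if let Poll::Ready(v) = fut.as_mut().poll(&mut cx) {
            return v;
        }
    }
}

fn main() {
    let t = node(1,
        node(2, node(4, leaf(), leaf()), node(5, leaf(), leaf())),
        node(3, node(6, leaf(), leaf()), leaf()));
    println!("{:?}", block_on(async_in_order(&t)));
}
```

Each level awaits its children in sequence and then owns the combined `Vec`, so nothing is borrowed from a child future after it completes.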