Constexpr derivatives of Rust functions, a concept

If several const-related features stabilize, we will be able to compute derivatives of primitive functions such as sin, cos, etc. at compile time. See the sketch below. Keep in mind that each of these functions is represented by its own distinct type with no data, not by a function pointer. The type system of Rust truly can talk about the f32::sin function itself, although we aren't (yet?) allowed to name that type directly.
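The "distinct types with no data" claim can be checked on stable Rust. The following is only an illustrative sketch and not part of the proof of concept; the helper names item_size, pointer_size and type_id_of are made up for this example.

```rust
use std::any::TypeId;
use std::mem::size_of_val;

/// Size in bytes of a callable, measured at its own (possibly unnameable) type.
fn item_size<F: Fn(f32) -> f32>(f: &F) -> usize {
    size_of_val(f)
}

/// Size in bytes of the same callable after coercion to a function pointer.
fn pointer_size(f: fn(f32) -> f32) -> usize {
    size_of_val(&f)
}

/// Identifier of the static type of `f`; different types give different ids.
fn type_id_of<F: 'static>(_f: &F) -> TypeId {
    TypeId::of::<F>()
}
```

With these helpers, item_size(&f32::sin) is zero while pointer_size(f32::sin) is the size of a pointer, and type_id_of(&f32::sin) differs from type_id_of(&f32::cos). The function items carry all their information in the type, which is what the derivative sketch relies on.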

This is also the only minor issue that remains: we can't implement Derivative as a trait on the functions themselves, only on a simple wrapper around them. The only way to refer to their types is via 'type alias impl trait' (a.k.a. existential types), but its rules forbid us from assuming that two such unnameable types are distinct.

Somewhat exciting future prospect, regardless.

Proof of concept code

#![feature(type_alias_impl_trait)]
#![feature(inline_const)]
#![feature(const_trait_impl)]
#![feature(generic_const_exprs)]

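fn main() {
    // Four derivatives of cos cycle back to cos:
    // cos -> -sin -> -cos -> sin -> cos.
    let f: Cos<false> = const {
        Cos::<false>::default()
            .derivative()
            .derivative()
            .derivative()
            .derivative()
    };

    // The call goes through Deref to the wrapped function.
    assert_eq!(f(0.0f32), 1.0f32);
}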

use core::ops::Deref;

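// Implemented by the wrappers, not by the functions themselves.
trait Derivative: Deref
where
    <Self as Deref>::Target: Fn(f32) -> f32,
{
    // The wrapper type that represents the derivative.
    type Dt;
    fn derivative(self) -> Self::Dt;
}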

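// Each alias names the otherwise unnameable type of one function item;
// the const fn next to it is the defining use.
pub type SinIn = impl Fn(f32) -> f32;
const fn get_sin() -> SinIn { f32::sin }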

pub type CosIn = impl Fn(f32) -> f32;
const fn get_cos() -> CosIn { f32::cos }

pub type ExpIn = impl Fn(f32) -> f32;
const fn get_exp() -> ExpIn { f32::exp }

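// INVERTED records the sign in the type only. Note that the Deref impls
// in this sketch do not apply it when the function is called.
struct Sin<const INVERTED: bool>(SinIn);
struct Cos<const INVERTED: bool>(CosIn);
struct Exp(ExpIn);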

impl<const A: bool> const Deref for Sin<A> {
    // FAIL: using SinIn as the Target leads to a cycle, because this impl
    // would then count as a defining use of SinIn.
    type Target = dyn Fn(f32) -> f32;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<const A: bool> const Deref for Cos<A> {
    type Target = dyn Fn(f32) -> f32;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl const Deref for Exp {
    type Target = dyn Fn(f32) -> f32;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

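// d/dx sin = cos, and the sign carries over unchanged.
impl<const A: bool> const Derivative for Sin<A> {
    type Dt = Cos<A>;
    // forget(self) sidesteps dropping the opaque value in a const context.
    fn derivative(self) -> Cos<A> { core::mem::forget(self); Cos(get_cos()) }
}

// d/dx cos = -sin, so the sign flag flips.
impl<const A: bool> const Derivative for Cos<A>
    where [(); !A as usize]:
{
    type Dt = Sin<{!A}>;
    fn derivative(self) -> Sin<{!A}> { core::mem::forget(self); Sin(get_sin()) }
}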

impl const Derivative for Exp {
    type Dt = Self;
    fn derivative(self) -> Self { self }
}


impl<const A: bool> const Default for Sin<A> {
    fn default() -> Self {
        Sin(get_sin())
    }
}

impl<const A: bool> const Default for Cos<A> {
    fn default() -> Self {
        Cos(get_cos())
    }
}

impl const Default for Exp {
    fn default() -> Self {
        Exp(get_exp())
    }
}
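
For comparison, the wrapper idea can be approximated on stable Rust today, at the cost of giving up the opaque function types and the const evaluation. The following is only a sketch under those assumptions: the wrappers are unit structs, the eval method and the NEGATED parameter name are invented for this example, and the two concrete Cos impls stand in for the {!A} expression, which needs generic_const_exprs. Unlike the proof of concept above, the sign flag is applied when evaluating.

```rust
/// A function of one variable whose derivative is known at the type level.
trait Derivative {
    /// The type that represents the derivative of `Self`.
    type Dt: Default;

    /// Evaluate the function at `x`.
    fn eval(&self, x: f32) -> f32;

    /// Produce the derivative; all information lives in the type.
    fn derivative(self) -> Self::Dt
    where
        Self: Sized,
    {
        <Self::Dt as Default>::default()
    }
}

/// `sin` when `NEGATED` is false, `-sin` when it is true.
#[derive(Default, Debug, PartialEq)]
struct Sin<const NEGATED: bool>;

/// `cos` when `NEGATED` is false, `-cos` when it is true.
#[derive(Default, Debug, PartialEq)]
struct Cos<const NEGATED: bool>;

/// `exp`, which is its own derivative.
#[derive(Default, Debug, PartialEq)]
struct Exp;

// d/dx (+-sin) = +-cos: the sign carries over unchanged.
impl<const NEGATED: bool> Derivative for Sin<NEGATED> {
    type Dt = Cos<NEGATED>;
    fn eval(&self, x: f32) -> f32 {
        if NEGATED { -x.sin() } else { x.sin() }
    }
}

// d/dx cos = -sin: the sign flips, spelled out once per value.
impl Derivative for Cos<false> {
    type Dt = Sin<true>;
    fn eval(&self, x: f32) -> f32 {
        x.cos()
    }
}

// d/dx (-cos) = sin.
impl Derivative for Cos<true> {
    type Dt = Sin<false>;
    fn eval(&self, x: f32) -> f32 {
        -x.cos()
    }
}

impl Derivative for Exp {
    type Dt = Exp;
    fn eval(&self, x: f32) -> f32 {
        x.exp()
    }
}

/// Differentiate `cos` four times; the compiler checks the resulting type.
fn fourth_derivative_of_cos() -> Cos<false> {
    Cos::<false>::default()
        .derivative()
        .derivative()
        .derivative()
        .derivative()
}
```

The return type of fourth_derivative_of_cos only type-checks because the chain cos, -sin, -cos, sin, cos is encoded in the impls. What this version cannot do is attach the trait to f32::sin and f32::cos themselves, which is the point of the nightly sketch.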