记一次Kotlin tailrec,lambda 作为入参遇到的坑

起因是这周又轮到我组内分享,因为上次分享了 lambda 演算,这次就想接着说下,递归可能导致的问题(调用栈溢出),以及尾递归优化的方式。然后就提到了 TCO (Tail Call Optimization)CPS (Continuation-passing style)

我写了个阶乘函数来作为例子,一开始是用 js 写的 (组里会 js 的同学比较多)。

原味:

function fact(n) {
  if (n == 0) {
    return 1;
  }

  return n * fact(n - 1);
}

CPS 风味:

function cps_fact(n, k) {
  if (n == 0) {
    return k(1);
  }

  return cps_fact(n - 1, (x) => { return k(x) * n });
}

但当我执行的时候,发现我的 CPS 方法 还是爆栈了。


经过一番搜索后我发现 Node 在 6.2 版本的时候加入了 TCO 但是在 8.x 版本的时候去掉了,原因大概就是 TCO 会给调试带来一些困扰。

然后正好电脑上还装了 Kotlin 想到 Kotlin 中的函数也是一等公民,所以又用 Kotlin 写了一版。

fun cpsFact(a: Int, f: (x: Int) -> Int): Int {
    if (a == 0) {
        return f(1)
    }

    return cpsFact(a - 1) { x -> a * f(x) }
}

结果还是运行失败了,经过一番搜索发现 JVM 并没有做 TCO ,但是 Kotlin 提供了 tailrec 这个关键字可以实现 TCO。

tailrec fun cpsFact(a: Int, f: (x: Int) -> Int): Int {
    if (a == 0) {
        return f(1)
    }

    return cpsFact(a - 1) { x -> a * f(x) }
}

看到这个标题你就应该知道,这件事情没那么简单,还是爆栈了,但这次的报错信息有所不同,显示堆栈溢出的位置是一个匿名内部类。


一开始我也没多想,以为原因就是因为 Kotlin 并不是纯粹的函数式编程语言,传入函数的本质还是构造了一个匿名内部类,调用一下这个类的方法,只不过是个语法糖。然后我又用 Scheme 重写了一版,并且直接跑过了。Chibi-Scheme 0.8.0 版本还自带了一个参数可以禁用 TCO ,禁用之后直接报错。

好了,现在一切都看起来那么的完美。但等到分享的时候,海啸打断了我,说他之前也尝试用 Kotlin 实现过,参数不用 lambda 是可以跑过的。

我试着用定义一个接口的方式重写了一下,果然跑过了。代码如下:

interface Function {
    fun invoke(x: Int): Int
}

tailrec fun cpsFact(a: Int, f: Function): Int {
    if (a == 0) {
        return f.invoke(1)
    }

    return cpsFact(a - 1, object : Function {
        override fun invoke(x: Int): Int {
            return a * f.invoke(x)
        }
    })
}

这就很奇怪了,为了找到这两份代码的区别,我把两份代码都用 IDE 自带的工具 Decompile 了一下

接口:

Lambda:

经过一番对比终于找到了原因所在,内建的 Function1 的接口签名用到了泛型,而泛型在编译之后是会帮你生成一个桥接方法来做类型转换的。这两个方法之间出现了相互的调用,从而导致 TCO 失败。

/** A function that takes 1 argument. */
public interface Function1<in P1, out R> : Function<R> {
    /** Invokes the function with the specified argument. */
    public operator fun invoke(p1: P1): R
}

上面反编译的代码可能看起来没那么清楚,我们修改下上面自己实现接口的代码,加一个方法,让他们相互调用,也一样跑不过。

interface Function {
    fun invoke(x: Int): Int

    fun invokeWrapper(x: Int): Int
}

tailrec fun cpsFact(a: Int, f: Function): Int {
    if (a == 0) {
        return f.invokeWrapper(1)
    }

    return cpsFact(a - 1, object : Function {
        override fun invoke(x: Int): Int {
            return a * f.invokeWrapper(x)
        }

        override fun invokeWrapper(x: Int): Int {
            return invoke(x)
        }
    })
}

推荐阅读更多精彩内容