Scala 中的 case class 和 pattern matching

摘要

本文将详细介绍 Scala 中两个非常重要的概念 case class 和 pattern matching,并通过具体的案例来说明两者的具体用途。

1、从一个简单的例子开始

惯例,先贴代码,再详细说明

abstract class Expr

case class Var(name: String) extends Expr
case class Number(num: Double) extends Expr
case class UnOp(operator: String, arg: Expr) extends Expr
case class BinOp(operator: String, left: Expr, right: Expr) extends Expr

首先,我们要明确我们的需求:编写一个操作算数表达式的库
上述代码可以看成是基础的数据结构,包含一个抽象类 Expr,和四个子类:Var(变量)、Number(数字)、UnOp(一元运算)、BinOp(二元运算),且每个子类使用 case 关键字修饰

case class

简单的说,使用 case 修饰的 class 就是 case class,使用 case 会给我们带来很多方便的地方:
1、自动为该类添加了一个和类名一样的工厂方法,所以在实例化的时候我们可以直接写成 Var("x"),不再写成 new Var("x"),这种写法在嵌套的情况下非常简洁,例如:

val op = BinOp("+", Number(1.0), Var("x"))

2、所有 case class 的参数被隐式的添加上了 val 前缀,也就是说我们可以直接像这样访问参数

op.operator

3、编译器添加了 toString、hashCode、equals 方法的实现

println(op)  //BinOp(+,Number(1.0),Var(x))

4、编译器添加了一个 copy 方法,可以通过修改特定的参数生成一个新的 case class 实例

op.copy(operator = "-")  //BinOp = BinOp(-,Number(1.0),Var(x))

pattern matching

接下来我们要实现三种运算律:负负得正、任何数加 0 还是它本身、任何数乘以 1 还是它本身,即如下形式:

UnOp("-", UnOp("-", null)) => null // Double negation 
BinOp("+", null, Number(0)) => null // Adding zero 
BinOp("*", null, Number(1)) => null // Multiplying by one

具体的实现如下:

object Expr {
  def main(args: Array[String]): Unit = {
    val result = simplifyTop(UnOp("-", UnOp("-", Number(2.0))))
    println(result)  //Number(2.0)
  }
  
  def simplifyTop(expr: Expr): Expr = expr match {
    case UnOp("-", UnOp("-", e)) => e     // Double negation
    case BinOp("+", e, Number(0)) => e    // Adding zero
    case BinOp("*", e, Number(1)) => e    // Multiplying by one
    case _ => expr
  }
}

可以看出 Scala 中模式匹配(pattern matching)使用如下形式:

selector match { alternatives }

而 Java 中的 switch 形如:

switch (selector) { alternatives }

case 后面是匹配的某种情况,匹配的种类有很多,这里我们使用的是 Constructor patterns,=> 后面是匹配成功后执行的操作,但是不管执行什么操作返回的结果都必须是函数的返回值类型,这里是 Expr;最后一行 case _ 中的 _ 是通配符,表示匹配任何值,匹配成功后不执行任何操作,只是将 expr 返回

对比 Java 中的 switch,match 有三个不同点:第一、在 Scala 中 match 是一个表达式,这说明 match 是有返回值的;第二、只要匹配到一个,下面的语句就不会执行;第三、如果没有匹配到任何项就会剖出一个名为 MatchError 的异常,所以要确保能够匹配到值。

2、匹配的种类

Wildcard patterns

即通配符匹配 _ ,匹配所有的情况

Constant patterns

即字面量匹配,代码说明一切:

scala> def describe(x: Any): Any = x match {
     | case 5 => "five"
     | case true => "truth!"
     | case "spark" => "the future!"
     | case Nil => "the empty list"
     | case _ => "something else"
     | }
describe: (x: Any)Any

测试结果:

scala> describe(5)
res10: Any = five

scala> describe(true)
res11: Any = truth!

scala> describe("spark")
res12: Any = the future!

scala> describe(Nil)
res13: Any = the empty list

scala> describe(Array(1,2,3,4,5,6))
res14: Any = something else

scala> describe(1)
res15: Any = something else

Variable patterns

variable patterns 跟通配符(_)类似可以匹配任何对象,跟通配符不同的是 Scala 会将需要匹配的对象传递给 case 后面的变量,测试如下:

scala> 1 match {
     | case v => "the variable is " + v
     | }
res19: String = the variable is 1

scala> "scala" match {
     | case v => "the variable is " + v
     | }
res20: String = the variable is scala

scala> case class Person(name: String)
defined class Person

scala> Person("Jack") match {
     | case v => "the variable is " + v
     | }
res21: String = the variable is Person(Jack)

另外需要注意的是,上例中的匹配不需要 case _,否则会报错:

scala> 1 match {
     | case v => "the variable is " + v
     | case _ => "something else"
     | }

<console>:14: warning: unreachable code
       case _ => "something else"

因为永远不可能匹配到 _
如果要匹配一个变量的值,需要使用“`”符号,例如:

scala> val pi = math.Pi
pi: Double = 3.141592653589793

scala> 1 match {
     | case `pi` => "success"
     | case _ => "fail"
     | }
res24: String = fail

*如果不使用“”符号,进行的就是 variable pattern,使用的话就是 constant pattern** *“”符号的另外一个作用就是将关键字变成标识符

Constructor patterns

也就是最开始的例子中使用的匹配模式:case BinOp("+", e, Number(0)) => e,首先要检查这个对象是否为指定的 case class 的成员,然后检查构造方法的参数是否符合额外匹配(deep matches:意味着进行深层次的匹配),例如上例中的 "+" 是 constant pattern,e 是 variable 匹配,而 Number(0) 又是一个 constructor pattern...

Sequence patterns

和 constructor patterns 类似,但是可以指定任意数量的元素,例如匹配以 0 开头的总共有 3 个元素的 List

scala> List(0, 1, 2) match {
     | case List(0, _, _) => "success"
     | }
res26: String = success

匹配以 0 开头任意元素的数组:

scala> Array(0, 1, 2, 3, 4, 5, 6) match{
     | case Array(0, _*) => "success"
     | }
res27: String = success
Tuple patterns

匹配 Tuple,Tuple 中的元素类型可以不同

scala> (1, "scala", true) match {
     | case (a, b, c) => "matched " + a + b + c
     | }
res30: String = matched 1scalatrue
Typed patterns

即类型匹配,例如:

scala> def generalSize(x: Any) = x match{
     | case s: String => s.length
     | case m: Map[_, _] => m.size
     | case _ => -1
     | }
generalSize: (x: Any)Int

scala> generalSize("abc")
res31: Int = 3

scala> generalSize(Map(1 -> 'a', 2 -> 'b'))
res32: Int = 2

scala> generalSize(123)
res33: Int = -1

也可以写成如下形式:

if (x.isInstanceOf[String]) {
      val s = x.asInstanceOf[String]
      s.length
} else ...

其中 isInstanceOf 判断是否为某种类型,而 x.asInstanceOf[String] 将 x 转换成 String 类型
下面来看一个非常重要的概念:
类型擦除(Type erasure)
先看测试:

scala> def isIntIntMap(x: Any) = x match {
     |    case m: Map[Int, Int] => true
     |    case _ => false
     | }
<console>:12: warning: non-variable type argument Int in type pattern scala.collection.immutable.Map[Int,Int] (the underlying of Map[Int,Int]) is unchecked since it is eliminated by erasure
                  case m: Map[Int, Int] => true
                          ^
isIntIntMap: (x: Any)Boolean

warning 说的很清楚,在运行时范型会被擦出,所以不能判断 Map 的具体元素的类型是否匹配,但是 Array 除外,因为数组存储的时候将值和类型一起进行存储:

scala> def isStringArray(x: Any) = x match{
     | case a: Array[String] => "success"
     | case _ => "failed"
     | }
isStringArray: (x: Any)String

scala> isStringArray(Array("scala", "spark"))
res36: String = success

scala> isStringArray(Array(1, 2))
res37: String = failed
Variable binding

直接看代码:

scala> UnOp("abs", UnOp("abs", Number(1.0))) match {
     | case UnOp("abs", e @ UnOp("abs", _)) => e
     | case _ => "something else"
     | }
res40: java.io.Serializable = UnOp(abs,Number(1.0))

注意使用 @ 将 e 绑定到 UnOp("abs", _),所以最终的返回结果是 UnOp(abs,Number(1.0))

Pattern guard

类似于 for() 中可以使用 if 守卫,case 后面也可以使用,例如我们期望将 x + x 变成 x * 2,按照上面的例子我们会这样写

scala> def simplifyAdd(e: Expr) = e match {
     | case BinOp("+", x, x) => BinOp("*", x, Number(2))
     | }
<console>:17: error: x is already defined as value x
       case BinOp("+", x, x) => BinOp("*", x, Number(2))
                          ^

但是报错:x 已经被定义过了,因为同一个 x 不能出现两次,解决方法是使用两个变量 x 和 y,并判断 x 是否等于 y:

scala> def simplifyAdd(e: Expr) = e match {
     | case BinOp("+", x, y) if x==y => BinOp("*", x, Number(2))
     | case _ => "something else"
     | }
simplifyAdd: (e: Expr)java.io.Serializable

scala> simplifyAdd(BinOp("+", Number(1), Number(1)))
res1: java.io.Serializable = BinOp(*,Number(1.0),Number(2.0))

我们在 case 的后面加入了 if x==y,只有在满足此条件才能匹配成功执行后面的操作,我们也可以匹配一个以 a 开头的字符串:

scala> "apple" match{
     | case s: String if s(0) == 'a' => "the " + s + " is start with a "
     | case _ => "something else"
     | }
res3: String = "the apple is start with a "

3、匹配的顺序

接下来看一下 case 的书写顺序:

def simplifyAll(expr: Expr): Expr = expr match {
  case UnOp("-", UnOp("-", UnOp("-", e))) => simplifyAll(e)
  case BinOp("+", e, Number(0)) => simplifyAll(e)
  case BinOp("*", e, Number(1)) => simplifyAll(e)
  case UnOp(op, e) => UnOp(op, simplifyAll(e))
  case BinOp(op, l, r) => BinOp(op, simplifyAll(l), simplifyAll(r))
  case _ => expr
}

我们来看一下上面代码中第 1 个 case 和第 4 个 case 的顺序,因为第 1 个的匹配要比第 4 个的匹配更为严格,所以需要放在第 4 个的前面,否则永远不可能匹配到第 1 个的情况:

scala> def simplifyBad(expr: Expr): Expr = expr match {
     | case UnOp(op, e) => UnOp(op, simplifyBad(e))
     | case UnOp("-", UnOp("-", e)) => simplifyBad(e)
     | case _ => expr
     | }
<console>:16: warning: unreachable code
       case UnOp("-", UnOp("-", e)) => simplifyBad(e)

所以在书写 case 时需要将更为严格的匹配放在前面

4、Sealed Class

如果想确保进行模式匹配的时候不会漏掉一些情况,可以使用 sealed 关键字修饰 class 如下所示:

sealed abstract class Expr
case class Var(name: String) extends Expr
case class Number(num: Double) extends Expr
case class UnOp(op: String, arg: Expr) extends Expr
case class BinOp(op: String, left: Expr, right: Expr) extends Expr

然后我们写个函数测试一下:

scala> def describe(e: Expr): String = e match {
     |     case Number(_) => "a number"
     |     case Var(_)    => "a variable"
     | }

<console>:16: warning: match may not be exhaustive.
It would fail on the following inputs: BinOp(_, _, _), UnOp(_, _)
       def describe(e: Expr): String = e match {

可见交互式终端剖出一个 warning 提示 匹配不是完全的,比如 BinOp(_, _, _), UnOp(_, _) 就不能匹配到,由此可以看出如果使用 sealed 关键字修饰类,那么进行模式匹配的时候,编译器会检查匹配的是否全面。
但是有的时候我们根据上下文可以确定只有上面代码中写到的两种结果,这时可以在 e 的后面加上 : @unchecked 防止编译器对其进行检查

scala> def describe(e: Expr): String = (e: @unchecked) match {
     |     case Number(_) => "a number"
     |     case Var(_)    => "a variable"
     | }
describe: (e: Expr)String

可以看出这时就不会有 warning 了,这里的 @unchecked 是 Annotations(注释),这时我们不做展开说明

5、Option 类型

Option 代表一个可选值,有两种情况 Some(x) 代表有值,None 代表没有找到对应值,这里我们以 Map 为例进行说明

scala> val writeCode = Map("Spark" -> "Scala", "Hadoop" -> "Java")
writeCode: scala.collection.immutable.Map[String,String] = Map(Spark -> Scala, Hadoop -> Java)

scala> writeCode.get("Spark")
res1: Option[String] = Some(Scala)

scala> writeCode.get("Kafka")
res2: Option[String] = None

使用 get 获取 Spark 返回的是 Some(Scala) 说明 writeCode 中有这个值,而 Kafka 不再里面,所以返回的是 None

最常用的方法就是通过模式匹配来获得可选值,示例如下:

scala> def show(x: Option[String]) = x match {
     | case Some(s) => s
     | case None => "?"
     | }
show: (x: Option[String])String

scala> show(writeCode.get("Hadoop"))
res3: String = Java

scala> show(writeCode.get("Kafka"))
res4: String = ?

6、随处可见的模式匹配

可以一次定义多个变量

例如:

scala> val (number, string) = (1, "abc")
number: Int = 1
string: String = abc

常用来接收函数的返回值,例如 Spark 源码 SparkContext 中的如下部分

val (sched, ts) = SparkContext.createTaskScheduler(this, master)

同样可以这样使用:

scala> val BinOp(op, left, right) = BinOp("+", Number(1), Number(2))
op: String = +
left: Expr = Number(1.0)
right: Expr = Number(2.0)
一系列的 case 可以作为函数的一部分

例如:

scala> val withDefault: Option[Int] => Int = {
     | case Some(x) => x
     | case None => 0
     | }
withDefault: Option[Int] => Int = <function1>

scala> withDefault(Some(100))
res7: Int = 100

scala> withDefault(None)
res8: Int = 0

可以看出 case 后面的部分其实是作为函数 withDefault 的函数体,这种写法在消息通信中非常有用,例如 Spark 源码 Worker 中的如下部分:

override def receive: PartialFunction[Any, Unit] = synchronized {
    case SendHeartbeat =>
      if (connected) { sendToMaster(Heartbeat(workerId, self)) }

    case WorkDirCleanup =>
    ...

Worker 接收到消息后会判断是什么消息,然后针对每种消息进行具体的操作

for 语句中使用模式匹配

例如我们要遍历上面例子中的 Map 类型的 writeCode,代码如下:

scala> for( (architecture, writeCode) <- writeCode )
     | println("Architecture: " + architecture + " writeCode: " + writeCode)
Architecture: Spark writeCode: Scala
Architecture: Hadoop writeCode: Java

再来看另外一个例子:

scala> val results = List(Some("apple"), None, Some("orange"))
results: List[Option[String]] = List(Some(apple), None, Some(orange))

scala> for( Some(fruit) <- results) println(fruit)
apple
orange

总结

本文详细阐述了 Scala 中的 case class 和 pattern matching,使我们更加便捷的编写代码,但是 Scala 中的模式匹配远远不止这么简单,需要了解的可以研究 Scala 中的 Extractors。

本文参照:Programming in Scala, 3rd Edition 中的 Chapter 15:Case Class and Pattern Matching

推荐阅读更多精彩内容