theano scan

参数

  • fn: fn 是函数描述了在一步 scan.fn 中所有的操作,这个个函数必须构造出描述一步迭代的输出的变量。同样还需要看成是 theano 的输入变量,表示输入序列的所有分片和过去的输出值,以及所有赋给 scan 的 non_sequences 的这些其他参数。而 scan 依照如下顺序传递给 fn 这些变量:
  • all time slices of the first sequence
  • all time slices of the second sequence
  • ...
  • all time slices of the last sequence
  • all time slices of the last sequence
  • all past slices of the first output
  • all past slices of the second otuput
  • ...
  • all past slices of the last output
  • all other arguments (the list given as non_sequences to scan)

序列的顺序和在列表 sequences 一致。输出的顺序和 output_info 的顺序相同。对任意序列或者输出,时间分片的顺序和他们给定作为 taps 的序列相同。例如,如果代码写成下面这样:

scan(fn, 
     sequences = [ dict(input= Sequence1, taps = [-3,2,-1]) , Sequence2 , dict(input = Sequence3, taps = 3) ] , 
     outputs_info = [ dict(initial = Output1, taps = [-3,-5]) , dict(initial = Output2, taps = None) , Output3 ] , 
     non_sequences = [ Argument1, Argument2])

fn 期待的输入顺序是:

  1. Sequence1[t-3]
  2. Sequence1[t+2]
  3. Sequence1[t-1]
  4. Sequence2[t]
  5. Sequence3[t+3]
  6. Output1[t-3]
  7. Output1[t-5]
  8. Output3[t-1]
  9. Argument1
  10. Argument2

non_sequences 同样包含共享变量,只是 scan 可以将这些变量忽略。为了代码的清晰,我们推荐将这些变量传递给 scan。在某种程度上,scan 可以确定其他即使没有传递给 scan (但被 fn使用的)non_sequences(not shared) 变量。例如:

import theano.tensor as TT
W = TT.matrix()
W_2 = W**2
def f(x): 
    return TT.dot(x,W_2)

函数需要返回两个东西。一个是按照 outputs_info 顺序排列的输出列表,不同的是对每个输出初始状态只有一个一个输出变量对应(即使没有使用 tap 值)。第二个 fn应当返回一个更新字典(告诉程序如何对共享变量进行每步的更新)。字典也可以以 tuple 的列表给出。对于这两个列表的顺序倒没有限制,fn 可以返回 (outputs_list, update_dictionary) 或者 (update_dictionary, outputs_list),或者就其中之一(另一个为空)。将 scan 当成是一个 while 循环,我们需要给 fn 增加一个退出循环的条件——将条件配置在一个 until 类中欧诺个。这些条件必须被返回为第三个元素,如下:

...
return [y1_t, y2_t], {x:x+1}, theano.scan_module.until(x<50)

注意,步数(最大迭代步骤数)仍然需要指定,即使一个有了一个条件

  • sequences: 序列是 Theano 变量或者字典的列表,告诉程序 scan 必须迭代的序列。如果序列以字典的形式给出,那么可选信息集合可给这个序列。字典应该包含如下的关键信息:
  • input(强制的)表示序列的 Theano 变量
  • taps fn 需要的时间片。通常以整数列表的方式提供,其中 k 的值表示一个迭代步骤 t scan 会 传递给 fn 序列片 t+k。默认值为 [0]

任何在 sequence 列表的 Theano 变量都会自动封装成一个字典,其 taps 被设置为 [0]

  • output_info outputs_info 是 Theano 变量或者字典的列表,给出了递归计算的输出初始时的状态。当这个初始状态给定为字典时,说明了对应这些初始状态的输出的可选信息。字典应当包含下面的元素:
  • initial 表示一个给定输出的初始状态的 Theano 变量。如果输出不是递归计算的(如 map)或者不需要初始状态,那么这里可以跳过。由 fn 前面时间步的输出,初始状态应该拥有和输出的同样形状,并且不能够包含输出的数据类型的转换。如果使用多时间 tap,初始状态应当由额外的维度来覆盖所有可能的 tap。例如,如果我们使用 -5, -2, -1 作为过去的 tap,在第 0 步,fn 会需要 output[-5], output[-2] output[-1]。这将由初始状态给出,这里的形状就是 (5,)+ output.shape。如果这个包含初始状态的变量称为 init_y 那么 init_y[0] 对应于 output[-5]init_y[1] 对应于 output[-4]init_y[2] 对应于 output[-3]init_y[3] 对应于 output[-2]init_y[4] 对应于 output[-1]。这个顺序可能看起来奇怪,不过这来自给定点的数组划分,也有相应的道理。假设我们有一个数组 x,选择 k 为时间步 0。那么初始的状态就是 x[:k],而输出就是 x[k:]。看看这个划分,在 x[:k] 中的元素顺序和 init_y 中完全一致。
  • taps 输出的时间 tap 将会被传递给 fn。他们是以负整数的列表给出,其中 k 表示在迭代步 t scan 会将切片 t+k 传递给 fn
    scan 会按照下面的规则进行:
  • 如果输出现在封装在一个字典中,scan 将会按照你仅仅在输出的最后一步使用他这个前提封装它(即让你的 tap 值设置为 [-1]
  • 如果你在一个字典中封装一个输出,并且你不提供任何的 tap 但是提供了一个初始状态,那么会假设你仅仅使用 tap 值为 -1.
  • 如果你将输出封装进一个字典中,不过你没有提供任何的初始状态,那么会假设你不回使用任何形式的 tap
  • 如果你提供 None 而非一变量或者一个空字典,那么 scan 假设你将不会对这个输出使用任何 tap(就像在 map 中那样)

如果 outputs_info 是一个空列表或者 Nonescan 假设了没有 tap 用在任何输出上。如果信息仅仅针对输出的子集给出,那么会抛出一个异常(因为并没有给出 scan 如何映射信息给 fn 的输出的默认行为)

  • non_sequences non_sequences 是在每一步被传递给 fn的参数的列表。我们可以可选择地将 fn 中使用的变量用此列表剔除,尽管为了代码清晰不建议这么做。
  • n_steps n_steps 是以 int 或者 Theano scalar 给出的迭代步数。如果任何输入序列没有足够的元素,scan 会给出一个错误。如果值为 0 输出将只有 0 行。如果值为负值,scan 会往回运行。如果 go_backwards flag 已经设置了,而且 n_steps 是负值,scan 将会向前运行。如果 n_steps 没有给出,scan 将在给定输入序列时就会搞清楚应当运行的步数。
  • truncate_gradient truncate_gradient 是用在 truncated BPTT 上的步数。如果你通过 scan op 来计算梯度,他们会使用 BPTT 来计算。通过给定不同于 -1 的值,你将确定使用 truncated BPTT 而非经典的 BPTT
  • go_backwards go_backwards 是表示 scan 是否往回走的标志。如果你将每个句子看做按照时间标记,让这个标志设置为 True 会让 scan 按照时间往回扫描。
  • name 在对 scan 进行性能分析时,给每个 scan 的实例进行命名是很重要的。性能分析器将产生整体的代码分析,甚至每个 scan 实例步骤的分析。实例的 name 则出现在这些分析中,提供了具有区分度的信息。
  • mode 推荐将此设置为 None,特别是对 scan 进行分析的时候(否则结果会不准确)。如果你倾向 scan 的某一步计算使用某种特殊的方式计算,可以使用 mode 来改变计算行为(参考 theano.function 来看看可能的使用方式)
  • profile Flag 或者 string。如果为 True,或者不同于一个空串,那么就会创建一个分析器对象,并绑定在 scan 的 inner 计算图上。如果 profile 设置为 True,该对象会有一个 scan 实例的名字,否则就使用传递的 string。分析器对象仅仅会在使用新 cvm链接器运行 inner 计算图的时候收集(并打印)信息(按照默认模式,对其他链接器,这个参数就是无用的)
  • allow_gc 设置此项可以允许 scan 的内部计算图进行 gc。如果为 None,就会使用 config.scan.allow_gc 的值。
  • strict 如果设置为 True,fn 中所有共享变量都必须作为 non_sequences或者 sequences 的一部分提供。

返回值

形为 (outputs, updates) 的元组,outputs 是 Theano 的变量或者 Theano 变量的列表,表示 scan 的输出(按照 outputs_info 的顺序)。updates 是一个字典的子类指定了所有共享变量的更新规则。这个字典应该被传递给 theano.function。不同于正常的字典的是我们验证这些 key 为 SharedVariable 并且确保这些字典的求和是一致的。

返回类型

元组(tuple)

推荐阅读更多精彩内容