使用Tensorflow XLA加速神经网络训练

前几天看了TensorFlow官方微信推的 利用 XLA 将 GPU 性能推向极限 一文,里面提到使用Tensorflow xla可以显著提升训练速度,在ResNet上速度最多可以提升3倍,而代码的改动是很少的。

XLA加速原理

按照google官方的说法,xla的加速原理主要是融合内核。看一个例子:

 def  model_fn(x,y,z):    
       return tf.reduce_sum(x + y * z) 

如果运行模型时不使用 XLA,图表会启动三个内核,分别用于乘法、加法和减法。

使用了XLA后,它会将加法、乘法和减法 “融合” 到单个 GPU 内核中。这种融合运算不会将 yz 和 x+yz 生成的中间值写入内存,而是将这些中间计算的结果直接 “流式传输” 给用户,并完整保存在 GPU 寄存器中。因为删除了内存运算,所以能够提升性能。

如何使用XLA

在官方提供的利用 XLA 将 GPU 性能推向极限一文中,提供了通过修改网络运算的方式来使用xla,这样针对的是单个运算,而且还必须先保证该运算可以被xla编译。这样我觉得不太方便,有没有可以自动将网络中所有可以被XLA编译的运算都改为使用XLA呢?翻了翻TensorFlow官方文档,的确是有的:

Turning on JIT compilation at the session level will result in all possible operators being greedily compiled into XLA computations. Each XLA computation will be compiled into one or more kernels for the underlying device. Subject to a few constraints, if there are two adjacent operators in the graph that both have XLA implementations, then they will be compiled into a single XLA computation. JIT compilation is turned on at the session level by setting the global_jit_level config to tf.OptimizerOptions.ON_1 and passing the config during session initialization.

里面提到了在session level配置jit可以把所有能被XLA编译的运算转化为XLA运算。这样就很方便了,只需要建立Session的时候配置好:

# Config to turn on JIT compilation
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

sess = tf.Session(config=config)

这样你就等于在使用xla加速了,十分方便!

推荐阅读更多精彩内容