Pyro简介:产生式模型实现库(六),Pyro的张量尺寸

太长不看版

  • 模型在学习或调试过程中,设置pyro.enable_validation(True)
  • 张量的“广播”,维度对齐自右向左:torch.ones(3,4,5) + torch.ones(5)
  • 分布的尺寸 .sample().shape == batch_shape + event_shape
  • 分布的尺寸 .log_prob(x).shape == batch_shape(没有event_shape);
  • 使用expand()从Pyro中采样一批数据,或使用plate机制自动扩展;
  • 使用my_dist.to_event(1)声明维度为依赖(dependent),或说不独立;
  • 使用with pyro.plate('name', size):声明条件独立;
  • 所有维度要么是依赖的,要么是条件独立的;
  • 支持维度最左方的批处理,启动Pyro的并行处理;
    • 使用负号指标,如x.sum(-1),而不是x.sum(2)
    • 使用省略号,如pixel = image[...,i, j]
    • 如果要枚举i,j,使用Vindex,如pixel = Vindex(image)[...,i, j]

内容列表

  • 概率分布的形状
  • plate声明条件独立
  • 在plate中部分采样
  • 并行地枚举,张量的广播

文件头如下

import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
pyro.enable_validation(True) #这句话最好加上

# 我们借助这个函数,检查模型是否正确
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)

概率分布的尺寸:batch_shapeevent_shape

Pytorch的张量Tensor只有一个尺寸.shape,但是Distributions有两个尺寸.batch_shape.event_shape,分别表示条件独立的随机变量的大小和不独立的随机变量的大小。这两部分构成了一个样本的尺寸。

x = d.sample()
assert x.shape == d.batch_shape + d.event_shape

由于计算对数似然只牵涉不独立的变量,所以.log_prob()方法后,event_shape就被缩并了,只剩下batch_shape

assert d.log_prob(x) == d.batch_shape

Distributions.sample()方法可以输入一个参数sample_shape,作为独立同分布(iid)的随机变量,所以指定样本大小的采样,具有三个尺寸。

x2 = d.sample(sample_shape)
assert x2.shape == sample_shape + batch_shape + event_shape

总结来说

      |      iid     | independent | dependent
------+--------------+-------------+------------
shape = sample_shape + batch_shape + event_shape

由上可推论,单变量随机分布的event_shape为0,因为每次采样值是一个实数,所以没有不独立的维度。像MultivariateNormal多元高斯分布这样的概率分布,具有len(event_shape) == 1,因为每个采样是一个向量,向量内部是彼此依赖的(这里假定方差矩阵不是对角阵)。而InverseWishart逆威沙特分布具有len(event_shape) == 2,等等。

关于概率分布尺寸的举例

从单变量随机分布开始。

d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()
x = d.sample()
# x是一个Pytorch张量,没有batch_shape和event_shape
assert x.shape == () 
assert d.log_prob(x).shape == ()

通过传入批参数,概率分布数据可以分成批。

d = Bernoulli(0.5 * torch.ones(3, 4))
assert d.batch_shape == (3,4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

另一种成批的方法,是通过expand()。不过只在参数的最左侧维度独立时才可使用。

d = Bernoulli(torch.tensor([.1, .2, .3, .4])).expand([3, 4])
# 注意expand的参数写在一个列表中
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

多元高斯分布具有非空的event_shape维度。对于这些分布来说,.sample().log_prob()的维度是不同的。

d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3, )
x = d.sample()
assert x.shape == (3, ) # == batch_shape + event_shape
assert d.log_prob(x).shape == () # == batch_shape

改变分布的维度独立性

使用关键字.to_event(n)改变不独立维度的情况,其中n表示从数第n维度开始,声明为不独立维度。

d = Bernoulli(0.5 * torch.ones(3, 4)).to_event(1)
assert d.batch_shape == (3, )
assert d.event_shape == (4, )
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, )

用户必须小心地设置.to_event(n)batch_shape缩减到合适的水平上,或者用pyro.plate声明维度的独立性。采样仍旧会保留batch_shape+event_shape的尺寸,然而log_prob(x)只剩下batch_shape

声明为不独立,通常是安全的做法

在Pyro中,我们常常会声明维度是不独立的,哪怕它们实际上是独立的。请看这个例子:

x = pyro.sample('x', dist.Normal(0, 1).expand([10]).to_event(1))
assert x.shape == (10,)

上面的例子很容易就可以换成MultivariateNormal分布。它将下面的写法简化了:

with pyro.plate('x_plate', 10):
    x = pyro.sample('x', dist.Normal(0, 1)) #不需要expand,系统自动补全
    assert x.shape == (10,)

实际上,这两份代码存在一点小小的差别。上面的代码中,Pyro默认x之间是不独立的,而下面的x则是条件独立的。声明为不独立通常是安全的,这与图论中的d-separation基于同一个原理:在不同节点之间连一条边,即便节点之间不存在互相依赖关系,随着优化该边的权重将越来越低,并不影响最终结果;而本就存在依赖的节点连了一条边,任优化策略多么高明,都无法弥补这一错误。这种错误常见于平均场假设的模型中。不过,在实际执行时,Pyro的SVI模块在估算Normal分布时,两份代码的梯度估计值是一样的。

通过plate声明维度为独立

Pyro的上下文管理器pyro.plate能够声明特定的维度为独立维度。推断算法可以利用这一独立性做一些算法优化,例如构造低方差的梯度估计器,再如求解推断问题不在指数空间而在线性空间采样。下面的例子中,我们将声明同一批次中的数据之间是互相独立的。
最简单的方法,是不声明独立维度,系统将缺省值-1——即最右边的维度,作为独立维度。

with pyro.plate('my_plate'):
    # 在该上下文中,维度-1将作为独立维度

虽然效果是一样的,不过我们仍提倡用户写出来,以帮助用户调试代码:

with pyro.plate('my_plate', len(data)):
    #  在该上下文中,维度-1将作为独立维度

从Pyro 0.2版本开始,plate语句可以嵌套使用。比如声明图像的每个像素都是独立的:

with pyro.plate('x_axis', 320):
    #  在该上下文中,维度-1将作为独立维度
    with pyro.plate('y_axis', 200):
        #  在该上下文中,维度-2和-1将作为独立维度

我们习惯上总从右向左声明独立维度,所以指标是负的,如-1,-2,等等。
有时情况会更复杂一些,比如我们希望声明一些噪声依赖x,另一些噪声依赖y,还有一些噪声依赖二者。这时Pyro允许用户声明多重独立,为了清楚地标明独立维度,必须指定dim这一参数,如下面的例子:

x_axis = pyro.plate('x_axis', dim = -2)
y_axis = pyro.plate('y_axis', dim = -3)
with x_axis:
    #  在该上下文中,维度-2将作为独立维度
with y_axis:
    #  在该上下文中,维度-3将作为独立维度
with x_axis, y_axis:
    #  在该上下文中,维度-2和-3将作为独立维度

让我们举更多例子,来展示plate的用法。

def model1():
    a = pyro.sample('a', Normal(0, 1))
    b = pyro.sample('b', Normal(torch.zeros(2), 1).to_event(1))
    with pyro.plate('c_plate', 2):
        c = pyro.sample('c', Normal(torch.zeros(2), 1))
    with pyro.plate('d_plate', 3):
        d = pyro.sample('d', Normal(torch.zeros(3, 4, 5), 1).to_event(2))
    assert a.shape == ()                  # batch_shape == (), event_shape == ()
    assert b.shape == (2,)                # batch_shape == (), event_shape == (2,)
    assert c.shape == (2,)                # batch_shape == (2,), event_shape == ()
    assert d.shape == (3, 4, 5)           # batch_shape == (3), event_shape == (4, 5)
    ##
    x_axis = pyro.plate('x_axis', 3, dim=-2)
    y_axis = pyro.plate('y_axis', 2, dim=-3)
    with x_axis:
        x = pyro.sample('x', Normal(0, 1))
    with y_axis:
        y = pyro.sample('y', Normal(0, 1))
    with x_axis, y_axis:
        xy = pyro.sample('xy', Normal(0, 1))
        z = pyro.sample('z', Normal(0, 1).expand([5]).to_event(1))
    assert x.shape == (3, 1)               # batch_shape == (3, 1), event_shape==()
    assert y.shape == (2, 1, 1)            # batch_shape == (2, 1, 1), event_shape==()
    assert xy.shape == (2, 3, 1)           # batch_shape == (2, 3, 1), event_shape==()
    assert z.shape == (2, 3, 1, 5)         # batch_shape == (2, 3, 1), event_shape==(5,)

test_model(model1, model1, Trace_ELBO())

可视化如下:

batch dims | event dims
-----------+-----------
           |        a = sample("a", Normal(0, 1))
           |2       b = sample("b", Normal(zeros(2), 1)
           |                        .to_event(1))
           |        with plate("c", 2):
          2|            c = sample("c", Normal(zeros(2), 1))
           |        with plate("d", 3):
          3|4 5         d = sample("d", Normal(zeros(3,4,5), 1)
           |                       .to_event(2))
           |
           |        x_axis = plate("x", 3, dim=-2)
           |        y_axis = plate("y", 2, dim=-3)
           |        with x_axis:
        3 1|            x = sample("x", Normal(0, 1))
           |        with y_axis:
      2 1 1|            y = sample("y", Normal(0, 1))
           |        with x_axis, y_axis:
      2 3 1|            xy = sample("xy", Normal(0, 1))
      2 3 1|5           z = sample("z", Normal(0, 1).expand([5])
           |                       .to_event(1))

为了在调试代码时方便地查看随机变量的形状,Pyro提供了Trace.format_shapes()
方法,在采样点上打印分布的形状(包含site['fn'].batch_shapesite['fn'].event_shape)、变量的形状(site['value'].shape)、如果计算对数似然概率时log_prob的形状(site['log_prob'].shape)。

trace = poutine.trace(model1).get_trace()
trace.compute_log_prob()  #  可选的,这句话可以打印log_prob的形状
print(trace.format_shapes())

打印结果:

Trace Shapes:
 Param Sites:
Sample Sites:
       a dist       |
        value       |
     log_prob       |
       b dist       | 2
        value       | 2
     log_prob       |
 c_plate dist       |
        value     2 |
     log_prob       |
       c dist     2 |
        value     2 |
     log_prob     2 |
 d_plate dist       |
        value     3 |
     log_prob       |
       d dist     3 | 4 5
        value     3 | 4 5
     log_prob     3 |
  x_axis dist       |
        value     3 |
     log_prob       |
  y_axis dist       |
        value     2 |
     log_prob       |
       x dist   3 1 |
        value   3 1 |
     log_prob   3 1 |
       y dist 2 1 1 |
        value 2 1 1 |
     log_prob 2 1 1 |
      xy dist 2 3 1 |
        value 2 3 1 |
     log_prob 2 3 1 |
       z dist 2 3 1 | 5
        value 2 3 1 | 5
     log_prob 2 3 1 |

plate句块中采样部分张量

plate最重要的功能之一就是部分采样,plate句块中的随机变量都是条件独立的。如果样本量为总样本的一半,那么样本损失的值将被认为是总损失的一半。
在实现部分时,用户需要通知Pyro采样量和样本总量的值,Pyro就会随机产生一定量的数据指标作为样本。

data = torch.arange(100.)

def model2():
    mean = pyro.param('mean', torch.zeros(len(data)))
    with pyro.plate('data', len(data), subsample_size=10) as ind: 
        assert len(ind) == 10
        batch = data[ind]
        mean_batch = mean[ind]
        # 在batch中做一些计算
        x = pyro.sample('x', Normal(mean_batch, 1), obs=batch)
        assert x.shape == (10,)

test_model(model2, guide=lambda: None, loss=Trace_ELBO())

广播功能,实现数据的并行枚举

Pyro 0.2后的版本都支持离散随机变量的并行枚举功能。这一功能可以极大地减少计算变分推断时梯度估计的方差,确保优化的稳定性。
为了实现枚举,Pyro需要用户指定哪些维度是不独立的,哪些是独立的,只有不独立的维度才允许枚举。自然地,这一指定需要用到plate语句,我们需要声明最大数量的枚举范围,这一关键字为max_plate_nesting,它是SVI类的一个参数(而且通过TraceEnum_ELBO传入)。通常来说,Pyro可以自动地指定枚举范围(只要运行一次modelguide,系统将了解枚举范围),不过在动态变化的模型中,用户需要人工地指定max_plate_nesting的数值。
为了弄清楚max_plate_nesting的作用机制,我们重新回顾model1(),这一次我们关心三种维度的形状:最左边的枚举维度,中间的批维度,最右边的不独立维度。而max_plate_nesting规定了中间的批维度

      max_plate_nesting = 3
           |<--->|
enumeration|batch|event
-----------+-----+-----
           |. . .|      a = sample("a", Normal(0, 1))
           |. . .|2     b = sample("b", Normal(zeros(2), 1)
           |     |                      .to_event(1))
           |     |      with plate("c", 2):
           |. . 2|          c = sample("c", Normal(zeros(2), 1))
           |     |      with plate("d", 3):
           |. . 3|4 5       d = sample("d", Normal(zeros(3,4,5), 1)
           |     |                     .to_event(2))
           |     |
           |     |      x_axis = plate("x", 3, dim=-2)
           |     |      y_axis = plate("y", 2, dim=-3)
           |     |      with x_axis:
           |. 3 1|          x = sample("x", Normal(0, 1))
           |     |      with y_axis:
           |2 1 1|          y = sample("y", Normal(0, 1))
           |     |      with x_axis, y_axis:
           |2 3 1|          xy = sample("xy", Normal(0, 1))
           |2 3 1|5         z = sample("z", Normal(0, 1).expand([5]))
           |     |                     .to_event(1))

上面的例子中,如果我们声明(过度)充裕的max_plate_nesting=4也是可以的,但不能声明例如max_plate_nesting=2,因为2<3,这时系统将会报错。
我们再举一个例子:

@config_enumerate
#该修饰符表示枚举类型,不能省略!!
def model3():
    p = pyro.param('p', torch.arange(6) / 6.)
    locs = pyro.param('locs', torch.tensor([-1., 1.]))
    # locs in [-1, 1]
    # a in [0, 1, 2, 3, 4, 5]
    a = pyro.sample('a', Categorical(torch.ones(6) / 6.))
    # p[a] in [0, 1/6, 2/6, 3/6, 4/6, 5/6]
    b = pyro.sample('b', Bernoulli(p[a])) # 声明b依赖于a
    # b in [0, 1]
    with pyro.plate('c_plate', 4):
        c = pyro.sample('c',  Bernoulli(0.4))
        # c in [0, 1]
        with pyro.plate('d_plate', 5):
            d = pyro.sample('d', Bernoulli(0.3))
            # d in [0, 1]
            e_loc = locs[d.long()].unsqueeze(-1)
            # e_loc in [-1, 1]
            e_scale = torch.arange(1., 8.)
            # e_scale in [1, 2, ..., 7]
            e = pyro.sample('e', Normal(e_loc, e_scale).to_event(1)) # 依赖于d
    #                            枚举维度|批维度(独立维度)|不独立维度
    assert a.shape == (                6,            1,1            )  # 多类别分布的维度大小为6
    assert b.shape == (              2,1,            1,1            )  # 枚举伯努利分布,非扩增
    assert c.shape == (            2,1,1,            1,1            )  # 伯努利分布,非扩增
    assert d.shape == (          2,1,1,1,            1,1            )  # 伯努利分布,非扩增
    assert e.shape == (          2,1,1,1,            5,4,          7)  # e是采样出来的,依赖于d
    #
    assert e_loc.shape ==   (    2,1,1,1,            1,1,         1,) # 最后的逗号可以省略
    assert e_scale.shape == (                                     7,) # 注意逗号不能省略!!

test_model(model3, model3, TraceEnum_ELBO(max_plate_nesting=2))

我们重新来可视化一下:

     max_plate_nesting = 2
            |<->|
enumeration batch event
------------|---|-----
           6|1 1|     a = pyro.sample("a", Categorical(torch.ones(6) / 6))
         2 1|1 1|     b = pyro.sample("b", Bernoulli(p[a]))
            |   |     with pyro.plate("c_plate", 4):
       2 1 1|1 1|         c = pyro.sample("c", Bernoulli(0.3))
            |   |         with pyro.plate("d_plate", 5):
     2 1 1 1|1 1|             d = pyro.sample("d", Bernoulli(0.4))
     2 1 1 1|1 1|1            e_loc = locs[d.long()].unsqueeze(-1)
            |   |7            e_scale = torch.arange(1., 8.)
     2 1 1 1|5 4|7            e = pyro.sample("e", Normal(e_loc, e_scale)
            |   |                             .to_event(1))

我们分析一下这些维度。我们为Pyro指定了枚举的维度max_plate_nesting:Pyro给a赋予枚举维度-3,给b赋予枚举维度-4,给c赋予枚举维度-5,给d赋予枚举维度-6。当用户不指定维度扩展后的数值时,新维度被默认为1,这方便计算。我们还可以观察到,log_prob的形状广播的范围是枚举维度和独立维度,比如trace.nodes['d']['log_prob'].shape == (2,1,1,1,5,4)

使用Pyro的自带工具Trace.format_shapes():

trace = poutine.trace(poutine.enum(model3, first_available_dim=-3)).get_trace()
trace.compute_log_prob() # 可选
print(trace.format_shapes())

结果:

Trace Shapes:                
 Param Sites:                
            p             6  
         locs             2  
Sample Sites:                
       a dist             |  
        value       6 1 1 |  
     log_prob       6 1 1 |  
       b dist       6 1 1 |  
        value     2 1 1 1 |  
     log_prob     2 6 1 1 |  
 c_plate dist             |  
        value           4 |  
     log_prob             |  
       c dist           4 |  
        value   2 1 1 1 1 |  
     log_prob   2 1 1 1 4 |  
 d_plate dist             |  
        value           5 |  
     log_prob             |  
       d dist         5 4 |  
        value 2 1 1 1 1 1 |  
     log_prob 2 1 1 1 5 4 |  
       e dist 2 1 1 1 5 4 | 7
        value 2 1 1 1 5 4 | 7
     log_prob 2 1 1 1 5 4 |  

编写并行代码

在Pyro中,我们需要掌握两个取巧的技术,来实现并行采样:广播椭圆分片。我们通过下面的例子来分别介绍枚举情形和非枚举情形下的用法。

width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
enumeration = None # 设为True或False

def fun(observe):
    p_x = pyro.param('p_x', torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param('p_y', torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.plate('x_axis', width, dim=-2)
    y_axis = pyro.plate('y_axis', height, dim=-1)
    # 在这些样本点上,分布形状取决于Pyro是否枚举
    with x_axis:
        x_active = pyro.sample('x_active', Bernoulli(p_x))
    with y_axis:
        y_active = pyro.sample('y_active', Bernoulli(p_y))
    if enumerated:
        assert x_active.shape == (2, 1, 1) # max_plate_nesting==2
        assert y_active.shape == (2, 1, 1, 1)
    else:
        assert x_active.shape == (width, 1)
        assert y_active.shape == (height, )
    # 第一个trick:广播,broadcast。枚举和非枚举都可使用。
    p = 0.1 + 0.5 * x_active * y_active
    if enumerated:
        assert p.shape == (2, 2, 1, 1)
    else:
        assert p.shape == (width, height)
    dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
    # 第二个trick:椭圆分片。Pyro可以在左方任意增加维度。
    for x, y in sparse_pixels:
        dense_pixels[..., x, y] = 1
    if enumerated:
        assert dense_pixels.shape == (2, 2, width, height)
    else:
        assert dense_pixels.shape == (width, height)
    #
    with x_axis, y_axis:
        if observe:
            pyro.sample('pixels', Bernoulli(p), obs=dense_pixels)

def model4():
    fun(observe=True)

def guide4():
    fun(observe=False)

# Test: 非枚举
enumerated = False
test_model(model4, guide4, Trace_ELBO())

# Test: 枚举。注意目标函数为TraceEnum_ELBO
enumerated = True
test_model(model4, config_enumerate(guide4, 'parallel'), TraceEnum_ELBO(max_plate_nesting=2))

在pyro.plate内部实现自动广播

在以上所有model/plate的实现中,我们都使用了pyro.plate的自动扩增功能,使变量满足pyro.sample规定的形状。这一广播方式等价于.expand()
我们稍许更改上面的代码作为例子,注意几点区别:

  • 我们仅考虑并行枚举的情况,但对于串行的、非枚举的情况也适用;
  • 我们将采样函数分离出来,model代码使用常规的形式,这样做有利于代码的维护;
  • pyro.plate使用ELBO的num_particles参数,将上下文中最远的内容打包。
# 规定采样的样本量
num_particals = 100
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])

def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample('x_active', Bernoulli(p_x).expand([num_particals, width, 1]))
    with y_axis:
        y_active = pyro.sample('y_active', Bernoulli(p_y).expand([num_particals, 1, height]))
    return x_active, y_active

def sample_pixel_locations_full_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample('x_active', Bernoulli(p_x))
    with y_axis:
        y_active = pyro.sample('y_acitve', Bernoulli(p_y))
    return x_active, y_active

def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample('x_active', Bernoulli(p_x).expand([width, 1]))
    with y_axis:
        y_active = pyro.sample('y_active', Bernoulli(p_y).expand([height]))
    return x_acitve, y_active

def fun(observe, sample_fn):
    p_x = pyro.param('p_x', torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param('p_y', torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.plate('x_axis', width, dim=-2)
    y_axis = pyro.plate('y_axis', height, dim=-1)
    # 
    with pyro.plate('num_particals', 100, dim=-3):
        x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)
        ## 并行枚举指标被扩增在“num_particals”最左边
        assert x_active.shape == (2, 1, 1, 1) 
        assert y_active.shape == (2, 1, 1, 1, 1)
        p = 0.1 + 0.5 * x_active * y_active
        assert p.shape == (2, 2, 1, 1, 1)
        dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
        for x, y in sparse_pixels:
            dense_pixels[..., x, y] = 1
        assert dense_pixels.shape == (2, 2, 1, width, height)
        #
        with x_axis, y_axis:
            if observe:
                pyro.sample('pixels', Bernoulli(p), obs=dense_pixels)

def test_model_with_sample_fn(sample_fn):
    def model():
        fun(observe=True, sample_fn=sample_fn)
    #
    @config_enumerate
    def guide():
        fun(observe=False, sample_fn=sample_fn)

test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_full_broadcasting) 
test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting)

在第一个采样函数中,我们像账房先生那样,仔细规定了Bernoulli分布的的形状。请仔细观察num_particles, widthheight传入sample_pixel_locations函数的方式。这一方式有些笨拙。
对于第二个采样函数,我们需要注意pyro.plate的参数必须要提供,这样系统才能猜出批维度的形状。
我们可以看到,对于张量操作,使用pyro.plate实现并行是多么容易!
pyro.plate还具有将代码模块化的效果。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 159,015评论 4 362
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,262评论 1 292
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,727评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,986评论 0 205
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,363评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,610评论 1 219
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,871评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,582评论 0 198
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,297评论 1 242
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,551评论 2 246
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,053评论 1 260
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,385评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,035评论 3 236
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,079评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,841评论 0 195
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,648评论 2 274
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,550评论 2 270

推荐阅读更多精彩内容