在PaddlePaddle中实现MNIST数据集训练:基础API

上一节《在PaddlePaddle中实现MNIST数据集训练:高层API》从MNIST数据集下载开始,详细介绍在PaddlePaddle中,基于高层API实现MNIST数据集训练。本节主要介绍在PaddlePaddle中,基于基础API实现MNIST数据集训练。高层API:Model.prepare()、Model.fit()、Model.evaluate()、Model.predict()都是由基础API封装而来,用基础API来实现模型创建与训练,就是用基础API来实现上述高层API的功能。

数据的载入与高层API实现部分一致,不同的是,需要用paddle.io.DataLoader类把paddle.io.Dataset类再封装一次,供基础API使用。

完整范例程序如下所示:

import gzip 
import struct 
import numpy as np 

# train-images-idx3-ubyte 文件格式, 参考:http://yann.lecun.com/exdb/mnist/
'''
[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000803(2051) magic number 
0004     32 bit integer  60000            number of images 
0008     32 bit integer  28               number of rows 
0012     32 bit integer  28               number of columns 
0016     unsigned byte   ??               pixel 
0017     unsigned byte   ??               pixel 
........ 
xxxx     unsigned byte   ??               pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 
0 means background (white), 255 means foreground (black).
'''
def load_images(image_file):
    # 读取*.gz格式文件
    with gzip.open(image_file) as f:
        buf = f.read()

    idx = 0
    # 读取文件信息
    magic, num_images, rows, cols = struct.unpack_from('>IIII', buf, idx)
    idx += struct.calcsize('>IIII')
    length = int(num_images*rows*cols)
    # 读取图像数据
    images = struct.unpack_from('>'+str(length)+'B', buf, idx)
    images = np.array(images).astype('float32')
    images = images.reshape(num_images, rows, cols)
    # 返回np.ndarray类型, N*r*c 图像数据
    return images


# train-labels-idx1-ubyte.gz 文件格式
'''
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  60000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.
'''
def load_labels(label_file):
    # 读取*.gz格式文件
    with gzip.open(label_file) as f:
        buf = f.read()
    # 读取文件信息
    idx = 0
    magic, num_labels = struct.unpack_from('>II', buf, idx)
    # 读取标签数据
    idx += struct.calcsize('>II')
    labels = struct.unpack_from('>'+str(num_labels)+'B',buf,idx)
 
    labels = np.array(labels).astype('int64')
    # 返回np.ndarray类型, 标签数据
    return labels

train_images = load_images('train-images-idx3-ubyte.gz')
test_images  = load_images('t10k-images-idx3-ubyte.gz')
train_labels = load_labels('train-labels-idx1-ubyte.gz').reshape(-1,1)
test_labels  = load_labels('t10k-labels-idx1-ubyte.gz').reshape(-1,1)

# 图像数据归一化
train_images = train_images / 255.0
test_images  = test_images / 255.0

num_train_samples = train_images.shape[0]
num_test_samples = test_images.shape[0]

import paddle
from paddle.io import Dataset
class TrainDataSet(Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, num_samples):
        """
        步骤二:实现构造函数,定义数据集大小
        """
        super().__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
        """
        data = train_images[index]
        label = train_labels[index]

        return data, label

    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集总数目
        """
        return self.num_samples

class TestDataSet(Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, num_samples):
        """
        步骤二:实现构造函数,定义数据集大小
        """
        super().__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
        """
        data = test_images[index]
        label = test_labels[index]

        return data, label

    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集总数目
        """
        return self.num_samples

# 测试定义的数据集
train_dataset = TrainDataSet(num_train_samples)
test_dataset = TestDataSet(num_test_samples)
train_loader = paddle.io.DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = paddle.io.DataLoader(test_dataset, batch_size=100, shuffle=True)

# 定义模型
mnist = paddle.nn.Sequential(
    paddle.nn.Flatten(),
    paddle.nn.Linear(784, 512),
    paddle.nn.ReLU(),
    paddle.nn.Dropout(0.2),
    paddle.nn.Linear(512, 10)
)

# 设置模型为训练模式,这只会影响某些模块,如Dropout和BatchNorm
mnist.train()

# 模型训练相关配置,准备损失计算方法,优化器方法
loss_fn = paddle.nn.CrossEntropyLoss()
optim = paddle.optimizer.Adam(parameters=mnist.parameters())

# 设置迭代次数
epochs = 5
# 开始模型训练
for epoch in range(epochs):
    for batch_id, data in enumerate(train_loader()):

        x_data = data[0]            # 训练数据
        y_data = data[1]            # 训练数据标签
        predicts = mnist(x_data)    # 预测结果
        # print(x_data.shape, y_data.shape)
        # 计算损失 等价于 prepare 中loss的设置
        loss = loss_fn(predicts, y_data)

        # 计算准确率 等价于 prepare 中metrics的设置
        acc = paddle.metric.accuracy(predicts, y_data)

        # 下面的反向传播、打印训练信息、更新参数、梯度清零都被封装到 Model.fit() 中

        # 反向传播
        loss.backward()

        if (batch_id+1) % 100 == 0:
            print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id+1, loss.numpy(), acc.numpy()))

        # 更新参数
        optim.step()

        # 梯度清零
        optim.clear_grad()

# 用 evaluate 在测试集上对模型进行验证
mnist.eval()

for batch_id, data in enumerate(test_loader()):

    x_data = data[0]            # 测试数据
    y_data = data[1]            # 测试数据标签
    predicts = mnist(x_data)    # 预测结果

    # 计算损失与精度
    loss = loss_fn(predicts, y_data)
    acc = paddle.metric.accuracy(predicts, y_data)

    # 打印信息
    if (batch_id+1) % 100 == 0:
        print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id+1, loss.numpy(), acc.numpy()))

# 用 predict 在测试集上对模型进行测试
mnist.eval()
for batch_id, data in enumerate(test_loader()):
    x_data = data[0]
    predicts = mnist(x_data)
    # 获取预测结果
print(f"predict[0]:{predicts[0]}")

运行结果:
运行结果
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 158,425评论 4 361
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,058评论 1 291
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 108,186评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,848评论 0 204
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,249评论 3 286
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,554评论 1 216
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,830评论 2 312
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,536评论 0 197
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,239评论 1 241
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,505评论 2 244
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,004评论 1 258
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,346评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 32,999评论 3 235
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,060评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,821评论 0 194
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,574评论 2 271
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,480评论 2 267

推荐阅读更多精彩内容