# 动手学深度学习(二)——多层感知机(从零开始)

```python
# 导入mxnet
import mxnet as mx

# 设置随机种子
mx.random.seed(2)

from mxnet import autograd
from mxnet import gluon
from mxnet import ndarray as nd
from mxnet import image

from utils import load_data_fashion_mnist, accuracy, evaluate_accuracy, SGD
``````

## 数据获取

```python
# 批数据大小
batch_size = 256

# Fetch the training and test data iterators; both are consumed by the
# training loop below (assumes utils.load_data_fashion_mnist(batch_size)
# returns (train_iter, test_iter) — TODO confirm against utils).
train_data, test_data = load_data_fashion_mnist(batch_size)
``````

## 多层感知机

```python
# 输入数据大小
num_inputs = 28 * 28  # each 28x28 image is flattened into one input vector
# Number of outputs: one score per class (10 classes).
num_outputs = 10

# Number of hidden units.
num_hidden = 256

# Standard deviation of the normal distribution used to initialise weights.
weight_scale = 0.01

# Input -> hidden layer: small random weights, zero biases.
W1 = nd.random_normal(shape=(num_inputs, num_hidden), scale=weight_scale)
b1 = nd.zeros(num_hidden)

# Hidden -> output layer: small random weights, zero biases.
W2 = nd.random_normal(shape=(num_hidden, num_outputs), scale=weight_scale)
b2 = nd.zeros(num_outputs)

# All trainable parameters, updated together by SGD in the training loop.
params = [W1, b1, W2, b2]

# Allocate a gradient buffer on every parameter so that loss.backward()
# can store d(loss)/d(param) for the optimiser to read.
for param in params:
    param.attach_grad()
``````

## 激活函数

```python
# 激活函数使用ReLU, relu(x)=max(x,0)
def relu(X):
    """Apply the rectified linear unit element-wise: max(X, 0)."""
    return nd.maximum(X, 0)
``````

## 定义模型

``````def net(X):
# 输入数据重排
X = X.reshape((-1, num_inputs))
# 计算激活值
h1 = relu(nd.dot(X, W1) + b1)
# 计算输出
output = nd.dot(h1, W2) + b2
return output
``````

## Softmax和交叉熵损失函数

```python
# 定义交叉熵损失
# Softmax and cross-entropy fused into one loss: it consumes the raw scores
# returned by ``net`` together with integer class-index labels. Both keyword
# arguments below are Gluon's defaults, written out to make that explicit.
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss(
    sparse_label=True,
    from_logits=False,
)
``````

## 训练

```python
# 定义迭代周期
epochs = 5  # number of full passes over the training set

# Learning rate for SGD.
learning_rate = 0.1

for epoch in range(epochs):
    # Per-epoch running totals, averaged over the number of batches below.
    train_loss = 0.0
    train_acc = 0.0
    for data, label in train_data:
        # Record the forward pass; without this, loss.backward() has no
        # computation graph to differentiate.
        with autograd.record():
            output = net(data)
            loss = softmax_cross_entropy(output, label)
        # loss holds one value per example; calling backward() on it yields
        # gradients of the summed loss, so the step size is divided by
        # batch_size to average over the batch (assumes utils.SGD performs
        # param -= lr * param.grad — TODO confirm).
        loss.backward()
        SGD(params, learning_rate / batch_size)
        train_loss += nd.mean(loss).asscalar()
        # presumably accuracy() returns the batch's fraction correct — verify.
        train_acc += accuracy(output, label)

    # Accuracy on the held-out test set after this epoch.
    test_acc = evaluate_accuracy(test_data, net)

    print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
        epoch, train_loss / len(train_data), train_acc / len(train_data), test_acc))
``````
```
Epoch 0. Loss: 1.042064, Train acc 0.630976, Test acc 0.776142
Epoch 1. Loss: 0.601578, Train acc 0.788862, Test acc 0.815204
Epoch 2. Loss: 0.525148, Train acc 0.816556, Test acc 0.835136
Epoch 3. Loss: 0.486619, Train acc 0.829427, Test acc 0.833033
Epoch 4. Loss: 0.459395, Train acc 0.836104, Test acc 0.835136
``````