MindSpore基础教程:LeNet-5 神经网络在MindSpore中的实现与训练

官方文档教程使用已经弃用的MindVision模块,本文是对官方文档的更新
深度学习在图像识别领域取得了显著的成功,LeNet-5 作为卷积神经网络的经典之作,在诸多研究和应用中占有重要地位。本文将详细介绍如何使用 MindSpore 框架实现并训练一个 LeNet-5 神经网络,专注于处理MNIST手写数字数据集。

前言

MindSpore 是华为推出的一种新型深度学习框架,旨在为用户提供高效、易用的编程体验。接下来,我们将通过实例来展示如何在 MindSpore 中构建、训练和评估一个经典的 LeNet-5 神经网络。

环境配置

MindSpore官网

LeNet-5 网络结构简介

LeNet-5 是一个简单的卷积神经网络,包含两个卷积层和三个全连接层。它经常被用于图像识别任务,特别是在处理像 MNIST 这样的手写数字数据集时表现出色。

数据集准备与预处理

首先,我们需要准备并预处理数据集。在这个例子中,我们将使用 MNIST 数据集。以下函数 create_dataset 负责加载数据集,并进行必要的预处理:

def create_dataset(data_path, batch_size=32, repeat_size=1):"""创建用于训练的MNIST数据集。此函数负责加载MNIST数据集,对数据进行预处理和转换,以便它们可以用于训练神经网络。数据预处理包括调整图像大小、重新缩放和类型转换。参数:data_path (str): MNIST数据集的路径。这应该是包含MNIST数据文件的目录路径。batch_size (int, 可选): 每个数据批次的大小。默认值为32。repeat_size (int, 可选): 数据集重复的次数。这用于增加数据集的大小。默认值为1。步骤:1. 加载MNIST数据集。2. 对图像执行大小调整操作,将图像大小统一调整为32x32像素。3. 对图像进行重新缩放和标准化处理。先将像素值缩放到0-1之间,然后进行标准化。4. 将图像的格式从高宽通道(HWC)转换为通道高宽(CHW)。5. 对标签进行类型转换,将其转换为整型(int32)。6. 对数据集进行洗牌、批处理和重复操作,以准备训练过程。返回:返回一个处理过的MNIST数据集,可以直接用于模型训练。注意:- 数据集的预处理步骤对于训练深度学习模型来说是非常重要的,它们会影响训练的效果和速度。- 调整batch_size和repeat_size可以影响模型训练时的内存消耗和速度。"""mnist_dataset = ds.MnistDataset(data_path)resize_operation = vision.Resize((32, 32), interpolation=Inter.LINEAR)rescale_normalization_op = vision.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)rescale_op = vision.Rescale(1.0 / 255.0, 0.0)hwc_to_chw_op = vision.HWC2CHW()type_cast_op = transforms.TypeCast(mstype.int32)mnist_dataset = mnist_dataset.map(input_columns="label", operations=type_cast_op)mnist_dataset = mnist_dataset.map(input_columns="image",operations=[resize_operation, rescale_op, rescale_normalization_op,hwc_to_chw_op])mnist_dataset = mnist_dataset.shuffle(buffer_size=10000)mnist_dataset = mnist_dataset.batch(batch_size, drop_remainder=True)mnist_dataset = mnist_dataset.repeat(repeat_size)return mnist_dataset

这个函数将数据集中的图像调整为统一的大小,并进行重新缩放和标准化。

构建 LeNet-5 模型

LeNet-5 模型的构建在 LeNet5 类中实现。此类定义了网络的各层及其排列:

class LeNet5(nn.Cell):"""LeNet-5 神经网络结构。这是一个经典的卷积神经网络,通常用于图像识别任务。它包含了两个卷积层和三个全连接层。参数:num_class (int): 输出层的类别数量。默认为10,适用于MNIST数据集。num_channel (int): 输入图像的通道数。对于灰度图像,此值为1。组件:- conv1: 第一个卷积层,使用有效填充。- conv2: 第二个卷积层,同样使用有效填充。- fc1: 第一个全连接层。- fc2: 第二个全连接层。- fc3: 第三个全连接层,输出层。- relu: 激活函数,使用ReLU。- max_pool2d: 最大池化层。- flatten: 扁平化层,用于全连接层之前的数据转换。方法:- construct(x): 定义了前向传播的过程。"""def __init__(self, num_class=10, num_channel=1):super(LeNet5, self).__init__()self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))self.relu = nn.ReLU()self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)self.flatten = nn.Flatten()def construct(self, x):x = self.conv1(x)x = self.relu(x)x = self.max_pool2d(x)x = self.conv2(x)x = self.relu(x)x = self.max_pool2d(x)x = self.flatten(x)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.relu(x)x = self.fc3(x)return x

训练模型

接下来,我们定义 train_network 函数来训练模型。此函数接受模型实例、数据集路径和其他训练参数:

def train_network(model, epoch_size, data_path, repeat_size, checkpoint_callback):"""训练神经网络模型。此函数负责初始化数据集,然后使用指定的模型进行训练。在训练过程中,它将记录损失并保存模型的检查点。参数:model (Model): 要训练的神经网络模型。epoch_size (int): 训练过程中遍历数据集的次数。data_path (str): 训练数据集的路径。repeat_size (int): 数据集的重复次数,用于扩充数据集。checkpoint_callback (Callback): 用于保存模型检查点的回调函数。过程:- 使用 `create_dataset` 函数创建训练数据集。- 调用模型的 `train` 方法进行训练。- 在训练过程中,会通过回调函数记录损失和保存检查点。注意:- 确保提供的 `data_path` 包含适当格式的数据。"""print("============== 开始训练 ==============")ds_train = create_dataset(data_path, 32, repeat_size)model.train(epoch_size, ds_train, callbacks=[checkpoint_callback, LossMonitor(), TimeMonitor()],dataset_sink_mode=False)print("============== 训练结束 ==============")

主函数

最后,我们通过 train 函数和 parse_arguments 函数将所有步骤串联起来。train 函数负责初始化模型、损失函数、优化器和检查点回调,然后调用 train_network 进行训练:

def train(args):"""初始化并训练LeNet-5神经网络模型。此函数设置了网络模型、损失函数、优化器,并定义了模型检查点。然后,使用指定的参数调用 `train_network` 函数来进行模型的训练。参数:args (Namespace): 一个包含训练参数的命名空间对象。此对象应该包含以下属性:- epochs (int): 模型训练的迭代次数。- data_url (str): 训练数据集的路径。- output_path (str): 保存模型检查点的路径。过程:1. 创建 LeNet-5 网络实例。2. 定义损失函数为 Softmax Cross-Entropy。3. 定义优化器为 Momentum 优化器。4. 创建模型实例,并指定网络、损失函数、优化器和评估指标。5. 设置模型检查点配置。6. 初始化模型检查点回调函数。7. 调用 `train_network` 函数进行训练。注意:- 确保 `args` 对象包含正确和完整的训练参数。- 调整优化器和损失函数的参数可以对训练结果产生影响。- 模型检查点将保存在 `args.output_path` 指定的路径中。"""net = LeNet5()net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")net_opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)model = Model(net, net_loss, net_opt, metrics={"Accuracy": nn.Accuracy()})config_checkpoint = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)checkpoint_callback = ModelCheckpoint(prefix="checkpoint_lenet", directory=args.output_path,config=config_checkpoint)train_network(model, args.epochs, args.data_url, 1, checkpoint_callback)

推理

# 加载网络param_dict = load_checkpoint("/root/MyCode/pycharm/lenet5/ckpt/checkpoint_lenet-19_1884.ckpt")network = LeNet5(num_class=NUM_CLASS, num_channel=1)# 用您定义的LeNet5类创建模型实例load_param_into_net(network, param_dict)# 将参数加载到网络中model = Model(network)def predict_digit(img):# 图像预处理img = cv2.resize(img, (32, 32))# 调整图像大小为32x32img = np.array(img, dtype=np.float32)# 转换图像数据类型img = (img - 0.1307) / 0.3081# 对图像进行标准化处理img = img[np.newaxis, np.newaxis, :, :]# 改变图像形状以符合网络输入要求(1, 1, 32, 32)# 将图像数据转换为MindSpore张量img_tensor = Tensor(img)# 使用模型进行预测output = model.predict(img_tensor)# 将输出转换为概率分布probabilities = Softmax()(output)# 获取每个类别的概率probabilities_np = probabilities.asnumpy()[0]# 将概率转换为字典格式labels = [str(i) for i in range(10)]# 类别标签,例如"0", "1", "2", ..., "9"probabilities_dict = {label: prob for label, prob in zip(labels, probabilities_np)}return probabilities_dictgr.Interface(fn=predict_digit,inputs=gr.Image(image_mode='L'),outputs=gr.Label(num_top_classes=NUM_CLASS),live=False,css=".footer {display:none !important}",title="0-9数字画板",description="画0-9数字",thumbnail="https://raw.githubusercontent.com/gradio-app/real-time-mnist/master/thumbnail2.png").launch()

结论

通过本文的指南,您可以在 MindSpore 框架中实现并训练一个经典的 LeNet-5 神经网络。LeNet-5 在图像识别任务中展现了卓越的性能,而 MindSpore 的高效和易用性使得深度学习研究和开发更加便捷。您可以根据本文的指导进行实验,并根据需要调整网络结构和训练参数。