使用mllib完成mnist手写识别任务

配置spark容器

使用docker-compose创建spark环境
yml配置文件:

version: '3.8'

services:
  spark-master:
    image: bde2020/spark-master:3.2.0-hadoop3.2
    container_name: spark-master
    ports:
      - "8080:8080"
      - "7077:7077"
      - "4040:4040"
    # 共享文件夹路径配置
    volumes:
      - /home/wmx/spark:/data
    environment:
      - INIT_DAEMON_STEP=setup_spark
  spark-worker-1:
    image: bde2020/spark-worker:3.2.0-hadoop3.2
    container_name: spark-worker-1
    depends_on:
      - spark-master
    ports:
      - "8081:8081"
    # 共享文件夹路径配置
    volumes:
      - /home/wmx/spark:/data
    environment:
      - "SPARK_MASTER=spark://spark-master:7077"
  spark-worker-2:
    image: bde2020/spark-worker:3.2.0-hadoop3.2
    container_name: spark-worker-2
    depends_on:
      - spark-master
    ports:
      - "8082:8081"
    # 共享文件夹路径配置
    volumes:
      - /home/wmx/spark:/data
    environment:
      - "SPARK_MASTER=spark://spark-master:7077"

可以用浏览器通过http://127.0.0.1:8080和http://127.0.0.1:4040来查看spark运行情况
可参考https://blog.csdn.net/Dragon_qing/article/details/124549698来配置

处理数据集

下载手写数字数据集

数据集在 http://yann.lecun.com/exdb/mnist/ 下载
数据集
下载这四个压缩文件

使用python来处理数据集

先将数据集解压出来,使用python转化为csv
python代码:

# mnist_to_csv.py
def convert(imgf, labelf, outf, n):
    f = open(imgf, "rb")
    o = open(outf, "w")
    l = open(labelf, "rb")

    f.read(16)
    l.read(8)
    images = []

    for i in range(n):
        image = [ord(l.read(1))]
        for j in range(28 * 28):
            image.append(ord(f.read(1)))
        images.append(image)

    for image in images:
        o.write(",".join(str(pix) for pix in image) + "\n")
    f.close()
    o.close()
    l.close()


convert("train-images.idx3-ubyte", "train-labels.idx1-ubyte",
        "mnist_train.csv", 60000)
convert("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte",
        "mnist_test.csv", 10000)

再将csv转化为libsvm格式

# csv_to_libsvm.py
import csv


def execute(data, savepath):

    csv_reader = csv.reader(open(data))
    f = open(savepath, 'wb')
    for line in csv_reader:
        label = line[0]
        features = line[1:]
        libsvm_line = label + ' '

        for index, feature in enumerate(features):
            libsvm_line += str(index + 1) + ':' + feature + ' '
        f.write(bytes(libsvm_line.strip() + '\n', 'UTF-8'))

    f.close()


execute('mnist_train.csv', 'mnist_train.libsvm')
execute('mnist_test.csv', 'mnist_test.libsvm')

使用spark训练模型

进入spark容器

进入spark-master

sudo docker exec -it spark-master /bin/bash

打开spark-shell

spark-shell位于/spark/bin目录下
/spark/bin/spark-shell --master spark://spark-master:7077 --total-executor-cores 8 --executor-memory 2560m
把运行内存改高一点,否则可能内存不够用(6000m)
读取训练集

val train = spark.read.format("libsvm").load("/data/mnist_train.libsvm")

读取测试集

val test = spark.read.format("libsvm").load("/data/mnist_test.libsvm")

定义网络结构。

如果计算机性能不好可以降低隐藏层的参数。

val layers = Array[Int](784, 784, 784, 10)

导入多层感知机与多分类评价器。

import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

使用多层感知机初始化训练器。

val trainer = new MultilayerPerceptronClassifier().setLayers(layers).setBlockSize(128).setSeed(1234L).setMaxIter(100)

最后setMaxIter()是用来设置迭代次数的,可以改小一点来加快程序运行速度,不过可能会降低准确度。

训练模型

var model = trainer.fit(train)

输入测试集进行识别

val result = model.transform(test)

获取测试结果中的预测结果与实际结果

val predictionAndLabels = result.select("prediction", "label")

初始化评价器

val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")

计算识别精度

println(s"Test set accuracy = ${evaluator.evaluate(predictionAndLabels)}")

在result上创建临时视图

result.toDF.createOrReplaceTempView("deep_learning")

使用Spark SQL的方式计算识别精度

spark.sql("select (select count(*) from deep_learning where label=prediction)/count(*) as accuracy from deep_learning").show()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值