Java自学者论坛

 找回密码
 立即注册

手机号码,快捷登录

恭喜Java自学者论坛(https://www.javazxz.com)已经为数万Java学习者服务超过8年了!积累会员资料超过10000G+
成为本站VIP会员,下载本站10000G+会员资源,会员资料板块,购买链接:点击进入购买VIP会员

JAVA高级面试进阶训练营视频教程

Java架构师系统进阶VIP课程

分布式高可用全栈开发微服务教程Go语言视频零基础入门到精通Java架构师3期(课件+源码)
Java开发全终端实战租房项目视频教程SpringBoot2.X入门到高级使用教程大数据培训第六期全套视频教程深度学习(CNN RNN GAN)算法原理Java亿级流量电商系统视频教程
互联网架构师视频教程年薪50万Spark2.0从入门到精通年薪50万!人工智能学习路线教程年薪50万大数据入门到精通学习路线年薪50万机器学习入门到精通教程
仿小米商城类app和小程序视频教程深度学习数据分析基础到实战最新黑马javaEE2.1就业课程从 0到JVM实战高手教程MySQL入门到精通教程
查看: 769|回复: 0

Python3读取深度学习CIFAR-10数据集出现的若干问题解决

[复制链接]
  • TA的每日心情
    奋斗
    7 天前
  • 签到天数: 745 天

    [LV.9]以坛为家II

    2041

    主题

    2099

    帖子

    70万

    积分

    管理员

    Rank: 9Rank: 9Rank: 9

    积分
    704660
    发表于 2021-7-2 05:39:10 | 显示全部楼层 |阅读模式

    今天在看网上的视频学习深度学习的时候,用到了CIFAR-10数据集。当我兴高采烈的运行代码时,却发现了一些错误:

    # -*- coding: utf-8 -*- import pickle as p import numpy as np import os def load_CIFAR_batch(filename): """ 载入cifar数据集的一个batch """ with open(filename, 'r') as f: datadict = p.load(f) X = datadict['data'] Y = datadict['labels'] X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float") Y = np.array(Y) return X, Y def load_CIFAR10(ROOT): """ 载入cifar全部数据 """ xs = [] ys = [] for b in range(1, 6): f = os.path.join(ROOT, 'data_batch_%d' % (b,)) X, Y = load_CIFAR_batch(f) xs.append(X) ys.append(Y) Xtr = np.concatenate(xs) Ytr = np.concatenate(ys) del X, Y Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch')) return Xtr, Ytr, Xte, Yte 
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32

      错误代码如下:

    'gbk' codec can't decode byte 0x80 in position 0: illegal multibyte sequence
    • 1

      于是乎开始各种搜索问题,问大佬,网上的答案都是类似:

     

    这里写图片描述

     

      然而并没有解决问题!还是错误的!(我大概搜索了一下午吧,都是上面的答案)

      哇,就当我很绝望的时候,我终于发现了一个新奇的答案,抱着试一试的态度,尝试了一下:

    
    def load_CIFAR_batch(filename): """ 载入cifar数据集的一个batch """ with open(filename, 'rb') as f: datadict = p.load(f, encoding='latin1') X = datadict['data'] Y = datadict['labels'] X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float") Y = np.array(Y) return X, Y
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

      竟然成功了,这里没有报错了!欣喜之余,我就很好奇,encoding=’latin1’到底是啥玩意呢,以前没有见过啊?于是,我搜索了一下,了解到:

    Latin1是ISO-8859-1的别名,有些环境下写作Latin-1。ISO-8859-1编码是单字节编码,向下兼容ASCII,其编码范围是0x00-0xFF,0x00-0x7F之间完全和ASCII一致,0x80-0x9F之间是控制字符,0xA0-0xFF之间是文字符号。

    因为ISO-8859-1编码范围使用了单字节内的所有空间,在支持ISO-8859-1的系统中传输和存储其他任何编码的字节流都不会被抛弃。换言之,把其他任何编码的字节流当作ISO-8859-1编码看待都没有问题。这是个很重要的特性,MySQL数据库默认编码是Latin1就是利用了这个特性。ASCII编码是一个7位的容器,ISO-8859-1编码是一个8位的容器。

      还没等我高兴起来,运行后,又发现了一个问题:

    memory error
    • 1

      什么鬼?内存错误!哇,原来是数据大小的问题。

    X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
    • 1

      这告诉我们每批数据都是10000 * 3 * 32 * 32,相当于超过3000万个浮点数。 float数据类型实际上与float64相同,意味着每个数字大小占8个字节。这意味着每个批次占用至少240 MB。你加载6这些(5训练+ 1测试)在总产量接近1.4 GB的数据。

     for b in range(1, 2): f = os.path.join(ROOT, 'data_batch_%d' % (b,)) X, Y = load_CIFAR_batch(f) xs.append(X) ys.append(Y)
    • 1
    • 2
    • 3
    • 4
    • 5

      所以如有可能,如上代码所示只能一次运行一批。

      到此为止,错误基本搞定,下面贴出正确代码:

    # -*- coding: utf-8 -*- import pickle as p import numpy as np import os def load_CIFAR_batch(filename): """ 载入cifar数据集的一个batch """ with open(filename, 'rb') as f: datadict = p.load(f, encoding='latin1') X = datadict['data'] Y = datadict['labels'] X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float") Y = np.array(Y) return X, Y def load_CIFAR10(ROOT): """ 载入cifar全部数据 """ xs = [] ys = [] for b in range(1, 2): f = os.path.join(ROOT, 'data_batch_%d' % (b,)) X, Y = load_CIFAR_batch(f) xs.append(X) #将所有batch整合起来 ys.append(Y) Xtr = np.concatenate(xs) #使变成行向量,最终Xtr的尺寸为(50000,32,32,3) Ytr = np.concatenate(ys) del X, Y Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch')) return Xtr, Ytr, Xte, Yte 
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    import numpy as np from julyedu.data_utils import load_CIFAR10 import matplotlib.pyplot as plt plt.rcParams['figure.figsize'] = (10.0, 8.0) plt.rcParams['image.interpolation'] = 'nearest' plt.rcParams['image.cmap'] = 'gray' # 载入CIFAR-10数据集 cifar10_dir = 'julyedu/datasets/cifar-10-batches-py' X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir) # 看看数据集中的一些样本:每个类别展示一些 print('Training data shape: ', X_train.shape) print('Training labels shape: ', y_train.shape) print('Test data shape: ', X_test.shape) print('Test labels shape: ', y_test.shape)

     顺便看一下CIFAR-10数据组成:

     

    CIFAR-10数据组成

     


    附件:CIFAR-10 python version下载地址

    CIFAR-10官网

    哎...今天够累的,签到来了1...
    回复

    使用道具 举报

    您需要登录后才可以回帖 登录 | 立即注册

    本版积分规则

    QQ|手机版|小黑屋|Java自学者论坛 ( 声明:本站文章及资料整理自互联网,用于Java自学者交流学习使用,对资料版权不负任何法律责任,若有侵权请及时联系客服屏蔽删除 )

    GMT+8, 2024-3-29 08:12 , Processed in 0.077312 second(s), 29 queries .

    Powered by Discuz! X3.4

    Copyright © 2001-2021, Tencent Cloud.

    快速回复 返回顶部 返回列表