Java自学者论坛

 找回密码
 立即注册

手机号码,快捷登录

恭喜Java自学者论坛(https://www.javazxz.com)已经为数万Java学习者服务超过8年了!积累会员资料超过10000G+
成为本站VIP会员,下载本站10000G+会员资源,会员资料板块,购买链接:点击进入购买VIP会员

JAVA高级面试进阶训练营视频教程

Java架构师系统进阶VIP课程

分布式高可用全栈开发微服务教程Go语言视频零基础入门到精通Java架构师3期(课件+源码)
Java开发全终端实战租房项目视频教程SpringBoot2.X入门到高级使用教程大数据培训第六期全套视频教程深度学习(CNN RNN GAN)算法原理Java亿级流量电商系统视频教程
互联网架构师视频教程年薪50万Spark2.0从入门到精通年薪50万!人工智能学习路线教程年薪50万大数据入门到精通学习路线年薪50万机器学习入门到精通教程
仿小米商城类app和小程序视频教程深度学习数据分析基础到实战最新黑马javaEE2.1就业课程从 0到JVM实战高手教程MySQL入门到精通教程
查看: 808|回复: 0

【火炉炼AI】深度学习008-Keras解决多分类问题

[复制链接]
  • TA的每日心情
    奋斗
    2024-11-24 15:47
  • 签到天数: 804 天

    [LV.10]以坛为家III

    2053

    主题

    2111

    帖子

    72万

    积分

    管理员

    Rank: 9Rank: 9Rank: 9

    积分
    726782
    发表于 2021-8-29 10:52:42 | 显示全部楼层 |阅读模式

    【火炉炼AI】深度学习008-Keras解决多分类问题

    (本文所使用的Python库和版本号: Python 3.6, Numpy 1.14, scikit-learn 0.19, matplotlib 2.2, Keras 2.1.6, Tensorflow 1.9.0)

    在我前面的文章【火炉炼AI】深度学习005-简单几行Keras代码解决二分类问题中,介绍了用Keras解决二分类问题。那么多分类问题该怎么解决?有哪些不同?


    1. 准备数据集

    为了演示,本次选用了博文keras系列︱图像多分类训练与利用bottleneck features进行微调(三)中提到的数据集,原始的数据集将所有类别的train照片放到train文件夹中,所有的test照片放在test文件夹中,而用不同数字开头来表示不同类别,比如以3开头的照片就是bus类等。首先将这些不同类别的照片放在不同的文件夹中,最终的train文件夹有5个子文件夹,每个子文件夹中有80张图片,最终的test文件夹中有5个子文件夹,每个子文件夹中有20张图片。总共只有500张图片。

    在代码上,需要用ImageDataGenerator来做数据增强,并且用flow_from_directory来从文件夹中产生数据流。

    代码和二分类的文章基本相同,此处就不贴出来了,可以去我的github直接看全部的代码。

    唯一的不同之处是要设置class_mode='categorical',而不是原来二分类问题的class_mode='binary'


    2. 模型的构建和训练

    基本和二分类一样,如下为模型的构建部分:

    # 4,建立Keras模型:模型的建立主要包括模型的搭建,模型的配置
    from keras.models import Sequential
    from keras.layers import Conv2D, MaxPooling2D
    from keras.layers import Activation, Dropout, Flatten, Dense
    from keras import optimizers
    def build_model(input_shape):
        # 模型的搭建:此处构建三个CNN层+2个全连接层的结构
        model = Sequential()
        model.add(Conv2D(32, (3, 3), input_shape=input_shape))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
    
        model.add(Conv2D(32, (3, 3)))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
    
        model.add(Conv2D(64, (3, 3)))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
    
        model.add(Flatten())
        model.add(Dense(64))
        model.add(Activation('relu'))
        model.add(Dropout(0.5)) # Dropout防止过拟合
        model.add(Dense(class_num)) # 此处多分类问题,用Dense(class_num)
        model.add(Activation('softmax')) #多分类问题用softmax作为activation function
        
        # 模型的配置
        model.compile(loss='categorical_crossentropy', # 定义模型的loss func,optimizer,
                      optimizer=optimizers.RMSprop(), # 使用默认的lr=0.001
                      metrics=['accuracy'])# 主要优化accuracy
    
        return model # 返回构建好的模型
    

    改变之处是:最后的Dense层需要用Dense(class_num)来代替Dense(1),然后用多分类的标配activation function: softmax。在模型的配置方面,也需要将loss function改为'categorical_crossentropy'。

    通过模型的训练后,最终结果如下所示:

    从结果上看:没有出现过拟合现象,但是test acc不太稳定,变化比较大。在平台期后的test acc约为0.85.

    ########################小**********结###############################

    1,多分类问题和二分类问题基本相同,不同之处在于:1,设置flow_flow_directory时要用设置class_mode='categorical'。2,模型的最后一层要用Dense(class_num)和softmax这个多分类专用激活函数。3,模型的loss function要使用categorical_crossentropy。

    #################################################################


    注:本部分代码已经全部上传到(我的github)上,欢迎下载。

    哎...今天够累的,签到来了1...
    回复

    使用道具 举报

    您需要登录后才可以回帖 登录 | 立即注册

    本版积分规则

    QQ|手机版|小黑屋|Java自学者论坛 ( 声明:本站文章及资料整理自互联网,用于Java自学者交流学习使用,对资料版权不负任何法律责任,若有侵权请及时联系客服屏蔽删除 )

    GMT+8, 2025-1-5 09:45 , Processed in 0.059238 second(s), 28 queries .

    Powered by Discuz! X3.4

    Copyright © 2001-2021, Tencent Cloud.

    快速回复 返回顶部 返回列表