Java自学者论坛

 找回密码
 立即注册

手机号码,快捷登录

恭喜Java自学者论坛(https://www.javazxz.com)已经为数万Java学习者服务超过8年了!积累会员资料超过10000G+
成为本站VIP会员,下载本站10000G+会员资源,会员资料板块,购买链接:点击进入购买VIP会员

JAVA高级面试进阶训练营视频教程

Java架构师系统进阶VIP课程

分布式高可用全栈开发微服务教程Go语言视频零基础入门到精通Java架构师3期(课件+源码)
Java开发全终端实战租房项目视频教程SpringBoot2.X入门到高级使用教程大数据培训第六期全套视频教程深度学习(CNN RNN GAN)算法原理Java亿级流量电商系统视频教程
互联网架构师视频教程年薪50万Spark2.0从入门到精通年薪50万!人工智能学习路线教程年薪50万大数据入门到精通学习路线年薪50万机器学习入门到精通教程
仿小米商城类app和小程序视频教程深度学习数据分析基础到实战最新黑马javaEE2.1就业课程从 0到JVM实战高手教程MySQL入门到精通教程
查看: 912|回复: 0

神经网络解决二值分类问题的完整程序

[复制链接]
  • TA的每日心情
    奋斗
    2024-4-6 11:05
  • 签到天数: 748 天

    [LV.9]以坛为家II

    2034

    主题

    2092

    帖子

    70万

    积分

    管理员

    Rank: 9Rank: 9Rank: 9

    积分
    705612
    发表于 2021-6-6 14:15:48 | 显示全部楼层 |阅读模式
    import tensorflow as tf
    # NumPy是一个科学计算工具包,这里通过numpy工具包生成模拟数据集
    from numpy.random import RandomState
    
    
    # 定义一批训练数据batch大小
    batch_size = 8
    w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
    w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))
    
    # 在shape的一个维度上使用None可以方便使用不同batch的大小。
    # 在训练时将数据分成比较小的batch,但在测试时,可以一次性使用所有数据。
    # 当数据比较小时这样比较方便测试数据,但当数据量大时,大的batch会导致内存溢出
    x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
    y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")
    
    # 定义前项传播的过程
    a = tf.matmul(x, w1)
    y = tf.matmul(a, w2)
    
    # 定义损失函数和和反向传播算法
    y = tf.sigmoid(y)
    # 定义损失函数,刻画预测值与真实值之间的差距
    cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)) +
                                    (1 - y_) * tf.log(tf.clip_by_value(1-y, 1e-10, 1.0)))
    
    train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)
    
    # 通过随机数生成一个模拟数据集
    rdm = RandomState(1)
    dataset_size = 128
    X = rdm.rand(dataset_size, 2)
    
    # 定义规则来给出样本标签,在这里x1 + x2 < 1被认为正样本
    # 而其他为负样本。和TensorFlow游乐场不同,这里使用0代表负样本,1代表正
    # 大部分解决分类问题的神经网络都会采用0,1表示
    Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]
    
    # 定义会话
    with tf.Session() as sess:
        # 定义初始化变量
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        print(sess.run(w1))
        print(sess.run(w2))
    
        # 设定训练轮数
        STEPS = 10000
        for i in range(STEPS):
            # 每次选取batch_size个样本进行训练
            start = (i * batch_size) % dataset_size
            end = min(start + batch_size, dataset_size)
    
            # 通过选取的样本训练神经网络并更新参数
            sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
            if i % 1000 == 0:
                # 每隔一段时间计算所有数据上的交叉熵并输出
                total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
                print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy))
        print(sess.run(w1))
        print(sess.run(w2))
    
    
    
    """
    输出:
    [[-0.8113182   1.4845988   0.06532937]
     [-2.4427042   0.0992484   0.5912243 ]]
    [[-0.8113182 ]
     [ 1.4845988 ]
     [ 0.06532937]]
    After 0 training step(s), cross entropy on all data is 1.89805
    After 1000 training step(s), cross entropy on all data is 0.655075
    After 2000 training step(s), cross entropy on all data is 0.626172
    After 3000 training step(s), cross entropy on all data is 0.615096
    After 4000 training step(s), cross entropy on all data is 0.610309
    After 5000 training step(s), cross entropy on all data is 0.608679
    After 6000 training step(s), cross entropy on all data is 0.608231
    After 7000 training step(s), cross entropy on all data is 0.608114
    After 8000 training step(s), cross entropy on all data is 0.608088
    After 9000 training step(s), cross entropy on all data is 0.608081
    [[ 0.08782727  0.51795506  1.7529843 ]
     [-2.2372198  -0.20525953  1.0744455 ]]
    [[-0.49522772]
     [ 0.40552336]
     [-1.0061253 ]]
    
    """

     

    哎...今天够累的,签到来了1...
    回复

    使用道具 举报

    您需要登录后才可以回帖 登录 | 立即注册

    本版积分规则

    QQ|手机版|小黑屋|Java自学者论坛 ( 声明:本站文章及资料整理自互联网,用于Java自学者交流学习使用,对资料版权不负任何法律责任,若有侵权请及时联系客服屏蔽删除 )

    GMT+8, 2024-4-20 22:05 , Processed in 0.072070 second(s), 29 queries .

    Powered by Discuz! X3.4

    Copyright © 2001-2021, Tencent Cloud.

    快速回复 返回顶部 返回列表