一起深度学习——CIFAR10

CIFAR10

  • 目的:
  • 实现步骤:
    • 1、导包:
    • 2、下载数据集
    • 3、整理数据集
    • 4、将验证集从原始的训练集中拆分出来
    • 5、数据增强
    • 6、加载数据集
    • 7、定义训练模型:
    • 8、定义训练函数:
    • 9、定义参数,开始训练:

目的:

实现从数据集中进行分类,一共有10个类别。

实现步骤:

1、导包:

import collections

import torch
from torch import nn
from d2l import torch as d2l
import shutil
import os
import math
import pandas as pd
import torchvision

2、下载数据集

#下载数据集
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
                                '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')

# 如果使用完整的Kaggle竞赛的数据集,设置demo为False
demo = True

if demo:
    data_dir = d2l.download_extract('cifar10_tiny')
else:
    data_dir = '../data/kaggle/cifar-10/'

3、整理数据集

ef read_csv_labels (fname):
    with open(fname,'r') as f:
        lines = f.readlines()[1:] #[1:0]表示从第二行开始读取,因为第一行是行头

    # 按照逗号分割每一行, 且rstrip 去除每行末尾的换行符
    # eg: ["apple,orange,banana\n"] => [['apple','orange','banana']]
    tokens = [l.rstrip().split(',') for l in lines]
    
    return dict((name,label) for name, label in tokens)

labels = read_csv_labels(os.path.join(data_dir,'trainLabels.csv'))

4、将验证集从原始的训练集中拆分出来

def copyfile(filename,target_dir):
    os.makedirs(target_dir,exist_ok=True)
    shutil.copy(filename,target_dir)

# print("训练样本:",len(labels))
# print("类别:",len(set(labels.values())))
def reorg_train_valid(data_dir,labels,valid_ratio):
    # Counter :用于计算每个类别出现的次数
    # most_common() :用于统计返回出现次数最多的元素(类别,次数),是一个列表,并且按照次数降序的方式存储
    # 【-1】表示取出列表中的最后一个元组。
    # 【1】 表示取出该元组的次数。
    n = collections.Counter(labels.values()).most_common()[-1][1]
    # math.floor(): 向下取整  math.ceil(): 向上取整
    # 每个类别分配给验证集的最小个数
    n_valid_per_label = max(1,math.floor((n * valid_ratio)))

    label_count = {}
    for photo in os.listdir(os.path.join(data_dir,'train')):
        label = labels[photo.split('.')[0]] #取出该标签所对应的类别
        # print(train_file)
        fname = os.path.join(data_dir,'train',photo)
        copyfile(fname,os.path.join(data_dir,'train_valid_test','train_valid',label))

        #如果该类别没有在label_count中或者是 数量小于规定的最小值,则将其复制到验证集中
        if label not in label_count or label_count[label] < n_valid_per_label:
            copyfile(fname,os.path.join(data_dir,'train_valid_test','valid',label))
            label_count[label] = label_count.get(label,0) +1
        else:
            copyfile(fname,os.path.join(data_dir,'train_valid_test','train',label))
    return n_valid_per_label

def reorg_test(data_dir):
    for test_file in os.listdir(os.path.join(data_dir,'test')):
        copyfile(os.path.join(data_dir,'test',test_file), os.path.join(data_dir,'train_valid_test','test','unknown'))


def reorg_cifar10_data(data_dir,valid_ratio):
    labels = read_csv_labels(os.path.join(data_dir,'trainLabels.csv'))
    reorg_train_valid(data_dir,labels,valid_ratio)
    reorg_test(data_dir)

batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_cifar10_data(data_dir,valid_ratio)
"""
结果会生成一个train_valid_test的文件夹,里面有:
- test文件夹---unknow文件夹:5张没有标签的测试照片
- train_valid文件夹---10个类被的文件夹:每个文件夹包含所属类别的全部照片
- train文件夹--10个类别的文件夹:每个文件夹下包含90%的照片用于训练
- valid文件夹--10个类别的文件夹:每个文件夹下包含10%的照片用于验证
"""

5、数据增强

transform_train = torchvision.transforms.Compose([
    # 原本图像是32*32,先放大成40*40, 在随机裁剪为32*32,实现训练数据的增强
    torchvision.transforms.Resize(40),
    torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        [0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010]
    )
])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    # 标准化图像的每个通道 : 消除评估结果中的随机性
    torchvision.transforms.Normalize(
        [0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010]
    )
])

6、加载数据集

#加载数据集
train_ds,train_valid_ds = [
    torchvision.datasets.ImageFolder(os.path.join(data_dir,'train_valid_test',folder),transform=transform_train)
    for folder in ['train','train_valid']
]
valid_ds, test_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder), transform=transform_test
    ) for folder in ['valid', 'test']
]
#定义迭代器
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(
        dataset, batch_size, shuffle=True, drop_last=True
    ) for dataset in (train_ds, train_valid_ds)
]

valid_iter = torch.utils.data.DataLoader(
    valid_ds,batch_size,shuffle=False,drop_last=True
)
test_iter = torch.utils.data.DataLoader(
    test_ds,batch_size,shuffle=False,drop_last=False
)


7、定义训练模型:

def get_net():
    num_classes = 10 #输出标签
    net = d2l.resnet18(num_classes,in_channels=3)
    return net

损失函数:

#损失函数
loss = nn.CrossEntropyLoss(reduction='none')


8、定义训练函数:

#定义训练函数
def train(net,train_iter,valid_iter,num_epochs,lr,wd,devices,lr_period,lr_decay):
    trainer = torch.optim.SGD(net.parameters(),lr=lr,momentum=0.9,weight_decay=wd)
    #学习率调度器:在经过lr_period个epoch之后,将学习率乘以lr_decay.
    scheduler = torch.optim.lr_scheduler.StepLR(trainer,lr_period,lr_decay)
    num_batches,timer = len(train_iter),d2l.Timer()
    legend = ['train_loss','train_acc']
    if valid_iter is not None:
        legend.append('valid_acc')
    animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=legend)
    net = nn.DataParallel(net,device_ids=devices).to(devices[0])

    for epoch in range(num_epochs):
        # 设置为训练模式
        net.train()
        metric = d2l.Accumulator(3)
        for i,(X,y) in enumerate(train_iter):
            timer.start()
            l,acc = d2l.train_batch_ch13(net,X,y,loss,trainer,devices)
            metric.add(l,acc,y.shape[0])
            timer.stop()
            if (i + 1) % (num_batches // 5 ) ==0 or i == num_batches - 1:
                animator.add(epoch + (i + 1)/num_batches,(metric[0]/metric[2],metric[1]/metric[2],None))

        if valid_iter is not None:
            valid_acc = d2l.evaluate_accuracy_gpu(net,valid_iter)
            animator.add(epoch+1,(None,None,valid_acc))
        scheduler.step()  #更新学习率
    measures = (f'train loss {metric[0] / metric[2]:.3f},'
                f'train acc{metric[1] / metric[2]:.3f}')
    if valid_iter is not None:
        measures += f', valid acc {valid_acc:.3f}'
    print(measures + f'\n{metric[2] * num_epochs / timer.sum():.1f}'
                     f'example/sec on {str(devices)}')

9、定义参数,开始训练:


import time

# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以

# 训练和验证模型
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 100, 2e-4, 5e-4
lr_period, lr_decay, net = 4, 0.9, get_net()
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)

# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以

# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/607409.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

卷积通用模型的剪枝、蒸馏---蒸馏篇--RKD关系蒸馏(以deeplabv3+为例)

本文使用RKD实现对deeplabv3+模型的蒸馏;与上一篇KD蒸馏的方法有所不同,RKD是对展平层的特征做蒸馏,蒸馏的loss分为二阶的距离损失Distance-wise Loss和三阶的角度损失Angle-wise Loss。 一、RKD简介 RKD算法的核心是以教师模型的多个输出为结构单元,取代传统蒸馏学习中以教…

Leetcode—394. 字符串解码【中等】

2024每日刷题&#xff08;131&#xff09; Leetcode—394. 字符串解码 实现代码 class Solution { public:string decodeString(string s) {string curstr;int curNum 0;stack<pair<string, int>> st; for(char c: s) {if(isdigit(c)) {curNum curNum * 10 (c…

电脑中的两个固态硬盘比一个好,想知道为什么吗

你当前的电脑很有可能有一个NVME SSD作为主驱动器&#xff0c;但可能至少还有一个插槽可以放另一个SSD&#xff0c;而且这样做可能是个好主意。 两个SSD可以提高性能 如果你有两个固态硬盘&#xff0c;你可以从中获得比有一个更好的性能。一种方法是使用RAID 0将两个驱动器组…

Python_AI库 Pandas的loc和iloc的区别与使用实例

Python中Pandas的loc和iloc的区别与使用实例 在Pandas中&#xff0c;loc和iloc是两个常用的方法&#xff0c;用于基于标签&#xff08;label&#xff09;和整数位置&#xff08;integer location&#xff09;来选择数据。尽管两者在功能上有重叠&#xff0c;但它们在用法和性能…

OceanBase开发者大会实录:SaaS 场景降本50%!石基零售应用 OB Cloud 实践

本文来自2024 OceanBase开发者大会&#xff0c;石基零售助理总裁 、 ROC 产品事业部负责人陈亮的演讲实录—《石基零售与 OB Cloud 零售行业应用实践》。完整视频回看&#xff0c;请点击这里&#xff1e;> 大家下午好&#xff01;我是石基零售的陈亮。今天和大家分享一下石基…

struct和union大小计算规则

Union 一&#xff1a;联合类型的定义 联合也是一种特殊的自定义类型&#xff0c;这种类型定义的变量也包含一系列的成员&#xff0c;特征是这些成员公用同一块空间&#xff08;所以联合也叫共用体&#xff09; 比如&#xff1a;共用了 i 这个较大的空间 二&#xff1a; 联合的…

数据分析从入门到精通 2.pandas修真之前戏基础

从爱上自己那天起&#xff0c;人生才真正开始 —— 24.5.6 为什么学习pandas numpy已经可以帮助我们进行数据的处理了&#xff0c;那么学习pandas的目的是什么呢? numpy能够帮助我们处理的是数值型的数据&#xff0c;当然在数据分析中除了数值型的数据还有好多其他类型…

信通院智能体标准发布,实在智能牵头编写

4月28日&#xff0c;由人工智能关键技术和应用评测工业和信息化部重点实验室、中国信息通信研究院&#xff08;以下简称&#xff1a;中国信通院&#xff09;人工智能研究所共同主办的“人工智能”高质量发展研讨会顺利召开&#xff0c;会上中国信通院正式发布全国首个Agent&…

【C++】string类的使用②(容量接口Capacity || 元素获取Element access)

&#x1f525;个人主页&#xff1a; Forcible Bug Maker &#x1f525;专栏&#xff1a; STL || C 目录 前言&#x1f525;容量接口&#xff08;Capacity&#xff09;size和lengthcapacitymax_sizereserveresizeclearemptyshrink_to_fit &#x1f525;元素获取&#xff08;Ele…

从零开始打造个性化生鲜微信商城小程序

随着移动互联网的普及&#xff0c;小程序商城已经成为越来越多商家的选择。本文将通过实战案例分享&#xff0c;教您如何在五分钟内快速搭建个性化生鲜小程序商城。 步骤一&#xff1a;登录乔拓云网后台&#xff0c;进入商城管理页面 打开乔拓云官网&#xff0c;点击右上角的“…

【连连国际注册_登录安全分析报告】

连连国际注册/登录安全分析报告 前言 由于网站注册入口容易被黑客攻击&#xff0c;存在如下安全问题&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露短信盗刷的安全问题&#xff0c;影响业务及导致用户投诉带来经济损失&#xff0c;尤其是后付费客户&#xff0c;风险巨…

队列 (Queue)

今日励志语句&#xff1a;别总听悲伤的歌&#xff0c;别总想从前的事&#xff0c;别让过去拖住脚&#xff0c;别让未来被辜负。 前言&#xff1a;前面写了一篇 栈的实现&#xff0c;接下来学习一下它的"兄弟" 一、队列的概念&#xff1a; 队列&#xff1a; 也是数据…

C++类和对象(三) 缺省值 | static成员 | 内部类

前言&#xff1a; 这是关于类和对象的最后一篇文章&#xff0c;当然还是基础篇的最后一篇&#xff0c;因为类的三大特性继承&#xff0c;封装和多态都还没有讲&#xff0c;少年&#xff0c;慢慢来。 缺省值&#xff1a; 之前讲过&#xff0c;在C11的新标准中&#xff0c;支持为…

百面算法工程师 | 支持向量机面试相关问题——SVM

本文给大家带来的百面算法工程师是深度学习支持向量机的面试总结&#xff0c;文章内总结了常见的提问问题&#xff0c;旨在为广大学子模拟出更贴合实际的面试问答场景。在这篇文章中&#xff0c;我们还将介绍一些常见的深度学习算法工程师面试问题&#xff0c;并提供参考的回答…

哈夫曼编码python算法实现(图片版)

一、问题&#xff1a; 请使用哈夫曼编码方法对给定的字符串&#xff0c;进行编码&#xff0c;以满足发送的编码总长度最小&#xff0c;且方便译码。“AABBCCDDEEABCDDCDBAEEAAA” 二、过程&#xff1a; 三、结果&#xff1a;

软件1班20240509

文章目录 1.JDBC本质2.增3.改4.删5.查6.JDBC标准写法 1.JDBC本质 重写 接口的 方法 idea 报错 – 不动脑 alt enter 知道没有重写方法 CTRL o 重写 方法 快捷键 package com.yanyu;/*** Author yanyu666_508200729qq.com* Date 2024/5/9 14:42* description:*/ public interf…

网安学习路线终极指南!一步步带你从入门到精通,详尽技能点全解析!

目录 零基础小白&#xff0c;到就业&#xff01;入门到入土的网安学习路线&#xff01; 建议的学习顺序&#xff1a; 一、夯实一下基础&#xff0c;梳理和复习 二、HTML与JAVASCRIPT&#xff08;了解一下语法即可&#xff0c;要求不高&#xff09; 三、PHP入门 四、MYSQL…

(三十七)第 6 章 树和二叉树(二叉树的二叉链表存储表示实现)

1. 背景说明 2. 示例代码 1) errorRecord.h // 记录错误宏定义头文件#ifndef ERROR_RECORD_H #define ERROR_RECORD_H#include <stdio.h> #include <string.h> #include <stdint.h>// 从文件路径中提取文件名 #define FILE_NAME(X) strrchr(X, \\) ? st…

欧洲杯/奥运会-云直播

欧洲杯/奥运会要来了&#xff0c;如何升级自己的网站让你的顾客都能观赏直播已提高用户量呢&#xff1f;&#xff01; 【功能完善、平滑兼容】 云直播支持 RTMP 推流、 HLS 源站等多种直播源接入方式&#xff0c;提供直播 SDK&#xff0c;支持多终端适配&#xff0c;上行码率…

蓝桥杯省三爆改省二,省一到底做错了什么?

到底怎么个事 这届蓝桥杯选的软件测试赛道&#xff0c;都说选择大于努力,软件测试一不卷二不难。省赛结束&#xff0c;自己就感觉稳啦&#xff0c;全部都稳啦。没想到一出结果&#xff0c;省三&#xff0c;g了。说落差&#xff0c;是真的有一点&#xff0c;就感觉和自己预期的…
最新文章