(动手学习深度学习)第13章 实战kaggle竞赛:狗的品种识别

文章目录

      • 1. 导入相关库
      • 2. 加载数据集
      • 3. 整理数据集
      • 4. 图像增广
      • 5. 读取数据
      • 6. 微调预训练模型
      • 7. 定义损失函数和评价损失函数
      • 9. 训练模型

1. 导入相关库

import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

2. 加载数据集

- 该数据集是完整数据集的小规模样本
# 下载数据集
d2l.DATA_HUB['dog_tiny'] = (d2l.DATA_URL + 'kaggle_dog_tiny.zip',
                            '0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d')

# 如果使用Kaggle比赛的完整数据集,请将下面的变量更改为False
demo = True
if demo:
    data_dir = d2l.download_extract('dog_tiny')
else:
    data_dir = os.path.join('..', 'data', 'dog-breed-identification')

3. 整理数据集

def reorg_dog_data(data_dir, valid_ratio):
    labels = d2l.read_csv_labels(os.path.join(data_dir, 'labels.csv'))
    d2l.reorg_train_valid(data_dir, labels, valid_ratio)
    d2l.reorg_test(data_dir)

batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_dog_data(data_dir, valid_ratio)

4. 图像增广

transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    )
])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    )
])

5. 读取数据

train_ds, train_valid_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),
        transform=transform_train
    ) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', folder),
        transform=transform_test
    ) for folder in ['valid', 'test']
]
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(
        dataset, batch_size, shuffle=True, drop_last=True
    ) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(
    valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(
    test_ds, batch_size, shuffle=False, drop_last=True
)

6. 微调预训练模型

def get_net(devices):
    finetune_net = nn.Sequential()
    finetune_net.features = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
    # 定义一个新的输出网络,共有120个输出类别
    finetune_net.output_new = nn.Sequential(
        nn.Linear(1000, 256),
        nn.ReLU(),
        nn.Linear(256, 120)
    )
    finetune_net = finetune_net.to(devices[0])
    # 冻结参数
    for param in finetune_net.features.parameters():
        param.requires_grad = False

    return finetune_net
# 查看网络模型
get_net(devices=d2l.try_all_gpus())

在这里插入图片描述

7. 定义损失函数和评价损失函数

# 定义损失函数
loss = nn.CrossEntropyLoss(reduction='none')

def evaluate_loss(data_iter, net, device):
    l_sum, n = 0.0, 0
    for features, labels in data_iter:
        features, labels = features.to(device[0]), labels.to(device[0])
        outputs = net(features)
        l = loss(outputs, labels)
        l_sum += l.sum()
        n += labels.numel()
        return (l_sum / n).to('cpu')
  1. 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):
    # 只训练小型定义输出网络
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    trainer = torch.optim.SGD(
        (param for param in net.parameters() if param.requires_grad),
        lr=lr, momentum=0.9, weight_decay=wd
    )
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    legend = ['train loss']
    if valid_iter is not None:
        legend.append('valid loss')
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)
    for epoch in range(num_epochs):
        metric = d2l.Accumulator(2)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            features, labels = features.to(devices[0]), labels.to(devices[0])
            trainer.zero_grad()
            output = net(features)
            l = loss(output, labels).sum()
            l.backward()
            trainer.step()
            metric.add(l, labels.shape[0])
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(
                    epoch + (i + 1) / num_batches, (metric[0] / metric[1], None)
                )
        measures = f'train loss {metric[0] / metric[1]:.3f}'
        if valid_iter is not None :
            valid_loss = evaluate_loss(valid_iter, net, devices)
            animator.add(epoch + 1, (None, valid_loss.detach().cpu()))
        scheduler.step()
    if valid_iter is not None:
        measures += f', valid loss {valid_loss:.3f}'
    print(measures + f'\n{metric[1] * num_epochs / timer.sum():.1f}'
                     f'examples/sec on {str(devices)}')

9. 训练模型

devices, num_epochs, lr, wd = d2l.try_all_gpus(), 10, 1e-4, 1e-4
lr_period, lr_decay, net, = 2, 0.9, get_net(devices)
import time

# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以

train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)

# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以

# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/176125.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Figma最全面的新手指南,从基础到高级,一网打尽

1 Figma界面介绍 Figma基础界面与传统设计软件没有太大区别,有Sketch使用经验的用户几乎可以无缝连接到Figma。 立即体验 免费的在线Figma汉化版即时设计是一款支持在线协作的专业级 UI 设计工具,支持 Sketch、Figma、XD 格式导入,海量优质设…

【计算机网络】多路复用的三种方案

文章目录 1. selectselect函数select的工作特性select的缺点 2. pollpoll函数poll与select的对比 3. epollepoll的三个接口epoll的工作原理epoll的优点LT和ET模式epoll的应用场景 🔎Linux提供三种不同的多路转接(又称多路复用)的方案&#xf…

图解Spark Graphx基于connectedComponents函数实现连通图底层原理

原创/朱季谦 第一次写这么长的graphx源码解读,还是比较晦涩,有较多不足之处,争取改进。 一、连通图说明 连通图是指图中的任意两个顶点之间都存在路径相连而组成的一个子图。 用一个图来说明,例如,下面这个叫graph…

sklearn模型中预测值的R2_score为负数

目录 正文评论区参考链接 正文 Sklearn.metrics下面的r2_score函数用于计算R(确定系数:coefficient of determination)。它用来度量未来的样本是否可能通过模型被很好地预测。 分值为 1 表示最好,但我们在使用过程中&#xff0c…

数据分层:打造数据资产管家

一、引言 随着企业数据规模的增长,数据的价值变得越来越重要。然而,传统的数据库在承载大量数据时面临挑战,需要高效有序的维护。因此,建立高效的数据仓库成为了企业决策和管理的基石,但现代技术的背景下,…

HUAWEI华为MateBook X Pro 2022 12代酷睿版(MRGF-16)笔记本电脑原装出厂Windows11系统工厂模式含F10还原

链接:https://pan.baidu.com/s/1ZI5mR6SOgFzMljbMym7u3A?pwdl2cu 提取码:l2cu 华为原厂Windows11系统工厂包,带F10一键智能还原恢复功能。 自带指纹、面部识别、声卡、网卡、显卡、蓝牙等所有驱动、出厂主题壁纸、Office办公软件、华为…

OpenCvSharp从入门到实践-(02)图像处理的基本操作

目录 图像处理的基础操作 1、读取图像 1.1、读取当前目录下的图像 2、显示图像 2.1、Cv2.ImShow 用于显示图像。 2.2、Cv2.WaitKey方法用于等待用户按下键盘上按键的时间。 2.3、Cv2.DestroyAllWindows方法用于销毁所有正在显示图像的窗口。 2.4实例1-显示图像 2.4实例…

数据结构与算法编程题8

试编写算法将带头结点的单链表就地逆置&#xff0c;所谓“就地”是指空间复杂度为 O(1)。 #include <iostream> using namespace std;typedef int Elemtype; #define ERROR 0; #define OK 1;typedef struct LNode {Elemtype data; //结点保存的数据struct LNode…

windows11记事本应用程序无法打开,未响应,崩溃,卡死

windows11记事本应用程序无法打开&#xff0c;未响应&#xff0c;崩溃&#xff0c;卡死 文章目录 问题描述搜索引擎&#xff08;度娘&#xff09;卸载后如何安装问题未解决另一个解决方案&#xff1a;步骤&#xff1a;1.设置 → 语音和区域 → 输入2.选择“高级键盘设置”3.替…

C语言中的多线程调用

功能 开启一个线程&#xff0c;不断打印传进去的参数&#xff0c;并且每次打印后自增1 代码 #include<windows.h> #include<pthread.h> #include<stdio.h>void* print(void *a) {int *ic(int*)a;float *fc(float*)(asizeof(int)*2);double *dc(double*)(as…

黑苹果入门:资源、安装、使用、问题、必备工具、驱动篇

黑苹果入门&#xff1a;资源、安装、使用、问题、必备工具、驱动篇 一. 黑苹果入门&#xff1a;安装使用篇资源篇安装篇AMD处理器(桌面级)可以安装黑苹果macOS吗&#xff1f;黑苹果跑码是什么意思&#xff1f;进入语言选择界面&#xff0c;鼠标/键盘无法使用&#xff1f;安装完…

HashMap知识点总结

文章目录 HashMapConcurrentHashMap线程安全问题 HashMap 1、null作为key只能有一个&#xff0c;作为value可以有多个 2、容量&#xff1a; 1.7&#xff1a;默认161.8&#xff1a;初始化并未指定容量大小&#xff0c;第一次put才初始化容量 3、负载因子 默认0.75&#xff0…

代码文档浏览器 Dash mac中文版软件特色

Dash mac是一个基于 Python 的 web 应用程序框架&#xff0c;它可以帮助开发者快速构建数据可视化应用。Dash 的工作原理是将 Python 代码转换成 HTML、CSS 和 JavaScript&#xff0c;从而在浏览器中呈现交互式的数据可视化界面。Dash 提供了一系列组件&#xff0c;包括图表、表…

HarmonyOS ArkTS语言,运行Hello World(一)

一、下载与安装DevEco Studio 在HarmonyOS应用开发学习之前&#xff0c;需要进行一些准备工作&#xff0c;首先需要完成开发工具DevEco Studio的下载与安装以及环境配置。 进入DevEco Studio下载官网&#xff0c;单击“立即下载”进入下载页面。 DevEco Studio提供了Windows…

动态规划求 x 轴上相距最远的两个相邻点 java 代码实现

如图为某一状态下 x 轴上的情况&#xff0c;此时 E、F相距最远&#xff0c;现在加入一个点H&#xff0c;如果H位于点A的左边的话&#xff0c;只需要比较 A、H 的距离 和 E、F 的距离&#xff1b;如果点H位于点G的右边&#xff0c;则值需要比较 G、H 的距离 和 E、F 的距离&…

Docker安装Rabbitmq3.12并且prometheus进行监听【亲测可用】

一、安装Rabbitmq 下载镜像&#xff1a; docker pull rabbitmq:3.12-management 安装镜像&#xff1a; docker run -id --restartalways --namerabbitmq -v /usr/local/rabbitmq:/var/lib/rabbitmq -p 15692:15692 -p 15672:15672 -p 5672:5672 -e RABBITMQ_DEFAULT_USERgu…

在线接口测试工具fastmock使用

1、fastmock线上数据模拟器 在平时的项目测试中&#xff0c;尤其是前后端分离的时候&#xff0c;前端人员需要测试调用后端的接口&#xff0c;这个时候会出现测试不方便的情况。此时我们可以使用fastmock平台在线上模拟出一个可以调用的接口&#xff0c;方便前端人员进行数据测…

linux服务器安装gitlab

一、安装gitlab sudo yum install curl policycoreutils-python openssh-server openssh-clients sudo systemctl enable sshd sudo systemctl start sshd sudo firewall-cmd --permanent --add-servicehttp curl https://packages.gitlab.com/install/repositories/gitla…

OpenAI乱局幕后大佬浮出水面:Quora联合创始人

丨划重点 ● Quora德安杰洛在过去的周末积极游说圈内科技领袖出任OpenAI首席执行官。 ● 在奥特曼与OpenAI董事会关于重返公司的谈判中&#xff0c;德安杰洛是真正的主角。 ● Quora前员工透露&#xff0c;德安杰洛性格倔强&#xff0c;很难被说服。 ● Quora之前开发了人工…

c++ 谓词

1. 一元谓词 #include <iostream> #include<vector> #include<algorithm>using namespace std;class CreaterFive{ public:bool operator()(int val){return val>5;} };int main() {vector<int> vec;for(int i0; i<10; i){vec.push_back(i);}ve…
最新文章