目录

第15章ConvNeXt图像分类实战遥感场景分类包含本地网页部署迁移学习

第15章:ConvNeXt图像分类实战:遥感场景分类【包含本地网页部署、迁移学习】


1. ConvNeXt 模型

ConvNeXt是一种基于卷积神经网络(CNN)的现代架构,由Facebook AI Research (FAIR) 团队在2022年提出。它通过借鉴Transformer的设计思想,对传统CNN进行了改进,使其在图像分类等任务中表现优异,甚至超越了Vision Transformers(ViT)

https://i-blog.csdnimg.cn/direct/bad600fdd2b14f1a8dfcefda3e259daa.png

详细介绍:

核心思想

ConvNeXt的核心思想是将Transformer的成功设计理念(如ViT)引入CNN,同时保留卷积的固有优势。通过一系列现代化改进,ConvNeXt在保持高效性的同时,提升了性能。

主要改进

  1. 大卷积核 :使用更大的卷积核(如7x7)来扩大感受野,类似于Transformer中自注意力机制捕捉全局信息的能力。
  2. 分层设计 :采用类似ResNet的分层结构,逐步降低分辨率并增加通道数,以提取多尺度特征。
  3. 倒置瓶颈结构 :借鉴MobileNetV2的倒置瓶颈设计,先扩展通道数再进行深度卷积,最后压缩通道数,提升计算效率。
  4. Layer Normalization :用Layer Normalization替换Batch Normalization,更适合小批量训练,并提升模型稳定性。
  5. GELU激活函数 :使用GELU激活函数替代ReLU,因其在Transformer中的表现更佳。
  6. 减少激活和归一化层 :减少不必要的激活和归一化层,简化网络结构,提升性能。
  7. Stochastic Depth :引入随机深度(Stochastic Depth),在训练时随机丢弃部分层,增强模型泛化能力。

2. 遥感场景建筑识别

ConvNeXt 实现的model部分代码如下面所示,这里如果采用官方预训练权重的话,会自动导入官方提供的最新版本(ImageNet)的权重

https://i-blog.csdnimg.cn/direct/1dc7880bc36244dd91a5209306baee10.png

2.1 数据集

数据集文件如下:

https://i-blog.csdnimg.cn/direct/e5098f1f6cee455293e082315f13a399.png

具体图像示例:

https://i-blog.csdnimg.cn/direct/b3b487334b8b4ce79ec39b241b4b8edd.png

标签如下:

{
    "0": "airport",
    "1": "bridge",
    "2": "church",
    "3": "forest",
    "4": "lake",
    "5": "river",
    "6": "skyscraper",
    "7": "stadium",
    "8": "statue",
    "9": "tower",
    "10": "urbanPark"
}

其中,训练集的总数为820,验证集的总数为345

https://i-blog.csdnimg.cn/direct/53acb5f04e744df0a3ff8e517f64755d.png

2.2 训练参数

训练的参数如下:

    parser.add_argument("--model", default='tiny', type=str,help='tiny,small,base,large')
    parser.add_argument("--pretrained", default=True, type=bool)       # 采用官方权重
    parser.add_argument("--freeze_layers", default=True, type=bool)    # 冻结权重

    parser.add_argument("--batch-size", default=8, type=int)
    parser.add_argument("--epochs", default=30, type=int)

    parser.add_argument("--optim", default='AdamW', type=str,help='SGD,Adam,AdamW')         # 优化器选择

    parser.add_argument('--lr', default=0.01, type=float)
    parser.add_argument('--lrf',default=0.01,type=float)                  # 最终学习率 = lr * lrf

    parser.add_argument('--save_ret', default='runs', type=str)             # 保存结果
    parser.add_argument('--data_train',default='./data/train',type=str)           # 训练集路径
    parser.add_argument('--data_val',default='./data/val',type=str)               # 验证集路径

需要注意的是网络分类的个数不需要指定,摆放好数据集后,代码会根据数据集自动生成!

更换数据集的话,将data-train和data-val路径更改即可,一键运行!

trian脚本会在训练同时自动验证,生成训练和验证的曲线图和指标

网络模型信息如下:

 "train parameters": {
        "model version": "tiny",
        "pretrained": true,
        "freeze_layers": true,
        "batch_size": 8,
        "epochs": 30,
        "optim": "AdamW",
        "lr": 0.01,
        "lrf": 0.01,
        "save_folder": "runs"
    },
    "dataset": {
        "trainset number": 820,
        "valset number": 345,
        "number classes": 11
    },
    "model": {
        "total parameters": 27818891.0,
        "train parameters": 9995,
        "flops": 4463390208.0
    },

2.3 训练结果

所有的结果都保存在 save_ret 目录下,这里是 runs

weights 下有最好和最后的权重, 在训练完成后控制台会打印最好的epoch

https://i-blog.csdnimg.cn/direct/d733e33596ba4a19852853485fa65bde.png

这里只展示部分结果: 可以看到网络没有完全收敛,增大epoch会得到更好的效果

https://i-blog.csdnimg.cn/direct/71eae0d543b4477e8562c76496d7717f.png

https://i-blog.csdnimg.cn/direct/641db0fbc2004b5591f49c7a2da8384a.png

最后一轮结果:

    "epoch:29": {
        "train info": {
            "accuracy": 0.9987804877926978,
            "airport": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "bridge": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "church": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "forest": {
                "Precision": 1.0,
                "Recall": 0.987,
                "Specificity": 1.0,
                "F1 score": 0.9935
            },
            "lake": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "river": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "skyscraper": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "stadium": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "statue": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "tower": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "urbanPark": {
                "Precision": 0.9872,
                "Recall": 1.0,
                "Specificity": 0.9987,
                "F1 score": 0.9936
            },
            "mean precision": 0.9988363636363636,
            "mean recall": 0.9988181818181818,
            "mean specificity": 0.9998818181818181,
            "mean f1 score": 0.9988272727272729
        },
        "valid info": {
            "accuracy": 0.8463768115696703,
            "airport": {
                "Precision": 0.9231,
                "Recall": 0.8571,
                "Specificity": 0.997,
                "F1 score": 0.8889
            },
            "bridge": {
                "Precision": 0.9032,
                "Recall": 0.8485,
                "Specificity": 0.9904,
                "F1 score": 0.875
            },
            "church": {
                "Precision": 0.7647,
                "Recall": 0.8125,
                "Specificity": 0.9878,
                "F1 score": 0.7879
            },
            "forest": {
                "Precision": 0.9259,
                "Recall": 0.7576,
                "Specificity": 0.9936,
                "F1 score": 0.8333
            },
            "lake": {
                "Precision": 0.9091,
                "Recall": 0.7692,
                "Specificity": 0.997,
                "F1 score": 0.8333
            },
            "river": {
                "Precision": 0.7872,
                "Recall": 0.9024,
                "Specificity": 0.9671,
                "F1 score": 0.8409
            },
            "skyscraper": {
                "Precision": 0.875,
                "Recall": 0.7,
                "Specificity": 0.997,
                "F1 score": 0.7778
            },
            "stadium": {
                "Precision": 0.8943,
                "Recall": 0.9649,
                "Specificity": 0.9437,
                "F1 score": 0.9283
            },
            "statue": {
                "Precision": 0.8095,
                "Recall": 0.6296,
                "Specificity": 0.9874,
                "F1 score": 0.7083
            },
            "tower": {
                "Precision": 0.7143,
                "Recall": 0.4167,
                "Specificity": 0.994,
                "F1 score": 0.5263
            },
            "urbanPark": {
                "Precision": 0.7,
                "Recall": 0.875,
                "Specificity": 0.9617,
                "F1 score": 0.7778
            },
            "mean precision": 0.8369363636363637,
            "mean recall": 0.7757727272727273,
            "mean specificity": 0.9833363636363637,
            "mean f1 score": 0.7979818181818181
        }
    }

训练集和测试集的混淆矩阵:

https://i-blog.csdnimg.cn/direct/792035d8ff414cea818f009b8a3ea77a.png

https://i-blog.csdnimg.cn/direct/c699b603418144629dd724fc4c47b47f.png

ROC曲线和auc值:

https://i-blog.csdnimg.cn/direct/30933361653d41d3bc3f9dbadf102005.png

2.4 本地部署推理

推理是指没有标签,只有图片数据的情况下对数据的预测,这里使用了网页推理

值得注意的是,如果训练了自己的数据集,需要对infer脚本进行更改,如下:

# 参数
MODEL = 'tiny'
LABELS = r'D:\project\ConvNeXt全家桶\runs\class_indices.json'
PTH = r'D:\project\ConvNeXt全家桶\runs\weights\best.pth'
IMAGE_PATH = r'D:\project\ConvNeXt全家桶\data\train\airport\0.jpg'

运行:

streamlit run D:\project\ConvNeXt全家桶\infer.py

https://i-blog.csdnimg.cn/direct/a851d9545406459d8f3bbab2000f3d4e.png

3. 下载

关于本项目代码和数据集、训练结果的下载:

https://i-blog.csdnimg.cn/direct/b9e21b43a6a64a9ab7b25c7f2a2b5b59.png

关于Ai 深度学习图像识别、医学图像分割改进系列:

神经网络改进完整实战项目: