计算机视觉MAE-的项目实战从图像重建到目标检测
计算机视觉|MAE 的项目实战:从图像重建到目标检测
一、引言
上一期文章 带大家走进计算机视觉的热门话题——MAE(Masked Autoencoders) 。俗话说:“光说不练假把式”。今天就带使用 MAE 进行图像重建和目标检测。如果你是个 Python 小白,别怕,我会用通俗的语言一步步带你入门。我们不仅会实现一个简单的图像重建项目,还会扩展到目标检测的实战,让你从零开始感受 MAE 的强大之处。准备好了吗?Let’s go!
二、什么是 MAE?它为什么重要?
MAE 全称是 Masked Autoencoders
,也就是“掩码自编码器”。简单来说,它的工作方式就像玩拼图:把一张图片的一部分“遮住”(掩码),然后让模型去“猜”被遮住的部分是什么。MAE 是 2021
年何恺明等人提出的一种自监督学习方法,基于 Transformer 架构,通过“掩码+重建”的方式,让模型学会理解图片的深层特征。
它的核心流程是:
- 随机遮住图片的 75%(或更多)区域。
- 用编码器(Encoder)处理剩下的可见部分。
- 用解码器(Decoder)尝试把遮住的部分补回来。
- 通过对比重建结果和原图,优化模型。 MAE 的厉害之处在于,它不需要大量标注数据,就能学到强大的图像表示。这对目标检测、图像分类等任务来说,是个很好的起点。好了,理论讲完,咱们直接上手干活!
三、项目实战 1:用 MAE 重建一张图片
我们先从一个简单的图像重建项目开始,帮你熟悉 MAE 的基本流程。
1 环境准备
确保你安装了以下库:
- Python 3.x
- PyTorch
- torchvision
- matplotlib 安装命令: pip install torch torchvision matplotlib
2 代码实现
下面是代码,带详细注释: import torch import torch.nn as nn import torch.optim as optim import torchvision.transforms as transforms from PIL import Image import matplotlib.pyplot as plt
1. 加载和预处理图片
def load_image(image_path): image = Image.open(image_path).convert(‘RGB’) transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor() ]) return transform(image).unsqueeze(0)
2. 定义简化的 MAE 模型
class SimpleMAE(nn.Module): def init(self, patch_size=16, mask_ratio=0.25): # 降低掩码比例 super(SimpleMAE, self).init() self.patch_size = patch_size self.mask_ratio = mask_ratio
编码器
self.encoder = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.ReLU() )
解码器
self.decoder = nn.Sequential( nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(), nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=1, output_padding=1)
去掉 Sigmoid,直接输出
) def forward(self, x): encoded = self.encoder(x) mask = torch.rand(encoded.shape) > self.mask_ratio masked_encoded = encoded * mask.to(encoded.device) reconstructed = self.decoder(masked_encoded) return reconstructed
3. 伪训练函数
def pseudo_train(model, image, epochs=10): optimizer = optim.Adam(model.parameters(), lr=0.001) criterion = nn.MSELoss() # 均方误差损失 for _ in range(epochs): optimizer.zero_grad() reconstructed = model(image) loss = criterion(reconstructed, image) loss.backward() optimizer.step() return model
4. 主函数
def main(): image_path = “example.jpg” # 替换为你的图片路径 original_image = load_image(image_path) device = torch.device(“cuda” if torch.cuda.is_available() else “cpu”) model = SimpleMAE().to(device) original_image = original_image.to(device)
伪训练模型
model = pseudo_train(model, original_image, epochs=50) # 跑 50 次简单拟合
前向传播
with torch.no_grad(): reconstructed_image = model(original_image)
显示结果
plt.figure(figsize=(10, 5)) plt.subplot(1, 2, 1) plt.title(“Original Image”) plt.imshow(original_image.squeeze(0).permute(1, 2, 0).cpu().numpy()) plt.axis(“off”) plt.subplot(1, 2, 2) plt.title(“Reconstructed Image”) plt.imshow(reconstructed_image.squeeze(0).permute(1, 2, 0).cpu().numpy()) plt.axis(“off”) plt.show() if name == “main”: main()
3 运行与结果
把一张图片(比如 example.jpg
)放在代码目录下,运行后会显示原图和重建图。第一次运行效果可能不好,因为模型没训练,但这能帮你理解 MAE
的基本原理。
四、项目实战 2:用 MAE 助力目标检测
图像重建只是 MAE 的“热身”,它的真正实力在于为下游任务(如目标检测)提供预训练模型。我们接下来实现一个简单的目标检测项目,用 MAE 提取特征,再结合一个检测头,识别图片中的对象。
1 目标检测简介
目标检测的任务是找到图片中的对象,并标注出它们的位置(用边界框表示)和类别(比如“猫”“狗”)。经典模型有 [YOLO](https://bthvi- leiqi.blog.csdn.net/article/details/145852743)、 等。我们这里用一个极简版,基于 MAE 的特征提取。
2 环境准备
除了之前的库,我们还需要 torchvision
的检测工具:
pip install torch torchvision matplotlib
3 代码实现
我们会在 MAE 的基础上加一个检测头,预测边界框和类别: import torch import torch.nn as nn import torch.optim as optim import torchvision.transforms as transforms from PIL import Image import matplotlib.pyplot as plt import matplotlib.patches as patches
1. 加载图片
def load_image(image_path, target_size=224): image = Image.open(image_path).convert(‘RGB’) transform = transforms.Compose([ transforms.Resize((target_size, target_size)), # 直接缩放到 224224 transforms.ToTensor() ]) return transform(image).unsqueeze(0), (800, 800) # 返回原始尺寸 800800
2. 定义带检测头的 MAE 模型
class MAEWithDetection(nn.Module): def init(self, patch_size=16, mask_ratio=0.25, num_classes=2): super(MAEWithDetection, self).init() self.patch_size = patch_size self.mask_ratio = mask_ratio self.encoder = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.ReLU() ) self.detection_head = nn.Sequential( nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(256, 4 + num_classes, kernel_size=1) ) def forward(self, x): encoded = self.encoder(x) mask = torch.rand(encoded.shape) > self.mask_ratio masked_encoded = encoded * mask.to(encoded.device) detection_output = self.detection_head(masked_encoded) return detection_output
3. 伪训练函数
def pseudo_train(model, image, target, epochs=100): optimizer = optim.Adam(model.parameters(), lr=0.001) criterion = nn.MSELoss() for _ in range(epochs): optimizer.zero_grad() output = model(image) loss = criterion(output, target) loss.backward() optimizer.step() return model
4. 创建伪标注(基于 800*800,猫的位置)
def create_pseudo_target(image_shape, feature_shape, orig_size): target = torch.zeros(feature_shape) # [1, 6, H, W] h, w = feature_shape[2], feature_shape[3] # 5656 orig_w, orig_h = orig_size # 800800
猫的边界框 (x, y, w, h) = (350, 170, 570, 580)
cat_x, cat_y, cat_w, cat_h = 350, 170, 570, 580 target_h = int(cat_y / orig_h * h) # 映射到特征图高度 target_w = int(cat_x / orig_w * w) # 映射到特征图宽度 target[0, 0, target_h, target_w] = cat_x / orig_w * w # x target[0, 1, target_h, target_w] = cat_y / orig_h * h # y target[0, 2, target_h, target_w] = cat_w / orig_w * w # w target[0, 3, target_h, target_w] = cat_h / orig_h * h # h target[0, 4, target_h, target_w] = 1.0 # 猫的类别得分 target[0, 5, target_h, target_w] = 0.0 # 狗的类别得分 return target
5. 可视化检测结果(映射回 800*800)
def visualize_detection(image, detection_output, orig_size): fig, ax = plt.subplots(1) ax.imshow(image.squeeze(0).permute(1, 2, 0).cpu().numpy()) detection = detection_output[0].cpu().detach().numpy() h, w = detection.shape[1:] # 5656 orig_w, orig_h = orig_size # 800800 scale_h, scale_w = orig_h / h, orig_w / w # 缩放系数
取类别得分最高的位置
cls_scores = detection[4:, :, :] # [2, H, W] max_score = cls_scores.max() max_idx = cls_scores.argmax() cls_idx = max_idx // (h * w) # 0: 猫, 1: 狗 max_h = (max_idx % (h * w)) // w max_w = (max_idx % (h * w)) % w if max_score > 0.1: x, y, w_box, h_box = detection[:4, max_h, max_w] rect = patches.Rectangle( (x * scale_w, y * scale_h), w_box * scale_w, h_box * scale_h, linewidth=2, edgecolor=‘r’ if cls_idx == 0 else ‘b’, facecolor=‘none’ ) ax.add_patch(rect) ax.text(x * scale_w, y * scale_h, ‘Cat’ if cls_idx == 0 else ‘Dog’, color=‘r’ if cls_idx == 0 else ‘b’) plt.title(“Detected Objects”) plt.axis(“off”) plt.show()
6. 主函数
def main(): image_path = “cat.jpeg” # 替换为你的 800*800 图片路径 original_image, orig_size = load_image(image_path) device = torch.device(“cuda” if torch.cuda.is_available() else “cpu”) model = MAEWithDetection(num_classes=2).to(device) original_image = original_image.to(device)
创建伪标注并伪训练
feature_shape = (1, 6, 56, 56) # 检测头输出形状 target = create_pseudo_target(original_image.shape, feature_shape, orig_size).to(device) model = pseudo_train(model, original_image, target, epochs=100)
前向传播
with torch.no_grad(): detection_output = model(original_image)
可视化
visualize_detection(original_image, detection_output, orig_size) if name == “main”: main()
4 代码解释
- 模型结构 :我们复用了 MAE 的编码器,然后加了个检测头,输出边界框
(x, y, w, h)
和类别得分。 - 检测输出 :
detection_output
的形状是[batch, 4+num_classes, H, W]
,每个位置预测一个边界框和类别。 - 可视化 :我们用红框标注检测结果,置信度大于 0.5 的框会被显示。
5 运行与结果
运行代码后,你会看到图片上的猫头标注了红色的检测框。由于模型未训练,框的位置和类别可能不准确。实际应用中,你需要:
- 用标注数据(如 COCO 数据集)训练检测头。
- 用 MAE 预训练编码器,再微调整个模型。
五、MAE 的进阶与应用
通过这两个实战,你应该对 MAE 的潜力有了初步认识。MAE 的真正威力在于:
- 图像重建 :补全缺失部分。
- 目标检测 :提取特征,提升检测精度。
- 其他任务 :图像分类、分割等。
想深入学习?可以参考 MAE 原论文
,或用 PyTorch 的
VisionTransformer
实现完整版 MAE。
六、写在最后
从图像重建到目标检测,我们用 Python 和 PyTorch 完成了两个小项目。作为 Python 小白,你已经迈出了重要一步!接下来可以试试:
- 用更大的数据集训练模型。
- 结合预训练的 ,提升效果。 有问题随时问我,我是 紫雾凌寒,乐意帮你解答。
延伸阅读