第七章 目标检测模型搭建,训练,预测


阅读本文需要有基础的pytorch编程经验,目标检测框架相关知识,不用很深入,大致了解概念即可。

本章简要介绍如何如何用C++实现一个目标检测器模型,该模型具有训练和预测的功能。本文的分割模型架构使用yolov4-tiny结构,代码结构参考了bubbliiiing yolov4-tiny,本文分享的c++模型几乎完美复现了pytorch的版本,且具有速度优势,30-40%的速度提升。

1.模型简介

简单介绍一下yolov4-tiny模型。yolov4-tiny模型是YOLO(you only look once)系列模型中,version 4的轻巧版,相比于yolov4,它牺牲了部分精度以实现速度上的大幅提升。yolov4_tiny模型结构如图(图片来源自这):


可以发现模型结构非常简单,以CSPDarknet53-tiny为骨干网络,FPN为颈部(neck),Yolo head为头部。最后输出两个特征层,分别是原图下采样32倍和下采样16倍的特征图。训练时,以这两个特征图分别输入损失计算中计算损失,再将损失求和(或平均,怎么都好),后做反向传播,预测时将两个特征图解码出的结果做并集再做NMS(非极大值抑制)。

2.骨干网络

CSPDarknet53-tiny是CSPNet的一种,CSPNet发表于CVPR2019,是用于提升目标检测模型检测性能的一种骨干网络,但对于分类效果提升有限,但在速度上有提升。感兴趣的同学可以去看原文,简单理解该论文贡献,就是将特征层沿着通道维度切成两片,两片分别做不同的卷积,然后再拼接起来,这样做相比于直接对原图做特征提取,能减少计算量。

默认看过我的libtorch系列教程的前部分,直接上代码。首先是基本单元,由Conv2d + BatchNorm2d + LeakyReLU构成。

//Conv2d + BatchNorm2d + LeakyReLU
class BasicConvImpl : public torch::nn::Module {
public:
    BasicConvImpl(int in_channels, int out_channels, int kernel_size, int stride = 1);
    torch::Tensor forward(torch::Tensor x);
private:
    // Declare layers
    torch::nn::Conv2d conv{ nullptr };
    torch::nn::BatchNorm2d bn{ nullptr };
    torch::nn::LeakyReLU acitivation{ nullptr };
}; TORCH_MODULE(BasicConv);

BasicConvImpl::BasicConvImpl(int in_channels, int out_channels, int kernel_size, 
    int stride) :
    conv(conv_options(in_channels, out_channels, kernel_size, stride, 
        int(kernel_size / 2), 1, false)),
    bn(torch::nn::BatchNorm2d(out_channels)),
    acitivation(torch::nn::LeakyReLU(torch::nn::LeakyReLUOptions().negative_slope(0.1)))
{
    register_module("conv", conv);
    register_module("bn", bn);
}

torch::Tensor BasicConvImpl::forward(torch::Tensor x)
{
    x = conv->forward(x);
    x = bn->forward(x);
    x = acitivation(x);
    return x;
}

该层作为基本模块,将在后期作为搭积木的基本块,搭建yolo4_tiny。

然后是Resblock_body模块,

class Resblock_bodyImpl : public torch::nn::Module {
public:
    Resblock_bodyImpl(int in_channels, int out_channels);
    std::vector<torch::Tensor> forward(torch::Tensor x);
private:
    int out_channels;
    BasicConv conv1{ nullptr };
    BasicConv conv2{ nullptr };
    BasicConv conv3{ nullptr };
    BasicConv conv4{ nullptr };
    torch::nn::MaxPool2d maxpool{ nullptr };
}; TORCH_MODULE(Resblock_body);

Resblock_bodyImpl::Resblock_bodyImpl(int in_channels, int out_channels) {
    this->out_channels = out_channels;
    conv1 = BasicConv(in_channels, out_channels, 3);
    conv2 = BasicConv(out_channels / 2, out_channels / 2, 3);
    conv3 = BasicConv(out_channels / 2, out_channels / 2, 3);
    conv4 = BasicConv(out_channels, out_channels, 1);
    maxpool = torch::nn::MaxPool2d(maxpool_options(2, 2));

    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("conv3", conv3);
    register_module("conv4", conv4);

}
std::vector<torch::Tensor> Resblock_bodyImpl::forward(torch::Tensor x) {
    auto c = out_channels;
    x = conv1->forward(x);
    auto route = x;

    x = torch::split(x, c / 2, 1)[1];
    x = conv2->forward(x);
    auto route1 = x;

    x = conv3->forward(x);
    x = torch::cat({ x, route1 }, 1);
    x = conv4->forward(x);
    auto feat = x;

    x = torch::cat({ route, x }, 1);
    x = maxpool->forward(x);
    return std::vector<torch::Tensor>({ x,feat });
}

最后是骨干网络主体

class CSPdarknet53_tinyImpl : public torch::nn::Module
{
public:
    CSPdarknet53_tinyImpl();
    std::vector<torch::Tensor> forward(torch::Tensor x);
private:
    BasicConv conv1{ nullptr };
    BasicConv conv2{ nullptr };
    Resblock_body resblock_body1{ nullptr };
    Resblock_body resblock_body2{ nullptr };
    Resblock_body resblock_body3{ nullptr };
    BasicConv conv3{ nullptr };
    int num_features = 1;
}; TORCH_MODULE(CSPdarknet53_tiny);

CSPdarknet53_tinyImpl::CSPdarknet53_tinyImpl() {
    conv1 = BasicConv(3, 32, 3, 2);
    conv2 = BasicConv(32, 64, 3, 2);
    resblock_body1 = Resblock_body(64, 64);
    resblock_body2 = Resblock_body(128, 128);
    resblock_body3 = Resblock_body(256, 256);
    conv3 = BasicConv(512, 512, 3);

    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("resblock_body1", resblock_body1);
    register_module("resblock_body2", resblock_body2);
    register_module("resblock_body3", resblock_body3);
    register_module("conv3", conv3);
}

std::vector<torch::Tensor> CSPdarknet53_tinyImpl::forward(torch::Tensor x) {
    // 416, 416, 3 -> 208, 208, 32 -> 104, 104, 64
    x = conv1(x);
    x = conv2(x);

    // 104, 104, 64 -> 52, 52, 128
    x = resblock_body1->forward(x)[0];
    // 52, 52, 128 -> 26, 26, 256
    x = resblock_body2->forward(x)[0];
    // 26, 26, 256->xΪ13, 13, 512
#   //        -> feat1Ϊ26,26,256
    auto res_out = resblock_body3->forward(x);
    x = res_out[0];
    auto feat1 = res_out[1];
    // 13, 13, 512 -> 13, 13, 512
    x = conv3->forward(x);
    auto feat2 = x;
    return std::vector<torch::Tensor>({ feat1, feat2 });
}

至此,yolo4_tiny中的骨干网络已经搭建好。接下来将搭建yolo4_tiny模型。

3.yolov4_tiny

骨干网络得到的特征图,将经过FPN,需要上采样模块。

//conv+upsample
class UpsampleImpl : public torch::nn::Module {
public:
    UpsampleImpl(int in_channels, int out_channels);
    torch::Tensor forward(torch::Tensor x);
private:
    // Declare layers
    torch::nn::Sequential upsample = torch::nn::Sequential();
}; TORCH_MODULE(Upsample);

UpsampleImpl::UpsampleImpl(int in_channels, int out_channels)
{
    upsample = torch::nn::Sequential(
        BasicConv(in_channels, out_channels, 1)
        //torch::nn::Upsample(torch::nn::UpsampleOptions().scale_factor(std::vector<double>({ 2 })).mode(torch::kNearest).align_corners(false))
    );
    register_module("upsample", upsample);
}

torch::Tensor UpsampleImpl::forward(torch::Tensor x)
{
    x = upsample->forward(x);
    x = at::upsample_nearest2d(x, { x.sizes()[2] * 2 , x.sizes()[3] * 2 });
    return x;
}

然后是yolo_head模块

torch::nn::Sequential yolo_head(std::vector<int> filters_list, int in_filters);

torch::nn::Sequential yolo_head(std::vector<int> filters_list, int in_filters) {
    auto m = torch::nn::Sequential(BasicConv(in_filters, filters_list[0], 3),
        torch::nn::Conv2d(conv_options(filters_list[0], filters_list[1], 1)));
    return m;
}

以及yolo_body

class YoloBody_tinyImpl : public torch::nn::Module {
public:
    YoloBody_tinyImpl(int num_anchors, int num_classes);
    std::vector<torch::Tensor> forward(torch::Tensor x);
private:
    // Declare layers
    CSPdarknet53_tiny backbone{ nullptr };
    BasicConv conv_for_P5{ nullptr };
    Upsample upsample{ nullptr };
    torch::nn::Sequential yolo_headP5{ nullptr };
    torch::nn::Sequential yolo_headP4{ nullptr };
}; TORCH_MODULE(YoloBody_tiny);

YoloBody_tinyImpl::YoloBody_tinyImpl(int num_anchors, int num_classes) {
    backbone = CSPdarknet53_tiny();
    conv_for_P5 = BasicConv(512, 256, 1);
    yolo_headP5 = yolo_head({ 512, num_anchors * (5 + num_classes) }, 256);
    upsample = Upsample(256, 128);
    yolo_headP4 = yolo_head({ 256, num_anchors * (5 + num_classes) }, 384);

    register_module("backbone", backbone);
    register_module("conv_for_P5", conv_for_P5);
    register_module("yolo_headP5", yolo_headP5);
    register_module("upsample", upsample);
    register_module("yolo_headP4", yolo_headP4);
}
std::vector<torch::Tensor> YoloBody_tinyImpl::forward(torch::Tensor x) {
    //return feat1 with shape of {26,26,256} and feat2 of {13, 13, 512}
    auto backbone_out = backbone->forward(x);
    auto feat1 = backbone_out[0];
    auto feat2 = backbone_out[1];
    //13,13,512 -> 13,13,256
    auto P5 = conv_for_P5->forward(feat2);
    //13, 13, 256 -> 13, 13, 512 -> 13, 13, 255
    auto out0 = yolo_headP5->forward(P5);


    //13,13,256 -> 13,13,128 -> 26,26,128
    auto P5_Upsample = upsample->forward(P5);
    //26, 26, 256 + 26, 26, 128 -> 26, 26, 384
    auto P4 = torch::cat({ P5_Upsample, feat1 }, 1);
    //26, 26, 384 -> 26, 26, 256 -> 26, 26, 255
    auto out1 = yolo_headP4->forward(P4);
    return std::vector<torch::Tensor>({ out0, out1 });
}

代码写到这一步,其实只要细心就会发现基本是对pytorch代码到libtorch的迁移,除了少数bug需要调试,大部分简单迁移到c++即可。可以说是非常简便了。

像前面章节中一样,生成torchscript模型。bubbliiiing yolov4-tiny中有提供一个coco训练版本,通过下述代码生成.pt文件:

import torch
from torchsummary import summary
import numpy as np

from nets.yolo4_tiny import YoloBody
from train import get_anchors, get_classes,YOLOLoss

device = torch.device('cpu')
model = YoloBody(3,80).to(device)
model_path = "model_data/yolov4_tiny_weights_coco.pth"

print('Loading weights into state dict...')
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path, map_location=torch.device("cpu"))
pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) ==  np.shape(v)}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
print('Finished!')

#生成pt模型,按照官网来即可
model=model.to(torch.device("cpu"))
model.eval()
var=torch.ones((1,3,416,416))
traced_script_module = torch.jit.trace(model, var)
traced_script_module.save("yolo4_tiny.pt")

然后在c++中使用下述代码测试是否能够正确加载:

auto model = YoloBody_tiny(3, 80);
torch::load(model, "weights/yolo4_tiny.pt");

执行通过即表明加载成功。

4.预测

预测需要将YOLO4_tiny模型输出的张量进行解码,根据源代码解码函数,写出c++版本的解码函数,此时将发现,libtorch教程第二章的重要性了。

torch::Tensor DecodeBox(torch::Tensor input, torch::Tensor anchors, int num_classes, int img_size[])
{
    int num_anchors = anchors.sizes()[0];
    int bbox_attrs = 5 + num_classes;
    int batch_size = input.sizes()[0];
    int input_height = input.sizes()[2];
    int input_width = input.sizes()[3];
    //计算步长
    //每一个特征点对应原来的图片上多少个像素点
    //如果特征层为13x13的话,一个特征点就对应原来的图片上的32个像素点
    //416 / 13 = 32
    auto stride_h = img_size[1] / input_height;
    auto stride_w = img_size[0] / input_width;
    //把先验框的尺寸调整成特征层大小的形式
    //计算出先验框在特征层上对应的宽高
    auto scaled_anchors = anchors.clone();
    scaled_anchors.select(1, 0) = scaled_anchors.select(1, 0) / stride_w;
    scaled_anchors.select(1, 1) = scaled_anchors.select(1, 1) / stride_h;

    //bs, 3 * (5 + num_classes), 13, 13->bs, 3, 13, 13, (5 + num_classes)
    //cout << "begin view"<<input.sizes()<<endl;
    auto prediction = input.view({ batch_size, num_anchors,bbox_attrs, input_height, input_width }).permute({ 0, 1, 3, 4, 2 }).contiguous();
    //cout << "end view" << endl;
    //先验框的中心位置的调整参数
    auto x = torch::sigmoid(prediction.select(-1, 0));
    auto y = torch::sigmoid(prediction.select(-1, 1));
    //先验框的宽高调整参数
    auto w = prediction.select(-1, 2); // Width
    auto h = prediction.select(-1, 3); // Height

    //获得置信度,是否有物体
    auto conf = torch::sigmoid(prediction.select(-1, 4));
    //种类置信度
    auto pred_cls = torch::sigmoid(prediction.narrow(-1, 5, num_classes));// Cls pred.

    auto LongType = x.clone().to(torch::kLong).options();
    auto FloatType = x.options();

    //生成网格,先验框中心,网格左上角 batch_size, 3, 13, 13
    auto grid_x = torch::linspace(0, input_width - 1, input_width).repeat({ input_height, 1 }).repeat(
        { batch_size * num_anchors, 1, 1 }).view(x.sizes()).to(FloatType);
    auto grid_y = torch::linspace(0, input_height - 1, input_height).repeat({ input_width, 1 }).t().repeat(
        { batch_size * num_anchors, 1, 1 }).view(y.sizes()).to(FloatType);

    //生成先验框的宽高
    auto anchor_w = scaled_anchors.to(FloatType).narrow(1, 0, 1);
    auto anchor_h = scaled_anchors.to(FloatType).narrow(1, 1, 1);
    anchor_w = anchor_w.repeat({ batch_size, 1 }).repeat({ 1, 1, input_height * input_width }).view(w.sizes());
    anchor_h = anchor_h.repeat({ batch_size, 1 }).repeat({ 1, 1, input_height * input_width }).view(h.sizes());

    //计算调整后的先验框中心与宽高
    auto pred_boxes = torch::randn_like(prediction.narrow(-1, 0, 4)).to(FloatType);
    pred_boxes.select(-1, 0) = x + grid_x;
    pred_boxes.select(-1, 1) = y + grid_y;
    pred_boxes.select(-1, 2) = w.exp() * anchor_w;
    pred_boxes.select(-1, 3) = h.exp() * anchor_h;

    //用于将输出调整为相对于416x416的大小
    std::vector<int> scales{ stride_w, stride_h, stride_w, stride_h };
    auto _scale = torch::tensor(scales).to(FloatType);
    //cout << pred_boxes << endl;
    //cout << conf << endl;
    //cout << pred_cls << endl;
    pred_boxes = pred_boxes.view({ batch_size, -1, 4 }) * _scale;
    conf = conf.view({ batch_size, -1, 1 });
    pred_cls = pred_cls.view({ batch_size, -1, num_classes });
    auto output = torch::cat({ pred_boxes, conf, pred_cls }, -1);
    return output;
}

此外,还需要将输出进行非极大值抑制。参考我的NMS的几种写法写出非极大值抑制函数:

```c++