Beijing University of Posts and Telecommunications - Hou Lu / deeplabv3+

forked from wanglin/deeplabv3+ 
Deeplab.py 6.15 KB
Committed by wanglin on 2023-03-08 19:15 +08:00: add Transformer_channal to aspp

from __future__ import absolute_import, print_function
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F

class _ImagePool(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv = _ConvBnReLU(in_ch, out_ch, 1, 1, 0, 1)

    def forward(self, x):
        _, _, H, W = x.shape
        h = self.pool(x)
        h = self.conv(h)
        h = F.interpolate(h, size=(H, W), mode="bilinear", align_corners=False)
        return h
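
# Shape sketch (illustrative, not in the original file): _ImagePool reduces a
# (N, in_ch, H, W) map to 1x1 by global average pooling, projects it to out_ch
# channels, and bilinearly upsamples back to (N, out_ch, H, W), giving every
# spatial position the same image-level context vector.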

class _ASPP(nn.Module):
    """
    Atrous spatial pyramid pooling with image-level feature
    """

    def __init__(self, in_ch, out_ch, rates):
        super(_ASPP, self).__init__()
        self.stages = nn.Module()
        self.stages.add_module("c0", _ConvBnReLU(in_ch, out_ch, 1, 1, 0, 1))
        for i, rate in enumerate(rates):
            self.stages.add_module(
                "c{}".format(i + 1),
                _ConvBnReLU(in_ch, out_ch, 3, 1, padding=rate, dilation=rate),
            )
        self.stages.add_module("imagepool", _ImagePool(in_ch, out_ch))

    def forward(self, x):
        return torch.cat([stage(x) for stage in self.stages.children()], dim=1)
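
# Channel sketch (illustrative, not in the original file): with rates
# [6, 12, 18], _ASPP concatenates one 1x1 branch, three dilated 3x3 branches,
# and the image pool, so the output has out_ch * (len(rates) + 2) channels,
# e.g. 256 * 5 = 1280 -- exactly the concat_ch computed in DeepLabV3Plus below.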

class DeepLabV3Plus(nn.Module):
    """
    DeepLab v3+: Dilated ResNet with multi-grid + improved ASPP + decoder
    """

    def __init__(self, n_classes, n_blocks, atrous_rates, multi_grids, output_stride):
        super(DeepLabV3Plus, self).__init__()

        # Stride and dilation
        if output_stride == 8:
            s = [1, 2, 1, 1]
            d = [1, 1, 2, 4]
        elif output_stride == 16:
            s = [1, 2, 2, 1]
            d = [1, 1, 1, 2]
        else:
            raise ValueError("output_stride must be 8 or 16, got {}".format(output_stride))

        # Encoder
        ch = [64 * 2 ** p for p in range(6)]
        self.layer1 = _Stem(ch[0])
        self.layer2 = _ResLayer(n_blocks[0], ch[0], ch[2], s[0], d[0])
        self.layer3 = _ResLayer(n_blocks[1], ch[2], ch[3], s[1], d[1])
        self.layer4 = _ResLayer(n_blocks[2], ch[3], ch[4], s[2], d[2])
        self.layer5 = _ResLayer(n_blocks[3], ch[4], ch[5], s[3], d[3], multi_grids)
        self.aspp = _ASPP(ch[5], 256, atrous_rates)
        concat_ch = 256 * (len(atrous_rates) + 2)
        self.fc1 = _ConvBnReLU(concat_ch, 256, 1, 1, 0, 1)

        # Decoder
        self.reduce = _ConvBnReLU(256, 48, 1, 1, 0, 1)
        self.fc2 = nn.Sequential(
            OrderedDict(
                [
                    ("conv1", _ConvBnReLU(304, 256, 3, 1, 1, 1)),
                    ("conv2", _ConvBnReLU(256, 256, 3, 1, 1, 1)),
                    ("conv3", nn.Conv2d(256, n_classes, kernel_size=1)),
                ]
            )
        )

    def forward(self, x):
        h = self.layer1(x)
        h = self.layer2(h)
        h_ = self.reduce(h)
        h = self.layer3(h)
        h = self.layer4(h)
        h = self.layer5(h)
        h = self.aspp(h)
        h = self.fc1(h)
        h = F.interpolate(h, size=h_.shape[2:], mode="bilinear", align_corners=False)
        h = torch.cat((h, h_), dim=1)
        h = self.fc2(h)
        h = F.interpolate(h, size=x.shape[2:], mode="bilinear", align_corners=False)
        return h
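
# Resolution sketch (illustrative, not in the original file): with
# output_stride=16 and a 513x513 input, the skip feature h_ from layer2 sits
# at stride 4 (129x129); the ASPP output is upsampled to that grid, fused with
# h_, and then upsampled again to the full input size.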

# torch.nn.SyncBatchNorm only accepts CUDA inputs inside an initialized
# distributed process group, so plain BatchNorm2d is the safe default here;
# convert with nn.SyncBatchNorm.convert_sync_batchnorm(model) for multi-GPU
# training.
_BATCH_NORM = nn.BatchNorm2d

_BOTTLENECK_EXPANSION = 4

class _ConvBnReLU(nn.Sequential):
    """
    Cascade of 2D convolution, batch norm, and ReLU.
    """

    BATCH_NORM = _BATCH_NORM

    def __init__(
        self, in_ch, out_ch, kernel_size, stride, padding, dilation, relu=True
    ):
        super(_ConvBnReLU, self).__init__()
        self.add_module(
            "conv",
            nn.Conv2d(
                in_ch, out_ch, kernel_size, stride, padding, dilation, bias=False
            ),
        )
        # PyTorch's BN momentum is (1 - decay), so a Caffe-style decay of
        # 0.999 corresponds to momentum=1 - 0.999, not 0.999.
        self.add_module("bn", _BATCH_NORM(out_ch, eps=1e-5, momentum=1 - 0.999))
        if relu:
            self.add_module("relu", nn.ReLU())

class _Bottleneck(nn.Module):
    """
    Bottleneck block of MSRA ResNet.
    """

    def __init__(self, in_ch, out_ch, stride, dilation, downsample):
        super(_Bottleneck, self).__init__()
        mid_ch = out_ch // _BOTTLENECK_EXPANSION
        self.reduce = _ConvBnReLU(in_ch, mid_ch, 1, stride, 0, 1, True)
        self.conv3x3 = _ConvBnReLU(mid_ch, mid_ch, 3, 1, dilation, dilation, True)
        self.increase = _ConvBnReLU(mid_ch, out_ch, 1, 1, 0, 1, False)
        self.shortcut = (
            _ConvBnReLU(in_ch, out_ch, 1, stride, 0, 1, False)
            if downsample
            else lambda x: x  # identity
        )

    def forward(self, x):
        h = self.reduce(x)
        h = self.conv3x3(h)
        h = self.increase(h)
        h += self.shortcut(x)
        return F.relu(h)
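
# Example (illustrative, not in the original file): _Bottleneck(256, 512, 2, 1, True)
# maps 256 -> 128 (1x1, stride 2) -> 128 (3x3) -> 512 (1x1), with a strided 1x1
# projection on the shortcut because downsample=True.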

class _ResLayer(nn.Sequential):
    """
    Residual layer with multi grids
    """

    def __init__(self, n_layers, in_ch, out_ch, stride, dilation, multi_grids=None):
        super(_ResLayer, self).__init__()
        if multi_grids is None:
            multi_grids = [1 for _ in range(n_layers)]
        else:
            assert n_layers == len(multi_grids)

        # Downsampling is only in the first block
        for i in range(n_layers):
            self.add_module(
                "block{}".format(i + 1),
                _Bottleneck(
                    in_ch=(in_ch if i == 0 else out_ch),
                    out_ch=out_ch,
                    stride=(stride if i == 0 else 1),
                    dilation=dilation * multi_grids[i],
                    downsample=(i == 0),
                ),
            )
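
# Dilation sketch (illustrative, not in the original file): layer5 with
# output_stride=16 gets dilation d[3] = 2 and multi_grids = [1, 2, 4], so its
# three bottlenecks use dilations 2, 4, and 8.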

class _Stem(nn.Sequential):
    """
    The 1st conv layer.
    Note that the max pooling is different from both MSRA and FAIR ResNet.
    """

    def __init__(self, out_ch):
        super(_Stem, self).__init__()
        self.add_module("conv1", _ConvBnReLU(3, out_ch, 7, 2, 3, 1))
        self.add_module("pool", nn.MaxPool2d(3, 2, 1, ceil_mode=True))

if __name__ == "__main__":
    model = DeepLabV3Plus(
        n_classes=21,
        n_blocks=[3, 4, 23, 3],
        atrous_rates=[6, 12, 18],
        multi_grids=[1, 2, 4],
        output_stride=16,
    )
    model.eval()
    image = torch.randn(1, 3, 513, 513)
    if torch.cuda.is_available():
        # Tensor.cuda() is not in-place; the result must be reassigned,
        # and the model has to move to the same device as the input.
        model = model.cuda()
        image = image.cuda()
    print(model)
    print("input:", image.shape)
    with torch.no_grad():
        print("output:", model(image).shape)