Skip to content

Commit f7007b6

Browse files
committed
Resnet2060!
updated gradient checkpoint for training resnet2060
1 parent 676b02d commit f7007b6

File tree

5 files changed

+189
-11
lines changed

5 files changed

+189
-11
lines changed

recognition/arcface_torch/backbones/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
2-
from .mobilefacenet import MobileFaceNet
32

43

54
def get_model(name, **kwargs):
@@ -13,7 +12,8 @@ def get_model(name, **kwargs):
1312
return iresnet100(False, **kwargs)
1413
elif name == "r200":
1514
return iresnet200(False, **kwargs)
16-
elif name == "mbf":
17-
return MobileFaceNet((112, 112), **kwargs)
15+
elif name == "r2060":
16+
from .iresnet2060 import iresnet2060
17+
return iresnet2060(False, **kwargs)
1818
else:
1919
raise ValueError()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import torch
2+
from torch import nn
3+
4+
def _release_tuple(version):
    """Return the leading ``major.minor[.patch]`` of *version* as a tuple of ints.

    Local build suffixes such as ``+cu102`` or ``a0+git1234`` are ignored and a
    missing patch component counts as 0.  A string that does not start with
    ``<digits>.<digits>`` yields ``(0, 0, 0)``.
    """
    import re  # local import keeps the version guard self-contained

    found = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", str(version))
    if found is None:
        return (0, 0, 0)
    return tuple(int(part or 0) for part in found.groups())


# Compare numerically: as plain strings "1.10.0" sorts *before* "1.8.1", which
# would wrongly reject every torch release from 1.10 onwards.
assert _release_tuple(torch.__version__) >= (1, 8, 1), \
    "iresnet2060 requires torch >= 1.8.1, found {}".format(torch.__version__)
5+
from torch.utils.checkpoint import checkpoint_sequential
6+
7+
__all__ = ['iresnet2060']
8+
9+
10+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of channel groups.
        dilation: kernel dilation; the same value is used as padding.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    settings = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **settings)
20+
21+
22+
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    pointwise = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return pointwise
29+
30+
31+
class IBasicBlock(nn.Module):
    """Residual block: BN -> 3x3 conv -> BN -> PReLU -> 3x3 conv -> BN, plus shortcut.

    The second convolution carries the block's stride.  ``downsample``, when
    given, is applied to the block input to produce the shortcut; otherwise the
    input itself is added to the residual branch.
    """

    # Output channels are ``planes * expansion``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        """Create the block.

        Args:
            inplanes: channels of the incoming tensor.
            planes: channels produced by the block.
            stride: stride of the second 3x3 convolution.
            downsample: optional module applied to the input for the shortcut.
            groups: must be 1.
            base_width: must be 64.
            dilation: must not exceed 1.

        Raises:
            ValueError: if ``groups`` or ``base_width`` differ from 1 / 64.
            NotImplementedError: if ``dilation`` is greater than 1.
        """
        super().__init__()
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Sub-modules are created in the same order as before so parameter
        # names and random initialisation stay identical.  The 3x3 convolutions
        # are written out directly (padding 1, no bias, no grouping/dilation).
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1,
                               padding=1, groups=1, bias=False, dilation=1)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05)
        self.prelu = nn.PReLU(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, groups=1, bias=False, dilation=1)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``residual_branch(x) + shortcut(x)``."""
        residual = x
        for stage in (self.bn1, self.conv1, self.bn2, self.prelu, self.conv2, self.bn3):
            residual = stage(residual)
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return residual
62+
63+
64+
class IResNet(nn.Module):
    """Four-stage residual backbone that maps an image batch to feature vectors.

    Pipeline: 3x3 stem conv -> BN -> PReLU -> layer1..layer4 -> BN -> flatten ->
    dropout -> linear -> BatchNorm1d.  In training mode ``layer2`` and
    ``layer3`` are executed through gradient checkpointing.

    The linear layer expects ``512 * block.expansion * fc_scale`` inputs, i.e. a
    7x7 map after layer4 (for example a 112x112 input when no stride is
    replaced by dilation, since each of the four stages halves the resolution).
    """

    # Number of spatial positions (7 x 7) flattened into the final linear layer.
    fc_scale = 7 * 7

    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        """Build the network.

        Args:
            block: residual block class; must expose ``expansion`` and accept
                ``(inplanes, planes, stride, downsample, groups, base_width, dilation)``.
            layers: four integers, the number of blocks in layer1..layer4.
            dropout: drop probability applied before the linear layer.
            num_features: size of the returned feature vector.
            zero_init_residual: if true, set ``bn2.weight`` of every
                ``IBasicBlock`` to zero after the default initialisation.
            groups: forwarded to every block.
            width_per_group: forwarded to every block as ``base_width``.
            replace_stride_with_dilation: ``None`` or three booleans, one each
                for layer2..layer4, replacing the stage stride with dilation.
            fp16: run the convolutional part under ``torch.cuda.amp.autocast``.

        Raises:
            ValueError: if ``replace_stride_with_dilation`` does not have
                exactly three elements.
        """
        super().__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        # The scale of the final BatchNorm1d is pinned to 1 and excluded from training.
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                # NOTE(review): the last BatchNorm of IBasicBlock is bn3, not
                # bn2 -- confirm that bn2 is the intended target here.
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Assemble one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        Only the first block receives ``stride`` and the optional projection
        shortcut; the remaining blocks keep resolution and channel count.
        When ``dilate`` is true the stride is folded into ``self.dilation``
        instead of being applied.
        """
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        out_channels = planes * block.expansion
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                # Bias-free 1x1 projection (the same layer conv1x1() builds),
                # written out so the class does not depend on the helper.
                nn.Conv2d(self.inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels, eps=1e-05),
            )
        layers = [
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation)
        ]
        self.inplanes = out_channels
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def checkpoint(self, func, num_seg, x):
        """Run ``func`` on ``x``, with gradient checkpointing while training.

        Args:
            func: the ``nn.Sequential`` stage to execute.
            num_seg: requested number of checkpoint segments.
            x: input tensor.

        Returns:
            The output of ``func(x)``.
        """
        if not self.training:
            return func(x)

        # checkpoint_sequential derives the chunk size as len(func) // segments
        # and fails with a zero range step when more segments than modules are
        # requested, so never ask for more segments than the stage has blocks.
        segments = max(1, min(num_seg, len(func)))

        # Newer PyTorch releases warn when ``use_reentrant`` is omitted, while
        # old ones reject the keyword.  Pass the legacy default (True)
        # explicitly only where the installed version accepts it.
        import inspect
        extra = {}
        if "use_reentrant" in inspect.signature(checkpoint_sequential).parameters:
            extra["use_reentrant"] = True
        return checkpoint_sequential(func, segments, x, **extra)

    def forward(self, x):
        """Compute the ``(batch, num_features)`` output for image batch ``x``."""
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            # layer2 and layer3 are checkpointed in training mode (see checkpoint()).
            x = self.checkpoint(self.layer2, 20, x)
            x = self.checkpoint(self.layer3, 100, x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        # The linear head runs outside autocast, on float32 input when fp16 is on.
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x
166+
167+
168+
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    """Construct an ``IResNet`` for the architecture called ``arch``.

    Args:
        arch: architecture name, used only in the error message.
        block: residual block class passed to ``IResNet``.
        layers: number of blocks per stage, passed to ``IResNet``.
        pretrained: must be falsy; no pretrained weights can be loaded here.
        progress: accepted for signature compatibility; unused.
        **kwargs: forwarded to the ``IResNet`` constructor.

    Returns:
        The freshly initialised model.

    Raises:
        ValueError: if ``pretrained`` is truthy.
    """
    # Reject the request before allocating a (potentially very deep) network
    # that would be thrown away immediately.
    if pretrained:
        raise ValueError("pretrained weights are not available for {!r}".format(arch))
    return IResNet(block, layers, **kwargs)
173+
174+
175+
def iresnet2060(pretrained=False, progress=True, **kwargs):
    """Create the 2060-layer IResNet built from ``IBasicBlock``.

    The four stages hold 3, 128, 896 and 3 blocks (1030 in total, each with
    two 3x3 convolutions).  ``pretrained``, ``progress`` and any extra keyword
    arguments are handed to ``_iresnet`` unchanged.
    """
    stage_depths = [3, 128, 896, 3]
    return _iresnet('iresnet2060', IBasicBlock, stage_depths, pretrained, progress, **kwargs)

recognition/arcface_torch/docs/install.md

+8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
## v1.8.1
2+
### Linux and Windows
3+
```shell
4+
# CUDA 10.2
5+
pip3 install torch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1
6+
```
7+
8+
19
## v1.7.1
210
### Linux and Windows
311
```shell

recognition/arcface_torch/train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def main(args):
4949

5050
dropout = 0.4 if cfg.dataset == "webface" else 0
5151
backbone = get_model(args.network, dropout=dropout, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank)
52-
backbone_onnx = get_model(args.network, dropout=dropout, fp16=False, num_features=cfg.embedding_size)
5352

5453
if args.resume:
5554
try:
@@ -121,7 +120,7 @@ def main(args):
121120
loss.update(loss_v, 1)
122121
callback_logging(global_step, loss, epoch, cfg.fp16, grad_amp)
123122
callback_verification(global_step, backbone)
124-
callback_checkpoint(global_step, backbone, module_partial_fc, backbone_onnx)
123+
callback_checkpoint(global_step, backbone, module_partial_fc)
125124
scheduler_backbone.step()
126125
scheduler_pfc.step()
127126
dist.destroy_process_group()

recognition/arcface_torch/utils/utils_callbacks.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import torch
77

88
from eval import verification
9-
from partial_fc import PartialFC
10-
from torch2onnx import convert_onnx
119
from utils.utils_logging import AverageMeter
1210

1311

@@ -100,14 +98,11 @@ def __init__(self, rank, output="./"):
10098
self.rank: int = rank
10199
self.output: str = output
102100

103-
def __call__(self, global_step, backbone, partial_fc, backbone_onnx):
101+
def __call__(self, global_step, backbone, partial_fc):
    """Persist training state once more than 100 steps have elapsed.

    Args:
        global_step: current global training step.
        backbone: wrapped model; ``backbone.module.state_dict()`` is saved.
        partial_fc: object exposing ``save_params()``, or ``None`` to skip it.

    On rank 0 the backbone weights are written to ``<output>/backbone.pth``.
    Independently of the rank, ``partial_fc.save_params()`` is called when
    ``partial_fc`` is not ``None``.
    """
    # ``==`` rather than ``is``: identity comparison against an int literal is
    # implementation-dependent and raises a SyntaxWarning on Python 3.8+.
    if global_step > 100 and self.rank == 0:
        path_module = os.path.join(self.output, "backbone.pth")
        torch.save(backbone.module.state_dict(), path_module)
        logging.info("Pytorch Model Saved in '{}'".format(path_module))

    if global_step > 100 and partial_fc is not None:
        partial_fc.save_params()

0 commit comments

Comments
 (0)