Skip to content

Commit d5c5f43

Browse files
committed
update training
1 parent d1476c5 commit d5c5f43

File tree

6 files changed

+19
-26
lines changed

6 files changed

+19
-26
lines changed

backbone_nets/mobilenetv2_backbone.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
from torch import nn
3-
#from .utils import load_state_dict_from_url
4-
3+
from torch.hub import load_state_dict_from_url
54

65
__all__ = ['MobileNetV2', 'mobilenet_v2']
76

@@ -205,5 +204,5 @@ def mobilenet_v2(pretrained=False, progress=True, **kwargs):
205204
if pretrained:
206205
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
207206
progress=progress)
208-
model.load_state_dict(state_dict)
207+
model.load_state_dict(state_dict, strict=False)
209208
return model

loss_definition.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch
22
import torch.nn as nn
3-
from math import sqrt
43
from utils.params import ParamsPack
54
param_pack = ParamsPack()
65
import math
@@ -33,20 +32,13 @@ def __init__(self):
3332
super(ParamLoss, self).__init__()
3433
self.criterion = nn.MSELoss(reduction="none")
3534

36-
def forward(self, input, target, mode='normal'):
35+
def forward(self, input, target, mode = 'normal'):
3736
if mode == 'normal':
3837
loss = self.criterion(input[:,:12], target[:,:12]).mean(1) + self.criterion(input[:,12:], target[:,12:]).mean(1)
3938
return torch.sqrt(loss)
40-
if mode == 'no_tex':
41-
loss = self.criterion(input[:,:12], target[:,:12]).mean(1) + self.criterion(input[:,12:62], target[:,12:62]).mean(1)
42-
return torch.sqrt(loss)
43-
if mode == 'only_3dmm':
39+
elif mode == 'only_3dmm':
4440
loss = self.criterion(input[:,:50], target[:,12:62]).mean(1)
4541
return torch.sqrt(loss)
46-
if mode == 'only_pose':
47-
loss = self.criterion(input[:,:12], target[:,:12]).mean(1)
48-
return loss
49-
5042
return torch.sqrt(loss.mean(1))
5143

5244

main_train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def parse_args():
5151
parser.add_argument('--test_initial', default='false', type=str2bool)
5252
parser.add_argument('--warmup', default=-1, type=int)
5353
parser.add_argument('--param-fp-train',default='',type=str)
54-
parser.add_argument('--data_ver', default='v1', type=str)
5554
parser.add_argument('--img_size', default=120, type=int)
5655
parser.add_argument('--save_val_freq', default=10, type=int)
5756

@@ -225,7 +224,7 @@ def main():
225224

226225
filename = f'{args.snapshot}_checkpoint_epoch_{epoch}.pth.tar'
227226
# save checkpoints and current model validation
228-
if (epoch % args.save_val_freq == 0) or (epoch==args.epochs) or ((epoch >= 45) and (epoch % 2 ==0)):
227+
if (epoch % args.save_val_freq == 0) or (epoch==args.epochs):
229228
save_checkpoint(
230229
{
231230
'epoch': epoch,

model_building.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,15 @@ def __init__(self, args):
3333
self.args = args
3434
# backbone definition
3535
if 'mobilenet_v2' in self.args.arch:
36-
self.backbone = getattr(mobilenetv2_backbone, args.arch)()
36+
self.backbone = getattr(mobilenetv2_backbone, args.arch)(pretrained=False)
3737
elif 'mobilenet' in self.args.arch:
3838
self.backbone = getattr(mobilenetv1_backbone, args.arch)()
3939
elif 'resnet' in self.args.arch:
4040
self.backbone = getattr(resnet_backbone, args.arch)(pretrained=False)
4141
elif 'ghostnet' in self.args.arch:
4242
self.backbone = getattr(ghostnet_backbone, args.arch)()
43+
else:
44+
raise RuntimeError("Please choose [mobilenet_v2, mobilenet_1, resnet50, or ghostnet]")
4345

4446
def forward(self,input, target):
4547
"""Training time forward"""
@@ -102,7 +104,7 @@ def reconstruct_vertex_62(self, param, whitening=True, dense=False, transform=Tr
102104
transform: whether transform to image space
103105
Working with batched tensors. Using Fortan-type reshape.
104106
"""
105-
107+
106108
if whitening:
107109
if param.shape[1] == 62:
108110
param_ = param * self.param_std[:62] + self.param_mean[:62]
@@ -134,16 +136,16 @@ def forward(self, input, target):
134136

135137
vertex_lmk = self.reconstruct_vertex_62(_3D_attr, dense=False)
136138
vertex_GT_lmk = self.reconstruct_vertex_62(_3D_attr_GT, dense=False)
137-
self.loss['loss_LMK_f0'] = 0.01 *self.LMKLoss_3D(vertex_lmk, vertex_GT_lmk, kp=True)
139+
self.loss['loss_LMK_f0'] = 0.05 *self.LMKLoss_3D(vertex_lmk, vertex_GT_lmk, kp=True)
138140
self.loss['loss_Param_In'] = 0.02 * self.ParamLoss(_3D_attr, _3D_attr_GT)
139-
141+
140142
point_residual = self.forwardDirection(vertex_lmk, avgpool, _3D_attr[:,12:52], _3D_attr[:,52:62])
141-
vertex_lmk = vertex_lmk + 0.01 * point_residual
142-
self.loss['loss_LMK_pointNet'] = 0.01 * self.LMKLoss_3D(vertex_lmk, vertex_GT_lmk, kp=True)
143+
vertex_lmk = vertex_lmk + 0.05 * point_residual
144+
self.loss['loss_LMK_pointNet'] = 0.05 * self.LMKLoss_3D(vertex_lmk, vertex_GT_lmk, kp=True)
143145

144146
_3D_attr_S2 = self.reverseDirection(vertex_lmk)
145147
self.loss['loss_Param_S2'] = 0.02 * self.ParamLoss(_3D_attr_S2, _3D_attr_GT, mode='only_3dmm')
146-
self.loss['loss_Param_S1S2'] = 1 * self.ParamLoss(_3D_attr_S2, _3D_attr, mode='only_3dmm')
148+
self.loss['loss_Param_S1S2'] = 0.001 * self.ParamLoss(_3D_attr_S2, _3D_attr, mode='only_3dmm')
147149

148150
return self.loss
149151

train_script.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ python3 main_train.py --arch="mobilenet_v2" \
1111
--snapshot="ckpts/SynergyNet" \
1212
--param-fp-train='./3dmm_data/param_all_norm_v201.pkl' \
1313
--warmup=5 \
14-
--batch-size=350 \
15-
--base-lr=0.01 \
16-
--epochs=80 \
17-
--milestones=48,64 \
14+
--batch-size=900 \
15+
--base-lr=0.027 \
16+
--epochs=50 \
17+
--milestones=30,40 \
1818
--print-freq=50 \
1919
--devices-id=0 \
2020
--workers=8 \

utils/ddfa.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def __call__(self, img, gt=None):
228228
if not(_is_tensor_image(img)):
229229
raise TypeError('img should be tensor. Got {}'.format(type(img)))
230230
if img.ndim == 3:
231+
crop_backgnd[:, crop_margins:h-1*crop_margins, crop_margins:w-1*crop_margins] = img[:, crop_margins: h-crop_margins, crop_margins: w-crop_margins]
231232
# random center crop
232233
if (rand < self.prob) and (self.mode=='train'):
233234
func = self.switcher.get(random.randint(1,7))

0 commit comments

Comments
 (0)