Commit ebe6d75

[cnn_architecture]: Add starting blocks to experiment with
Currently, single run only
1 parent a9ff117 commit ebe6d75

File tree

11 files changed: +344 -0 lines changed

.gitignore

Lines changed: 2 additions & 0 deletions

@@ -8,3 +8,5 @@
 data
 *.pkl
 *.onnx
+*.model
+*.log

cnn_architecture/__init__.py

Whitespace-only changes.

cnn_architecture/blocks.py

Lines changed: 117 additions & 0 deletions

import torch.nn.functional as F
from torch import Tensor, nn


# Complexity of Conv2d, without feature_map_size: kernel_size^2*in_channels*out_channels

# All blocks should have equal or lower computational complexity than the simple block.
# Count multiplications only, as an approximation.
# Total complexity: 9*in*hidden + 3*9*hidden^2
class SimpleBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self._num_hidden = in_channels * 2
        self._conv1 = nn.Conv2d(in_channels, self._num_hidden, 3, stride=2, padding=1)
        self._conv2 = nn.Conv2d(self._num_hidden, self._num_hidden, 3, padding=1)
        self._conv3 = nn.Conv2d(self._num_hidden, self._num_hidden, 3, padding=1)

    def forward(self, batch: Tensor) -> Tensor:
        batch = F.relu(self._conv1(batch))
        batch = F.relu(self._conv2(batch))
        batch = F.relu(self._conv3(batch))
        return batch

    @property
    def num_hidden(self):
        return self._num_hidden


# Total complexity: 9*in*hidden + 3*9*(hidden*hidden/2 + hidden/2*hidden) + 3*relu
class AddBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self._num_hidden = in_channels * 2
        self._stride = nn.Conv2d(in_channels, self._num_hidden, 3, stride=2, padding=1)

        self._conv1 = nn.Conv2d(self._num_hidden, self._num_hidden // 2, 3, padding=1)
        self._adj1 = nn.Conv2d(self._num_hidden // 2, self._num_hidden, 3, padding=1)

        self._conv2 = nn.Conv2d(self._num_hidden, self._num_hidden // 2, 3, padding=1)
        self._adj2 = nn.Conv2d(self._num_hidden // 2, self._num_hidden, 3, padding=1)

    def forward(self, batch: Tensor) -> Tensor:
        batch = F.relu(self._stride(batch))
        batch = batch + self._adj1(F.relu(self._conv1(batch)))
        batch = batch + self._adj2(F.relu(self._conv2(batch)))
        return batch

    @property
    def num_hidden(self):
        return self._num_hidden


# BN complexity: 2
class BNPreBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self._num_hidden = in_channels * 2
        self._bn1 = nn.BatchNorm2d(in_channels)
        self._conv1 = nn.Conv2d(in_channels, self._num_hidden, 3, stride=2, padding=1)
        self._bn2 = nn.BatchNorm2d(self._num_hidden)
        self._conv2 = nn.Conv2d(self._num_hidden, self._num_hidden - 1, 3, padding=1)
        self._bn3 = nn.BatchNorm2d(self._num_hidden - 1)
        self._conv3 = nn.Conv2d(self._num_hidden - 1, self._num_hidden, 3, padding=1)

    def forward(self, batch: Tensor) -> Tensor:
        batch = F.relu(self._conv1(self._bn1(batch)))
        batch = F.relu(self._conv2(self._bn2(batch)))
        batch = F.relu(self._conv3(self._bn3(batch)))
        return batch

    @property
    def num_hidden(self):
        return self._num_hidden


class BNBetweenBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self._num_hidden = in_channels * 2
        self._conv1 = nn.Conv2d(in_channels, self._num_hidden, 3, stride=2, padding=1)
        self._bn1 = nn.BatchNorm2d(self._num_hidden)
        self._conv2 = nn.Conv2d(self._num_hidden, self._num_hidden - 1, 3, padding=1)
        self._bn2 = nn.BatchNorm2d(self._num_hidden - 1)
        self._conv3 = nn.Conv2d(self._num_hidden - 1, self._num_hidden, 3, padding=1)
        self._bn3 = nn.BatchNorm2d(self._num_hidden)

    def forward(self, batch: Tensor) -> Tensor:
        batch = F.relu(self._bn1(self._conv1(batch)))
        batch = F.relu(self._bn2(self._conv2(batch)))
        batch = F.relu(self._bn3(self._conv3(batch)))
        return batch

    @property
    def num_hidden(self):
        return self._num_hidden


class BNPostBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self._num_hidden = in_channels * 2
        self._conv1 = nn.Conv2d(in_channels, self._num_hidden, 3, stride=2, padding=1)
        self._bn1 = nn.BatchNorm2d(self._num_hidden)
        self._conv2 = nn.Conv2d(self._num_hidden, self._num_hidden - 1, 3, padding=1)
        self._bn2 = nn.BatchNorm2d(self._num_hidden - 1)
        self._conv3 = nn.Conv2d(self._num_hidden - 1, self._num_hidden, 3, padding=1)
        self._bn3 = nn.BatchNorm2d(self._num_hidden)

    def forward(self, batch: Tensor) -> Tensor:
        batch = self._bn1(F.relu(self._conv1(batch)))
        batch = self._bn2(F.relu(self._conv2(batch)))
        batch = self._bn3(F.relu(self._conv3(batch)))
        return batch

    @property
    def num_hidden(self):
        return self._num_hidden
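
All five blocks are meant to be interchangeable: each strides once by 2, doubles the channel count, and should cost no more multiplications than SimpleBlock. A small sanity-check sketch (hypothetical, not part of the commit) that tallies kernel_size^2 * in_channels * out_channels for each conv as instantiated above suggests the budget holds: AddBlock lands exactly on SimpleBlock's count, and the BN variants come in slightly under it because their middle conv drops one channel to pay for the BatchNorm layers.

# Rough per-position multiplication count, ignoring the feature-map size as the
# comments in blocks.py do: kernel_size^2 * in_channels * out_channels per conv.
def _conv_mults(in_ch: int, out_ch: int, k: int = 3) -> int:
    return k * k * in_ch * out_ch


def simple_mults(in_ch: int) -> int:
    hidden = 2 * in_ch
    return _conv_mults(in_ch, hidden) + 2 * _conv_mults(hidden, hidden)


def add_mults(in_ch: int) -> int:
    hidden = 2 * in_ch
    return _conv_mults(in_ch, hidden) + 2 * (_conv_mults(hidden, hidden // 2) + _conv_mults(hidden // 2, hidden))


def bn_mults(in_ch: int) -> int:  # same conv shapes in BNPre/BNBetween/BNPostBlock
    hidden = 2 * in_ch
    return _conv_mults(in_ch, hidden) + _conv_mults(hidden, hidden - 1) + _conv_mults(hidden - 1, hidden)


print(simple_mults(128), add_mults(128), bn_mults(128))
# 1474560 1474560 1469952 -> AddBlock is equal, the BN blocks are slightly cheaper.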

cnn_architecture/config.py

Lines changed: 56 additions & 0 deletions

from pathlib import Path

import numpy as np
import torch
from yacs.config import CfgNode as CN

_cfg = CN()
_cfg.NAME = ''
_cfg.OUTPUT_DIR = ''
_cfg.SEED = 42

_cfg.MODEL = CN()
_cfg.MODEL.CONV0 = CN()
_cfg.MODEL.CONV0.IN_CHANNELS = 3
_cfg.MODEL.CONV0.NUM_FILTERS = 128
_cfg.MODEL.CONV0.SIZE = 7
_cfg.MODEL.CONV0.STRIDE = 2

_cfg.MODEL.NUM_CLASSES = 100
_cfg.MODEL.BLOCK_TYPE = 'simple'

_cfg.TRAIN = CN()
_cfg.TRAIN.LR = 0.1
_cfg.TRAIN.MOMENTUM = 0.9
_cfg.TRAIN.WEIGHT_DECAY = 0.0005
_cfg.TRAIN.NUM_ITERS = 10_000
_cfg.TRAIN.BATCH_SIZE = 256
_cfg.TRAIN.SHUFFLE = True

_cfg.VAL = CN()
_cfg.VAL.BATCH_SIZE = 0


def load_cfg(cfg_path: Path):
    cfg = _cfg.clone()
    cfg.merge_from_file(cfg_path)
    cfg = _transform_cfg(cfg, cfg_path.stem)

    Path(cfg.OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
    _update_seeds(cfg.SEED)
    return cfg


def _transform_cfg(cfg: CN, name: str):
    if cfg.NAME == '':
        cfg.NAME = name
    if cfg.OUTPUT_DIR == '':
        cfg.OUTPUT_DIR = f"output/{cfg.NAME}"
    if cfg.VAL.BATCH_SIZE == 0:
        cfg.VAL.BATCH_SIZE = cfg.TRAIN.BATCH_SIZE * 2
    return cfg


def _update_seeds(seed: int):
    torch.manual_seed(seed)
    np.random.seed(seed)
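
load_cfg merges a YAML file over these defaults and then fills in the derived fields: NAME falls back to the YAML file's stem, OUTPUT_DIR to output/<NAME>, and VAL.BATCH_SIZE to twice TRAIN.BATCH_SIZE; it also creates the output directory and seeds torch and numpy. A hypothetical usage sketch (not part of the commit, and it assumes the interpreter is started inside cnn_architecture/ so the package's flat imports resolve):

from pathlib import Path

from config import load_cfg  # flat import, matching model_builder.py's style

cfg = load_cfg(Path("configs/add.yaml"))
print(cfg.NAME)              # 'add'         (falls back to the file stem)
print(cfg.OUTPUT_DIR)        # 'output/add'  (derived from NAME; the directory is created)
print(cfg.VAL.BATCH_SIZE)    # 512           (defaults to 2 * TRAIN.BATCH_SIZE)
print(cfg.MODEL.BLOCK_TYPE)  # 'add'         (overridden by the YAML file)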

cnn_architecture/configs/add.yaml

Lines changed: 2 additions & 0 deletions

MODEL:
  BLOCK_TYPE: 'add'

cnn_architecture/configs/bn_between.yaml

Lines changed: 2 additions & 0 deletions

MODEL:
  BLOCK_TYPE: 'bn_between'

cnn_architecture/configs/bn_post.yaml

Lines changed: 2 additions & 0 deletions

MODEL:
  BLOCK_TYPE: 'bn_post'

cnn_architecture/configs/bn_pre.yaml

Lines changed: 2 additions & 0 deletions

MODEL:
  BLOCK_TYPE: 'bn_pre'

cnn_architecture/configs/simple.yaml

Whitespace-only changes.

cnn_architecture/model_builder.py

Lines changed: 37 additions & 0 deletions

import torch.nn.functional as F
from torch import nn
from yacs.config import CfgNode as CN

from blocks import AddBlock, BNBetweenBlock, BNPostBlock, BNPreBlock, SimpleBlock


class ModelTemplate(nn.Module):
    def __init__(self, block_type, conv_cfg: CN, num_classes: int):
        super().__init__()
        self._conv = nn.Conv2d(conv_cfg.IN_CHANNELS, conv_cfg.NUM_FILTERS, conv_cfg.SIZE, stride=conv_cfg.STRIDE,
                               padding=conv_cfg.SIZE // 2)

        self._block1 = block_type(conv_cfg.NUM_FILTERS)
        self._block2 = block_type(self._block1.num_hidden)
        self._cls = nn.Linear(self._block2.num_hidden, num_classes)

    def forward(self, batch):
        batch = F.relu(self._conv(batch))
        batch = self._block1(batch)
        batch = self._block2(batch)
        batch = F.max_pool2d(batch, batch.shape[-2:])
        batch = self._cls(batch.view(batch.shape[:2]))
        return batch


_block_factory = {
    'simple': SimpleBlock,
    'add': AddBlock,
    'bn_pre': BNPreBlock,
    'bn_between': BNBetweenBlock,
    'bn_post': BNPostBlock
}


def create_model(model_cfg: CN):
    return ModelTemplate(_block_factory[model_cfg.BLOCK_TYPE], model_cfg.CONV0, model_cfg.NUM_CLASSES)
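
create_model looks the block class up in _block_factory by MODEL.BLOCK_TYPE and drops it into a fixed template: a 7x7 strided stem, two blocks, a global max-pool, and a linear classifier. A rough end-to-end sketch (hypothetical, again run from inside cnn_architecture/, with a made-up 64x64 input just to show how the shapes work out):

from pathlib import Path

import torch

from config import load_cfg
from model_builder import create_model

cfg = load_cfg(Path("configs/bn_pre.yaml"))
model = create_model(cfg.MODEL)

# CONV0 strides by 2 and each block strides by 2 again: 64 -> 32 -> 16 -> 8,
# then the global max-pool collapses the 8x8 map before the linear classifier.
images = torch.randn(2, cfg.MODEL.CONV0.IN_CHANNELS, 64, 64)
logits = model(images)
print(logits.shape)  # torch.Size([2, 100]) with the default NUM_CLASSES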
