|
| 1 | +import torch.nn.functional as F |
| 2 | +from torch import Tensor, nn |
| 3 | + |
| 4 | + |
# Complexity of Conv2d, without feature_map_size: kernel_size^2 * in_channels * out_channels

# All blocks should have equal or lower computational complexity than the simple block.
# Count multiplications only, as an approximation.
# Total complexity: 9*in*hidden + 3*9*hidden^2
class SimpleBlock(nn.Module):
    """Baseline block: one strided 3x3 conv that doubles the channel count
    and halves the spatial resolution, followed by two same-width 3x3 convs,
    each activated with ReLU.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        hidden = in_channels * 2
        self._num_hidden = hidden
        # Strided conv downsamples by 2 while widening to `hidden` channels.
        self._conv1 = nn.Conv2d(in_channels, hidden, 3, stride=2, padding=1)
        self._conv2 = nn.Conv2d(hidden, hidden, 3, padding=1)
        self._conv3 = nn.Conv2d(hidden, hidden, 3, padding=1)

    def forward(self, batch: Tensor) -> Tensor:
        # conv -> ReLU, three times in sequence.
        for conv in (self._conv1, self._conv2, self._conv3):
            batch = F.relu(conv(batch))
        return batch

    @property
    def num_hidden(self) -> int:
        # Output channel count produced by this block.
        return self._num_hidden
| 27 | + |
| 28 | + |
# Total complexity: 9*in*hidden + 3*9*(hidden*(hidden/2) + (hidden/2)*hidden) + 3*relu
class AddBlock(nn.Module):
    """Residual variant: after the strided widening conv, each stage squeezes
    to half the hidden width, applies ReLU, expands back, and adds the result
    to the stage input (identity skip connection).
    """

    def __init__(self, in_channels: int):
        super().__init__()
        hidden = in_channels * 2
        half = hidden // 2
        self._num_hidden = hidden
        # Downsampling conv: stride 2, doubles the channel count.
        self._stride = nn.Conv2d(in_channels, hidden, 3, stride=2, padding=1)

        # Stage 1: squeeze to half width, then project back to full width.
        self._conv1 = nn.Conv2d(hidden, half, 3, padding=1)
        self._adj1 = nn.Conv2d(half, hidden, 3, padding=1)

        # Stage 2: same squeeze/expand shape as stage 1.
        self._conv2 = nn.Conv2d(hidden, half, 3, padding=1)
        self._adj2 = nn.Conv2d(half, hidden, 3, padding=1)

    def forward(self, batch: Tensor) -> Tensor:
        out = F.relu(self._stride(batch))
        # Each pair is (squeeze conv, expand conv); the expansion output is
        # added back onto the stage input.
        for squeeze, expand in ((self._conv1, self._adj1), (self._conv2, self._adj2)):
            out = out + expand(F.relu(squeeze(out)))
        return out

    @property
    def num_hidden(self) -> int:
        # Output channel count produced by this block.
        return self._num_hidden
| 51 | + |
| 52 | + |
# BN complexity: 2
class BNPreBlock(nn.Module):
    """BatchNorm placed BEFORE each conv: BN -> conv -> ReLU per stage.

    The middle width is `num_hidden - 1`, presumably to offset the extra
    BatchNorm cost against the SimpleBlock complexity budget (see the
    complexity notes above) — confirm against the experiment design.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        hidden = in_channels * 2
        self._num_hidden = hidden
        self._bn1 = nn.BatchNorm2d(in_channels)
        self._conv1 = nn.Conv2d(in_channels, hidden, 3, stride=2, padding=1)
        self._bn2 = nn.BatchNorm2d(hidden)
        self._conv2 = nn.Conv2d(hidden, hidden - 1, 3, padding=1)
        self._bn3 = nn.BatchNorm2d(hidden - 1)
        self._conv3 = nn.Conv2d(hidden - 1, hidden, 3, padding=1)

    def forward(self, batch: Tensor) -> Tensor:
        out = batch
        # Per stage: normalize first, then convolve, then activate.
        stages = (
            (self._bn1, self._conv1),
            (self._bn2, self._conv2),
            (self._bn3, self._conv3),
        )
        for bn, conv in stages:
            out = F.relu(conv(bn(out)))
        return out

    @property
    def num_hidden(self) -> int:
        # Output channel count produced by this block.
        return self._num_hidden
| 74 | + |
| 75 | + |
class BNBetweenBlock(nn.Module):
    """BatchNorm placed BETWEEN conv and activation: conv -> BN -> ReLU per stage.

    The middle width is `num_hidden - 1`, presumably to offset the extra
    BatchNorm cost against the SimpleBlock complexity budget (see the
    complexity notes above) — confirm against the experiment design.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        hidden = in_channels * 2
        self._num_hidden = hidden
        self._conv1 = nn.Conv2d(in_channels, hidden, 3, stride=2, padding=1)
        self._bn1 = nn.BatchNorm2d(hidden)
        self._conv2 = nn.Conv2d(hidden, hidden - 1, 3, padding=1)
        self._bn2 = nn.BatchNorm2d(hidden - 1)
        self._conv3 = nn.Conv2d(hidden - 1, hidden, 3, padding=1)
        self._bn3 = nn.BatchNorm2d(hidden)

    def forward(self, batch: Tensor) -> Tensor:
        out = batch
        # Per stage: convolve, normalize, then activate.
        stages = (
            (self._conv1, self._bn1),
            (self._conv2, self._bn2),
            (self._conv3, self._bn3),
        )
        for conv, bn in stages:
            out = F.relu(bn(conv(out)))
        return out

    @property
    def num_hidden(self) -> int:
        # Output channel count produced by this block.
        return self._num_hidden
| 96 | + |
| 97 | + |
class BNPostBlock(nn.Module):
    """BatchNorm placed AFTER the activation: conv -> ReLU -> BN per stage.

    Note: because BN is last, the block output is NOT guaranteed non-negative,
    unlike the other blocks in this file.

    The middle width is `num_hidden - 1`, presumably to offset the extra
    BatchNorm cost against the SimpleBlock complexity budget (see the
    complexity notes above) — confirm against the experiment design.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        hidden = in_channels * 2
        self._num_hidden = hidden
        self._conv1 = nn.Conv2d(in_channels, hidden, 3, stride=2, padding=1)
        self._bn1 = nn.BatchNorm2d(hidden)
        self._conv2 = nn.Conv2d(hidden, hidden - 1, 3, padding=1)
        self._bn2 = nn.BatchNorm2d(hidden - 1)
        self._conv3 = nn.Conv2d(hidden - 1, hidden, 3, padding=1)
        self._bn3 = nn.BatchNorm2d(hidden)

    def forward(self, batch: Tensor) -> Tensor:
        out = batch
        # Per stage: convolve, activate, then normalize.
        stages = (
            (self._conv1, self._bn1),
            (self._conv2, self._bn2),
            (self._conv3, self._bn3),
        )
        for conv, bn in stages:
            out = bn(F.relu(conv(out)))
        return out

    @property
    def num_hidden(self) -> int:
        # Output channel count produced by this block.
        return self._num_hidden
0 commit comments