Commit 716eded

Allow disabling scaler for DL models (#3251)
*Issue #, if available:*

*Description of changes:* Currently, the incorrect type hint makes it impossible to set `scaling=None` for the TiDE, PatchTST, DLinear, and LagTST models: if the user sets `scaling=None`, Pydantic raises a ValidationError. This PR fixes that.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

**Please tag this PR with at least one of these labels to make our release process faster:** BREAKING, new feature, bug fix, other change, dev setup
1 parent 12089c7 commit 716eded
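For illustration, a minimal sketch of what this change enables, assuming the corresponding estimators (e.g. DLinearEstimator) forward their scaling argument to the modules patched below; the surrounding constructor arguments are illustrative, not taken from this diff:

from gluonts.torch.model.d_linear import DLinearEstimator

# With the corrected Optional[str] annotation, scaling=None passes
# validation and disables the built-in scaler instead of raising a
# Pydantic ValidationError.
estimator = DLinearEstimator(
    prediction_length=24,   # illustrative values
    context_length=96,
    scaling=None,
)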

File tree (4 files changed: +6 −6 lines)

src/gluonts/torch/model/d_linear/module.py
src/gluonts/torch/model/lag_tst/module.py
src/gluonts/torch/model/patch_tst/module.py
src/gluonts/torch/model/tide/module.py

src/gluonts/torch/model/d_linear/module.py (2 additions & 2 deletions)

@@ -11,7 +11,7 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from torch import nn
@@ -87,7 +87,7 @@ def __init__(
         hidden_dimension: int,
         distr_output=StudentTOutput(),
         kernel_size: int = 25,
-        scaling: str = "mean",
+        scaling: Optional[str] = "mean",
     ) -> None:
         super().__init__()
 

src/gluonts/torch/model/lag_tst/module.py (1 addition & 1 deletion)

@@ -55,7 +55,7 @@ def __init__(
         activation: str,
         norm_first: bool,
         num_encoder_layers: int,
-        scaling: str,
+        scaling: Optional[str],
         lags_seq: Optional[List[int]] = None,
         distr_output=StudentTOutput(),
     ) -> None:

src/gluonts/torch/model/patch_tst/module.py (1 addition & 1 deletion)

@@ -108,7 +108,7 @@ def __init__(
         activation: str,
         norm_first: bool,
         num_encoder_layers: int,
-        scaling: str,
+        scaling: Optional[str],
         distr_output=StudentTOutput(),
     ) -> None:
         super().__init__()

src/gluonts/torch/model/tide/module.py (2 additions & 2 deletions)

@@ -11,7 +11,7 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -243,7 +243,7 @@ def __init__(
         num_layers_decoder: int,
         layer_norm: bool,
         distr_output: Output,
-        scaling: str,
+        scaling: Optional[str],
     ) -> None:
         super().__init__()
 
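Each of the diffs above makes the same change: widening the scaling annotation from str to Optional[str]. Since constructor arguments are validated against these type hints with Pydantic (per the description above), the old annotation rejected None. A standalone sketch of that behavior, using plain pydantic.BaseModel classes purely for illustration:

from typing import Optional

from pydantic import BaseModel, ValidationError


class Before(BaseModel):
    # Mirrors the old annotation: None is not a valid str.
    scaling: str = "mean"


class After(BaseModel):
    # Mirrors the new annotation: None is explicitly allowed.
    scaling: Optional[str] = "mean"


try:
    Before(scaling=None)
except ValidationError:
    print("scaling=None rejected by the old str annotation")

print(After(scaling=None))  # scaling=None is accepted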
