Skip to content

Commit c6aec86

Browse files
committed
update logic for mean in lstm
1 parent 2e28021 commit c6aec86

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

nbs/02_lstm.ipynb

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,9 @@
799799
"source": [
800800
"### Get statistics of embeddings (train)\n",
801801
"\n",
802-
"Mean of training data to standardize the embeddings. These are global statistics for the train set."
802+
"Mean of training data to standardize the embeddings. These are global statistics for the train set.\n",
803+
"\n",
804+
"Statistics are computed over the `[:-1]` slice of each subset only, as this is the only part seen by the model."
803805
]
804806
},
805807
{
@@ -820,10 +822,10 @@
820822
],
821823
"source": [
822824
"window_means = np.asarray(\n",
823-
" [data_embeddings[i][\"subset\"].mean().item() for i in trn_data_idxs]\n",
825+
" [data_embeddings[i][\"subset\"][:-1].mean().item() for i in trn_data_idxs]\n",
824826
")\n",
825827
"window_stds = np.asarray(\n",
826-
" [data_embeddings[i][\"subset\"].std().item() for i in trn_data_idxs]\n",
828+
" [data_embeddings[i][\"subset\"][:-1].std().item() for i in trn_data_idxs]\n",
827829
")\n",
828830
"emb_mean, emb_std = window_means.mean(), window_stds.mean()\n",
829831
"emb_mean, emb_std"
@@ -854,10 +856,10 @@
854856
],
855857
"source": [
856858
"val_window_means = np.asarray(\n",
857-
" [data_embeddings[i][\"subset\"].mean().item() for i in val_data_idxs]\n",
859+
" [data_embeddings[i][\"subset\"][:-1].mean().item() for i in val_data_idxs]\n",
858860
")\n",
859861
"val_window_stds = np.asarray(\n",
860-
" [data_embeddings[i][\"subset\"].std().item() for i in val_data_idxs]\n",
862+
" [data_embeddings[i][\"subset\"][:-1].std().item() for i in val_data_idxs]\n",
861863
")\n",
862864
"val_emb_mean, val_emb_std = val_window_means.mean(), val_window_stds.mean()\n",
863865
"val_emb_mean, val_emb_std"

scripts/train_lstm.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def get_embedding_windows(data, vae, cfg):
8080
logging.info(
8181
f"From {len(data_windowed)} windows, {len(embeddings)} Embeddings generated using VAE."
8282
)
83+
8384
return embeddings
8485

8586

@@ -246,10 +247,17 @@ def main(cfg):
246247
np.random.shuffle(trn_data_idxs)
247248
logging.info(f"Train embedding indices: {len(trn_data_idxs)}")
248249
# calculate mean and std of the embeddings; they should be very close to 0 and 1, as the sampler of the VAE is Normal
249-
window_means = np.asarray([emb[i]["subset"].mean().item() for i in trn_data_idxs])
250-
window_stds = np.asarray([emb[i]["subset"].std().item() for i in trn_data_idxs])
250+
# take mean and std over x[:-1] only, as this is the real training size in each subset
251+
window_means = np.asarray(
252+
[emb[i]["subset"][:-1].mean().item() for i in trn_data_idxs]
253+
)
254+
window_stds = np.asarray(
255+
[emb[i]["subset"][:-1].std().item() for i in trn_data_idxs]
256+
)
257+
# mean of the per-window means (and mean of the per-window stds)
251258
emb_mean, emb_std = window_means.mean(), window_stds.mean()
252259
logging.info(f"Embedding mean and std of train: {emb_mean} ({emb_std})")
260+
params.update({"means": emb_mean, "stds": emb_std})
253261

254262
dset_trn = TSLSTMDataset(
255263
emb,

0 commit comments

Comments
 (0)