|
799 | 799 | "source": [
|
800 | 800 | "### Get statistics of embeddings (train)\n",
|
801 | 801 | "\n",
|
802 |
| - "Mean of training data to standardize the embeddings. These are global statistics for the train set." |
| 802 | + "Mean of training data to standardize the embeddings. These are global statistics for the train set.\n", |
| 803 | + "\n", |
| 804 | + "Take the mean and standard deviation over the `[:-1]` slice of each subset only, as this is the only part seen by the model."
803 | 805 | ]
|
804 | 806 | },
|
805 | 807 | {
|
|
820 | 822 | ],
|
821 | 823 | "source": [
|
822 | 824 | "window_means = np.asarray(\n",
|
823 |
| - " [data_embeddings[i][\"subset\"].mean().item() for i in trn_data_idxs]\n", |
| 825 | + " [data_embeddings[i][\"subset\"][:-1].mean().item() for i in trn_data_idxs]\n", |
824 | 826 | ")\n",
|
825 | 827 | "window_stds = np.asarray(\n",
|
826 |
| - " [data_embeddings[i][\"subset\"].std().item() for i in trn_data_idxs]\n", |
| 828 | + " [data_embeddings[i][\"subset\"][:-1].std().item() for i in trn_data_idxs]\n", |
827 | 829 | ")\n",
|
828 | 830 | "emb_mean, emb_std = window_means.mean(), window_stds.mean()\n",
|
829 | 831 | "emb_mean, emb_std"
|
|
854 | 856 | ],
|
855 | 857 | "source": [
|
856 | 858 | "val_window_means = np.asarray(\n",
|
857 |
| - " [data_embeddings[i][\"subset\"].mean().item() for i in val_data_idxs]\n", |
| 859 | + " [data_embeddings[i][\"subset\"][:-1].mean().item() for i in val_data_idxs]\n", |
858 | 860 | ")\n",
|
859 | 861 | "val_window_stds = np.asarray(\n",
|
860 |
| - " [data_embeddings[i][\"subset\"].std().item() for i in val_data_idxs]\n", |
| 862 | + " [data_embeddings[i][\"subset\"][:-1].std().item() for i in val_data_idxs]\n", |
861 | 863 | ")\n",
|
862 | 864 | "val_emb_mean, val_emb_std = val_window_means.mean(), val_window_stds.mean()\n",
|
863 | 865 | "val_emb_mean, val_emb_std"
|
|
0 commit comments