Skip to content

Commit 8cfdcfa

Browse files
thomaspinder (Thomas Pinder), semihakbayrak (Semih Akbayrak)
authored
Rmspe test stat (#24)
* RMSPE WIP * Rmspe test stat (#22) * Minor change in README to fix guidance for developers (#18) * Noise transform (#19) * Add noise transformation that apply perturbations on treatment * Formatting * Add docstring * Fix linting * Add tests to check perturbation impact and randomness over timepoints * bump version (#20) * Initial implementation of RMSPE * Add TestResultFrame parent class for test results * Add test for RMSPE * Add doc string * Fix linting * Update src/causal_validation/validation/rmspe.py Co-authored-by: Thomas Pinder <[email protected]> * Fix typo --------- Co-authored-by: Thomas Pinder <[email protected]> --------- Co-authored-by: Thomas Pinder <[email protected]> Co-authored-by: Semih Akbayrak <[email protected]>
1 parent 8b44c71 commit 8cfdcfa

File tree

11 files changed

+512
-34
lines changed

11 files changed

+512
-34
lines changed

docs/examples/placebo_test.ipynb

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,22 @@
207207
"datasets = DatasetContainer([data, complex_data], names=[\"Simple\", \"Complex\"])\n",
208208
"PlaceboTest([model, did_model], datasets).execute().summary()"
209209
]
210+
},
211+
{
212+
"cell_type": "code",
213+
"execution_count": null,
214+
"id": "14",
215+
"metadata": {},
216+
"outputs": [],
217+
"source": []
218+
},
219+
{
220+
"cell_type": "code",
221+
"execution_count": null,
222+
"id": "15",
223+
"metadata": {},
224+
"outputs": [],
225+
"source": []
210226
}
211227
],
212228
"metadata": {

src/causal_validation/data.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class Dataset:
2727
yte: Float[np.ndarray, "M 1"]
2828
_start_date: dt.date
2929
counterfactual: tp.Optional[Float[np.ndarray, "M 1"]] = None
30+
synthetic: tp.Optional[Float[np.ndarray, "M 1"]] = None
3031
_name: str = None
3132

3233
def to_df(
@@ -151,7 +152,13 @@ def drop_unit(self, idx: int) -> Dataset:
151152
Xtr = np.delete(self.Xtr, [idx], axis=1)
152153
Xte = np.delete(self.Xte, [idx], axis=1)
153154
return Dataset(
154-
Xtr, Xte, self.ytr, self.yte, self._start_date, self.counterfactual
155+
Xtr,
156+
Xte,
157+
self.ytr,
158+
self.yte,
159+
self._start_date,
160+
self.counterfactual,
161+
self.synthetic,
155162
)
156163

157164
def to_placebo_data(self, to_treat_idx: int) -> Dataset:
@@ -204,4 +211,6 @@ def reassign_treatment(
204211
) -> Dataset:
205212
Xtr = data.Xtr
206213
Xte = data.Xte
207-
return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)
214+
return Dataset(
215+
Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
216+
)

src/causal_validation/models.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,29 @@
11
from dataclasses import dataclass
22
import typing as tp
33

4+
from azcausal.core.effect import Effect
45
from azcausal.core.error import Error
56
from azcausal.core.estimator import Estimator
6-
from azcausal.core.result import Result
7+
from azcausal.core.result import Result as _Result
8+
from jaxtyping import Float
79

810
from causal_validation.data import Dataset
11+
from causal_validation.types import NPArray
12+
13+
14+
@dataclass
class Result:
    """Container for the outputs of a fitted causal model.

    Bundles the azcausal effect estimate together with the three
    time-series extracted from the fit (see the ``AZCausalWrapper``
    properties that produce them).
    """

    # Estimated treatment effect returned by the azcausal estimator.
    effect: Effect
    # Counterfactual prediction for the treated unit, column vector (N, 1).
    counterfactual: Float[NPArray, "N 1"]
    # Synthetic-control series, column vector (N, 1).
    synthetic: Float[NPArray, "N 1"]
    # Observed (treated) outcome series, column vector (N, 1).
    observed: Float[NPArray, "N 1"]
920

1021

1122
@dataclass
1223
class AZCausalWrapper:
1324
model: Estimator
1425
error_estimator: tp.Optional[Error] = None
26+
_az_result: _Result = None
1527

1628
def __post_init__(self):
1729
self._model_name = self.model.__class__.__name__
@@ -21,4 +33,30 @@ def __call__(self, data: Dataset, **kwargs) -> Result:
2133
result = self.model.fit(panel, **kwargs)
2234
if self.error_estimator:
2335
self.model.error(result, self.error_estimator)
24-
return result
36+
self._az_result = result
37+
38+
res = Result(
39+
effect=result.effect,
40+
counterfactual=self.counterfactual,
41+
synthetic=self.synthetic,
42+
observed=self.observed,
43+
)
44+
return res
45+
46+
# Methods of AZCausalWrapper (rendered unindented in this diff view).
# The three accessors below previously triplicated the same
# frame-column-extraction logic; it is factored into _effect_column.

@property
def counterfactual(self) -> Float[NPArray, "N 1"]:
    """Counterfactual prediction: the "CF" column of the fitted
    result's by-time effect frame, as an (N, 1) column vector."""
    return self._effect_column("CF")

@property
def synthetic(self) -> Float[NPArray, "N 1"]:
    """Synthetic-control series: the "C" column of the fitted
    result's by-time effect frame, as an (N, 1) column vector."""
    return self._effect_column("C")

@property
def observed(self) -> Float[NPArray, "N 1"]:
    """Treated unit's series: the "T" column of the fitted result's
    by-time effect frame, as an (N, 1) column vector."""
    return self._effect_column("T")

def _effect_column(self, column: str) -> Float[NPArray, "N 1"]:
    """Extract one column of ``self._az_result.effect.by_time`` and
    reshape it to a column vector.

    NOTE(review): assumes the wrapper has been called so that
    ``_az_result`` is populated; accessing these properties before a
    fit raises AttributeError on ``None`` — confirm intended.
    """
    frame = self._az_result.effect.by_time
    return frame.loc[:, column].values.reshape(-1, 1)

src/causal_validation/transforms/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def apply_values(
7474
ytr = ytr + pre_intervention_vals[:, :1]
7575
Xte = Xte + post_intervention_vals[:, 1:]
7676
yte = yte + post_intervention_vals[:, :1]
77-
return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)
77+
return Dataset(
78+
Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
79+
)
7880

7981

8082
@dataclass(kw_only=True)
@@ -91,4 +93,6 @@ def apply_values(
9193
ytr = ytr * pre_intervention_vals
9294
Xte = Xte * post_intervention_vals
9395
yte = yte * post_intervention_vals
94-
return Dataset(Xtr, Xte, ytr, yte, data._start_date, data.counterfactual)
96+
return Dataset(
97+
Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
98+
)

src/causal_validation/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import typing as tp
22

3+
import numpy as np
34
from scipy.stats._distn_infrastructure import (
45
rv_continuous,
56
rv_discrete,
@@ -10,3 +11,4 @@
1011
InterventionTypes = tp.Literal["pre-intervention", "post-intervention", "both"]
1112
RandomVariable = tp.Union[rv_continuous, rv_discrete]
1213
Number = tp.Union[float, int]
14+
NPArray = np.ndarray

src/causal_validation/validation/placebo.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,26 @@
99
Column,
1010
DataFrameSchema,
1111
)
12-
from rich import box
1312
from rich.progress import (
1413
Progress,
1514
ProgressBar,
1615
track,
1716
)
18-
from rich.table import Table
1917
from scipy.stats import ttest_1samp
18+
from tqdm import (
19+
tqdm,
20+
trange,
21+
)
2022

2123
from causal_validation.data import (
2224
Dataset,
2325
DatasetContainer,
2426
)
25-
from causal_validation.models import AZCausalWrapper
27+
from causal_validation.models import (
28+
AZCausalWrapper,
29+
Result,
30+
)
31+
from causal_validation.validation.testing import TestResultFrame
2632

2733
PlaceboSchema = DataFrameSchema(
2834
{
@@ -39,13 +45,13 @@
3945

4046

4147
@dataclass
42-
class PlaceboTestResult:
43-
effects: tp.Dict[tp.Tuple[str, str], tp.List[Effect]]
48+
class PlaceboTestResult(TestResultFrame):
49+
effects: tp.Dict[tp.Tuple[str, str], tp.List[Result]]
4450

4551
def _model_to_df(
46-
self, model_name: str, dataset_name: str, effects: tp.List[Effect]
52+
self, model_name: str, dataset_name: str, effects: tp.List[Result]
4753
) -> pd.DataFrame:
48-
_effects = [effect.value for effect in effects]
54+
_effects = [e.effect.percentage().value for e in effects]
4955
_n_effects = len(_effects)
5056
expected_effect = np.mean(_effects)
5157
stddev_effect = np.std(_effects)
@@ -71,21 +77,6 @@ def to_df(self) -> pd.DataFrame:
7177
PlaceboSchema.validate(df)
7278
return df
7379

74-
def summary(self, precision: int = 4) -> Table:
75-
table = Table(show_header=True, box=box.MARKDOWN)
76-
df = self.to_df()
77-
numeric_cols = df.select_dtypes(include=[np.number])
78-
df.loc[:, numeric_cols.columns] = np.round(numeric_cols, decimals=precision)
79-
80-
for column in df.columns:
81-
table.add_column(str(column), style="magenta")
82-
83-
for _, value_list in enumerate(df.values.tolist()):
84-
row = [str(x) for x in value_list]
85-
table.add_row(*row)
86-
87-
return table
88-
8980

9081
@dataclass
9182
class PlaceboTest:
@@ -109,7 +100,7 @@ def execute(self, verbose: bool = True) -> PlaceboTestResult:
109100
datasets = self.dataset_dict
110101
n_datasets = len(datasets)
111102
n_control = sum([d.n_units for d in datasets.values()])
112-
with Progress() as progress:
103+
with Progress(disable=not verbose) as progress:
113104
model_task = progress.add_task(
114105
"[red]Models", total=len(self.models), visible=verbose
115106
)
@@ -130,7 +121,6 @@ def execute(self, verbose: bool = True) -> PlaceboTestResult:
130121
progress.update(unit_task, advance=1)
131122
placebo_data = dataset.to_placebo_data(i)
132123
result = model(placebo_data)
133-
result = result.effect.percentage()
134124
model_result.append(result)
135125
results[(model._model_name, data_name)] = model_result
136126
return PlaceboTestResult(effects=results)
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
from dataclasses import dataclass
2+
import typing as tp
3+
4+
from jaxtyping import Float
5+
import numpy as np
6+
import pandas as pd
7+
from pandera import (
8+
Check,
9+
Column,
10+
DataFrameSchema,
11+
)
12+
from rich import box
13+
from rich.progress import (
14+
Progress,
15+
ProgressBar,
16+
track,
17+
)
18+
19+
from causal_validation.validation.placebo import PlaceboTest
20+
from causal_validation.validation.testing import (
21+
RMSPETestStatistic,
22+
TestResult,
23+
TestResultFrame,
24+
)
25+
26+
# p-values are probabilities, so they must lie in the closed interval [0, 1].
_p_value_column = Column(
    float,
    checks=[
        Check.greater_than_or_equal_to(0.0),
        Check.less_than_or_equal_to(1.0),
    ],
    coerce=True,
)

# Expected layout of the RMSPE summary frame produced by
# RMSPETestResult.to_df; numeric columns are coerced to float.
RMSPESchema = DataFrameSchema(
    {
        "Model": Column(str),
        "Dataset": Column(str),
        "Test statistic": Column(float, coerce=True),
        "p-value": _p_value_column,
    }
)
41+
42+
43+
@dataclass
class RMSPETestResult(TestResultFrame):
    """
    A subclass of TestResultFrame, RMSPETestResult stores the test statistic
    and p-value for the treated unit, alongside the test statistics computed
    for each pseudo (placebo) treatment unit.
    """

    # (model name, dataset name) -> test result for the true treated unit.
    treatment_test_results: tp.Dict[tp.Tuple[str, str], TestResult]
    # (model name, dataset name) -> RMSPE statistics of the placebo units.
    pseudo_treatment_test_statistics: tp.Dict[tp.Tuple[str, str], tp.List[Float]]

    def to_df(self) -> pd.DataFrame:
        """Tabulate the treated-unit results, one row per (model, dataset)
        pair, validated against ``RMSPESchema``."""
        per_pair_frames = [
            pd.DataFrame(
                [
                    {
                        "Model": model,
                        "Dataset": dataset,
                        "Test statistic": res.test_statistic,
                        "p-value": res.p_value,
                    }
                ]
            )
            for (model, dataset), res in self.treatment_test_results.items()
        ]
        summary = pd.concat(per_pair_frames)
        RMSPESchema.validate(summary)
        return summary
67+
68+
69+
@dataclass
class RMSPETest(PlaceboTest):
    """
    A subclass of PlaceboTest that calculates RMSPE as the test statistic for
    all units. Given the RMSPE test statistics, a permutation p-value for the
    actual treatment is calculated per dataset.
    """

    def execute(self, verbose: bool = True) -> RMSPETestResult:
        """Run the RMSPE placebo test for every (model, dataset) pair.

        Args:
            verbose: when True, render progress bars while fitting.

        Returns:
            RMSPETestResult holding the treated unit's statistic and p-value
            per (model, dataset) pair, plus all placebo statistics.
        """
        treatment_results: tp.Dict[tp.Tuple[str, str], TestResult] = {}
        pseudo_treatment_results: tp.Dict[tp.Tuple[str, str], tp.List[Float]] = {}
        datasets = self.dataset_dict
        n_datasets = len(datasets)
        n_control = sum(d.n_units for d in datasets.values())
        rmspe = RMSPETestStatistic()
        with Progress(disable=not verbose) as progress:
            model_task = progress.add_task(
                "[red]Models", total=len(self.models), visible=verbose
            )
            data_task = progress.add_task(
                "[blue]Datasets", total=n_datasets, visible=verbose
            )
            # Per (model, dataset) pair we fit once on the true treated unit
            # and once per control unit, hence the total below. (The old
            # total, n_control + 1, under-counted when multiple models or
            # datasets were supplied.)
            unit_task = progress.add_task(
                "[green]Treatment and Control Units",
                total=len(self.models) * (n_datasets + n_control),
                visible=verbose,
            )
            for data_name, dataset in datasets.items():
                progress.update(data_task, advance=1)
                for model in self.models:
                    progress.update(unit_task, advance=1)
                    treatment_result = model(dataset)
                    # Number of pre-intervention timepoints; tells the RMSPE
                    # statistic where to split the series into pre/post.
                    treatment_idx = dataset.ytr.shape[0]
                    treatment_test_stat = rmspe(
                        dataset,
                        treatment_result.counterfactual,
                        treatment_result.synthetic,
                        treatment_idx,
                    )
                    progress.update(model_task, advance=1)
                    placebo_test_stats = []
                    for unit_idx in range(dataset.n_units):
                        progress.update(unit_task, advance=1)
                        placebo_data = dataset.to_placebo_data(unit_idx)
                        placebo_result = model(placebo_data)
                        placebo_test_stats.append(
                            rmspe(
                                placebo_data,
                                placebo_result.counterfactual,
                                placebo_result.synthetic,
                                treatment_idx,
                            )
                        )
                    # Permutation p-value: rank of the treated unit's
                    # statistic among *this* dataset's placebo statistics.
                    # BUGFIX: the denominator previously used n_control + 1
                    # with n_control summed over ALL datasets, deflating
                    # p-values whenever more than one dataset was supplied;
                    # it must match the number of placebo fits for this
                    # dataset, i.e. dataset.n_units + 1.
                    n_greater = sum(
                        treatment_test_stat < stat for stat in placebo_test_stats
                    )
                    p_value = (1 + n_greater) / (dataset.n_units + 1)
                    key = (model._model_name, data_name)
                    treatment_results[key] = TestResult(
                        p_value=p_value, test_statistic=treatment_test_stat
                    )
                    pseudo_treatment_results[key] = placebo_test_stats
        return RMSPETestResult(
            treatment_test_results=treatment_results,
            pseudo_treatment_test_statistics=pseudo_treatment_results,
        )

0 commit comments

Comments
 (0)