Skip to content

Commit 95ac34a

Browse files
fix: Fixing transformations on writes (feast-dev#5127)
1 parent bebd7be commit 95ac34a

14 files changed

+624
-144
lines changed

sdk/python/feast/feature_store.py

+69-2
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@
8989
from feast.saved_dataset import SavedDataset, SavedDatasetStorage, ValidationReference
9090
from feast.ssl_ca_trust_store_setup import configure_ca_trust_store_env_variables
9191
from feast.stream_feature_view import StreamFeatureView
92+
from feast.transformation.pandas_transformation import PandasTransformation
93+
from feast.transformation.python_transformation import PythonTransformation
9294
from feast.utils import _utc_now
9395

9496
warnings.simplefilter("once", DeprecationWarning)
@@ -1546,6 +1548,64 @@ def _get_feature_view_and_df_for_online_write(
15461548
df = pd.DataFrame(df)
15471549
except Exception as _:
15481550
raise DataFrameSerializationError(df)
1551+
1552+
# # Apply transformations if this is an OnDemandFeatureView with write_to_online_store=True
1553+
if (
1554+
isinstance(feature_view, OnDemandFeatureView)
1555+
and feature_view.write_to_online_store
1556+
):
1557+
if (
1558+
feature_view.mode == "python"
1559+
and isinstance(
1560+
feature_view.feature_transformation, PythonTransformation
1561+
)
1562+
and df is not None
1563+
):
1564+
input_dict = (
1565+
df.to_dict(orient="records")[0]
1566+
if feature_view.singleton
1567+
else df.to_dict(orient="list")
1568+
)
1569+
transformed_data = feature_view.feature_transformation.udf(input_dict)
1570+
if feature_view.write_to_online_store:
1571+
entities = [
1572+
self.get_entity(entity)
1573+
for entity in (feature_view.entities or [])
1574+
]
1575+
join_keys = [entity.join_key for entity in entities if entity]
1576+
join_keys = [k for k in join_keys if k in input_dict.keys()]
1577+
transformed_df = pd.DataFrame(transformed_data)
1578+
input_df = pd.DataFrame(input_dict)
1579+
if input_df.shape[0] == transformed_df.shape[0]:
1580+
for k in input_dict:
1581+
if k not in transformed_data:
1582+
transformed_data[k] = input_dict[k]
1583+
transformed_df = pd.DataFrame(transformed_data)
1584+
else:
1585+
transformed_df = pd.merge(
1586+
transformed_df,
1587+
input_df,
1588+
how="left",
1589+
on=join_keys,
1590+
)
1591+
else:
1592+
# overwrite any transformed features and update the dictionary
1593+
for k in input_dict:
1594+
if k not in transformed_data:
1595+
transformed_data[k] = input_dict[k]
1596+
df = pd.DataFrame(transformed_data)
1597+
elif feature_view.mode == "pandas" and isinstance(
1598+
feature_view.feature_transformation, PandasTransformation
1599+
):
1600+
transformed_df = feature_view.feature_transformation.udf(df)
1601+
if df is not None:
1602+
for col in df.columns:
1603+
transformed_df[col] = df[col]
1604+
df = transformed_df
1605+
1606+
else:
1607+
raise Exception("Unsupported OnDemandFeatureView mode")
1608+
15491609
return feature_view, df
15501610

15511611
def write_to_online_store(
@@ -1887,7 +1947,7 @@ def retrieve_online_documents_v2(
18871947

18881948
(
18891949
available_feature_views,
1890-
_,
1950+
available_odfv_views,
18911951
) = utils._get_feature_views_to_use(
18921952
registry=self._registry,
18931953
project=self.project,
@@ -1898,13 +1958,20 @@ def retrieve_online_documents_v2(
18981958
feature_view_set = set()
18991959
for feature in features:
19001960
feature_view_name = feature.split(":")[0]
1901-
feature_view = self.get_feature_view(feature_view_name)
1961+
if feature_view_name in [fv.name for fv in available_odfv_views]:
1962+
feature_view: Union[OnDemandFeatureView, FeatureView] = (
1963+
self.get_on_demand_feature_view(feature_view_name)
1964+
)
1965+
else:
1966+
feature_view = self.get_feature_view(feature_view_name)
19021967
feature_view_set.add(feature_view.name)
19031968
if len(feature_view_set) > 1:
19041969
raise ValueError("Document retrieval only supports a single feature view.")
19051970
requested_features = [
19061971
f.split(":")[1] for f in features if isinstance(f, str) and ":" in f
19071972
]
1973+
if len(available_feature_views) == 0:
1974+
available_feature_views.extend(available_odfv_views) # type: ignore[arg-type]
19081975

19091976
requested_feature_view = available_feature_views[0]
19101977
if not requested_feature_view:

sdk/python/feast/feature_view.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -348,12 +348,11 @@ def to_proto(self) -> FeatureViewProto:
348348
if self.stream_source:
349349
stream_source_proto = self.stream_source.to_proto()
350350
stream_source_proto.data_source_class_type = f"{self.stream_source.__class__.__module__}.{self.stream_source.__class__.__name__}"
351-
352351
spec = FeatureViewSpecProto(
353352
name=self.name,
354353
entities=self.entities,
355354
entity_columns=[field.to_proto() for field in self.entity_columns],
356-
features=[field.to_proto() for field in self.features],
355+
features=[feature.to_proto() for feature in self.features],
357356
description=self.description,
358357
tags=self.tags,
359358
owner=self.owner,

sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,14 @@ def _get_or_create_collection(
197197
)
198198
index_params = self.client.prepare_index_params()
199199
for vector_field in schema.fields:
200-
if vector_field.dtype in [
201-
DataType.FLOAT_VECTOR,
202-
DataType.BINARY_VECTOR,
203-
]:
200+
if (
201+
vector_field.dtype
202+
in [
203+
DataType.FLOAT_VECTOR,
204+
DataType.BINARY_VECTOR,
205+
]
206+
and vector_field.name in vector_field_dict
207+
):
204208
metric = vector_field_dict[
205209
vector_field.name
206210
].vector_search_metric

sdk/python/feast/infra/online_stores/sqlite.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,10 @@ def online_write_batch(
167167
table_name = _table_id(project, table)
168168
for feature_name, val in values.items():
169169
if config.online_store.vector_enabled:
170-
if feature_type_dict[feature_name] in FEAST_VECTOR_TYPES:
170+
if (
171+
feature_type_dict.get(feature_name, None)
172+
in FEAST_VECTOR_TYPES
173+
):
171174
val_bin = serialize_f32(
172175
val.float_list_val.val, config.online_store.vector_len
173176
) # type: ignore
@@ -226,22 +229,22 @@ def online_read(
226229

227230
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
228231

232+
serialized_entity_keys = [
233+
serialize_entity_key(
234+
entity_key,
235+
entity_key_serialization_version=config.entity_key_serialization_version,
236+
)
237+
for entity_key in entity_keys
238+
]
229239
# Fetch all entities in one go
230240
cur.execute(
231241
f"SELECT entity_key, feature_name, value, event_ts "
232242
f"FROM {_table_id(config.project, table)} "
233243
f"WHERE entity_key IN ({','.join('?' * len(entity_keys))}) "
234244
f"ORDER BY entity_key",
235-
[
236-
serialize_entity_key(
237-
entity_key,
238-
entity_key_serialization_version=config.entity_key_serialization_version,
239-
)
240-
for entity_key in entity_keys
241-
],
245+
serialized_entity_keys,
242246
)
243247
rows = cur.fetchall()
244-
245248
rows = {
246249
k: list(group) for k, group in itertools.groupby(rows, key=lambda r: r[0])
247250
}

sdk/python/feast/infra/passthrough_provider.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def materialize_single_feature_view(
449449
def get_historical_features(
450450
self,
451451
config: RepoConfig,
452-
feature_views: List[FeatureView],
452+
feature_views: List[Union[FeatureView, OnDemandFeatureView]],
453453
feature_refs: List[str],
454454
entity_df: Union[pd.DataFrame, str],
455455
registry: BaseRegistry,

sdk/python/feast/infra/provider.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def materialize_single_feature_view(
242242
def get_historical_features(
243243
self,
244244
config: RepoConfig,
245-
feature_views: List[FeatureView],
245+
feature_views: List[Union[FeatureView, OnDemandFeatureView]],
246246
feature_refs: List[str],
247247
entity_df: Union[pd.DataFrame, str],
248248
registry: BaseRegistry,

sdk/python/feast/nlp_test_data.py

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from datetime import datetime
2+
from typing import Dict
3+
4+
import numpy as np
5+
import pandas as pd
6+
7+
8+
def create_document_chunks_df(
9+
documents: Dict[str, str],
10+
start_date: datetime,
11+
end_date: datetime,
12+
embedding_size: int = 60,
13+
) -> pd.DataFrame:
14+
"""
15+
Example df generated by this function:
16+
17+
| event_timestamp | document_id | chunk_id | chunk_text | embedding | created |
18+
|------------------+-------------+----------+------------------+-----------+------------------|
19+
| 2021-03-17 19:31 | doc_1 | chunk-1 | Hello world | [0.1, ...]| 2021-03-24 19:34 |
20+
| 2021-03-17 19:31 | doc_1 | chunk-2 | How are you? | [0.2, ...]| 2021-03-24 19:34 |
21+
| 2021-03-17 19:31 | doc_2 | chunk-1 | This is a test | [0.3, ...]| 2021-03-24 19:34 |
22+
| 2021-03-17 19:31 | doc_2 | chunk-2 | Document chunk | [0.4, ...]| 2021-03-24 19:34 |
23+
"""
24+
df_hourly = pd.DataFrame(
25+
{
26+
"event_timestamp": [
27+
pd.Timestamp(dt, unit="ms").round("ms")
28+
for dt in pd.date_range(
29+
start=start_date,
30+
end=end_date,
31+
freq="1h",
32+
inclusive="left",
33+
tz="UTC",
34+
)
35+
]
36+
+ [
37+
pd.Timestamp(
38+
year=2021, month=4, day=12, hour=7, minute=0, second=0, tz="UTC"
39+
)
40+
]
41+
}
42+
)
43+
df_all_chunks = pd.DataFrame()
44+
45+
for doc_id, doc_text in documents.items():
46+
chunks = doc_text.split(". ") # Simple chunking by sentence
47+
for chunk_id, chunk_text in enumerate(chunks, start=1):
48+
df_hourly_copy = df_hourly.copy()
49+
df_hourly_copy["document_id"] = doc_id
50+
df_hourly_copy["chunk_id"] = f"chunk-{chunk_id}"
51+
df_hourly_copy["chunk_text"] = chunk_text
52+
df_all_chunks = pd.concat([df_hourly_copy, df_all_chunks])
53+
54+
df_all_chunks.reset_index(drop=True, inplace=True)
55+
rows = df_all_chunks["event_timestamp"].count()
56+
57+
# Generate random embeddings for each chunk
58+
df_all_chunks["embedding"] = [
59+
np.random.rand(embedding_size).tolist() for _ in range(rows)
60+
]
61+
df_all_chunks["created"] = pd.to_datetime(pd.Timestamp.now(tz=None).round("ms"))
62+
63+
# Create duplicate rows that should be filtered by created timestamp
64+
late_row = df_all_chunks[rows // 2 : rows // 2 + 1]
65+
df_all_chunks = pd.concat([df_all_chunks, late_row, late_row], ignore_index=True)
66+
67+
return df_all_chunks

sdk/python/feast/on_demand_feature_view.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,6 @@ def to_proto(self) -> OnDemandFeatureViewProto:
339339
write_to_online_store=self.write_to_online_store,
340340
singleton=self.singleton if self.singleton else False,
341341
)
342-
343342
return OnDemandFeatureViewProto(spec=spec, meta=meta)
344343

345344
@classmethod
@@ -454,6 +453,8 @@ def from_proto(
454453
Field(
455454
name=feature.name,
456455
dtype=from_value_type(ValueType(feature.value_type)),
456+
vector_index=feature.vector_index,
457+
vector_search_metric=feature.vector_search_metric,
457458
)
458459
for feature in on_demand_feature_view_proto.spec.features
459460
],
@@ -640,13 +641,25 @@ def transform_dict(
640641

641642
def infer_features(self) -> None:
642643
random_input = self._construct_random_input(singleton=self.singleton)
643-
inferred_features = self.feature_transformation.infer_features(random_input)
644+
inferred_features = self.feature_transformation.infer_features(
645+
random_input=random_input, singleton=self.singleton
646+
)
644647

645648
if self.features:
646649
missing_features = []
647650
for specified_feature in self.features:
648-
if specified_feature not in inferred_features:
651+
if (
652+
specified_feature not in inferred_features
653+
and "Array" not in specified_feature.dtype.__str__()
654+
):
649655
missing_features.append(specified_feature)
656+
elif "Array" in specified_feature.dtype.__str__():
657+
if specified_feature.name not in [
658+
f.name for f in inferred_features
659+
]:
660+
missing_features.append(specified_feature)
661+
else:
662+
pass
650663
if missing_features:
651664
raise SpecifiedFeaturesNotPresentError(
652665
missing_features, inferred_features, self.name
@@ -738,6 +751,7 @@ def on_demand_feature_view(
738751
owner: str = "",
739752
write_to_online_store: bool = False,
740753
singleton: bool = False,
754+
explode: bool = False,
741755
):
742756
"""
743757
Creates an OnDemandFeatureView object with the given user function as udf.
@@ -759,6 +773,7 @@ def on_demand_feature_view(
759773
the online store for faster retrieval.
760774
singleton (optional): A boolean that indicates whether the transformation is executed on a singleton
761775
(only applicable when mode="python").
776+
explode (optional): A boolean that indicates whether the transformation explodes the input data into multiple rows.
762777
"""
763778

764779
def mainify(obj) -> None:
@@ -778,10 +793,6 @@ def decorator(user_function):
778793
)
779794
transformation = PandasTransformation(user_function, udf_string)
780795
elif mode == "python":
781-
if return_annotation not in (inspect._empty, dict[str, Any]):
782-
raise TypeError(
783-
f"return signature for {user_function} is {return_annotation} but should be dict[str, Any]"
784-
)
785796
transformation = PythonTransformation(user_function, udf_string)
786797
elif mode == "substrait":
787798
from ibis.expr.types.relations import Table

sdk/python/feast/transformation/pandas_transformation.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Callable
1+
from typing import Any, Callable, Optional
22

33
import dill
44
import pandas as pd
@@ -40,7 +40,9 @@ def transform_singleton(self, input_df: pd.DataFrame) -> pd.DataFrame:
4040
"PandasTransformation does not support singleton transformations."
4141
)
4242

43-
def infer_features(self, random_input: dict[str, list[Any]]) -> list[Field]:
43+
def infer_features(
44+
self, random_input: dict[str, list[Any]], singleton: Optional[bool]
45+
) -> list[Field]:
4446
df = pd.DataFrame.from_dict(random_input)
4547
output_df: pd.DataFrame = self.transform(df)
4648

sdk/python/feast/transformation/python_transformation.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from types import FunctionType
2-
from typing import Any
2+
from typing import Any, Optional
33

44
import dill
55
import pyarrow
@@ -45,7 +45,9 @@ def transform_singleton(self, input_dict: dict) -> dict:
4545
output_dict = self.udf.__call__(input_dict)
4646
return {**input_dict, **output_dict}
4747

48-
def infer_features(self, random_input: dict[str, Any]) -> list[Field]:
48+
def infer_features(
49+
self, random_input: dict[str, Any], singleton: Optional[bool] = False
50+
) -> list[Field]:
4951
output_dict: dict[str, Any] = self.transform(random_input)
5052

5153
fields = []
@@ -58,6 +60,10 @@ def infer_features(self, random_input: dict[str, Any]) -> list[Field]:
5860
)
5961
inferred_type = type(feature_value[0])
6062
inferred_value = feature_value[0]
63+
if singleton:
64+
inferred_value = feature_value
65+
inferred_type = None # type: ignore
66+
6167
else:
6268
inferred_type = type(feature_value)
6369
inferred_value = feature_value
@@ -69,7 +75,7 @@ def infer_features(self, random_input: dict[str, Any]) -> list[Field]:
6975
python_type_to_feast_value_type(
7076
feature_name,
7177
value=inferred_value,
72-
type_name=inferred_type.__name__,
78+
type_name=inferred_type.__name__ if inferred_type else None,
7379
)
7480
),
7581
)

0 commit comments

Comments
 (0)