-
Notifications
You must be signed in to change notification settings - Fork 1.3k
feat: Add interfaces for batch materialization engine #2901
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 4 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
6cc7e75
feat: Add scaffolding for batch materialization engine
achals 1656436
fix tests
achals c16734b
fix tests
achals ceedbab
a little better
achals 37fa081
a little better
achals 795e65a
docs
achals 5833556
more api updates
achals 5d1af33
fix typos
achals ff680a9
make engine importable
achals 11b6da0
style stuff
achals 2e9e7a8
style stuff
achals File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| from .batch_materialization_engine import ( | ||
| BatchMaterializationEngine, | ||
| MaterializationJob, | ||
| MaterializationTask, | ||
| ) | ||
| from .local_engine import LocalMaterializationEngine | ||
|
|
||
# Public API of the materialization package.
__all__ = [
    "MaterializationJob",
    "MaterializationTask",
    "BatchMaterializationEngine",
    "LocalMaterializationEngine",
]
61 changes: 61 additions & 0 deletions
61
sdk/python/feast/infra/materialization/batch_materialization_engine.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| import dataclasses | ||
achals marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| from abc import ABC, abstractmethod | ||
| from datetime import datetime | ||
| from typing import Callable, List, Optional, Union | ||
|
|
||
| from tqdm import tqdm | ||
|
|
||
| from feast.batch_feature_view import BatchFeatureView | ||
| from feast.infra.offline_stores.offline_store import OfflineStore | ||
| from feast.infra.online_stores.online_store import OnlineStore | ||
| from feast.repo_config import RepoConfig | ||
| from feast.stream_feature_view import StreamFeatureView | ||
|
|
||
|
|
||
@dataclasses.dataclass
class MaterializationTask:
    """A unit of materialization work: load the values of one feature view over
    a bounded time window into the online store.

    Plain data container consumed by BatchMaterializationEngine implementations.
    """

    # Name of the Feast project the feature view belongs to.
    project: str
    # The feature view whose values should be materialized.
    feature_view: Union[BatchFeatureView, StreamFeatureView]
    # Start of the time window to materialize.
    start_time: datetime
    # End of the time window to materialize.
    end_time: datetime
    # Factory that, given a total row count, returns a tqdm progress bar.
    tqdm_builder: Callable[[int], tqdm]
|
|
||
|
|
||
class MaterializationJob(ABC):
    """Handle for a (possibly asynchronous) materialization run.

    Concrete engines return one job per MaterializationTask so callers can
    inspect its identity, progress, and outcome.
    """

    # The task this job was created for.
    task: MaterializationTask

    @abstractmethod
    def status(self) -> str:
        """Return the job's current status as a string (values are engine-specific)."""
        ...

    @abstractmethod
    def should_be_retried(self) -> bool:
        """Return True if the job failed in a way that warrants resubmission."""
        ...

    @abstractmethod
    def job_id(self) -> str:
        """Return a unique identifier for this job."""
        ...

    @abstractmethod
    def url(self) -> Optional[str]:
        """Return a URL where the job can be monitored, or None if the engine has none."""
        ...
|
|
||
|
|
||
class BatchMaterializationEngine(ABC):
    """Interface for pluggable batch materialization backends.

    An engine reads historical feature values from the offline store and writes
    them into the online store for low-latency serving.
    """

    def __init__(
        self,
        *,
        repo_config: RepoConfig,
        offline_store: OfflineStore,
        online_store: OnlineStore,
        **kwargs,
    ):
        """
        Args:
            repo_config: Feast repository configuration.
            offline_store: Source of historical feature data.
            online_store: Destination for the materialized feature rows.
            **kwargs: Engine-specific options; ignored by this base class.
        """
        self.repo_config = repo_config
        self.offline_store = offline_store
        self.online_store = online_store

    @abstractmethod
    def materialize(
        self, registry, tasks: List[MaterializationTask]
    ) -> List[MaterializationJob]:
        """Materialize the given tasks, returning one job per task.

        Args:
            registry: Used to resolve entities for each feature view (untyped
                here; presumably a feast Registry — TODO confirm with callers).
            tasks: The materialization tasks to execute.
        """
        ...
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,302 @@ | ||
| from datetime import datetime | ||
| from typing import Callable, Dict, List, Literal, Optional, Tuple, Union | ||
|
|
||
| import dask.dataframe as dd | ||
| import pandas as pd | ||
| import pyarrow as pa | ||
| from tqdm import tqdm | ||
|
|
||
| from feast import ( | ||
| BatchFeatureView, | ||
| Entity, | ||
| FeatureView, | ||
| RepoConfig, | ||
| StreamFeatureView, | ||
| ValueType, | ||
| ) | ||
| from feast.feature_view import DUMMY_ENTITY_ID | ||
| from feast.infra.offline_stores.offline_store import OfflineStore | ||
| from feast.infra.online_stores.online_store import OnlineStore | ||
| from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto | ||
| from feast.protos.feast.types.Value_pb2 import Value as ValueProto | ||
| from feast.repo_config import FeastConfigBaseModel | ||
| from feast.type_map import python_values_to_proto_values | ||
|
|
||
| from .batch_materialization_engine import ( | ||
| BatchMaterializationEngine, | ||
| MaterializationJob, | ||
| MaterializationTask, | ||
| ) | ||
|
|
||
# Maximum rows per Arrow batch when converting and writing to the online store.
DEFAULT_BATCH_SIZE = 10_000
|
|
||
|
|
||
class LocalMaterializationEngineConfig(FeastConfigBaseModel):
    """Batch Materialization Engine config for local in-process engine"""

    # Literal discriminator used by Feast config parsing to select this engine.
    type: Literal["local"] = "local"
    """ Type selector"""
|
|
||
|
|
||
class LocalMaterializationJob(MaterializationJob):
    """Job handle for the local in-process engine.

    The local engine materializes synchronously, so by the time a job object
    exists the work has already completed: status is always reported as
    success and a retry is never requested.
    """

    def __init__(self, job_id: str) -> None:
        """
        Args:
            job_id: Unique identifier for this materialization run.
        """
        super().__init__()
        self._job_id: str = job_id

    def status(self) -> str:
        # The local engine runs to completion before constructing the job.
        return "success"

    def should_be_retried(self) -> bool:
        return False

    def job_id(self) -> str:
        # Bug fix: this previously returned self.job_id(), which recursed
        # forever; return the identifier stored by __init__ instead.
        return self._job_id

    def url(self) -> Optional[str]:
        # In-process runs have no monitoring endpoint.
        return None
|
|
||
|
|
||
class LocalMaterializationEngine(BatchMaterializationEngine):
    """In-process materialization engine that runs tasks synchronously.

    For each task it pulls the latest feature rows from the offline store as an
    Arrow table and writes them to the online store in fixed-size batches.
    """

    def __init__(
        self,
        *,
        repo_config: RepoConfig,
        offline_store: OfflineStore,
        online_store: OnlineStore,
        **kwargs,
    ):
        """See BatchMaterializationEngine.__init__ for parameter semantics."""
        super().__init__(
            repo_config=repo_config,
            offline_store=offline_store,
            online_store=online_store,
            **kwargs,
        )

    def materialize(
        self, registry, tasks: List[MaterializationTask]
    ) -> List[MaterializationJob]:
        """Run every task to completion, in order, returning one job per task."""
        return [
            self.materialize_one(
                registry,
                task.feature_view,
                task.start_time,
                task.end_time,
                task.project,
                task.tqdm_builder,
            )
            for task in tasks
        ]

    def materialize_one(
        self,
        registry,
        feature_view: Union[BatchFeatureView, StreamFeatureView],
        start_date: datetime,
        end_date: datetime,
        project: str,
        tqdm_builder: Callable[[int], tqdm],
    ):
        """Materialize one feature view over the [start_date, end_date] window.

        Args:
            registry: Used to resolve the view's entities by name (untyped;
                presumably a feast Registry — TODO confirm).
            feature_view: The view whose values are materialized.
            start_date: Start of the extraction window.
            end_date: End of the extraction window.
            project: Feast project that owns the view and its entities.
            tqdm_builder: Given a total row count, returns a progress bar.

        Returns:
            A completed LocalMaterializationJob for this run.
        """
        # Resolve the entity objects referenced by name on the feature view.
        entities = []
        for entity_name in feature_view.entities:
            entities.append(registry.get_entity(entity_name, project))

        # Map join keys / feature names / timestamp columns back to the raw
        # source column names (reversing any field_mapping on the source).
        (
            join_key_columns,
            feature_name_columns,
            timestamp_field,
            created_timestamp_column,
        ) = _get_column_names(feature_view, entities)

        offline_job = self.offline_store.pull_latest_from_table_or_query(
            config=self.repo_config,
            data_source=feature_view.batch_source,
            join_key_columns=join_key_columns,
            feature_name_columns=feature_name_columns,
            timestamp_field=timestamp_field,
            created_timestamp_column=created_timestamp_column,
            start_date=start_date,
            end_date=end_date,
        )

        table = offline_job.to_arrow()

        # Apply the forward field mapping so columns carry their final names
        # before being converted to protos.
        if feature_view.batch_source.field_mapping is not None:
            table = _run_field_mapping(table, feature_view.batch_source.field_mapping)

        join_key_to_value_type = {
            entity.name: entity.dtype.to_value_type()
            for entity in feature_view.entity_columns
        }

        # Write in batches of DEFAULT_BATCH_SIZE rows, advancing the progress
        # bar as the online store acknowledges each write.
        with tqdm_builder(table.num_rows) as pbar:
            for batch in table.to_batches(DEFAULT_BATCH_SIZE):
                rows_to_write = _convert_arrow_to_proto(
                    batch, feature_view, join_key_to_value_type
                )
                self.online_store.online_write_batch(
                    self.repo_config,
                    feature_view,
                    rows_to_write,
                    lambda x: pbar.update(x),
                )
        job_id = f"{feature_view.name}-{start_date}-{end_date}"
        return LocalMaterializationJob(job_id=job_id)
|
|
||
|
|
||
def _get_column_names(
    feature_view: FeatureView, entities: List[Entity]
) -> Tuple[List[str], List[str], str, Optional[str]]:
    """
    If a field mapping exists, run it in reverse on the join keys,
    feature names, event timestamp column, and created timestamp column
    to get the names of the relevant columns in the offline feature store table.

    Returns:
        Tuple containing the list of reverse-mapped join_keys,
        reverse-mapped feature names, reverse-mapped event timestamp column,
        and reverse-mapped created timestamp column that will be passed into
        the query to the offline store.
    """
    source = feature_view.batch_source
    timestamp_field = source.timestamp_field
    created_timestamp_column = source.created_timestamp_column
    feature_names = [feature.name for feature in feature_view.features]
    join_keys = [
        entity.join_key for entity in entities if entity.join_key != DUMMY_ENTITY_ID
    ]

    if source.field_mapping is not None:
        # Invert the mapping so final (mapped) names translate back to the
        # raw column names present in the source table.
        reverse_mapping = {v: k for k, v in source.field_mapping.items()}
        timestamp_field = reverse_mapping.get(timestamp_field, timestamp_field)
        if created_timestamp_column:
            created_timestamp_column = reverse_mapping.get(
                created_timestamp_column, created_timestamp_column
            )
        join_keys = [reverse_mapping.get(name, name) for name in join_keys]
        feature_names = [reverse_mapping.get(name, name) for name in feature_names]

    # Join keys and timestamp columns are not features; exclude them after the
    # reverse mapping so the comparison uses the source-side names.
    excluded = set(join_keys) | {timestamp_field, created_timestamp_column}
    feature_names = [name for name in feature_names if name not in excluded]

    return (
        join_keys,
        feature_names,
        timestamp_field,
        created_timestamp_column,
    )
|
|
||
|
|
||
def _run_field_mapping(table: pa.Table, field_mapping: Dict[str, str],) -> pa.Table:
    """Rename the table's columns in the forward direction of field_mapping.

    Columns absent from the mapping keep their original names.
    """
    renamed = [field_mapping.get(name, name) for name in table.column_names]
    return table.rename_columns(renamed)
|
|
||
|
|
||
def _run_dask_field_mapping(
    table: dd.DataFrame, field_mapping: Dict[str, str],
):
    """Apply field_mapping (forward direction) to a dask dataframe's columns.

    No-op when the mapping is empty; otherwise the renamed frame is persisted.
    """
    if not field_mapping:
        return table
    return table.rename(columns=field_mapping).persist()
|
|
||
|
|
||
| def _coerce_datetime(ts): | ||
| """ | ||
| Depending on underlying time resolution, arrow to_pydict() sometimes returns pd | ||
| timestamp type (for nanosecond resolution), and sometimes you get standard python datetime | ||
| (for microsecond resolution). | ||
| While pd timestamp class is a subclass of python datetime, it doesn't always behave the | ||
| same way. We convert it to normal datetime so that consumers downstream don't have to deal | ||
| with these quirks. | ||
| """ | ||
| if isinstance(ts, pd.Timestamp): | ||
| return ts.to_pydatetime() | ||
| else: | ||
| return ts | ||
|
|
||
|
|
||
def _convert_arrow_to_proto(
    table: Union[pa.Table, pa.RecordBatch],
    feature_view: FeatureView,
    join_keys: Dict[str, ValueType],
) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]:
    """Convert an Arrow table/batch into rows ready for online_write_batch.

    Returns one tuple per row:
    (entity key proto, feature-name -> value proto mapping,
     event timestamp, created timestamp or None).
    """
    # Avoid ChunkedArrays, which guarantees `zero_copy_only` is available.
    # NOTE(review): only the first batch of a Table is used here — callers in
    # this file pass RecordBatches already; confirm any Table input always
    # fits in a single batch, otherwise rows beyond it would be dropped.
    if isinstance(table, pa.Table):
        table = table.to_batches()[0]

    # Every column we must serialize: features first, then join keys, each
    # paired with its Feast value type.
    columns = [
        (field.name, field.dtype.to_value_type()) for field in feature_view.features
    ] + list(join_keys.items())

    # Convert each column to protos in one vectorized pass per column.
    proto_values_by_column = {
        column: python_values_to_proto_values(
            table.column(column).to_numpy(zero_copy_only=False), value_type
        )
        for column, value_type in columns
    }

    # Build one entity key proto per row from the join-key columns.
    entity_keys = [
        EntityKeyProto(
            join_keys=join_keys,
            entity_values=[proto_values_by_column[k][idx] for k in join_keys],
        )
        for idx in range(table.num_rows)
    ]

    # Serialize the features per row: transpose the per-column proto lists
    # into one {feature name: proto value} dict per row.
    feature_dict = {
        feature.name: proto_values_by_column[feature.name]
        for feature in feature_view.features
    }
    features = [dict(zip(feature_dict, vars)) for vars in zip(*feature_dict.values())]

    # Convert event timestamps, normalizing pd.Timestamp to plain datetime.
    event_timestamps = [
        _coerce_datetime(val)
        for val in pd.to_datetime(
            table.column(feature_view.batch_source.timestamp_field).to_numpy(
                zero_copy_only=False
            )
        )
    ]

    # Convert created timestamps if the source defines such a column;
    # otherwise pad with None so the row tuples keep a uniform shape.
    if feature_view.batch_source.created_timestamp_column:
        created_timestamps = [
            _coerce_datetime(val)
            for val in pd.to_datetime(
                table.column(
                    feature_view.batch_source.created_timestamp_column
                ).to_numpy(zero_copy_only=False)
            )
        ]
    else:
        created_timestamps = [None] * table.num_rows

    return list(zip(entity_keys, features, event_timestamps, created_timestamps))
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.