Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spanner_orm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
not_less_than = condition.not_less_than
order_by = condition.order_by
select_columns = condition.select_columns
raw_field = condition.raw_field
ORDER_ASC = condition.OrderType.ASC
ORDER_DESC = condition.OrderType.DESC

Expand Down
43 changes: 43 additions & 0 deletions spanner_orm/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,33 @@ def _validate(self, model_class: Type[Any]) -> None:
)


class RawFieldCondition(Condition):
"""Used to add additional raw fields in a Spanner query."""

def __init__(self, expr: str, field_name: str):
super().__init__()
self.expr = expr
self.field = field_name

def _params(self) -> Dict[str, Any]:
return {}

def segment(self) -> Segment:
return Segment.SELECT

def _sql(self) -> str:
return f"{self.expr} AS {self.field}"

def _types(self) -> Dict[str, type_pb2.Type]:
return {}

def _validate(self, model_class: Type[Any]) -> None:
if self.field in model_class.columns:
raise error.ValidationError(
f"Invalid alias name: {self.field} already in in model class: {model_class}"
)


def columns_equal(
origin_column: str, dest_model_class: Type[Any], dest_column: str
) -> ColumnsEqualCondition:
Expand Down Expand Up @@ -884,3 +911,19 @@ def select_columns(columns: List[Union[field.Field, str]]) -> SelectColumnsCondi
A Condition subclass that will be used in the query
"""
return SelectColumnsCondition(columns)


def raw_field(expr: str, field_name: str) -> RawFieldCondition:
"""Condition to include additional raw fields in objects. This may be used to add
Spanner functions and use their outputs. For example, expr = SUBSTR(s, 0, 2), alias = sub_s
will add an additional attribute on the object named `sub_s` which will be a substring
as evaluated by Spanner.

Args:
expr: Expression which can be evaluated by Spanner. May use column names.
field_name: Name of the field

Returns:
A Condition subclass that will be used in the query
"""
return RawFieldCondition(expr, field_name)
4 changes: 4 additions & 0 deletions spanner_orm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,10 @@ def __init__(self, values: Dict[str, Any], persisted: bool = False):
if relation in values:
self.__dict__[relation] = values[relation]

for k, v in values.items():
if k not in self._columns and k not in self._relations:
self.__dict__[k] = v

def __setattr__(self, name: str, value: Any) -> None:
if name in self._relations:
raise AttributeError(name)
Expand Down
37 changes: 25 additions & 12 deletions spanner_orm/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,19 @@
"""Helps build SQL for complex Spanner queries."""

import abc
from typing import Any, Dict, Iterable, List, Tuple, Type, cast
from typing import Any, Dict, Iterable, List, Tuple, Type, cast, TYPE_CHECKING

from spanner_orm import condition
from spanner_orm import error

if TYPE_CHECKING:
from spanner_orm import Model


class SpannerQuery(abc.ABC):
"""Helps build SQL for complex Spanner queries."""

def __init__(self, model: Type[Any], conditions: Iterable[condition.Condition]):
def __init__(self, model: Type["Model"], conditions: Iterable[condition.Condition]):
self.param_offset = 0
self._model = model
self._conditions = conditions
Expand Down Expand Up @@ -140,7 +143,7 @@ def _limit(self) -> Tuple[str, Dict[str, Any], Dict[str, Any]]:
class CountQuery(SpannerQuery):
"""Handles COUNT Spanner queries."""

def __init__(self, model: Type[Any], conditions: Iterable[condition.Condition]):
def __init__(self, model: Type["Model"], conditions: Iterable[condition.Condition]):
super().__init__(model, conditions)
for c in conditions:
if c.segment() not in [condition.Segment.WHERE, condition.Segment.FROM]:
Expand All @@ -159,10 +162,11 @@ def process_results(self, results: List[List[Any]]) -> int:
class SelectQuery(SpannerQuery):
"""Handles SELECT Spanner queries."""

def __init__(self, model: Type[Any], conditions: Iterable[condition.Condition]):
def __init__(self, model: Type["Model"], conditions: Iterable[condition.Condition]):
self._model = model
self._conditions = conditions
self._columns = model.columns
self._additional_fields = []
self._joins = self._segments(condition.Segment.JOIN)
self._subqueries = [
_SelectSubQuery(join.destination, join.conditions)
Expand All @@ -177,18 +181,27 @@ def _select_prefix(self) -> str:
def _select(self) -> Tuple[str, Dict[str, Any], Dict[str, Any]]:
parameters, types = {}, {}
selects = self._segments(condition.Segment.SELECT)
if selects:
if len(selects) != 1:
select_columns = [
select
for select in selects
if isinstance(select, condition.SelectColumnsCondition)
]
if select_columns:
if len(select_columns) != 1:
raise error.SpannerError(
"Only one select column condition may be specified"
)
select_columns = cast(condition.SelectColumnsCondition, selects[0])
select_columns = cast(condition.SelectColumnsCondition, select_columns[0])
self._columns = select_columns.columns

columns = [
"{alias}.{column}".format(alias=self._model.column_prefix, column=column)
for column in self._columns
columns = [f"{self._model.column_prefix}.{column}" for column in self._columns]
raw_columns = [
select
for select in selects
if isinstance(select, condition.RawFieldCondition)
]
columns += [raw.sql() for raw in raw_columns]
self._additional_fields = [raw.field for raw in raw_columns]
for subquery in self._subqueries:
subquery.param_offset = self._next_param_index()
columns.append("ARRAY({subquery})".format(subquery=subquery.sql()))
Expand All @@ -207,8 +220,8 @@ def process_results(self, results: List[List[Any]]) -> List[Type[Any]]:

def _process_row(self, row: List[Any]) -> Type[Any]:
"""Parses a row of results from a Spanner query based on the conditions."""
values = dict(zip(self._columns, row))
join_values = row[len(self._columns) :]
values = dict(zip(self._columns + self._additional_fields, row))
join_values = row[len(self._columns + self._additional_fields) :]
for join, subquery, join_value in zip(
self._joins, self._subqueries, join_values
):
Expand Down
6 changes: 6 additions & 0 deletions spanner_orm/tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,12 @@ def test_delete_deletes(self, delete):
self.assertEqual(table, models.SmallTestModel.table)
self.assertEqual(keyset.keys, [[model.key]])

def test_create_raw_fields(self):
test_model = models.SmallTestModel(
{"key": "key", "value_1": "value", "extra_field": "value"}
)
self.assertEqual(test_model.extra_field, "value")


if __name__ == "__main__":
logging.basicConfig()
Expand Down
17 changes: 17 additions & 0 deletions spanner_orm/tests/query_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,23 @@ def test_or(self):
{"int_0": field.Integer.grpc_type(), "int_1": field.Integer.grpc_type()},
)

def test_raw(self):
select_query = self.select(condition.raw_field("COS(float_)", "cosine"))
expected_sql = r"COS\(float_\) AS cosine"
self.assertRegex(select_query.sql(), expected_sql)

select_query = self.select(
condition.select_columns([models.UnittestModel.int_]),
condition.raw_field("CAST(MOD(int_, 2) AS BOOL)", "is_odd"),
)

self.assertEqual(
select_query.sql(),
"SELECT table.int_, CAST(MOD(int_, 2) AS BOOL) AS is_odd FROM table",
)
self.assertEmpty(select_query.parameters())
self.assertEmpty(select_query.types())


if __name__ == "__main__":
logging.basicConfig()
Expand Down