Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spanner_orm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
not_in_list = condition.not_in_list
not_less_than = condition.not_less_than
order_by = condition.order_by
select_columns = condition.select_columns
ORDER_ASC = condition.OrderType.ASC
ORDER_DESC = condition.OrderType.DESC

Expand Down
41 changes: 41 additions & 0 deletions spanner_orm/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
class Segment(enum.Enum):
"""The segment of the SQL query that a Condition belongs to."""

SELECT = 0
FROM = 1
JOIN = 2
WHERE = 3
Expand Down Expand Up @@ -595,6 +596,33 @@ def __init__(self, column: Union[field.Field, str], value: Any):
super().__init__("!=", "IS NOT", column, value)


class SelectColumnsCondition(Condition):
"""Used to indicate which columns should be queried in a Spanner query."""

def __init__(self, columns: List[Union[field.Field, str]]):
super().__init__()
self.columns = [c if isinstance(c, str) else c.name for c in columns]

def _params(self) -> Dict[str, Any]:
return {}

def segment(self) -> Segment:
return Segment.SELECT

def _sql(self) -> str:
pass

def _types(self) -> Dict[str, type_pb2.Type]:
return {}

def _validate(self, model_class: Type[Any]) -> None:
for column in self.columns:
if column not in model_class.columns:
raise error.ValidationError(
f"Invalid column name: {column} not in model class: {model_class}"
)


def columns_equal(
origin_column: str, dest_model_class: Type[Any], dest_column: str
) -> ColumnsEqualCondition:
Expand Down Expand Up @@ -843,3 +871,16 @@ def order_by(*orderings: Tuple[Union[field.Field, str], OrderType]) -> OrderByCo
A Condition subclass that will be used in the query
"""
return OrderByCondition(*orderings)


def select_columns(columns: List[Union[field.Field, str]]) -> SelectColumnsCondition:
"""Condition to limit which columns should be queried. Default is to query all columns.
All the omitted fields will be set to None.

Args:
columns: Column names to query

Returns:
A Condition subclass that will be used in the query
"""
return SelectColumnsCondition(columns)
16 changes: 12 additions & 4 deletions spanner_orm/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Helps build SQL for complex Spanner queries."""

import abc
from typing import Any, Dict, Iterable, List, Tuple, Type
from typing import Any, Dict, Iterable, List, Tuple, Type, cast

from spanner_orm import condition
from spanner_orm import error
Expand Down Expand Up @@ -162,6 +162,7 @@ class SelectQuery(SpannerQuery):
def __init__(self, model: Type[Any], conditions: Iterable[condition.Condition]):
self._model = model
self._conditions = conditions
self._columns = model.columns
self._joins = self._segments(condition.Segment.JOIN)
self._subqueries = [
_SelectSubQuery(join.destination, join.conditions)
Expand All @@ -175,9 +176,16 @@ def _select_prefix(self) -> str:

def _select(self) -> Tuple[str, Dict[str, Any], Dict[str, Any]]:
parameters, types = {}, {}
selects = self._segments(condition.Segment.SELECT)
if selects:
if len(selects) != 1:
raise error.SpannerError("Only one select column condition may be specified")
select_columns = cast(condition.SelectColumnsCondition, selects[0])
self._columns = select_columns.columns

columns = [
"{alias}.{column}".format(alias=self._model.column_prefix, column=column)
for column in self._model.columns
for column in self._columns
]
for subquery in self._subqueries:
subquery.param_offset = self._next_param_index()
Expand All @@ -197,8 +205,8 @@ def process_results(self, results: List[List[Any]]) -> List[Type[Any]]:

def _process_row(self, row: List[Any]) -> Type[Any]:
"""Parses a row of results from a Spanner query based on the conditions."""
values = dict(zip(self._model.columns, row))
join_values = row[len(self._model.columns) :]
values = dict(zip(self._columns, row))
join_values = row[len(self._columns) :]
for join, subquery, join_value in zip(
self._joins, self._subqueries, join_values
):
Expand Down
14 changes: 14 additions & 0 deletions spanner_orm/tests/query_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,20 @@ def test_query_order_by_with_object(self):
select_query = self.select()
self.assertNotRegex(select_query.sql(), "ORDER BY")

def test_query_select_fields(self):
select_query = self.select(condition.select_columns([models.UnittestModel.int_]))

self.assertEqual(select_query.sql(), "SELECT table.int_ FROM table")
self.assertEmpty(select_query.parameters())
self.assertEmpty(select_query.types())

select_query2 = self.select(condition.select_columns(["int_"]))

self.assertEqual(select_query.sql(), select_query2.sql())

select_query = self.select()
self.assertRegex(select_query.sql(), "table.int_2")

@parameterized.parameters(
("int_", 5, field.Integer.grpc_type()),
("string", "foo", field.String.grpc_type()),
Expand Down