Source code for pallas.results

# Copyright 2020 Akamai Technologies, Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Encapsulation of results returned from Athena.
"""

from __future__ import annotations

from typing import Dict, Mapping, Sequence, TextIO, cast, overload

from pallas._compat import pandas as pd
from pallas.conversions import Converter, get_converter
from pallas.csv import read_csv, write_csv

QueryRecord = Dict[str, object]


[docs]class QueryResults(Sequence[QueryRecord]): """ Collection of Athena query results. Implements a list-like interface for accessing individual records. Alternatively, can be converted to :class:`pandas.DataFrame` using the :meth:`.to_df` method. """ _column_names: Sequence[str] _column_types: Sequence[str] _data: Sequence[Sequence[str | None]] def __init__( self, column_names: Sequence[str], column_types: Sequence[str], data: Sequence[Sequence[str | None]], ) -> None: self._column_names = column_names self._column_types = column_types self._data = data def __repr__(self) -> str: parts = [ f"{len(self)} results", f"column_names={self.column_names!r}", f"column_types={self.column_types!r}", ] return f"<{type(self).__name__}: {', '.join(parts)}>" @overload def __getitem__(self, index: int) -> QueryRecord: ... @overload # noqa: F811 def __getitem__(self, index: slice) -> Sequence[QueryRecord]: ...
[docs] def __getitem__( # noqa: F811 self, index: int | slice ) -> QueryRecord | Sequence[QueryRecord]: """ Return one result or slice of results. Records are returned as mappings from column names to values. """ if isinstance(index, slice): return QueryResults(self.column_names, self.column_types, self._data[index]) row = self._data[index] return { cn: converter.read(v) for cn, converter, v in zip(self.column_names, self._converters, row) }
[docs] def __len__(self) -> int: """ Return count of this results. """ return len(self._data)
[docs] @classmethod def load(cls, stream: TextIO) -> QueryResults: """ Deserialize results from a text stream. """ reader = read_csv(stream) column_names = next(reader) column_types = next(reader) data = list(reader) if any(v is None for v in column_names): raise ValueError("Missing column name") if any(v is None for v in column_types): raise ValueError("Missing column type") column_names = cast(Sequence[str], column_names) column_types = cast(Sequence[str], column_types) return cls(column_names, column_types, data)
[docs] def save(self, stream: TextIO) -> None: """ Serialize results to a text stream. """ write_csv([self._column_names, self._column_types], stream) write_csv(self._data, stream)
@property def column_names(self) -> Sequence[str]: """List of column names.""" return list(self._column_names) @property def column_types(self) -> Sequence[str]: """List of column types.""" return list(self._column_types) @property def _converters(self) -> Sequence[Converter[object]]: return list(map(get_converter, self.column_types))
[docs] def to_df(self, dtypes: Mapping[str, object] | None = None) -> pd.DataFrame: """ Convert this results to :class:`pandas.DataFrame`. """ if pd is None: raise RuntimeError("Pandas cannot be imported.") if dtypes is None: dtypes = {} frame_data = {} for i, (name, converter) in enumerate(zip(self.column_names, self._converters)): values = (row[i] for row in self._data) frame_data[name] = converter.read_array(values, dtype=dtypes.get(name)) return pd.DataFrame(frame_data, copy=False)