# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from io import BytesIO
from typing import Any, Dict, Generator, IO, Iterable, Optional, Type, Union, TYPE_CHECKING

from ._shared.avro.avro_io import DatumReader
from ._shared.avro.datafile import DataFileReader

if TYPE_CHECKING:
    from ._models import BlobQueryError


class BlobQueryReader(object):  # pylint: disable=too-many-instance-attributes
    """A streaming object to read query results."""

    name: str
    """The name of the blob being queried."""
    container: str
    """The name of the container where the blob is."""
    response_headers: Dict[str, Any]
    """The response_headers of the quick query request."""
    record_delimiter: str
    """The delimiter used to separate lines, or records, within the data.
    The `records` method will return these lines via a generator."""

    def __init__(
        self,
        name: str = None,  # type: ignore [assignment]
        container: str = None,  # type: ignore [assignment]
        errors: Any = None,
        record_delimiter: str = '\n',
        encoding: Optional[str] = None,
        headers: Dict[str, Any] = None,  # type: ignore [assignment]
        response: Any = None,
        error_cls: Type["BlobQueryError"] = None,  # type: ignore [assignment]
    ) -> None:
        self.name = name
        self.container = container
        self.response_headers = headers
        self.record_delimiter = record_delimiter
        self._size = 0
        self._bytes_processed = 0
        self._errors = errors
        self._encoding = encoding
        self._parsed_results = DataFileReader(QuickQueryStreamer(response), DatumReader())
        self._first_result = self._process_record(next(self._parsed_results))
        self._error_cls = error_cls

    def __len__(self):
        return self._size

    def _process_record(self, result: Dict[str, Any]) -> Optional[bytes]:
        self._size = result.get('totalBytes', self._size)
        self._bytes_processed = result.get('bytesScanned', self._bytes_processed)
        if 'data' in result:
            return result.get('data')
        if 'fatal' in result:
            error = self._error_cls(
                error=result['name'],
                is_fatal=result['fatal'],
                description=result['description'],
                position=result['position']
            )
            if self._errors:
                self._errors(error)
        return None

    def _iter_stream(self) -> Generator[bytes, None, None]:
        if self._first_result is not None:
            yield self._first_result
        for next_result in self._parsed_results:
            processed_result = self._process_record(next_result)
            if processed_result is not None:
                yield processed_result

    def readall(self) -> Union[bytes, str]:
        """Return all query results.

        This operation is blocking until all data is downloaded.
        If an encoding has been configured, it will be used to decode individual
        records as they are received.

        :returns: The query results.
        :rtype: Union[bytes, str]
        """
        stream = BytesIO()
        self.readinto(stream)
        data = stream.getvalue()
        if self._encoding:
            return data.decode(self._encoding)
        return data

    def readinto(self, stream: IO) -> None:
        """Download the query result to a stream.

        :param IO stream:
            The stream to download to. This can be an open file-handle,
            or any writable stream.
        :returns: None
        """
        for record in self._iter_stream():
            stream.write(record)

    def records(self) -> Iterable[Union[bytes, str]]:
        """Returns a record generator for the query result.

        Records will be returned line by line.
        If an encoding has been configured, it will be used to decode individual
        records as they are received.

        :returns: A record generator for the query result.
        :rtype: Iterable[Union[bytes, str]]
        """
        delimiter = self.record_delimiter.encode('utf-8')
        for record_chunk in self._iter_stream():
            for record in record_chunk.split(delimiter):
                if self._encoding:
                    yield record.decode(self._encoding)
                else:
                    yield record
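

# Example usage (a minimal sketch, not part of this module's API surface):
# a ``BlobQueryReader`` is normally obtained from ``BlobClient.query_blob``;
# the ``blob_client`` variable and the query string below are illustrative
# assumptions.
#
#     reader = blob_client.query_blob("SELECT * FROM BlobStorage")
#
#     # Read everything at once (bytes, or str if an encoding was configured):
#     data = reader.readall()
#
#     # Or stream the results record by record:
#     for record in reader.records():
#         process(record)  # ``process`` is a hypothetical callback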


class QuickQueryStreamer(object):
    """File-like streaming iterator."""

    def __init__(self, generator):
        self.generator = generator
        self.iterator = iter(generator)
        self._buf = b""
        self._point = 0
        self._download_offset = 0
        self._buf_start = 0
        self.file_length = None

    def __len__(self):
        return self.file_length

    def __iter__(self):
        return self.iterator

    @staticmethod
    def seekable():
        return True

    def __next__(self):
        next_part = next(self.iterator)
        self._download_offset += len(next_part)
        return next_part

    def tell(self):
        return self._point

    def seek(self, offset, whence=0):
        if whence == 0:
            self._point = offset
        elif whence == 1:
            self._point += offset
        else:
            raise ValueError("whence must be 0 or 1")
        if self._point < 0:  # pylint: disable=consider-using-max-builtin
            self._point = 0  # clamp the read position to the start of the stream

    def read(self, size):
        try:
            # Keep pulling from the generator until the buffer of this stream
            # holds enough data to satisfy the read.
            while self._point + size > self._download_offset:
                self._buf += self.__next__()
        except StopIteration:
            self.file_length = self._download_offset

        start_point = self._point

        # Cap the read position at EOF.
        self._point = min(self._point + size, self._download_offset)

        relative_start = start_point - self._buf_start
        if relative_start < 0:
            raise ValueError("Buffer has dumped too much data")
        relative_end = relative_start + size
        data = self._buf[relative_start: relative_end]

        # Dump data the reader has moved past, always retaining at least the
        # 16 bytes before the current read position so that small backwards
        # seeks stay within the buffer:
        # buffer start--------------------16bytes----current read position
        dumped_size = max(relative_end - 16 - relative_start, 0)
        self._buf_start += dumped_size
        self._buf = self._buf[dumped_size:]

        return data
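

# A minimal sketch of the streamer's buffered, seekable behaviour over a
# chunked byte generator (the chunk values below are illustrative; the Avro
# ``DataFileReader`` drives it with a similar read/seek pattern):
#
#     streamer = QuickQueryStreamer(iter([b"abcd", b"efgh", b"ijkl"]))
#     streamer.read(6)     # b"abcdef" -- pulls two chunks from the generator
#     streamer.seek(2, 1)  # skip two bytes forward, relative to the cursor
#     streamer.read(4)     # b"ijkl"   -- pulls the final chunk
#     streamer.read(4)     # b""       -- generator exhausted; file_length set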