diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml new file mode 100644 index 0000000000..af59935321 --- /dev/null +++ b/.github/sync-repo-settings.yaml @@ -0,0 +1,13 @@ +# https://github.com/googleapis/repo-automation-bots/tree/master/packages/sync-repo-settings +# Rules for master branch protection +branchProtectionRules: +# Identifies the protection rule pattern. Name of the branch to be protected. +# Defaults to `master` +- pattern: master + requiredStatusCheckContexts: + - 'Kokoro' + - 'cla/google' + - 'Samples - Lint' + - 'Samples - Python 3.6' + - 'Samples - Python 3.7' + - 'Samples - Python 3.8' diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d7907b0dc..0d8f77c32b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,38 @@ [1]: https://pypi.org/project/google-cloud-spanner/#history +## [3.0.0](https://www.github.com/googleapis/python-spanner/compare/v2.1.0...v3.0.0) (2021-01-15) + + +### ⚠ BREAKING CHANGES + +* convert operations pbs into Operation objects when listing operations (#186) + +### Features + +* add support for instance labels ([#193](https://www.github.com/googleapis/python-spanner/issues/193)) ([ed462b5](https://www.github.com/googleapis/python-spanner/commit/ed462b567a1a33f9105ffb37ba1218f379603614)) +* add support for ssl credentials; add throttled field to UpdateDatabaseDdlMetadata ([#161](https://www.github.com/googleapis/python-spanner/issues/161)) ([2faf01b](https://www.github.com/googleapis/python-spanner/commit/2faf01b135360586ef27c66976646593fd85fd1e)) +* adding missing docstrings for functions & classes ([#188](https://www.github.com/googleapis/python-spanner/issues/188)) ([9788cf8](https://www.github.com/googleapis/python-spanner/commit/9788cf8678d882bd4ccf551f828050cbbb8c8f3a)) +* autocommit sample ([#172](https://www.github.com/googleapis/python-spanner/issues/172)) ([4ef793c](https://www.github.com/googleapis/python-spanner/commit/4ef793c9cd5d6dec6e92faf159665e11d63762ad)) + + +### Bug Fixes + +* convert operations pbs into Operation objects when listing operations ([#186](https://www.github.com/googleapis/python-spanner/issues/186)) ([ed7152a](https://www.github.com/googleapis/python-spanner/commit/ed7152adc37290c63e59865265f36c593d9b8da3)) +* Convert PBs in system test cleanup ([#199](https://www.github.com/googleapis/python-spanner/issues/199)) ([ede4343](https://www.github.com/googleapis/python-spanner/commit/ede4343e518780a4ab13ae83017480d7046464d6)) +* **dbapi:** autocommit enabling fails if no transactions begun ([#177](https://www.github.com/googleapis/python-spanner/issues/177)) ([e981adb](https://www.github.com/googleapis/python-spanner/commit/e981adb3157bb06e4cb466ca81d74d85da976754)) +* **dbapi:** executemany() hiding all the results except the last ([#181](https://www.github.com/googleapis/python-spanner/issues/181)) ([020dc17](https://www.github.com/googleapis/python-spanner/commit/020dc17c823dfb65bfaacace14d2c9f491c97e11)) +* **dbapi:** Spanner protobuf changes causes KeyError's ([#206](https://www.github.com/googleapis/python-spanner/issues/206)) ([f1e21ed](https://www.github.com/googleapis/python-spanner/commit/f1e21edbf37aab93615fd415d61f829d2574916b)) +* remove client side gRPC receive limits ([#192](https://www.github.com/googleapis/python-spanner/issues/192)) ([90effc4](https://www.github.com/googleapis/python-spanner/commit/90effc4d0f4780b7a7c466169f9fc1e45dab8e7f)) +* Rename to fix "Mismatched region tag" check ([#201](https://www.github.com/googleapis/python-spanner/issues/201)) ([c000ec4](https://www.github.com/googleapis/python-spanner/commit/c000ec4d9b306baa0d5e9ed95f23c0273d9adf32)) + + +### Documentation + +* homogenize region tags ([#194](https://www.github.com/googleapis/python-spanner/issues/194)) ([1501022](https://www.github.com/googleapis/python-spanner/commit/1501022239dfa8c20290ca0e0cf6a36e9255732c)) +* homogenize region tags pt 2 ([#202](https://www.github.com/googleapis/python-spanner/issues/202)) ([87789c9](https://www.github.com/googleapis/python-spanner/commit/87789c939990794bfd91f5300bedc449fd74bd7e)) +* update CHANGELOG breaking change comment ([#180](https://www.github.com/googleapis/python-spanner/issues/180)) ([c7b3b9e](https://www.github.com/googleapis/python-spanner/commit/c7b3b9e4be29a199618be9d9ffa1d63a9d0f8de7)) + ## [2.1.0](https://www.github.com/googleapis/python-spanner/compare/v2.0.0...v2.1.0) (2020-11-24) @@ -27,7 +59,7 @@ ### ⚠ BREAKING CHANGES -* migrate to v2.0.0 (#147) +* list_instances, list_databases, list_instance_configs, and list_backups will now return protos rather than the handwritten wrapper (#147) ### Features diff --git a/docs/spanner_admin_database_v1/types.rst b/docs/spanner_admin_database_v1/types.rst index da44c33458..fe6c27778b 100644 --- a/docs/spanner_admin_database_v1/types.rst +++ b/docs/spanner_admin_database_v1/types.rst @@ -3,3 +3,4 @@ Types for Google Cloud Spanner Admin Database v1 API .. automodule:: google.cloud.spanner_admin_database_v1.types :members: + :show-inheritance: diff --git a/docs/spanner_admin_instance_v1/types.rst b/docs/spanner_admin_instance_v1/types.rst index b496dfc681..250cf6bf9b 100644 --- a/docs/spanner_admin_instance_v1/types.rst +++ b/docs/spanner_admin_instance_v1/types.rst @@ -3,3 +3,4 @@ Types for Google Cloud Spanner Admin Instance v1 API .. automodule:: google.cloud.spanner_admin_instance_v1.types :members: + :show-inheritance: diff --git a/docs/spanner_v1/types.rst b/docs/spanner_v1/types.rst index 15b938d7f3..c7ff7e6c71 100644 --- a/docs/spanner_v1/types.rst +++ b/docs/spanner_v1/types.rst @@ -3,3 +3,4 @@ Types for Google Cloud Spanner v1 API .. automodule:: google.cloud.spanner_v1.types :members: + :show-inheritance: diff --git a/google/cloud/spanner_admin_database_v1/proto/spanner_database_admin.proto b/google/cloud/spanner_admin_database_v1/proto/spanner_database_admin.proto index af440c1a36..db6192bc02 100644 --- a/google/cloud/spanner_admin_database_v1/proto/spanner_database_admin.proto +++ b/google/cloud/spanner_admin_database_v1/proto/spanner_database_admin.proto @@ -514,6 +514,11 @@ message UpdateDatabaseDdlMetadata { // succeeded so far, where `commit_timestamps[i]` is the commit // timestamp for the statement `statements[i]`. repeated google.protobuf.Timestamp commit_timestamps = 3; + + // Output only. When true, indicates that the operation is throttled e.g + // due to resource constraints. When resources become available the operation + // will resume and this field will be false again. + bool throttled = 4 [(google.api.field_behavior) = OUTPUT_ONLY]; } // The request for [DropDatabase][google.spanner.admin.database.v1.DatabaseAdmin.DropDatabase]. diff --git a/google/cloud/spanner_admin_database_v1/services/database_admin/transports/__init__.py b/google/cloud/spanner_admin_database_v1/services/database_admin/transports/__init__.py index 348af3f043..00a3ab8549 100644 --- a/google/cloud/spanner_admin_database_v1/services/database_admin/transports/__init__.py +++ b/google/cloud/spanner_admin_database_v1/services/database_admin/transports/__init__.py @@ -28,7 +28,6 @@ _transport_registry["grpc"] = DatabaseAdminGrpcTransport _transport_registry["grpc_asyncio"] = DatabaseAdminGrpcAsyncIOTransport - __all__ = ( "DatabaseAdminTransport", "DatabaseAdminGrpcTransport", diff --git a/google/cloud/spanner_admin_database_v1/services/database_admin/transports/grpc.py b/google/cloud/spanner_admin_database_v1/services/database_admin/transports/grpc.py index 0f8d56f05a..e8a0a6f93d 100644 --- a/google/cloud/spanner_admin_database_v1/services/database_admin/transports/grpc.py +++ b/google/cloud/spanner_admin_database_v1/services/database_admin/transports/grpc.py @@ -113,6 +113,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -120,6 +122,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -155,7 +158,12 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" @@ -172,9 +180,14 @@ def __init__( ssl_credentials=ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None # Run the base constructor. super().__init__( @@ -198,7 +211,7 @@ def create_channel( ) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optionsl[str]): The host for the channel to use. + address (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -245,13 +258,11 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsClient( - self.grpc_channel - ) + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. - return self.__dict__["operations_client"] + return self._operations_client @property def list_databases( diff --git a/google/cloud/spanner_admin_database_v1/services/database_admin/transports/grpc_asyncio.py b/google/cloud/spanner_admin_database_v1/services/database_admin/transports/grpc_asyncio.py index 45f2e2d9e6..7a83120018 100644 --- a/google/cloud/spanner_admin_database_v1/services/database_admin/transports/grpc_asyncio.py +++ b/google/cloud/spanner_admin_database_v1/services/database_admin/transports/grpc_asyncio.py @@ -158,6 +158,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -165,6 +167,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -200,7 +203,12 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" @@ -217,6 +225,10 @@ def __init__( ssl_credentials=ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) # Run the base constructor. @@ -230,6 +242,7 @@ def __init__( ) self._stubs = {} + self._operations_client = None @property def grpc_channel(self) -> aio.Channel: @@ -249,13 +262,13 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__["operations_client"] + return self._operations_client @property def list_databases( diff --git a/google/cloud/spanner_admin_database_v1/types/__init__.py b/google/cloud/spanner_admin_database_v1/types/__init__.py index d02a26ffb5..79b682aab9 100644 --- a/google/cloud/spanner_admin_database_v1/types/__init__.py +++ b/google/cloud/spanner_admin_database_v1/types/__init__.py @@ -47,9 +47,9 @@ RestoreDatabaseRequest, RestoreDatabaseMetadata, OptimizeRestoredDatabaseMetadata, + RestoreSourceType, ) - __all__ = ( "OperationProgress", "Backup", @@ -80,4 +80,5 @@ "RestoreDatabaseRequest", "RestoreDatabaseMetadata", "OptimizeRestoredDatabaseMetadata", + "RestoreSourceType", ) diff --git a/google/cloud/spanner_admin_database_v1/types/spanner_database_admin.py b/google/cloud/spanner_admin_database_v1/types/spanner_database_admin.py index b2b5939f5b..e99d200906 100644 --- a/google/cloud/spanner_admin_database_v1/types/spanner_database_admin.py +++ b/google/cloud/spanner_admin_database_v1/types/spanner_database_admin.py @@ -287,6 +287,12 @@ class UpdateDatabaseDdlMetadata(proto.Message): Reports the commit timestamps of all statements that have succeeded so far, where ``commit_timestamps[i]`` is the commit timestamp for the statement ``statements[i]``. + throttled (bool): + Output only. When true, indicates that the + operation is throttled e.g due to resource + constraints. When resources become available the + operation will resume and this field will be + false again. """ database = proto.Field(proto.STRING, number=1) @@ -297,6 +303,8 @@ class UpdateDatabaseDdlMetadata(proto.Message): proto.MESSAGE, number=3, message=timestamp.Timestamp, ) + throttled = proto.Field(proto.BOOL, number=4) + class DropDatabaseRequest(proto.Message): r"""The request for diff --git a/google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/__init__.py b/google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/__init__.py index 2b8e6a24b6..b18f099ef8 100644 --- a/google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/__init__.py +++ b/google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/__init__.py @@ -28,7 +28,6 @@ _transport_registry["grpc"] = InstanceAdminGrpcTransport _transport_registry["grpc_asyncio"] = InstanceAdminGrpcAsyncIOTransport - __all__ = ( "InstanceAdminTransport", "InstanceAdminGrpcTransport", diff --git a/google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/grpc.py b/google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/grpc.py index a758bb6ad4..aa827a3b75 100644 --- a/google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/grpc.py +++ b/google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/grpc.py @@ -126,6 +126,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -133,6 +135,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -168,7 +171,12 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" @@ -185,9 +193,14 @@ def __init__( ssl_credentials=ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None # Run the base constructor. super().__init__( @@ -211,7 +224,7 @@ def create_channel( ) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optionsl[str]): The host for the channel to use. + address (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -258,13 +271,11 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsClient( - self.grpc_channel - ) + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. - return self.__dict__["operations_client"] + return self._operations_client @property def list_instance_configs( diff --git a/google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/grpc_asyncio.py b/google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/grpc_asyncio.py index 91fb40d1e7..a2d22c56f6 100644 --- a/google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/grpc_asyncio.py +++ b/google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/grpc_asyncio.py @@ -171,6 +171,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -178,6 +180,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -213,7 +216,12 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" @@ -230,6 +238,10 @@ def __init__( ssl_credentials=ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) # Run the base constructor. @@ -243,6 +255,7 @@ def __init__( ) self._stubs = {} + self._operations_client = None @property def grpc_channel(self) -> aio.Channel: @@ -262,13 +275,13 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__["operations_client"] + return self._operations_client @property def list_instance_configs( diff --git a/google/cloud/spanner_admin_instance_v1/types/__init__.py b/google/cloud/spanner_admin_instance_v1/types/__init__.py index 0f096f84c9..37b771feed 100644 --- a/google/cloud/spanner_admin_instance_v1/types/__init__.py +++ b/google/cloud/spanner_admin_instance_v1/types/__init__.py @@ -32,7 +32,6 @@ UpdateInstanceMetadata, ) - __all__ = ( "ReplicaInfo", "InstanceConfig", diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index a397028287..6438605d3b 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -22,6 +22,9 @@ from google.cloud import spanner_v1 as spanner from google.cloud.spanner_v1.session import _get_retry_delay +from google.cloud.spanner_dbapi._helpers import _execute_insert_heterogenous +from google.cloud.spanner_dbapi._helpers import _execute_insert_homogenous +from google.cloud.spanner_dbapi._helpers import parse_insert from google.cloud.spanner_dbapi.checksum import _compare_checksums from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import Cursor @@ -82,7 +85,7 @@ def autocommit(self, value): :type value: bool :param value: New autocommit mode state. """ - if value and not self._autocommit: + if value and not self._autocommit and self.inside_transaction: self.commit() self._autocommit = value @@ -96,6 +99,19 @@ def database(self): """ return self._database + @property + def inside_transaction(self): + """Flag: transaction is started. + + Returns: + bool: True if transaction begun, False otherwise. + """ + return ( + self._transaction + and not self._transaction.committed + and not self._transaction.rolled_back + ) + @property def instance(self): """Instance to which this connection relates. @@ -191,11 +207,7 @@ def transaction_checkout(self): :returns: A Cloud Spanner transaction object, ready to use. """ if not self.autocommit: - if ( - not self._transaction - or self._transaction.committed - or self._transaction.rolled_back - ): + if not self.inside_transaction: self._transaction = self._session_checkout().transaction() self._transaction.begin() @@ -216,11 +228,7 @@ def close(self): The connection will be unusable from this point forward. If the connection has an active transaction, it will be rolled back. """ - if ( - self._transaction - and not self._transaction.committed - and not self._transaction.rolled_back - ): + if self.inside_transaction: self._transaction.rollback() if self._own_pool: @@ -235,7 +243,7 @@ def commit(self): """ if self._autocommit: warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) - elif self._transaction: + elif self.inside_transaction: try: self._transaction.commit() self._release_session() @@ -291,6 +299,24 @@ def run_statement(self, statement, retried=False): if not retried: self._statements.append(statement) + if statement.is_insert: + parts = parse_insert(statement.sql, statement.params) + + if parts.get("homogenous"): + _execute_insert_homogenous(transaction, parts) + return ( + iter(()), + ResultsChecksum() if retried else statement.checksum, + ) + else: + _execute_insert_heterogenous( + transaction, parts.get("sql_params_list"), + ) + return ( + iter(()), + ResultsChecksum() if retried else statement.checksum, + ) + return ( transaction.execute_sql( statement.sql, statement.params, param_types=statement.param_types, diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index e2667f0599..4b5a0d9652 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -37,11 +37,12 @@ from google.cloud.spanner_dbapi.parse_utils import get_param_types from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner from google.cloud.spanner_dbapi.utils import PeekIterator +from google.cloud.spanner_dbapi.utils import StreamedManyResultSets _UNSET_COUNT = -1 ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) -Statement = namedtuple("Statement", "sql, params, param_types, checksum") +Statement = namedtuple("Statement", "sql, params, param_types, checksum, is_insert") class Cursor(object): @@ -84,6 +85,9 @@ def description(self): - ``precision`` - ``scale`` - ``null_ok`` + + :rtype: tuple + :returns: A tuple of columns' information. """ if not (self._result_set and self._result_set.metadata): return None @@ -94,11 +98,11 @@ def description(self): for field in row_type.fields: column_info = ColumnInfo( name=field.name, - type_code=field.type.code, + type_code=field.type_.code, # Size of the SQL type of the column. - display_size=code_to_display_size.get(field.type.code), + display_size=code_to_display_size.get(field.type_.code), # Client perceived size of the column. - internal_size=field.ByteSize(), + internal_size=field._pb.ByteSize(), ) columns.append(column_info) @@ -106,7 +110,11 @@ def description(self): @property def rowcount(self): - """The number of rows produced by the last `.execute()`.""" + """The number of rows produced by the last `.execute()`. + + :rtype: int + :returns: The number of rows produced by the last .execute*(). + """ return self._row_count def _raise_if_closed(self): @@ -126,7 +134,14 @@ def callproc(self, procname, args=None): self._raise_if_closed() def close(self): - """Closes this Cursor, making it unusable from this point forward.""" + """Prepare and execute a Spanner database operation. + + :type sql: str + :param sql: A SQL query statement. + + :type args: list + :param args: Additional parameters to supplement the SQL query. + """ self._is_closed = True def _do_execute_update(self, transaction, sql, params, param_types=None): @@ -171,10 +186,20 @@ def execute(self, sql, args=None): self.connection.run_prior_DDL_statements() if not self.connection.autocommit: - sql, params = sql_pyformat_args_to_spanner(sql, args) + if classification == parse_utils.STMT_UPDATING: + sql = parse_utils.ensure_where_clause(sql) + + if classification != parse_utils.STMT_INSERT: + sql, args = sql_pyformat_args_to_spanner(sql, args or None) statement = Statement( - sql, params, get_param_types(params), ResultsChecksum(), + sql, + args, + get_param_types(args or None) + if classification != parse_utils.STMT_INSERT + else {}, + ResultsChecksum(), + classification == parse_utils.STMT_INSERT, ) (self._result_set, self._checksum,) = self.connection.run_statement( statement @@ -210,8 +235,20 @@ def executemany(self, operation, seq_of_params): """ self._raise_if_closed() + classification = parse_utils.classify_stmt(operation) + if classification == parse_utils.STMT_DDL: + raise ProgrammingError( + "Executing DDL statements with executemany() method is not allowed." + ) + + many_result_set = StreamedManyResultSets() + for params in seq_of_params: self.execute(operation, params) + many_result_set.add_iter(self._itr) + + self._result_set = many_result_set + self._itr = many_result_set def fetchone(self): """Fetch the next row of a query result set, returning a single @@ -220,7 +257,8 @@ def fetchone(self): try: res = next(self) - self._checksum.consume_result(res) + if not self.connection.autocommit: + self._checksum.consume_result(res) return res except StopIteration: return @@ -237,7 +275,8 @@ def fetchall(self): res = [] try: for row in self: - self._checksum.consume_result(row) + if not self.connection.autocommit: + self._checksum.consume_result(row) res.append(row) except Aborted: self._connection.retry_transaction() @@ -265,7 +304,8 @@ def fetchmany(self, size=None): for i in range(size): try: res = next(self) - self._checksum.consume_result(res) + if not self.connection.autocommit: + self._checksum.consume_result(res) items.append(res) except StopIteration: break @@ -332,6 +372,11 @@ def __iter__(self): return self._itr def list_tables(self): + """List the tables of the linked Database. + + :rtype: list + :returns: The list of tables within the Database. + """ return self.run_sql_in_snapshot(_helpers.SQL_LIST_TABLES) def run_sql_in_snapshot(self, sql, params=None, param_types=None): diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 8848233d45..abc36b397c 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -176,11 +176,11 @@ def classify_stmt(query): """Determine SQL query type. - :type query: :class:`str` - :param query: SQL query. + :type query: str + :param query: A SQL query. - :rtype: :class:`str` - :returns: Query type name. + :rtype: str + :returns: The query type name. """ if RE_DDL.match(query): return STMT_DDL @@ -253,6 +253,17 @@ def parse_insert(insert_sql, params): ('INSERT INTO T (f1, f2) VALUES (UPPER(%s), %s)', ('c', 'd',)) ], } + + :type insert_sql: str + :param insert_sql: A SQL insert request. + + :type params: list + :param params: A list of parameters. + + :rtype: dict + :returns: A dictionary that maps `sql_params_list` to the list of + parameters in cases a), b), d) or the dictionary with information + about the resulting table in case c). """ # noqa match = RE_INSERT.search(insert_sql) @@ -348,8 +359,16 @@ def rows_for_insert_or_update(columns, params, pyformat_args=None): We'll have to convert both params types into: Params: [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)] - """ # noqa + :type columns: list + :param columns: A list of the columns of the table. + + :type params: list + :param params: A list of parameters. + + :rtype: list + :returns: A properly restructured list of the parameters. + """ # noqa if not pyformat_args: # This is the case where we have for example: # SQL: 'INSERT INTO t (f1, f2, f3)' @@ -445,6 +464,16 @@ def sql_pyformat_args_to_spanner(sql, params): becomes: SQL: 'SELECT * from t where f1=@a0, f2=@a1, f3=@a2' Params: {'a0': 'a', 'a1': 23, 'a2': '888***'} + + :type sql: str + :param sql: A SQL request. + + :type params: list + :param params: A list of parameters. + + :rtype: tuple(str, dict) + :returns: A tuple of the sanitized SQL and a dictionary of the named + arguments. """ if not params: return sanitize_literals_for_upload(sql), params @@ -488,10 +517,10 @@ def cast_for_spanner(value): """Convert the param to its Cloud Spanner equivalent type. :type value: Any - :param value: Value to convert to a Cloud Spanner type. + :param value: The value to convert to a Cloud Spanner type. :rtype: Any - :returns: Value converted to a Cloud Spanner type. + :returns: The value converted to a Cloud Spanner type. """ if isinstance(value, decimal.Decimal): return str(value) @@ -501,10 +530,10 @@ def cast_for_spanner(value): def get_param_types(params): """Determine Cloud Spanner types for the given parameters. - :type params: :class:`dict` + :type params: dict :param params: Parameters requiring to find Cloud Spanner types. - :rtype: :class:`dict` + :rtype: dict :returns: The types index for the given parameters. """ if params is None: @@ -523,19 +552,15 @@ def get_param_types(params): def ensure_where_clause(sql): """ Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. - Raise an error, if the given sql doesn't include it. + Add a dummy WHERE clause if non detected. - :type sql: `str` + :type sql: str :param sql: SQL code to check. - - :raises: :class:`ProgrammingError` if the given sql doesn't include a WHERE clause. """ if any(isinstance(token, sqlparse.sql.Where) for token in sqlparse.parse(sql)[0]): return sql - raise ProgrammingError( - "Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query" - ) + return sql + " WHERE 1=1" def escape_name(name): @@ -543,10 +568,10 @@ def escape_name(name): Apply backticks to the name that either contain '-' or ' ', or is a Cloud Spanner's reserved keyword. - :type name: :class:`str` + :type name: str :param name: Name to escape. - :rtype: :class:`str` + :rtype: str :returns: Name escaped if it has to be escaped. """ if "-" in name or " " in name or name.upper() in SPANNER_RESERVED_KEYWORDS: diff --git a/google/cloud/spanner_dbapi/parser.py b/google/cloud/spanner_dbapi/parser.py index 9271631b25..43e446c58e 100644 --- a/google/cloud/spanner_dbapi/parser.py +++ b/google/cloud/spanner_dbapi/parser.py @@ -68,14 +68,18 @@ def __len__(self): class terminal(str): - """ - terminal represents the unit symbol that can be part of a SQL values clause. - """ + """Represent the unit symbol that can be part of a SQL values clause.""" pass class a_args(object): + """Expression arguments. + + :type argv: list + :param argv: A List of expression arguments. + """ + def __init__(self, argv): self.argv = argv @@ -108,9 +112,11 @@ def __getitem__(self, index): return self.argv[index] def homogenous(self): - """ - Return True if all the arguments are pyformat - args and have the same number of arguments. + """Check arguments of the expression to be homogeneous. + + :rtype: bool + :return: True if all the arguments of the expression are in pyformat + and each has the same length, False otherwise. """ if not self._is_equal_length(): return False @@ -126,8 +132,10 @@ def homogenous(self): return True def _is_equal_length(self): - """ - Return False if all the arguments have the same length. + """Return False if all the arguments have the same length. + + :rtype: bool + :return: False if the sequences of the arguments have the same length. """ if len(self) == 0: return True @@ -141,6 +149,12 @@ def _is_equal_length(self): class values(a_args): + """A wrapper for values. + + :rtype: str + :returns: A string of the values expression in a tree view. + """ + def __str__(self): return "VALUES%s" % super().__str__() @@ -153,6 +167,21 @@ def parse_values(stmt): def expect(word, token): + """Parse the given expression recursively. + + :type word: str + :param word: A string expression. + + :type token: str + :param token: An expression token. + + :rtype: `Tuple(str, Any)` + :returns: A tuple containing the rest of the expression string and the + parse tree for the part of the expression that has already been + parsed. + + :raises :class:`ProgrammingError`: If there is a parsing error. + """ word = word.strip() if token == VALUES: if not word.startswith("VALUES"): @@ -242,5 +271,13 @@ def expect(word, token): def as_values(values_stmt): + """Return the parsed values. + + :type values_stmt: str + :param values_stmt: Raw values. + + :rtype: Any + :returns: A tree of the already parsed expression. + """ _, _values = parse_values(values_stmt) return _values diff --git a/google/cloud/spanner_dbapi/utils.py b/google/cloud/spanner_dbapi/utils.py index b0ad3922a5..bfb97346cf 100644 --- a/google/cloud/spanner_dbapi/utils.py +++ b/google/cloud/spanner_dbapi/utils.py @@ -14,14 +14,18 @@ import re +re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)") + class PeekIterator: """ - PeekIterator peeks at the first element out of an iterator - for the sake of operations like auto-population of fields on reading - the first element. - If next's result is an instance of list, it'll be converted into a tuple - to conform with DBAPI v2's sequence expectations. + Peek at the first element out of an iterator for the sake of operations + like auto-population of fields on reading the first element. + If next's result is an instance of list, it'll be converted into a tuple to + conform with DBAPI v2's sequence expectations. + + :type source: list + :param source: A list of source for the Iterator. """ def __init__(self, source): @@ -55,10 +59,55 @@ def __iter__(self): return self -re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)") +class StreamedManyResultSets: + """Iterator to walk through several `StreamedResultsSet` iterators. + This type of iterator is used by `Cursor.executemany()` + method to iterate through several `StreamedResultsSet` + iterators like they all are merged into single iterator. + """ + + def __init__(self): + self._iterators = [] + self._index = 0 + + def add_iter(self, iterator): + """Add new iterator into this one. + :type iterator: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet` + :param iterator: Iterator to merge into this one. + """ + self._iterators.append(iterator) + + def __next__(self): + """Return the next value from the currently streamed iterator. + If the current iterator is streamed to the end, + start to stream the next one. + :rtype: list + :returns: The next result row. + """ + try: + res = next(self._iterators[self._index]) + except StopIteration: + self._index += 1 + res = self.__next__() + except IndexError: + raise StopIteration + + return res + + def __iter__(self): + return self def backtick_unicode(sql): + """Check the SQL to be valid and split it by segments. + + :type sql: str + :param sql: A SQL request. + + :rtype: str + :returns: A SQL parsed by segments in unicode if initial SQL is valid, + initial string otherwise. + """ matches = list(re_UNICODE_POINTS.finditer(sql)) if not matches: return sql @@ -79,11 +128,20 @@ def backtick_unicode(sql): def sanitize_literals_for_upload(s): - """ - Convert literals in s, to be fit for consumption by Cloud Spanner. - 1. Convert %% (escaped percent literals) to %. Percent signs must be escaped when - values like %s are used as SQL parameter placeholders but Spanner's query language - uses placeholders like @a0 and doesn't expect percent signs to be escaped. - 2. Quote words containing non-ASCII, with backticks, for example föö to `föö`. + """Convert literals in s, to be fit for consumption by Cloud Spanner. + + * Convert %% (escaped percent literals) to %. Percent signs must be escaped + when values like %s are used as SQL parameter placeholders but Spanner's + query language uses placeholders like @a0 and doesn't expect percent + signs to be escaped. + * Quote words containing non-ASCII, with backticks, for example föö to + `föö`. + + :type s: str + :param s: A string with literals to escaped for consumption by Cloud + Spanner. + + :rtype: str + :returns: A sanitized string for uploading. """ return backtick_unicode(s.replace("%%", "%")) diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index b433f0c7b0..f4cd6ef910 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -289,6 +289,7 @@ def instance( configuration_name=None, display_name=None, node_count=DEFAULT_NODE_COUNT, + labels=None, ): """Factory to create a instance associated with this client. @@ -313,6 +314,9 @@ def instance( :param node_count: (Optional) The number of nodes in the instance's cluster; used to set up the instance's cluster. + :type labels: dict (str -> str) or None + :param labels: (Optional) User-assigned labels for this instance. + :rtype: :class:`~google.cloud.spanner_v1.instance.Instance` :returns: an instance owned by this client. """ @@ -323,6 +327,7 @@ def instance( node_count, display_name, self._emulator_host, + labels, ) def list_instances(self, filter_="", page_size=None): diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index be49dd2d84..b422c57afd 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -14,6 +14,7 @@ """User friendly container for Cloud Spanner Instance.""" +import google.api_core.operation import re from google.cloud.spanner_admin_instance_v1 import Instance as InstancePB @@ -99,6 +100,9 @@ class Instance(object): Cloud Console UI. (Must be between 4 and 30 characters.) If this value is not set in the constructor, will fall back to the instance ID. + + :type labels: dict (str -> str) or None + :param labels: (Optional) User-assigned labels for this instance. """ def __init__( @@ -109,6 +113,7 @@ def __init__( node_count=DEFAULT_NODE_COUNT, display_name=None, emulator_host=None, + labels=None, ): self.instance_id = instance_id self._client = client @@ -116,6 +121,9 @@ def __init__( self.node_count = node_count self.display_name = display_name or instance_id self.emulator_host = emulator_host + if labels is None: + labels = {} + self.labels = labels def _update_from_pb(self, instance_pb): """Refresh self from the server-provided protobuf. @@ -127,6 +135,7 @@ def _update_from_pb(self, instance_pb): self.display_name = instance_pb.display_name self.configuration_name = instance_pb.config self.node_count = instance_pb.node_count + self.labels = instance_pb.labels @classmethod def from_pb(cls, instance_pb, client): @@ -242,6 +251,7 @@ def create(self): config=self.configuration_name, display_name=self.display_name, node_count=self.node_count, + labels=self.labels, ) metadata = _metadata_with_prefix(self.name) @@ -296,7 +306,7 @@ def update(self): .. note:: - Updates the ``display_name`` and ``node_count``. To change those + Updates the ``display_name``, ``node_count`` and ``labels``. To change those values before updating, set them via .. code:: python @@ -316,8 +326,9 @@ def update(self): config=self.configuration_name, display_name=self.display_name, node_count=self.node_count, + labels=self.labels, ) - field_mask = FieldMask(paths=["config", "display_name", "node_count"]) + field_mask = FieldMask(paths=["config", "display_name", "node_count", "labels"]) metadata = _metadata_with_prefix(self.name) future = api.update_instance( @@ -465,7 +476,7 @@ def list_backup_operations(self, filter_="", page_size=None): page_iter = self._client.database_admin_api.list_backup_operations( request=request, metadata=metadata ) - return page_iter + return map(self._item_to_operation, page_iter) def list_database_operations(self, filter_="", page_size=None): """List database operations for the instance. @@ -493,4 +504,18 @@ def list_database_operations(self, filter_="", page_size=None): page_iter = self._client.database_admin_api.list_database_operations( request=request, metadata=metadata ) - return page_iter + return map(self._item_to_operation, page_iter) + + def _item_to_operation(self, operation_pb): + """Convert an operation protobuf to the native object. + :type operation_pb: :class:`~google.longrunning.operations.Operation` + :param operation_pb: An operation returned from the API. + :rtype: :class:`~google.api_core.operation.Operation` + :returns: The next operation in the page. + """ + operations_client = self._client.database_admin_api.transport.operations_client + metadata_type = _type_string_to_type_pb(operation_pb.metadata.type_url) + response_type = _OPERATION_RESPONSE_TYPES[metadata_type] + return google.api_core.operation.from_gapic( + operation_pb, operations_client, response_type, metadata_type=metadata_type + ) diff --git a/google/cloud/spanner_v1/services/spanner/transports/__init__.py b/google/cloud/spanner_v1/services/spanner/transports/__init__.py index 1bf46eb475..2210e30dd8 100644 --- a/google/cloud/spanner_v1/services/spanner/transports/__init__.py +++ b/google/cloud/spanner_v1/services/spanner/transports/__init__.py @@ -28,7 +28,6 @@ _transport_registry["grpc"] = SpannerGrpcTransport _transport_registry["grpc_asyncio"] = SpannerGrpcAsyncIOTransport - __all__ = ( "SpannerTransport", "SpannerGrpcTransport", diff --git a/google/cloud/spanner_v1/services/spanner/transports/grpc.py b/google/cloud/spanner_v1/services/spanner/transports/grpc.py index 620a971775..d1688acb92 100644 --- a/google/cloud/spanner_v1/services/spanner/transports/grpc.py +++ b/google/cloud/spanner_v1/services/spanner/transports/grpc.py @@ -106,6 +106,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -113,6 +115,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -148,7 +151,12 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" @@ -165,6 +173,10 @@ def __init__( ssl_credentials=ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._stubs = {} # type: Dict[str, Callable] @@ -191,7 +203,7 @@ def create_channel( ) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optionsl[str]): The host for the channel to use. + address (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If diff --git a/google/cloud/spanner_v1/services/spanner/transports/grpc_asyncio.py b/google/cloud/spanner_v1/services/spanner/transports/grpc_asyncio.py index 79ab4a1f94..422c51ef6f 100644 --- a/google/cloud/spanner_v1/services/spanner/transports/grpc_asyncio.py +++ b/google/cloud/spanner_v1/services/spanner/transports/grpc_asyncio.py @@ -151,6 +151,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -158,6 +160,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -193,7 +196,12 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" @@ -210,6 +218,10 @@ def __init__( ssl_credentials=ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) # Run the base constructor. diff --git a/google/cloud/spanner_v1/types/__init__.py b/google/cloud/spanner_v1/types/__init__.py index 890a024f01..a71a15855c 100644 --- a/google/cloud/spanner_v1/types/__init__.py +++ b/google/cloud/spanner_v1/types/__init__.py @@ -32,6 +32,7 @@ from .type import ( Type, StructType, + TypeCode, ) from .result_set import ( ResultSet, @@ -63,7 +64,6 @@ RollbackRequest, ) - __all__ = ( "KeyRange", "KeySet", @@ -75,6 +75,7 @@ "TransactionSelector", "Type", "StructType", + "TypeCode", "ResultSet", "PartialResultSet", "ResultSetMetadata", diff --git a/samples/samples/autocommit.py b/samples/samples/autocommit.py new file mode 100644 index 0000000000..873ed2b7bd --- /dev/null +++ b/samples/samples/autocommit.py @@ -0,0 +1,64 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import argparse + +from google.cloud.spanner_dbapi import connect + + +def enable_autocommit_mode(instance_id, database_id): + """Enables autocommit mode.""" + # [START spanner_enable_autocommit_mode] + + connection = connect(instance_id, database_id) + connection.autocommit = True + print("Autocommit mode is enabled.") + + cursor = connection.cursor() + + cursor.execute( + """CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + SingerInfo BYTES(MAX) + ) PRIMARY KEY (SingerId)""" + ) + + cursor.execute( + """INSERT INTO Singers (SingerId, FirstName, LastName) VALUES + (12, 'Melissa', 'Garcia'), + (13, 'Russell', 'Morales'), + (14, 'Jacqueline', 'Long'), + (15, 'Dylan', 'Shaw')""" + ) + + cursor.execute("""SELECT * FROM Singers WHERE SingerId = 13""") + + print("SingerId: {}, AlbumId: {}, AlbumTitle: {}".format(*cursor.fetchone())) + + connection.close() + # [END spanner_enable_autocommit_mode] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("instance_id", help="Your Cloud Spanner instance ID.") + parser.add_argument( + "--database-id", + help="Your Cloud Spanner database ID.", + default="example_db", + ) + subparsers = parser.add_subparsers(dest="command") + subparsers.add_parser("enable_autocommit_mode", help=enable_autocommit_mode.__doc__) + args = parser.parse_args() + if args.command == "enable_autocommit_mode": + enable_autocommit_mode(args.instance_id, args.database_id) + else: + print(f"Command {args.command} did not match expected commands.") diff --git a/samples/samples/autocommit_test.py b/samples/samples/autocommit_test.py new file mode 100644 index 0000000000..c906f060e0 --- /dev/null +++ b/samples/samples/autocommit_test.py @@ -0,0 +1,62 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import uuid + +from google.cloud import spanner +from google.cloud.spanner_dbapi import connect +import mock +import pytest + +import autocommit + + +def unique_instance_id(): + """Creates a unique id for the database.""" + return f"test-instance-{uuid.uuid4().hex[:10]}" + + +def unique_database_id(): + """Creates a unique id for the database.""" + return f"test-db-{uuid.uuid4().hex[:10]}" + + +INSTANCE_ID = unique_instance_id() +DATABASE_ID = unique_database_id() + + +@pytest.fixture(scope="module") +def spanner_instance(): + spanner_client = spanner.Client() + config_name = f"{spanner_client.project_name}/instanceConfigs/regional-us-central1" + + instance = spanner_client.instance(INSTANCE_ID, config_name) + op = instance.create() + op.result(120) # block until completion + yield instance + instance.delete() + + +@pytest.fixture(scope="module") +def database(spanner_instance): + """Creates a temporary database that is removed after testing.""" + db = spanner_instance.database(DATABASE_ID) + db.create() + yield db + db.drop() + + +def test_enable_autocommit_mode(capsys, database): + connection = connect(INSTANCE_ID, DATABASE_ID) + cursor = connection.cursor() + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Cursor", return_value=cursor, + ): + autocommit.enable_autocommit_mode(INSTANCE_ID, DATABASE_ID) + out, _ = capsys.readouterr() + assert "Autocommit mode is enabled." in out + assert "SingerId: 13, AlbumId: Russell, AlbumTitle: Morales" in out diff --git a/samples/samples/backup_sample.py b/samples/samples/backup_sample.py index 29492c5872..5e2f51679b 100644 --- a/samples/samples/backup_sample.py +++ b/samples/samples/backup_sample.py @@ -56,7 +56,7 @@ def create_backup(instance_id, database_id, backup_id): # [END spanner_create_backup] -# [START spanner_restore_database] +# [START spanner_restore_backup] def restore_database(instance_id, new_database_id, backup_id): """Restores a database from a backup.""" spanner_client = spanner.Client() @@ -83,10 +83,10 @@ def restore_database(instance_id, new_database_id, backup_id): ) -# [END spanner_restore_database] +# [END spanner_restore_backup] -# [START spanner_cancel_backup] +# [START spanner_cancel_backup_create] def cancel_backup(instance_id, database_id, backup_id): spanner_client = spanner.Client() instance = spanner_client.instance(instance_id) @@ -115,7 +115,7 @@ def cancel_backup(instance_id, database_id, backup_id): print("Backup creation was successfully cancelled.") -# [END spanner_cancel_backup] +# [END spanner_cancel_backup_create] # [START spanner_list_backup_operations] diff --git a/samples/samples/requirements.txt b/samples/samples/requirements.txt index daa9cd5a4f..816e298236 100644 --- a/samples/samples/requirements.txt +++ b/samples/samples/requirements.txt @@ -1,2 +1,2 @@ -google-cloud-spanner==2.0.0 +google-cloud-spanner==2.1.0 futures==3.3.0; python_version < "3" diff --git a/scripts/fixup_spanner_admin_database_v1_keywords.py b/scripts/fixup_spanner_admin_database_v1_keywords.py index 9f1a9bb9f1..96334a9f32 100644 --- a/scripts/fixup_spanner_admin_database_v1_keywords.py +++ b/scripts/fixup_spanner_admin_database_v1_keywords.py @@ -1,3 +1,4 @@ +#! /usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright 2020 Google LLC diff --git a/scripts/fixup_spanner_admin_instance_v1_keywords.py b/scripts/fixup_spanner_admin_instance_v1_keywords.py index 0871592c96..eb5507ec97 100644 --- a/scripts/fixup_spanner_admin_instance_v1_keywords.py +++ b/scripts/fixup_spanner_admin_instance_v1_keywords.py @@ -1,3 +1,4 @@ +#! /usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright 2020 Google LLC diff --git a/scripts/fixup_spanner_v1_keywords.py b/scripts/fixup_spanner_v1_keywords.py index 7c83aaf33d..bb76ae0e8c 100644 --- a/scripts/fixup_spanner_v1_keywords.py +++ b/scripts/fixup_spanner_v1_keywords.py @@ -1,3 +1,4 @@ +#! /usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright 2020 Google LLC diff --git a/setup.py b/setup.py index 87f3e26874..28f21ad515 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ name = "google-cloud-spanner" description = "Cloud Spanner API client library" -version = "2.1.0" +version = "3.0.0" # Should be one of: # 'Development Status :: 3 - Alpha' # 'Development Status :: 4 - Beta' @@ -33,7 +33,7 @@ "google-cloud-core >= 1.4.1, < 2.0dev", "grpc-google-iam-v1 >= 0.12.3, < 0.13dev", "libcst >= 0.2.5", - "proto-plus == 1.11.0", + "proto-plus==1.13.0", "sqlparse >= 0.3.0", ] extras = { diff --git a/synth.metadata b/synth.metadata index bba4518649..99b49c42da 100644 --- a/synth.metadata +++ b/synth.metadata @@ -3,8 +3,16 @@ { "git": { "name": ".", - "remote": "git@github.com:larkee/python-spanner.git", - "sha": "1d3e65af688c31937b0110223679607c19c328e9" + "remote": "https://github.com/googleapis/python-spanner.git", + "sha": "2faf01b135360586ef27c66976646593fd85fd1e" + } + }, + { + "git": { + "name": "googleapis", + "remote": "https://github.com/googleapis/googleapis.git", + "sha": "dd372aa22ded7a8ba6f0e03a80e06358a3fa0907", + "internalRef": "347055288" } }, { @@ -50,5 +58,144 @@ "generator": "bazel" } } + ], + "generatedFiles": [ + ".flake8", + ".github/CONTRIBUTING.md", + ".github/ISSUE_TEMPLATE/bug_report.md", + ".github/ISSUE_TEMPLATE/feature_request.md", + ".github/ISSUE_TEMPLATE/support_request.md", + ".github/PULL_REQUEST_TEMPLATE.md", + ".github/release-please.yml", + ".github/snippet-bot.yml", + ".gitignore", + ".kokoro/build.sh", + ".kokoro/continuous/common.cfg", + ".kokoro/continuous/continuous.cfg", + ".kokoro/docker/docs/Dockerfile", + ".kokoro/docker/docs/fetch_gpg_keys.sh", + ".kokoro/docs/common.cfg", + ".kokoro/docs/docs-presubmit.cfg", + ".kokoro/docs/docs.cfg", + ".kokoro/populate-secrets.sh", + ".kokoro/presubmit/common.cfg", + ".kokoro/presubmit/presubmit.cfg", + ".kokoro/publish-docs.sh", + ".kokoro/release.sh", + ".kokoro/release/common.cfg", + ".kokoro/release/release.cfg", + ".kokoro/samples/lint/common.cfg", + ".kokoro/samples/lint/continuous.cfg", + ".kokoro/samples/lint/periodic.cfg", + ".kokoro/samples/lint/presubmit.cfg", + ".kokoro/samples/python3.6/common.cfg", + ".kokoro/samples/python3.6/continuous.cfg", + ".kokoro/samples/python3.6/periodic.cfg", + ".kokoro/samples/python3.6/presubmit.cfg", + ".kokoro/samples/python3.7/common.cfg", + ".kokoro/samples/python3.7/continuous.cfg", + ".kokoro/samples/python3.7/periodic.cfg", + ".kokoro/samples/python3.7/presubmit.cfg", + ".kokoro/samples/python3.8/common.cfg", + ".kokoro/samples/python3.8/continuous.cfg", + ".kokoro/samples/python3.8/periodic.cfg", + ".kokoro/samples/python3.8/presubmit.cfg", + ".kokoro/test-samples.sh", + ".kokoro/trampoline.sh", + ".kokoro/trampoline_v2.sh", + ".trampolinerc", + "CODE_OF_CONDUCT.md", + "CONTRIBUTING.rst", + "LICENSE", + "MANIFEST.in", + "docs/_static/custom.css", + "docs/_templates/layout.html", + "docs/conf.py", + "docs/multiprocessing.rst", + "docs/spanner_admin_database_v1/services.rst", + "docs/spanner_admin_database_v1/types.rst", + "docs/spanner_admin_instance_v1/services.rst", + "docs/spanner_admin_instance_v1/types.rst", + "docs/spanner_v1/services.rst", + "docs/spanner_v1/types.rst", + "google/cloud/spanner_admin_database_v1/__init__.py", + "google/cloud/spanner_admin_database_v1/proto/backup.proto", + "google/cloud/spanner_admin_database_v1/proto/common.proto", + "google/cloud/spanner_admin_database_v1/proto/spanner_database_admin.proto", + "google/cloud/spanner_admin_database_v1/py.typed", + "google/cloud/spanner_admin_database_v1/services/__init__.py", + "google/cloud/spanner_admin_database_v1/services/database_admin/__init__.py", + "google/cloud/spanner_admin_database_v1/services/database_admin/async_client.py", + "google/cloud/spanner_admin_database_v1/services/database_admin/client.py", + "google/cloud/spanner_admin_database_v1/services/database_admin/pagers.py", + "google/cloud/spanner_admin_database_v1/services/database_admin/transports/__init__.py", + "google/cloud/spanner_admin_database_v1/services/database_admin/transports/base.py", + "google/cloud/spanner_admin_database_v1/services/database_admin/transports/grpc.py", + "google/cloud/spanner_admin_database_v1/services/database_admin/transports/grpc_asyncio.py", + "google/cloud/spanner_admin_database_v1/types/__init__.py", + "google/cloud/spanner_admin_database_v1/types/backup.py", + "google/cloud/spanner_admin_database_v1/types/common.py", + "google/cloud/spanner_admin_database_v1/types/spanner_database_admin.py", + "google/cloud/spanner_admin_instance_v1/__init__.py", + "google/cloud/spanner_admin_instance_v1/proto/spanner_instance_admin.proto", + "google/cloud/spanner_admin_instance_v1/py.typed", + "google/cloud/spanner_admin_instance_v1/services/__init__.py", + "google/cloud/spanner_admin_instance_v1/services/instance_admin/__init__.py", + "google/cloud/spanner_admin_instance_v1/services/instance_admin/async_client.py", + "google/cloud/spanner_admin_instance_v1/services/instance_admin/client.py", + "google/cloud/spanner_admin_instance_v1/services/instance_admin/pagers.py", + "google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/__init__.py", + "google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/base.py", + "google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/grpc.py", + "google/cloud/spanner_admin_instance_v1/services/instance_admin/transports/grpc_asyncio.py", + "google/cloud/spanner_admin_instance_v1/types/__init__.py", + "google/cloud/spanner_admin_instance_v1/types/spanner_instance_admin.py", + "google/cloud/spanner_v1/proto/keys.proto", + "google/cloud/spanner_v1/proto/mutation.proto", + "google/cloud/spanner_v1/proto/query_plan.proto", + "google/cloud/spanner_v1/proto/result_set.proto", + "google/cloud/spanner_v1/proto/spanner.proto", + "google/cloud/spanner_v1/proto/transaction.proto", + "google/cloud/spanner_v1/proto/type.proto", + "google/cloud/spanner_v1/py.typed", + "google/cloud/spanner_v1/services/__init__.py", + "google/cloud/spanner_v1/services/spanner/__init__.py", + "google/cloud/spanner_v1/services/spanner/async_client.py", + "google/cloud/spanner_v1/services/spanner/client.py", + "google/cloud/spanner_v1/services/spanner/pagers.py", + "google/cloud/spanner_v1/services/spanner/transports/__init__.py", + "google/cloud/spanner_v1/services/spanner/transports/base.py", + "google/cloud/spanner_v1/services/spanner/transports/grpc.py", + "google/cloud/spanner_v1/services/spanner/transports/grpc_asyncio.py", + "google/cloud/spanner_v1/types/__init__.py", + "google/cloud/spanner_v1/types/keys.py", + "google/cloud/spanner_v1/types/mutation.py", + "google/cloud/spanner_v1/types/query_plan.py", + "google/cloud/spanner_v1/types/result_set.py", + "google/cloud/spanner_v1/types/spanner.py", + "google/cloud/spanner_v1/types/transaction.py", + "google/cloud/spanner_v1/types/type.py", + "renovate.json", + "samples/AUTHORING_GUIDE.md", + "samples/CONTRIBUTING.md", + "samples/samples/noxfile.py", + "scripts/decrypt-secrets.sh", + "scripts/fixup_spanner_admin_database_v1_keywords.py", + "scripts/fixup_spanner_admin_instance_v1_keywords.py", + "scripts/fixup_spanner_v1_keywords.py", + "scripts/readme-gen/readme_gen.py", + "scripts/readme-gen/templates/README.tmpl.rst", + "scripts/readme-gen/templates/auth.tmpl.rst", + "scripts/readme-gen/templates/auth_api_key.tmpl.rst", + "scripts/readme-gen/templates/install_deps.tmpl.rst", + "scripts/readme-gen/templates/install_portaudio.tmpl.rst", + "setup.cfg", + "testing/.gitignore", + "tests/unit/gapic/spanner_admin_database_v1/__init__.py", + "tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py", + "tests/unit/gapic/spanner_admin_instance_v1/__init__.py", + "tests/unit/gapic/spanner_admin_instance_v1/test_instance_admin.py", + "tests/unit/gapic/spanner_v1/__init__.py", + "tests/unit/gapic/spanner_v1/test_spanner.py" ] } \ No newline at end of file diff --git a/testing/constraints-3.10.txt b/testing/constraints-3.10.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/testing/constraints-3.11.txt b/testing/constraints-3.11.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/testing/constraints-3.6.txt b/testing/constraints-3.6.txt new file mode 100644 index 0000000000..050e9c7a18 --- /dev/null +++ b/testing/constraints-3.6.txt @@ -0,0 +1,16 @@ +# This constraints file is used to check that lower bounds +# are correct in setup.py +# List *all* library dependencies and extras in this file. +# Pin the version to the lower bound. +# +# e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", +# Then this file should have foo==1.14.0 +google-api-core==1.22.0 +google-cloud-core==1.4.1 +grpc-google-iam-v1==0.12.3 +libcst==0.2.5 +proto-plus==1.13.0 +sqlparse==0.3.0 +opentelemetry-api==0.11b0 +opentelemetry-sdk==0.11b0 +opentelemetry-instrumentation==0.11b0 \ No newline at end of file diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/testing/constraints-3.8.txt b/testing/constraints-3.8.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 1ba9b59163..495824044b 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -35,11 +35,13 @@ from google.cloud.spanner_v1 import Type from google.cloud._helpers import UTC +from google.cloud.spanner_v1 import BurstyPool +from google.cloud.spanner_v1 import COMMIT_TIMESTAMP from google.cloud.spanner_v1 import Client from google.cloud.spanner_v1 import KeyRange from google.cloud.spanner_v1 import KeySet -from google.cloud.spanner_v1 import BurstyPool -from google.cloud.spanner_v1 import COMMIT_TIMESTAMP +from google.cloud.spanner_v1.instance import Backup +from google.cloud.spanner_v1.instance import Instance from test_utils.retry import RetryErrors from test_utils.retry import RetryInstanceState @@ -113,6 +115,22 @@ def setUpModule(): instances = retry(_list_instances)() EXISTING_INSTANCES[:] = instances + # Delete test instances that are older than an hour. + cutoff = int(time.time()) - 1 * 60 * 60 + instance_pbs = Config.CLIENT.list_instances("labels.python-spanner-systests:true") + for instance_pb in instance_pbs: + instance = Instance.from_pb(instance_pb, Config.CLIENT) + if "created" not in instance.labels: + continue + create_time = int(instance.labels["created"]) + if create_time > cutoff: + continue + # Instance cannot be deleted while backups exist. + for backup_pb in instance.list_backups(): + backup = Backup.from_pb(backup_pb, instance) + backup.delete() + instance.delete() + if CREATE_INSTANCE: if not USE_EMULATOR: # Defend against back-end returning configs for regions we aren't @@ -124,8 +142,12 @@ def setUpModule(): Config.INSTANCE_CONFIG = configs[0] config_name = configs[0].name + create_time = str(int(time.time())) + labels = {"python-spanner-systests": "true", "created": create_time} - Config.INSTANCE = Config.CLIENT.instance(INSTANCE_ID, config_name) + Config.INSTANCE = Config.CLIENT.instance( + INSTANCE_ID, config_name, labels=labels + ) created_op = Config.INSTANCE.create() created_op.result(30) # block until completion @@ -466,8 +488,10 @@ def setUpClass(cls): current_config = Config.INSTANCE.configuration_name same_config_instance_id = "same-config" + unique_resource_id("-") + create_time = str(int(time.time())) + labels = {"python-spanner-systests": "true", "created": create_time} cls._same_config_instance = Config.CLIENT.instance( - same_config_instance_id, current_config + same_config_instance_id, current_config, labels=labels ) op = cls._same_config_instance.create() op.result(30) @@ -483,8 +507,10 @@ def setUpClass(cls): cls._diff_config_instance = None if len(diff_configs) > 0: diff_config_instance_id = "diff-config" + unique_resource_id("-") + create_time = str(int(time.time())) + labels = {"python-spanner-systests": "true", "created": create_time} cls._diff_config_instance = Config.CLIENT.instance( - diff_config_instance_id, diff_configs[0] + diff_config_instance_id, diff_configs[0], labels=labels ) op = cls._diff_config_instance.create() op.result(30) @@ -918,9 +944,9 @@ def test_batch_insert_then_read(self): ) def test_batch_insert_then_read_string_array_of_string(self): - TABLE = "string_plus_array_of_string" - COLUMNS = ["id", "name", "tags"] - ROWDATA = [ + table = "string_plus_array_of_string" + columns = ["id", "name", "tags"] + rowdata = [ (0, None, None), (1, "phred", ["yabba", "dabba", "do"]), (2, "bharney", []), @@ -930,12 +956,12 @@ def test_batch_insert_then_read_string_array_of_string(self): retry(self._db.reload)() with self._db.batch() as batch: - batch.delete(TABLE, self.ALL) - batch.insert(TABLE, COLUMNS, ROWDATA) + batch.delete(table, self.ALL) + batch.insert(table, columns, rowdata) with self._db.snapshot(read_timestamp=batch.committed) as snapshot: - rows = list(snapshot.read(TABLE, COLUMNS, self.ALL)) - self._check_rows_data(rows, expected=ROWDATA) + rows = list(snapshot.read(table, columns, self.ALL)) + self._check_rows_data(rows, expected=rowdata) def test_batch_insert_then_read_all_datatypes(self): retry = RetryInstanceState(_has_all_ddl) @@ -1549,14 +1575,14 @@ def _read_w_concurrent_update(self, transaction, pkey): transaction.update(COUNTERS_TABLE, COUNTERS_COLUMNS, [[pkey, value + 1]]) def test_transaction_read_w_concurrent_updates(self): - PKEY = "read_w_concurrent_updates" - self._transaction_concurrency_helper(self._read_w_concurrent_update, PKEY) + pkey = "read_w_concurrent_updates" + self._transaction_concurrency_helper(self._read_w_concurrent_update, pkey) def _query_w_concurrent_update(self, transaction, pkey): - SQL = "SELECT * FROM counters WHERE name = @name" + sql = "SELECT * FROM counters WHERE name = @name" rows = list( transaction.execute_sql( - SQL, params={"name": pkey}, param_types={"name": param_types.STRING} + sql, params={"name": pkey}, param_types={"name": param_types.STRING} ) ) self.assertEqual(len(rows), 1) @@ -1564,8 +1590,8 @@ def _query_w_concurrent_update(self, transaction, pkey): transaction.update(COUNTERS_TABLE, COUNTERS_COLUMNS, [[pkey, value + 1]]) def test_transaction_query_w_concurrent_updates(self): - PKEY = "query_w_concurrent_updates" - self._transaction_concurrency_helper(self._query_w_concurrent_update, PKEY) + pkey = "query_w_concurrent_updates" + self._transaction_concurrency_helper(self._query_w_concurrent_update, pkey) @unittest.skipIf(USE_EMULATOR, "Skipping concurrent transactions") def test_transaction_read_w_abort(self): @@ -1663,9 +1689,9 @@ def test_snapshot_read_w_various_staleness(self): from datetime import datetime from google.cloud._helpers import UTC - ROW_COUNT = 400 - committed = self._set_up_table(ROW_COUNT) - all_data_rows = list(self._row_data(ROW_COUNT)) + row_count = 400 + committed = self._set_up_table(row_count) + all_data_rows = list(self._row_data(row_count)) before_reads = datetime.utcnow().replace(tzinfo=UTC) @@ -1697,9 +1723,9 @@ def test_snapshot_read_w_various_staleness(self): self._check_row_data(rows, all_data_rows) def test_multiuse_snapshot_read_isolation_strong(self): - ROW_COUNT = 40 - self._set_up_table(ROW_COUNT) - all_data_rows = list(self._row_data(ROW_COUNT)) + row_count = 40 + self._set_up_table(row_count) + all_data_rows = list(self._row_data(row_count)) with self._db.snapshot(multi_use=True) as strong: before = list(strong.read(self.TABLE, self.COLUMNS, self.ALL)) self._check_row_data(before, all_data_rows) @@ -1711,9 +1737,9 @@ def test_multiuse_snapshot_read_isolation_strong(self): self._check_row_data(after, all_data_rows) def test_multiuse_snapshot_read_isolation_read_timestamp(self): - ROW_COUNT = 40 - committed = self._set_up_table(ROW_COUNT) - all_data_rows = list(self._row_data(ROW_COUNT)) + row_count = 40 + committed = self._set_up_table(row_count) + all_data_rows = list(self._row_data(row_count)) with self._db.snapshot(read_timestamp=committed, multi_use=True) as read_ts: @@ -1727,10 +1753,10 @@ def test_multiuse_snapshot_read_isolation_read_timestamp(self): self._check_row_data(after, all_data_rows) def test_multiuse_snapshot_read_isolation_exact_staleness(self): - ROW_COUNT = 40 + row_count = 40 - self._set_up_table(ROW_COUNT) - all_data_rows = list(self._row_data(ROW_COUNT)) + self._set_up_table(row_count) + all_data_rows = list(self._row_data(row_count)) time.sleep(1) delta = datetime.timedelta(microseconds=1000) @@ -1747,7 +1773,7 @@ def test_multiuse_snapshot_read_isolation_exact_staleness(self): self._check_row_data(after, all_data_rows) def test_read_w_index(self): - ROW_COUNT = 2000 + row_count = 2000 # Indexed reads cannot return non-indexed columns MY_COLUMNS = self.COLUMNS[0], self.COLUMNS[2] EXTRA_DDL = ["CREATE INDEX contacts_by_last_name ON contacts(last_name)"] @@ -1763,7 +1789,7 @@ def test_read_w_index(self): # We want to make sure the operation completes. operation.result(30) # raises on failure / timeout. - committed = self._set_up_table(ROW_COUNT, database=temp_db) + committed = self._set_up_table(row_count, database=temp_db) with temp_db.snapshot(read_timestamp=committed) as snapshot: rows = list( @@ -1773,36 +1799,36 @@ def test_read_w_index(self): ) expected = list( - reversed([(row[0], row[2]) for row in self._row_data(ROW_COUNT)]) + reversed([(row[0], row[2]) for row in self._row_data(row_count)]) ) self._check_rows_data(rows, expected) def test_read_w_single_key(self): # [START spanner_test_single_key_read] - ROW_COUNT = 40 - committed = self._set_up_table(ROW_COUNT) + row_count = 40 + committed = self._set_up_table(row_count) with self._db.snapshot(read_timestamp=committed) as snapshot: rows = list(snapshot.read(self.TABLE, self.COLUMNS, KeySet(keys=[(0,)]))) - all_data_rows = list(self._row_data(ROW_COUNT)) + all_data_rows = list(self._row_data(row_count)) expected = [all_data_rows[0]] self._check_row_data(rows, expected) # [END spanner_test_single_key_read] def test_empty_read(self): # [START spanner_test_empty_read] - ROW_COUNT = 40 - self._set_up_table(ROW_COUNT) + row_count = 40 + self._set_up_table(row_count) with self._db.snapshot() as snapshot: rows = list(snapshot.read(self.TABLE, self.COLUMNS, KeySet(keys=[(40,)]))) self._check_row_data(rows, []) # [END spanner_test_empty_read] def test_read_w_multiple_keys(self): - ROW_COUNT = 40 + row_count = 40 indices = [0, 5, 17] - committed = self._set_up_table(ROW_COUNT) + committed = self._set_up_table(row_count) with self._db.snapshot(read_timestamp=committed) as snapshot: rows = list( @@ -1813,58 +1839,58 @@ def test_read_w_multiple_keys(self): ) ) - all_data_rows = list(self._row_data(ROW_COUNT)) + all_data_rows = list(self._row_data(row_count)) expected = [row for row in all_data_rows if row[0] in indices] self._check_row_data(rows, expected) def test_read_w_limit(self): - ROW_COUNT = 3000 - LIMIT = 100 - committed = self._set_up_table(ROW_COUNT) + row_count = 3000 + limit = 100 + committed = self._set_up_table(row_count) with self._db.snapshot(read_timestamp=committed) as snapshot: - rows = list(snapshot.read(self.TABLE, self.COLUMNS, self.ALL, limit=LIMIT)) + rows = list(snapshot.read(self.TABLE, self.COLUMNS, self.ALL, limit=limit)) - all_data_rows = list(self._row_data(ROW_COUNT)) - expected = all_data_rows[:LIMIT] + all_data_rows = list(self._row_data(row_count)) + expected = all_data_rows[:limit] self._check_row_data(rows, expected) def test_read_w_ranges(self): - ROW_COUNT = 3000 - START = 1000 - END = 2000 - committed = self._set_up_table(ROW_COUNT) + row_count = 3000 + start = 1000 + end = 2000 + committed = self._set_up_table(row_count) with self._db.snapshot(read_timestamp=committed, multi_use=True) as snapshot: - all_data_rows = list(self._row_data(ROW_COUNT)) + all_data_rows = list(self._row_data(row_count)) - single_key = KeyRange(start_closed=[START], end_open=[START + 1]) + single_key = KeyRange(start_closed=[start], end_open=[start + 1]) keyset = KeySet(ranges=(single_key,)) rows = list(snapshot.read(self.TABLE, self.COLUMNS, keyset)) - expected = all_data_rows[START : START + 1] + expected = all_data_rows[start : start + 1] self._check_rows_data(rows, expected) - closed_closed = KeyRange(start_closed=[START], end_closed=[END]) + closed_closed = KeyRange(start_closed=[start], end_closed=[end]) keyset = KeySet(ranges=(closed_closed,)) rows = list(snapshot.read(self.TABLE, self.COLUMNS, keyset)) - expected = all_data_rows[START : END + 1] + expected = all_data_rows[start : end + 1] self._check_row_data(rows, expected) - closed_open = KeyRange(start_closed=[START], end_open=[END]) + closed_open = KeyRange(start_closed=[start], end_open=[end]) keyset = KeySet(ranges=(closed_open,)) rows = list(snapshot.read(self.TABLE, self.COLUMNS, keyset)) - expected = all_data_rows[START:END] + expected = all_data_rows[start:end] self._check_row_data(rows, expected) - open_open = KeyRange(start_open=[START], end_open=[END]) + open_open = KeyRange(start_open=[start], end_open=[end]) keyset = KeySet(ranges=(open_open,)) rows = list(snapshot.read(self.TABLE, self.COLUMNS, keyset)) - expected = all_data_rows[START + 1 : END] + expected = all_data_rows[start + 1 : end] self._check_row_data(rows, expected) - open_closed = KeyRange(start_open=[START], end_closed=[END]) + open_closed = KeyRange(start_open=[start], end_closed=[end]) keyset = KeySet(ranges=(open_closed,)) rows = list(snapshot.read(self.TABLE, self.COLUMNS, keyset)) - expected = all_data_rows[START + 1 : END + 1] + expected = all_data_rows[start + 1 : end + 1] self._check_row_data(rows, expected) def test_read_partial_range_until_end(self): @@ -2108,8 +2134,8 @@ def test_partition_read_w_index(self): batch_txn.close() def test_execute_sql_w_manual_consume(self): - ROW_COUNT = 3000 - committed = self._set_up_table(ROW_COUNT) + row_count = 3000 + committed = self._set_up_table(row_count) with self._db.snapshot(read_timestamp=committed) as snapshot: streamed = snapshot.execute_sql(self.SQL) @@ -2133,9 +2159,9 @@ def _check_sql_results( self._check_rows_data(rows, expected=expected) def test_multiuse_snapshot_execute_sql_isolation_strong(self): - ROW_COUNT = 40 - self._set_up_table(ROW_COUNT) - all_data_rows = list(self._row_data(ROW_COUNT)) + row_count = 40 + self._set_up_table(row_count) + all_data_rows = list(self._row_data(row_count)) with self._db.snapshot(multi_use=True) as strong: before = list(strong.execute_sql(self.SQL)) @@ -2148,7 +2174,7 @@ def test_multiuse_snapshot_execute_sql_isolation_strong(self): self._check_row_data(after, all_data_rows) def test_execute_sql_returning_array_of_struct(self): - SQL = ( + sql = ( "SELECT ARRAY(SELECT AS STRUCT C1, C2 " "FROM (SELECT 'a' AS C1, 1 AS C2 " "UNION ALL SELECT 'b' AS C1, 2 AS C2) " @@ -2156,14 +2182,14 @@ def test_execute_sql_returning_array_of_struct(self): ) self._check_sql_results( self._db, - sql=SQL, + sql=sql, params=None, param_types=None, expected=[[[["a", 1], ["b", 2]]]], ) def test_execute_sql_returning_empty_array_of_struct(self): - SQL = ( + sql = ( "SELECT ARRAY(SELECT AS STRUCT C1, C2 " "FROM (SELECT 2 AS C1) X " "JOIN (SELECT 1 AS C2) Y " @@ -2173,7 +2199,7 @@ def test_execute_sql_returning_empty_array_of_struct(self): self._db.snapshot(multi_use=True) self._check_sql_results( - self._db, sql=SQL, params=None, param_types=None, expected=[[[]]] + self._db, sql=sql, params=None, param_types=None, expected=[[[]]] ) def test_invalid_type(self): @@ -2338,11 +2364,11 @@ def test_execute_sql_w_numeric_bindings(self): self._bind_test_helper(TypeCode.NUMERIC, NUMERIC_1, [NUMERIC_1, NUMERIC_2]) def test_execute_sql_w_query_param_struct(self): - NAME = "Phred" - COUNT = 123 - SIZE = 23.456 - HEIGHT = 188.0 - WEIGHT = 97.6 + name = "Phred" + count = 123 + size = 23.456 + height = 188.0 + weight = 97.6 record_type = param_types.Struct( [ @@ -2395,9 +2421,9 @@ def test_execute_sql_w_query_param_struct(self): self._check_sql_results( self._db, sql="SELECT @r.name, @r.count, @r.size, @r.nested.weight", - params={"r": (NAME, COUNT, SIZE, (HEIGHT, WEIGHT))}, + params={"r": (name, count, size, (height, weight))}, param_types={"r": record_type}, - expected=[(NAME, COUNT, SIZE, WEIGHT)], + expected=[(name, count, size, weight)], order=False, ) diff --git a/tests/system/test_system_dbapi.py b/tests/system/test_system_dbapi.py index be8e9f2a26..baeadd2c44 100644 --- a/tests/system/test_system_dbapi.py +++ b/tests/system/test_system_dbapi.py @@ -15,12 +15,15 @@ import hashlib import os import pickle +import time import unittest from google.api_core import exceptions -from google.cloud.spanner_v1 import Client from google.cloud.spanner_v1 import BurstyPool +from google.cloud.spanner_v1 import Client +from google.cloud.spanner_v1.instance import Backup +from google.cloud.spanner_v1.instance import Instance from google.cloud.spanner_dbapi.connection import Connection @@ -53,6 +56,23 @@ def setUpModule(): instances = retry(_list_instances)() EXISTING_INSTANCES[:] = instances + # Delete test instances that are older than an hour. + cutoff = int(time.time()) - 1 * 60 * 60 + for instance_pb in Config.CLIENT.list_instances( + "labels.python-spanner-dbapi-systests:true" + ): + instance = Instance.from_pb(instance_pb, Config.CLIENT) + if "created" not in instance.labels: + continue + create_time = int(instance.labels["created"]) + if create_time > cutoff: + continue + # Instance cannot be deleted while backups exist. + for backup_pb in instance.list_backups(): + backup = Backup.from_pb(backup_pb, instance) + backup.delete() + instance.delete() + if CREATE_INSTANCE: if not USE_EMULATOR: # Defend against back-end returning configs for regions we aren't @@ -64,8 +84,12 @@ def setUpModule(): Config.INSTANCE_CONFIG = configs[0] config_name = configs[0].name + create_time = str(int(time.time())) + labels = {"python-spanner-dbapi-systests": "true", "created": create_time} - Config.INSTANCE = Config.CLIENT.instance(INSTANCE_ID, config_name) + Config.INSTANCE = Config.CLIENT.instance( + INSTANCE_ID, config_name, labels=labels + ) created_op = Config.INSTANCE.create() created_op.result(30) # block until completion @@ -305,6 +329,46 @@ def test_results_checksum(self): self.assertEqual(cursor._checksum.checksum.digest(), checksum.digest()) + def test_execute_many(self): + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (1, 'first-name', 'last-name', 'test.email@example.com'), + (2, 'first-name2', 'last-name2', 'test.email2@example.com') + """ + ) + conn.commit() + + cursor.executemany( + """ +SELECT * FROM contacts WHERE contact_id = @a1 +""", + ({"a1": 1}, {"a1": 2}), + ) + res = cursor.fetchall() + conn.commit() + + self.assertEqual(len(res), 2) + self.assertEqual(res[0][0], 1) + self.assertEqual(res[1][0], 2) + + # checking that execute() and executemany() + # results are not mixed together + cursor.execute( + """ +SELECT * FROM contacts WHERE contact_id = 1 +""", + ) + res = cursor.fetchone() + conn.commit() + + self.assertEqual(res[0], 1) + conn.close() + def clear_table(transaction): """Clear the test table.""" diff --git a/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py b/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py index ea79f63e86..7779e49659 100644 --- a/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py +++ b/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py @@ -4725,6 +4725,7 @@ def test_database_admin_grpc_transport_channel(): ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None def test_database_admin_grpc_asyncio_transport_channel(): @@ -4736,6 +4737,7 @@ def test_database_admin_grpc_asyncio_transport_channel(): ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None @pytest.mark.parametrize( @@ -4782,8 +4784,13 @@ def test_database_admin_transport_channel_mtls_with_client_cert_source(transport ), ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred @pytest.mark.parametrize( @@ -4825,6 +4832,10 @@ def test_database_admin_transport_channel_mtls_with_adc(transport_class): ), ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel diff --git a/tests/unit/gapic/spanner_admin_instance_v1/test_instance_admin.py b/tests/unit/gapic/spanner_admin_instance_v1/test_instance_admin.py index 0db8185b79..bb4e98d401 100644 --- a/tests/unit/gapic/spanner_admin_instance_v1/test_instance_admin.py +++ b/tests/unit/gapic/spanner_admin_instance_v1/test_instance_admin.py @@ -3082,6 +3082,7 @@ def test_instance_admin_grpc_transport_channel(): ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None def test_instance_admin_grpc_asyncio_transport_channel(): @@ -3093,6 +3094,7 @@ def test_instance_admin_grpc_asyncio_transport_channel(): ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None @pytest.mark.parametrize( @@ -3139,8 +3141,13 @@ def test_instance_admin_transport_channel_mtls_with_client_cert_source(transport ), ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred @pytest.mark.parametrize( @@ -3182,6 +3189,10 @@ def test_instance_admin_transport_channel_mtls_with_adc(transport_class): ), ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel diff --git a/tests/unit/gapic/spanner_v1/test_spanner.py b/tests/unit/gapic/spanner_v1/test_spanner.py index d891f27d94..2bb2324fac 100644 --- a/tests/unit/gapic/spanner_v1/test_spanner.py +++ b/tests/unit/gapic/spanner_v1/test_spanner.py @@ -3190,6 +3190,7 @@ def test_spanner_grpc_transport_channel(): ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None def test_spanner_grpc_asyncio_transport_channel(): @@ -3201,6 +3202,7 @@ def test_spanner_grpc_asyncio_transport_channel(): ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None @pytest.mark.parametrize( @@ -3244,8 +3246,13 @@ def test_spanner_transport_channel_mtls_with_client_cert_source(transport_class) ), ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred @pytest.mark.parametrize( @@ -3284,6 +3291,10 @@ def test_spanner_transport_channel_mtls_with_adc(transport_class): ), ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 213eb24d84..a338055a2c 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -15,7 +15,6 @@ """Cloud Spanner DB-API Connection class unit tests.""" import mock -import sys import unittest import warnings @@ -51,25 +50,57 @@ def _make_connection(self): database = instance.database(self.DATABASE) return Connection(instance, database) - @unittest.skipIf(sys.version_info[0] < 3, "Python 2 patching is outdated") - def test_property_autocommit_setter(self): - from google.cloud.spanner_dbapi import Connection - - connection = Connection(self.INSTANCE, self.DATABASE) + def test_autocommit_setter_transaction_not_started(self): + connection = self._make_connection() with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.commit" ) as mock_commit: connection.autocommit = True - mock_commit.assert_called_once_with() - self.assertEqual(connection._autocommit, True) + mock_commit.assert_not_called() + self.assertTrue(connection._autocommit) with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.commit" ) as mock_commit: connection.autocommit = False mock_commit.assert_not_called() - self.assertEqual(connection._autocommit, False) + self.assertFalse(connection._autocommit) + + def test_autocommit_setter_transaction_started(self): + connection = self._make_connection() + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection._transaction = mock.Mock(committed=False, rolled_back=False) + + connection.autocommit = True + mock_commit.assert_called_once() + self.assertTrue(connection._autocommit) + + def test_autocommit_setter_transaction_started_commited_rolled_back(self): + connection = self._make_connection() + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection._transaction = mock.Mock(committed=True, rolled_back=False) + + connection.autocommit = True + mock_commit.assert_not_called() + self.assertTrue(connection._autocommit) + + connection.autocommit = False + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection._transaction = mock.Mock(committed=False, rolled_back=True) + + connection.autocommit = True + mock_commit.assert_not_called() + self.assertTrue(connection._autocommit) def test_property_database(self): from google.cloud.spanner_v1.database import Database @@ -166,7 +197,9 @@ def test_commit(self, mock_warn): connection.commit() mock_release.assert_not_called() - connection._transaction = mock_transaction = mock.MagicMock() + connection._transaction = mock_transaction = mock.MagicMock( + rolled_back=False, committed=False + ) mock_transaction.commit = mock_commit = mock.MagicMock() with mock.patch( @@ -316,7 +349,7 @@ def test_run_statement_remember_statements(self): connection = self._make_connection() - statement = Statement(sql, params, param_types, ResultsChecksum(),) + statement = Statement(sql, params, param_types, ResultsChecksum(), False) with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" ): @@ -338,7 +371,7 @@ def test_run_statement_dont_remember_retried_statements(self): connection = self._make_connection() - statement = Statement(sql, params, param_types, ResultsChecksum(),) + statement = Statement(sql, params, param_types, ResultsChecksum(), False) with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" ): @@ -352,7 +385,7 @@ def test_clear_statements_on_commit(self): cleared, when the transaction is commited. """ connection = self._make_connection() - connection._transaction = mock.Mock() + connection._transaction = mock.Mock(rolled_back=False, committed=False) connection._statements = [{}, {}] self.assertEqual(len(connection._statements), 2) @@ -390,7 +423,7 @@ def test_retry_transaction(self): checksum.consume_result(row) retried_checkum = ResultsChecksum() - statement = Statement("SELECT 1", [], {}, checksum,) + statement = Statement("SELECT 1", [], {}, checksum, False) connection._statements.append(statement) with mock.patch( @@ -423,7 +456,7 @@ def test_retry_transaction_checksum_mismatch(self): checksum.consume_result(row) retried_checkum = ResultsChecksum() - statement = Statement("SELECT 1", [], {}, checksum,) + statement = Statement("SELECT 1", [], {}, checksum, False) connection._statements.append(statement) with mock.patch( @@ -453,9 +486,9 @@ def test_commit_retry_aborted_statements(self): cursor._checksum = ResultsChecksum() cursor._checksum.consume_result(row) - statement = Statement("SELECT 1", [], {}, cursor._checksum,) + statement = Statement("SELECT 1", [], {}, cursor._checksum, False) connection._statements.append(statement) - connection._transaction = mock.Mock() + connection._transaction = mock.Mock(rolled_back=False, committed=False) with mock.patch.object( connection._transaction, "commit", side_effect=(Aborted("Aborted"), None), @@ -507,7 +540,7 @@ def test_retry_aborted_retry(self): cursor._checksum = ResultsChecksum() cursor._checksum.consume_result(row) - statement = Statement("SELECT 1", [], {}, cursor._checksum,) + statement = Statement("SELECT 1", [], {}, cursor._checksum, False) connection._statements.append(statement) metadata_mock = mock.Mock() diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 43fc077abe..9f0510c4ab 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -126,7 +126,7 @@ def test_execute_attribute_error(self): cursor = self._make_one(connection) with self.assertRaises(AttributeError): - cursor.execute(sql="") + cursor.execute(sql="SELECT 1") def test_execute_autocommit_off(self): from google.cloud.spanner_dbapi.utils import PeekIterator @@ -257,6 +257,22 @@ def test_executemany_on_closed_cursor(self): with self.assertRaises(InterfaceError): cursor.executemany("""SELECT * FROM table1 WHERE "col1" = @a1""", ()) + def test_executemany_DLL(self): + from google.cloud.spanner_dbapi import connect, ProgrammingError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + + with self.assertRaises(ProgrammingError): + cursor.executemany("""DROP DATABASE database_name""", ()) + def test_executemany(self): from google.cloud.spanner_dbapi import connect @@ -272,6 +288,9 @@ def test_executemany(self): connection = connect("test-instance", "test-database") cursor = connection.cursor() + cursor._result_set = [1, 2, 3] + cursor._itr = iter([1, 2, 3]) + with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.execute" ) as execute_mock: @@ -512,7 +531,7 @@ def test_fetchone_retry_aborted_statements(self): cursor._checksum = ResultsChecksum() cursor._checksum.consume_result(row) - statement = Statement("SELECT 1", [], {}, cursor._checksum,) + statement = Statement("SELECT 1", [], {}, cursor._checksum, False) connection._statements.append(statement) with mock.patch( @@ -551,7 +570,7 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self): cursor._checksum = ResultsChecksum() cursor._checksum.consume_result(row) - statement = Statement("SELECT 1", [], {}, cursor._checksum,) + statement = Statement("SELECT 1", [], {}, cursor._checksum, False) connection._statements.append(statement) with mock.patch( diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 6d89a8a46a..3713ac11a8 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -391,7 +391,6 @@ def test_get_param_types_none(self): @unittest.skipIf(skip_condition, skip_message) def test_ensure_where_clause(self): - from google.cloud.spanner_dbapi.exceptions import ProgrammingError from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause cases = ( @@ -409,8 +408,7 @@ def test_ensure_where_clause(self): for sql in err_cases: with self.subTest(sql=sql): - with self.assertRaises(ProgrammingError): - ensure_where_clause(sql) + self.assertEqual(ensure_where_clause(sql), sql + " WHERE 1=1") @unittest.skipIf(skip_condition, skip_message) def test_escape_name(self): diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index a3001e61ae..9c260c5f95 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -37,6 +37,7 @@ class TestClient(unittest.TestCase): INSTANCE_NAME = "%s/instances/%s" % (PATH, INSTANCE_ID) DISPLAY_NAME = "display-name" NODE_COUNT = 5 + LABELS = {"test": "true"} TIMEOUT_SECONDS = 80 def _get_target_class(self): @@ -518,6 +519,7 @@ def test_instance_factory_defaults(self): self.assertIsNone(instance.configuration_name) self.assertEqual(instance.display_name, self.INSTANCE_ID) self.assertEqual(instance.node_count, DEFAULT_NODE_COUNT) + self.assertEqual(instance.labels, {}) self.assertIs(instance._client, client) def test_instance_factory_explicit(self): @@ -531,6 +533,7 @@ def test_instance_factory_explicit(self): self.CONFIGURATION_NAME, display_name=self.DISPLAY_NAME, node_count=self.NODE_COUNT, + labels=self.LABELS, ) self.assertIsInstance(instance, Instance) @@ -538,6 +541,7 @@ def test_instance_factory_explicit(self): self.assertEqual(instance.configuration_name, self.CONFIGURATION_NAME) self.assertEqual(instance.display_name, self.DISPLAY_NAME) self.assertEqual(instance.node_count, self.NODE_COUNT) + self.assertEqual(instance.labels, self.LABELS) self.assertIs(instance._client, client) def test_list_instances(self): diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 0694d438a2..edd8249c67 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -38,6 +38,7 @@ class TestInstance(unittest.TestCase): TIMEOUT_SECONDS = 1 DATABASE_ID = "database_id" DATABASE_NAME = "%s/databases/%s" % (INSTANCE_NAME, DATABASE_ID) + LABELS = {"test": "true"} def _getTargetClass(self): from google.cloud.spanner_v1.instance import Instance @@ -57,6 +58,7 @@ def test_constructor_defaults(self): self.assertIs(instance.configuration_name, None) self.assertEqual(instance.node_count, DEFAULT_NODE_COUNT) self.assertEqual(instance.display_name, self.INSTANCE_ID) + self.assertEqual(instance.labels, {}) def test_constructor_non_default(self): DISPLAY_NAME = "display_name" @@ -68,12 +70,14 @@ def test_constructor_non_default(self): configuration_name=self.CONFIG_NAME, node_count=self.NODE_COUNT, display_name=DISPLAY_NAME, + labels=self.LABELS, ) self.assertEqual(instance.instance_id, self.INSTANCE_ID) self.assertIs(instance._client, client) self.assertEqual(instance.configuration_name, self.CONFIG_NAME) self.assertEqual(instance.node_count, self.NODE_COUNT) self.assertEqual(instance.display_name, DISPLAY_NAME) + self.assertEqual(instance.labels, self.LABELS) def test_copy(self): DISPLAY_NAME = "display_name" @@ -145,6 +149,7 @@ def test_from_pb_success(self): name=self.INSTANCE_NAME, config=self.CONFIG_NAME, display_name=self.INSTANCE_ID, + labels=self.LABELS, ) klass = self._getTargetClass() @@ -153,6 +158,7 @@ def test_from_pb_success(self): self.assertEqual(instance._client, client) self.assertEqual(instance.instance_id, self.INSTANCE_ID) self.assertEqual(instance.configuration_name, self.CONFIG_NAME) + self.assertEqual(instance.labels, self.LABELS) def test_name_property(self): client = _Client(project=self.PROJECT) @@ -160,6 +166,14 @@ def test_name_property(self): instance = self._make_one(self.INSTANCE_ID, client, self.CONFIG_NAME) self.assertEqual(instance.name, self.INSTANCE_NAME) + def test_labels_property(self): + client = _Client(project=self.PROJECT) + + instance = self._make_one( + self.INSTANCE_ID, client, self.CONFIG_NAME, labels=self.LABELS + ) + self.assertEqual(instance.labels, self.LABELS) + def test___eq__(self): client = object() instance1 = self._make_one(self.INSTANCE_ID, client, self.CONFIG_NAME) @@ -231,6 +245,7 @@ def test_create_success(self): configuration_name=self.CONFIG_NAME, display_name=self.DISPLAY_NAME, node_count=self.NODE_COUNT, + labels=self.LABELS, ) future = instance.create() @@ -244,6 +259,7 @@ def test_create_success(self): self.assertEqual(instance.config, self.CONFIG_NAME) self.assertEqual(instance.display_name, self.DISPLAY_NAME) self.assertEqual(instance.node_count, self.NODE_COUNT) + self.assertEqual(instance.labels, self.LABELS) self.assertEqual(metadata, [("google-cloud-resource-prefix", instance.name)]) def test_exists_instance_grpc_error(self): @@ -327,6 +343,7 @@ def test_reload_success(self): config=self.CONFIG_NAME, display_name=self.DISPLAY_NAME, node_count=self.NODE_COUNT, + labels=self.LABELS, ) api = client.instance_admin_api = _FauxInstanceAdminAPI( _get_instance_response=instance_pb @@ -338,6 +355,7 @@ def test_reload_success(self): self.assertEqual(instance.configuration_name, self.CONFIG_NAME) self.assertEqual(instance.node_count, self.NODE_COUNT) self.assertEqual(instance.display_name, self.DISPLAY_NAME) + self.assertEqual(instance.labels, self.LABELS) name, metadata = api._got_instance self.assertEqual(name, self.INSTANCE_NAME) @@ -371,7 +389,9 @@ def test_update_not_found(self): instance.update() instance, field_mask, metadata = api._updated_instance - self.assertEqual(field_mask.paths, ["config", "display_name", "node_count"]) + self.assertEqual( + field_mask.paths, ["config", "display_name", "node_count", "labels"] + ) self.assertEqual(instance.name, self.INSTANCE_NAME) self.assertEqual(instance.config, self.CONFIG_NAME) self.assertEqual(instance.display_name, self.INSTANCE_ID) @@ -390,6 +410,7 @@ def test_update_success(self): configuration_name=self.CONFIG_NAME, node_count=self.NODE_COUNT, display_name=self.DISPLAY_NAME, + labels=self.LABELS, ) future = instance.update() @@ -397,11 +418,14 @@ def test_update_success(self): self.assertIs(future, op_future) instance, field_mask, metadata = api._updated_instance - self.assertEqual(field_mask.paths, ["config", "display_name", "node_count"]) + self.assertEqual( + field_mask.paths, ["config", "display_name", "node_count", "labels"] + ) self.assertEqual(instance.name, self.INSTANCE_NAME) self.assertEqual(instance.config, self.CONFIG_NAME) self.assertEqual(instance.display_name, self.DISPLAY_NAME) self.assertEqual(instance.node_count, self.NODE_COUNT) + self.assertEqual(instance.labels, self.LABELS) self.assertEqual(metadata, [("google-cloud-resource-prefix", instance.name)]) def test_delete_grpc_error(self): @@ -673,6 +697,7 @@ def test_list_backups_w_options(self): ) def test_list_backup_operations_defaults(self): + from google.api_core.operation import Operation from google.cloud.spanner_admin_database_v1 import CreateBackupMetadata from google.cloud.spanner_admin_database_v1 import DatabaseAdminClient from google.cloud.spanner_admin_database_v1 import ListBackupOperationsRequest @@ -702,7 +727,7 @@ def test_list_backup_operations_defaults(self): api._transport.list_backup_operations ] = mock.Mock(return_value=operations_pb) - instance.list_backup_operations() + ops = instance.list_backup_operations() expected_metadata = ( ("google-cloud-resource-prefix", instance.name), @@ -714,8 +739,10 @@ def test_list_backup_operations_defaults(self): retry=mock.ANY, timeout=mock.ANY, ) + self.assertTrue(all([type(op) == Operation for op in ops])) def test_list_backup_operations_w_options(self): + from google.api_core.operation import Operation from google.cloud.spanner_admin_database_v1 import CreateBackupMetadata from google.cloud.spanner_admin_database_v1 import DatabaseAdminClient from google.cloud.spanner_admin_database_v1 import ListBackupOperationsRequest @@ -745,7 +772,7 @@ def test_list_backup_operations_w_options(self): api._transport.list_backup_operations ] = mock.Mock(return_value=operations_pb) - instance.list_backup_operations(filter_="filter", page_size=10) + ops = instance.list_backup_operations(filter_="filter", page_size=10) expected_metadata = ( ("google-cloud-resource-prefix", instance.name), @@ -759,8 +786,10 @@ def test_list_backup_operations_w_options(self): retry=mock.ANY, timeout=mock.ANY, ) + self.assertTrue(all([type(op) == Operation for op in ops])) def test_list_database_operations_defaults(self): + from google.api_core.operation import Operation from google.cloud.spanner_admin_database_v1 import CreateDatabaseMetadata from google.cloud.spanner_admin_database_v1 import DatabaseAdminClient from google.cloud.spanner_admin_database_v1 import ListDatabaseOperationsRequest @@ -803,7 +832,7 @@ def test_list_database_operations_defaults(self): api._transport.list_database_operations ] = mock.Mock(return_value=databases_pb) - instance.list_database_operations() + ops = instance.list_database_operations() expected_metadata = ( ("google-cloud-resource-prefix", instance.name), @@ -815,8 +844,10 @@ def test_list_database_operations_defaults(self): retry=mock.ANY, timeout=mock.ANY, ) + self.assertTrue(all([type(op) == Operation for op in ops])) def test_list_database_operations_w_options(self): + from google.api_core.operation import Operation from google.cloud.spanner_admin_database_v1 import DatabaseAdminClient from google.cloud.spanner_admin_database_v1 import ListDatabaseOperationsRequest from google.cloud.spanner_admin_database_v1 import ( @@ -864,7 +895,7 @@ def test_list_database_operations_w_options(self): api._transport.list_database_operations ] = mock.Mock(return_value=databases_pb) - instance.list_database_operations(filter_="filter", page_size=10) + ops = instance.list_database_operations(filter_="filter", page_size=10) expected_metadata = ( ("google-cloud-resource-prefix", instance.name), @@ -878,6 +909,7 @@ def test_list_database_operations_w_options(self): retry=mock.ANY, timeout=mock.ANY, ) + self.assertTrue(all([type(op) == Operation for op in ops])) def test_type_string_to_type_pb_hit(self): from google.cloud.spanner_admin_database_v1 import (