diff --git a/.coveragerc b/.coveragerc index 5b3f287a0f..2719524048 100644 --- a/.coveragerc +++ b/.coveragerc @@ -2,10 +2,10 @@ branch = True [report] -fail_under = 100 +fail_under = 99 show_missing = True omit = - google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py + .nox/* exclude_lines = # Re-enable the standard pragma pragma: NO COVER @@ -15,4 +15,4 @@ exclude_lines = # This is added at the module level as a safeguard for if someone # generates the code and tries to run it without pip installing. This # makes it virtually impossible to test properly. - except pkg_resources.DistributionNotFound + except pkg_resources.DistributionNotFound \ No newline at end of file diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml index b703be9596..1e00173609 100644 --- a/.github/sync-repo-settings.yaml +++ b/.github/sync-repo-settings.yaml @@ -9,4 +9,3 @@ branchProtectionRules: - 'Kokoro' - 'cla/google' - 'Samples - Lint' - - 'Samples - Python 3.7' diff --git a/.gitignore b/.gitignore index b4243ced74..5555e7de6d 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ pip-log.txt # Unit test / coverage reports .coverage +.coverage.* .nox .cache .pytest_cache diff --git a/.kokoro/samples/python3.8/common.cfg b/.kokoro/samples/python3.8/common.cfg index 512c9ee399..21b411c8e1 100644 --- a/.kokoro/samples/python3.8/common.cfg +++ b/.kokoro/samples/python3.8/common.cfg @@ -19,6 +19,12 @@ env_vars: { value: "py-3.8" } +# Run tests located under tests/system +env_vars: { + key: "RUN_SYSTEM_TESTS" + value: "true" +} + env_vars: { key: "TRAMPOLINE_BUILD_FILE" value: "github/python-aiplatform/.kokoro/test-samples.sh" diff --git a/CHANGELOG.md b/CHANGELOG.md index be2d9a602f..8c81dd57c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,24 @@ # Changelog +## [0.7.0](https://www.github.com/googleapis/python-aiplatform/compare/v0.6.0...v0.7.0) (2021-04-14) + + +### Features + +* Add Custom Container Prediction support, move to single 
API endpoint ([#277](https://www.github.com/googleapis/python-aiplatform/issues/277)) ([ca7f6d6](https://www.github.com/googleapis/python-aiplatform/commit/ca7f6d64ea75349a841b53fe6ef6547942439e35)) +* Add initial Model Builder SDK samples ([#265](https://www.github.com/googleapis/python-aiplatform/issues/265)) ([1230dc6](https://www.github.com/googleapis/python-aiplatform/commit/1230dc68a34c5b747186d31a25d1b8f40bf7a97e)) +* Add list() method to all resource nouns ([#294](https://www.github.com/googleapis/python-aiplatform/issues/294)) ([3ec9386](https://www.github.com/googleapis/python-aiplatform/commit/3ec9386f8f766662c91922af66b8098ddfa1eb8f)) +* add support for multiple client versions, change aiplatform from compat.V1BETA1 to compat.V1 ([#290](https://www.github.com/googleapis/python-aiplatform/issues/290)) ([89e3212](https://www.github.com/googleapis/python-aiplatform/commit/89e321246b6223a2355947d8dbd0161b84523478)) +* Make aiplatform.Dataset private ([#296](https://www.github.com/googleapis/python-aiplatform/issues/296)) ([1f0d5f3](https://www.github.com/googleapis/python-aiplatform/commit/1f0d5f3e3f95ee5056545e9d4742b96e9380a22e)) +* parse project location when passed full resource name to get apis ([#297](https://www.github.com/googleapis/python-aiplatform/issues/297)) ([674227d](https://www.github.com/googleapis/python-aiplatform/commit/674227d2e7ed4a4a4e180213dc1178dde7d65a3a)) + + +### Bug Fixes + +* add quotes to logged snippet ([0ecd0a8](https://www.github.com/googleapis/python-aiplatform/commit/0ecd0a8bbc5a2fc645877d0eb3b930e1b03a270a)) +* make logging more informative during training ([#310](https://www.github.com/googleapis/python-aiplatform/issues/310)) ([9a4d991](https://www.github.com/googleapis/python-aiplatform/commit/9a4d99150a035b8dde7b4f9e72f25745af17b609)) +* remove TPU from accelerator test cases ([57f4fcf](https://www.github.com/googleapis/python-aiplatform/commit/57f4fcf7637467f6176436f6d2e1f6c8be909c4a)) + ## 
[0.6.0](https://www.github.com/googleapis/python-aiplatform/compare/v0.5.1...v0.6.0) (2021-03-22) diff --git a/README.rst b/README.rst index 209b577ead..e0e66ce2da 100644 --- a/README.rst +++ b/README.rst @@ -1,6 +1,14 @@ Python Client for Cloud AI Platform ================================================= +**Experimental** + +This is an Experimental release. Experiments are focused on validating a prototype. They are not guaranteed to be released and might be subject to backward-incompatible changes. They are not intended for production use or covered by any SLA, support obligation, or deprecation policy. They are covered by the `Pre-GA Offerings Terms`_ of the Google Cloud Platform Terms of Services. + +.. _Pre-GA Offerings Terms: https://cloud.google.com/terms/service-terms#1 + +---- + |beta| |pypi| |versions| diff --git a/docs/aiplatform.rst b/docs/aiplatform.rst new file mode 100644 index 0000000000..bf5cd4625b --- /dev/null +++ b/docs/aiplatform.rst @@ -0,0 +1,6 @@ +Google Cloud Aiplatform SDK +============================================= + +.. automodule:: google.cloud.aiplatform + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 765eb55989..031271a261 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,6 +7,7 @@ API Reference .. 
toctree:: :maxdepth: 2 + aiplatform aiplatform_v1/services aiplatform_v1/types diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index bb196e2c19..58eb824454 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -16,6 +16,52 @@ # from google.cloud.aiplatform import gapic +from google.cloud.aiplatform import explain +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform.datasets import ( + ImageDataset, + TabularDataset, + TextDataset, + VideoDataset, +) +from google.cloud.aiplatform.models import Endpoint +from google.cloud.aiplatform.models import Model +from google.cloud.aiplatform.jobs import BatchPredictionJob +from google.cloud.aiplatform.training_jobs import ( + CustomTrainingJob, + CustomContainerTrainingJob, + CustomPythonPackageTrainingJob, + AutoMLTabularTrainingJob, + AutoMLImageTrainingJob, + AutoMLTextTrainingJob, + AutoMLVideoTrainingJob, +) -__all__ = ("gapic",) +""" +Usage: +from google.cloud import aiplatform + +aiplatform.init(project='my_project') +""" +init = initializer.global_config.init + +__all__ = ( + "explain", + "gapic", + "init", + "AutoMLImageTrainingJob", + "AutoMLTabularTrainingJob", + "AutoMLTextTrainingJob", + "AutoMLVideoTrainingJob", + "BatchPredictionJob", + "CustomTrainingJob", + "CustomContainerTrainingJob", + "CustomPythonPackageTrainingJob", + "Endpoint", + "ImageDataset", + "Model", + "TabularDataset", + "TextDataset", + "VideoDataset", +) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py new file mode 100644 index 0000000000..4bb996c881 --- /dev/null +++ b/google/cloud/aiplatform/base.py @@ -0,0 +1,1021 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +from concurrent import futures +import datetime +import functools +import inspect +import logging +import sys +import threading +from typing import ( + Any, + Callable, + Dict, + List, + Iterable, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +import proto + +from google.api_core import operation +from google.auth import credentials as auth_credentials +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import utils + + +logging.basicConfig(level=logging.INFO, stream=sys.stdout) + + +class Logger: + """Logging wrapper class with high level helper methods.""" + + def __init__(self, name: str = ""): + """Initializes logger with name. + + Args: + name (str): Name to associate with logger. + """ + self._logger = logging.getLogger(name) + + def log_create_with_lro( + self, + cls: Type["AiPlatformResourceNoun"], + lro: Optional[operation.Operation] = None, + ): + """Logs create event with LRO. + + Args: + cls (AiPlatformResourceNoune): + AI Platform Resource Noun class that is being created. + lro (operation.Operation): + Optional. Backing LRO for creation. + """ + self._logger.info(f"Creating {cls.__name__}") + + if lro: + self._logger.info( + f"Create {cls.__name__} backing LRO: {lro.operation.name}" + ) + + def log_create_complete( + self, + cls: Type["AiPlatformResourceNoun"], + resource: proto.Message, + variable_name: str, + ): + """Logs create event is complete. + + Will also include code snippet to instantiate resource in SDK. 
+ + Args: + cls (AiPlatformResourceNoun): + AI Platform Resource Noun class that is being created. + resource (proto.Message): + AI Platform Resourc proto.Message + variable_name (str): Name of variable to use for code snippet + + """ + self._logger.info(f"{cls.__name__} created. Resource name: {resource.name}") + self._logger.info(f"To use this {cls.__name__} in another session:") + self._logger.info( + f"{variable_name} = aiplatform.{cls.__name__}('{resource.name}')" + ) + + def log_action_start_against_resource( + self, action: str, noun: str, resource_noun_obj: "AiPlatformResourceNoun" + ): + """Logs intention to start an action against a resource. + + Args: + action (str): Action to complete against the resource ie: "Deploying". Can be empty string. + noun (str): Noun the action acts on against the resource. Can be empty string. + resource_noun_obj (AiPlatformResourceNoun): + Resource noun object the action is acting against. + """ + self._logger.info( + f"{action} {resource_noun_obj.__class__.__name__} {noun}: {resource_noun_obj.resource_name}" + ) + + def log_action_started_against_resource_with_lro( + self, + action: str, + noun: str, + cls: Type["AiPlatformResourceNoun"], + lro: operation.Operation, + ): + """Logs an action started against a resource with lro. + + Args: + action (str): Action started against resource. ie: "Deploy". Can be empty string. + noun (str): Noun the action acts on against the resource. Can be empty string. + cls (AiPlatformResourceNoun): + Resource noun object the action is acting against. + lro (operation.Operation): Backing LRO for action. + """ + self._logger.info( + f"{action} {cls.__name__} {noun} backing LRO: {lro.operation.name}" + ) + + def log_action_completed_against_resource( + self, noun: str, action: str, resource_noun_obj: "AiPlatformResourceNoun" + ): + """Logs action completed against resource. + + Args: + noun (str): Noun the action acts on against the resource. Can be empty string. 
+ action (str): Action started against resource. ie: "Deployed". Can be empty string. + resource_noun_obj (AiPlatformResourceNoun): + Resource noun object the action is acting against + """ + self._logger.info( + f"{resource_noun_obj.__class__.__name__} {noun} {action}. Resource name: {resource_noun_obj.resource_name}" + ) + + def __getattr__(self, attr: str): + """Forward remainder of logging to underlying logger.""" + return getattr(self._logger, attr) + + +_LOGGER = Logger(__name__) + + +class FutureManager(metaclass=abc.ABCMeta): + """Tracks concurrent futures against this object.""" + + def __init__(self): + self.__latest_future_lock = threading.Lock() + + # Always points to the latest future. All submitted futures will always + # form a dependency on the latest future. + self.__latest_future = None + + # Caches Exception of any executed future. Once one exception occurs + # all additional futures should fail and any additional invocations will block. + self._exception = None + + def _raise_future_exception(self): + """Raises exception if one of the object's futures has raised.""" + with self.__latest_future_lock: + if self._exception: + raise self._exception + + def _complete_future(self, future: futures.Future): + """Checks for exception of future and removes the pointer if it's still latest. + + Args: + future (futures.Future): Required. A future to complete. + """ + + with self.__latest_future_lock: + try: + future.result() # raises + except Exception as e: + self._exception = e + + if self.__latest_future is future: + self.__latest_future = None + + def _are_futures_done(self) -> bool: + """Helper method to check to all futures are complete. + + Returns: + True if no latest future. 
+ """ + with self.__latest_future_lock: + return self.__latest_future is None + + def wait(self): + """Helper method to that blocks until all futures are complete.""" + future = self.__latest_future + if future: + futures.wait([future], return_when=futures.FIRST_EXCEPTION) + + self._raise_future_exception() + + @property + def _latest_future(self) -> Optional[futures.Future]: + """Get the latest future if it exists""" + with self.__latest_future_lock: + return self.__latest_future + + @_latest_future.setter + def _latest_future(self, future: Optional[futures.Future]): + """Optionally set the latest future and add a complete_future callback.""" + with self.__latest_future_lock: + self.__latest_future = future + if future: + future.add_done_callback(self._complete_future) + + def _submit( + self, + method: Callable[..., Any], + args: Sequence[Any], + kwargs: Dict[str, Any], + additional_dependencies: Optional[Sequence[futures.Future]] = None, + callbacks: Optional[Sequence[Callable[[futures.Future], Any]]] = None, + internal_callbacks: Iterable[Callable[[Any], Any]] = None, + ) -> futures.Future: + """Submit a method as a future against this object. + + Args: + method (Callable): Required. The method to submit. + args (Sequence): Required. The arguments to call the method with. + kwargs (dict): Required. The keyword arguments to call the method with. + additional_dependencies (Optional[Sequence[futures.Future]]): + Optional. Additional dependent futures to wait on before executing + method. Note: No validation is done on the dependencies. + callbacks (Optional[Sequence[Callable[[futures.Future], Any]]]): + Optional. Additional Future callbacks to execute once this created + Future is complete. + + Returns: + future (Future): Future of the submitted method call. 
+ """ + + def wait_for_dependencies_and_invoke( + deps: Sequence[futures.Future], + method: Callable[..., Any], + args: Sequence[Any], + kwargs: Dict[str, Any], + internal_callbacks: Iterable[Callable[[Any], Any]], + ) -> Any: + """Wrapper method to wait on any dependencies before submitting method. + + Args: + deps (Sequence[futures.Future]): + Required. Dependent futures to wait on before executing method. + Note: No validation is done on the dependencies. + method (Callable): Required. The method to submit. + args (Sequence[Any]): Required. The arguments to call the method with. + kwargs (Dict[str, Any]): + Required. The keyword arguments to call the method with. + internal_callbacks: (Callable[[Any], Any]): + Callbacks that take the result of method. + + """ + + for future in set(deps): + future.result() + + result = method(*args, **kwargs) + + # call callbacks from within future + if internal_callbacks: + for callback in internal_callbacks: + callback(result) + + return result + + # Retrieves any dependencies from arguments. 
+ deps = [ + arg._latest_future + for arg in list(args) + list(kwargs.values()) + if isinstance(arg, FutureManager) + ] + + # Retrieves exceptions and raises + # if any upstream dependency has an exception + exceptions = [ + arg._exception + for arg in list(args) + list(kwargs.values()) + if isinstance(arg, FutureManager) and arg._exception + ] + + if exceptions: + raise exceptions[0] + + # filter out objects that do not have pending tasks + deps = [dep for dep in deps if dep] + + if additional_dependencies: + deps.extend(additional_dependencies) + + with self.__latest_future_lock: + + # form a dependency on the latest future of this object + if self.__latest_future: + deps.append(self.__latest_future) + + self.__latest_future = initializer.global_pool.submit( + wait_for_dependencies_and_invoke, + deps=deps, + method=method, + args=args, + kwargs=kwargs, + internal_callbacks=internal_callbacks, + ) + + future = self.__latest_future + + # Clean up callback captures exception as well as removes future. + # May execute immediately and take lock. + + future.add_done_callback(self._complete_future) + + if callbacks: + for c in callbacks: + future.add_done_callback(c) + + return future + + @classmethod + @abc.abstractmethod + def _empty_constructor(cls) -> "FutureManager": + """Should construct object with all non FutureManager attributes as None""" + pass + + @abc.abstractmethod + def _sync_object_with_future_result(self, result: "FutureManager"): + """Should sync the object from _empty_constructor with result of future.""" + + def __repr__(self) -> str: + if self._exception: + return f"{object.__repr__(self)} failed with {str(self._exception)}" + + if self.__latest_future: + return f"{object.__repr__(self)} is waiting for upstream dependencies to complete." + + return object.__repr__(self) + + +class AiPlatformResourceNoun(metaclass=abc.ABCMeta): + """Base class the AI Platform resource nouns. 
+ + Subclasses require two class attributes: + + client_class: The client to instantiate to interact with this resource noun. + _is_client_prediction_client: Flag to indicate if the client requires a prediction endpoint. + + Subclass is required to populate private attribute _gca_resource which is the + service representation of the resource noun. + """ + + @property + @classmethod + @abc.abstractmethod + def client_class(cls) -> Type[utils.AiPlatformServiceClientWithOverride]: + """Client class required to interact with resource with optional overrides.""" + pass + + @property + @classmethod + @abc.abstractmethod + def _is_client_prediction_client(cls) -> bool: + """Flag to indicate whether to use prediction endpoint with client.""" + pass + + @property + @abc.abstractmethod + def _getter_method(cls) -> str: + """Name of getter method of client class for retrieving the resource.""" + pass + + @property + @abc.abstractmethod + def _delete_method(cls) -> str: + """Name of delete method of client class for deleting the resource.""" + pass + + @property + @abc.abstractmethod + def _resource_noun(cls) -> str: + """Resource noun""" + pass + + def __init__( + self, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + resource_name: Optional[str] = None, + ): + """Initializes class with project, location, and api_client. + + Args: + project(str): Project of the resource noun. + location(str): The location of the resource noun. + credentials(google.auth.crendentials.Crendentials): Optional custom + credentials to use when accessing interacting with resource noun. + resource_name(str): A fully-qualified resource name or ID. 
+ """ + + if resource_name: + project, location = self._get_and_validate_project_location( + resource_name=resource_name, project=project, location=location + ) + + self.project = project or initializer.global_config.project + self.location = location or initializer.global_config.location + self.credentials = credentials or initializer.global_config.credentials + + self.api_client = self._instantiate_client(self.location, self.credentials) + + @classmethod + def _instantiate_client( + cls, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> utils.AiPlatformServiceClientWithOverride: + """Helper method to instantiate service client for resource noun. + + Args: + location (str): The location of the resource noun. + credentials (google.auth.credentials.Credentials): + Optional custom credentials to use when accessing interacting with + resource noun. + Returns: + client (utils.AiPlatformServiceClientWithOverride): + Initialized service client for this service noun with optional overrides. + """ + return initializer.global_config.create_client( + client_class=cls.client_class, + credentials=credentials, + location_override=location, + prediction_client=cls._is_client_prediction_client, + ) + + def _get_and_validate_project_location( + self, + resource_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + ) -> Tuple: + + """Validate the project and location for the resource. + + Args: + resource_name(str): Required. A fully-qualified resource name or ID. + project(str): Project of the resource noun. + location(str): The location of the resource noun. 
+ + Raises: + RuntimeError if location is different from resource location + """ + + if not project and not location: + return project, location + + fields = utils.extract_fields_from_resource_name( + resource_name, self._resource_noun + ) + if not fields: + return project, location + + if location and fields.location != location: + raise RuntimeError( + f"location {location} is provided, but different from " + f"the resource location {fields.location}" + ) + + return fields.project, fields.location + + def _get_gca_resource(self, resource_name: str) -> proto.Message: + """Returns GAPIC service representation of client class resource.""" + """ + Args: + resource_name (str): + Required. A fully-qualified resource name or ID. + """ + + resource_name = utils.full_resource_name( + resource_name=resource_name, + resource_noun=self._resource_noun, + project=self.project, + location=self.location, + ) + + return getattr(self.api_client, self._getter_method)(name=resource_name) + + def _sync_gca_resource(self): + """Sync GAPIC service representation of client class resource.""" + + self._gca_resource = self._get_gca_resource(resource_name=self.resource_name) + + @property + def name(self) -> str: + """Name of this resource.""" + return self._gca_resource.name.split("/")[-1] + + @property + def resource_name(self) -> str: + """Full qualified resource name.""" + return self._gca_resource.name + + @property + def display_name(self) -> str: + """Display name of this resource.""" + return self._gca_resource.display_name + + @property + def create_time(self) -> datetime.datetime: + """Time this resource was created.""" + return self._gca_resource.create_time + + @property + def update_time(self) -> datetime.datetime: + """Time this resource was last updated.""" + self._sync_gca_resource() + return self._gca_resource.update_time + + def __repr__(self) -> str: + return f"{object.__repr__(self)} \nresource name: {self.resource_name}" + + +def optional_sync( + 
construct_object_on_arg: Optional[str] = None, + return_input_arg: Optional[str] = None, + bind_future_to_self: bool = True, +): + """Decorator for AiPlatformResourceNounWithFutureManager with optional sync support. + + Methods with this decorator should include a "sync" argument that defaults to + True. If called with sync=False this decorator will launch the method as a + concurrent Future in a separate Thread. + + Note that this is only robust enough to support our current end to end patterns + and may not be suitable for new patterns. + + Args: + construct_object_on_arg (str): + Optional. If provided, will only construct output object if arg is present. + Example: If custom training does not produce a model. + return_input_arg (str): + Optional. If provided will return passed in argument instead of + constructing. + Example: Model.deploy(Endpoint) returns the passed in Endpoint + bind_future_to_self (bool): + Whether to add this future to the calling object. + Example: Model.deploy(Endpoint) would be set to False because we only + want the deployment Future to be associated with Endpoint. + """ + + def optional_run_in_thread(method: Callable[..., Any]): + """Optionally run this method concurrently in separate Thread. + + Args: + method (Callable[..., Any]): Method to optionally run in separate Thread. 
+ """ + + @functools.wraps(method) + def wrapper(*args, **kwargs): + """Wraps method.""" + sync = kwargs.pop("sync", True) + bound_args = inspect.signature(method).bind(*args, **kwargs) + self = bound_args.arguments.get("self") + calling_object_latest_future = None + + # check to see if this object has any exceptions + if self: + calling_object_latest_future = self._latest_future + self._raise_future_exception() + + # if sync then wait for any Futures to complete and execute + if sync: + if self: + self.wait() + return method(*args, **kwargs) + + # callbacks to call within the Future (in same Thread) + internal_callbacks = [] + # callbacks to add to the Future (may or may not be in same Thread) + callbacks = [] + # additional Future dependencies to capture + dependencies = [] + + # all methods should have type signatures + return_type = get_annotation_class( + inspect.getfullargspec(method).annotations["return"] + ) + + # is a classmethod that creates the object and returns it + if args and inspect.isclass(args[0]): + # assumes classmethod is our resource noun + returned_object = args[0]._empty_constructor() + self = returned_object + + else: # instance method + + # object produced by the method + returned_object = bound_args.arguments.get(return_input_arg) + + # if we're returning an input object + if returned_object and returned_object is not self: + + # make sure the input object doesn't have any exceptions + # from previous futures + returned_object._raise_future_exception() + + # if the future will be associated with both the returned object + # and calling object then we need to add additional callback + # to remove the future from the returned object + + # if we need to construct a new empty returned object + should_construct = not returned_object and bound_args.arguments.get( + construct_object_on_arg, not construct_object_on_arg + ) + + if should_construct: + if return_type is not None: + returned_object = return_type._empty_constructor() + + # if the 
future will be associated with both the returned object + # and calling object then we need to add additional callback + # to remove the future from the returned object + if returned_object and bind_future_to_self: + callbacks.append(returned_object._complete_future) + + if returned_object: + # sync objects after future completes + internal_callbacks.append( + returned_object._sync_object_with_future_result + ) + + # If the future is not associated with the calling object + # then the return object future needs to form a dependency on the + # the latest future in the calling object. + if not bind_future_to_self: + if calling_object_latest_future: + dependencies.append(calling_object_latest_future) + self = returned_object + + future = self._submit( + method=method, + callbacks=callbacks, + internal_callbacks=internal_callbacks, + additional_dependencies=dependencies, + args=[], + kwargs=bound_args.arguments, + ) + + # if the calling object is the one that submitted then add it's future + # to the returned object + if returned_object and returned_object is not self: + returned_object._latest_future = future + + return returned_object + + return wrapper + + return optional_run_in_thread + + +class AiPlatformResourceNounWithFutureManager(AiPlatformResourceNoun, FutureManager): + """Allows optional asynchronous calls to this AI Platform Resource Nouns.""" + + def __init__( + self, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + resource_name: Optional[str] = None, + ): + """Initializes class with project, location, and api_client. + + Args: + project (str): Optional. Project of the resource noun. + location (str): Optional. The location of the resource noun. + credentials(google.auth.crendentials.Crendentials): + Optional. custom credentials to use when accessing interacting with + resource noun. + resource_name(str): A fully-qualified resource name or ID. 
+ """ + AiPlatformResourceNoun.__init__( + self, + project=project, + location=location, + credentials=credentials, + resource_name=resource_name, + ) + FutureManager.__init__(self) + + @classmethod + def _empty_constructor( + cls, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + resource_name: Optional[str] = None, + ) -> "AiPlatformResourceNounWithFutureManager": + """Initializes with all attributes set to None. + + The attributes should be populated after a future is complete. This allows + scheduling of additional API calls before the resource is created. + + Args: + project (str): Optional. Project of the resource noun. + location (str): Optional. The location of the resource noun. + credentials(google.auth.crendentials.Crendentials): + Optional. custom credentials to use when accessing interacting with + resource noun. + resource_name(str): A fully-qualified resource name or ID. + Returns: + An instance of this class with attributes set to None. + """ + self = cls.__new__(cls) + AiPlatformResourceNoun.__init__( + self, + project=project, + location=location, + credentials=credentials, + resource_name=resource_name, + ) + FutureManager.__init__(self) + self._gca_resource = None + return self + + def _sync_object_with_future_result( + self, result: "AiPlatformResourceNounWithFutureManager" + ): + """Populates attributes from a Future result to this object. + + Args: + result: AiPlatformResourceNounWithFutureManager + Required. Result of future with same type as this object. 
+ """ + sync_attributes = [ + "project", + "location", + "api_client", + "_gca_resource", + "credentials", + ] + optional_sync_attributes = ["_prediction_client"] + + for attribute in sync_attributes: + setattr(self, attribute, getattr(result, attribute)) + + for attribute in optional_sync_attributes: + value = getattr(result, attribute, None) + if value: + setattr(self, attribute, value) + + def _construct_sdk_resource_from_gapic( + self, + gapic_resource: proto.Message, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> AiPlatformResourceNoun: + """Given a GAPIC resource object, return the SDK representation. + + Args: + gapic_resource (proto.Message): + A GAPIC representation of an AI Platform resource, usually + retrieved by a get_* or in a list_* API call. + project (str): + Optional. Project to construct SDK object from. If not set, + project set in aiplatform.init will be used. + location (str): + Optional. Location to construct SDK object from. If not set, + location set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to construct SDK object. + Overrides credentials set in aiplatform.init. + + Returns: + AiPlatformResourceNoun: + An initialized SDK object that represents GAPIC type. 
+ """ + sdk_resource = self._empty_constructor( + project=project, location=location, credentials=credentials + ) + sdk_resource._gca_resource = gapic_resource + return sdk_resource + + # TODO(b/144545165): Improve documentation for list filtering once available + # TODO(b/184910159): Expose `page_size` field in list method + @classmethod + def _list( + cls, + cls_filter: Callable[[proto.Message], bool] = lambda _: True, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[AiPlatformResourceNoun]: + """Private method to list all instances of this AI Platform Resource, + takes a `cls_filter` arg to filter to a particular SDK resource subclass. + + Args: + cls_filter (Callable[[proto.Message], bool]): + A function that takes one argument, a GAPIC resource, and returns + a bool. If the function returns False, that resource will be + excluded from the returned list. Example usage: + cls_filter = lambda obj: obj.metadata in cls.valid_metadatas + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. 
+ + Returns: + List[AiPlatformResourceNoun] - A list of SDK resource objects + """ + self = cls._empty_constructor( + project=project, location=location, credentials=credentials + ) + + # Fetch credentials once and re-use for all `_empty_constructor()` calls + creds = initializer.global_config.credentials + + resource_list_method = getattr(self.api_client, self._list_method) + + list_request = { + "parent": initializer.global_config.common_location_path( + project=project, location=location + ), + "filter": filter, + } + + if order_by: + list_request["order_by"] = order_by + + resource_list = resource_list_method(request=list_request) or [] + + return [ + self._construct_sdk_resource_from_gapic( + gapic_resource, project=project, location=location, credentials=creds + ) + for gapic_resource in resource_list + if cls_filter(gapic_resource) + ] + + @classmethod + def _list_with_local_order( + cls, + cls_filter: Callable[[proto.Message], bool] = lambda _: True, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[AiPlatformResourceNoun]: + """Private method to list all instances of this AI Platform Resource, + takes a `cls_filter` arg to filter to a particular SDK resource subclass. + Provides client-side sorting when a list API doesn't support `order_by`. + + Args: + cls_filter (Callable[[proto.Message], bool]): + A function that takes one argument, a GAPIC resource, and returns + a bool. If the function returns False, that resource will be + excluded from the returned list. Example usage: + cls_filter = lambda obj: obj.metadata in cls.valid_metadatas + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. 
Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[AiPlatformResourceNoun] - A list of SDK resource objects + """ + + li = cls._list( + cls_filter=cls_filter, + filter=filter, + order_by=None, # This method will handle the ordering locally + project=project, + location=location, + credentials=credentials, + ) + + desc = "desc" in order_by + order_by = order_by.replace("desc", "") + order_by = order_by.split(",") + + li.sort( + key=lambda x: tuple(getattr(x, field.strip()) for field in order_by), + reverse=desc, + ) + + return li + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[AiPlatformResourceNoun]: + """List all instances of this AI Platform Resource. + + Example Usage: + + aiplatform.BatchPredictionJobs.list( + filter='state="JOB_STATE_SUCCEEDED" AND display_name="my_job"', + ) + + aiplatform.Model.list(order_by="create_time desc, display_name") + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. 
If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[AiPlatformResourceNoun] - A list of SDK resource objects + """ + + return cls._list( + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) + + @optional_sync() + def delete(self, sync: bool = True) -> None: + """Deletes this AI Platform resource. WARNING: This deletion is permanent. + + Args: + sync (bool): + Whether to execute this deletion synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + """ + _LOGGER.log_action_start_against_resource("Deleting", "", self) + lro = getattr(self.api_client, self._delete_method)(name=self.resource_name) + _LOGGER.log_action_started_against_resource_with_lro( + "Delete", "", self.__class__, lro + ) + lro.result() + _LOGGER.log_action_completed_against_resource("deleted.", "", self) + + def __repr__(self) -> str: + if self._gca_resource: + return AiPlatformResourceNoun.__repr__(self) + + return FutureManager.__repr__(self) + + +def get_annotation_class(annotation: type) -> type: + """Helper method to retrieve type annotation. 
+ + Args: + annotation (type): Type hint + """ + # typing.Optional + if getattr(annotation, "__origin__", None) is Union: + return annotation.__args__[0] + else: + return annotation diff --git a/google/cloud/aiplatform/compat/__init__.py b/google/cloud/aiplatform/compat/__init__.py new file mode 100644 index 0000000000..36d805c6cb --- /dev/null +++ b/google/cloud/aiplatform/compat/__init__.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from google.cloud.aiplatform.compat import services +from google.cloud.aiplatform.compat import types + +V1BETA1 = "v1beta1" +V1 = "v1" + +DEFAULT_VERSION = V1 + +if DEFAULT_VERSION == V1BETA1: + + services.dataset_service_client = services.dataset_service_client_v1beta1 + services.endpoint_service_client = services.endpoint_service_client_v1beta1 + services.job_service_client = services.job_service_client_v1beta1 + services.model_service_client = services.model_service_client_v1beta1 + services.pipeline_service_client = services.pipeline_service_client_v1beta1 + services.prediction_service_client = services.prediction_service_client_v1beta1 + services.specialist_pool_service_client = ( + services.specialist_pool_service_client_v1beta1 + ) + + types.accelerator_type = types.accelerator_type_v1beta1 + types.annotation = types.annotation_v1beta1 + types.annotation_spec = types.annotation_spec_v1beta1 + types.batch_prediction_job = types.batch_prediction_job_v1beta1 + types.completion_stats = types.completion_stats_v1beta1 + types.custom_job = types.custom_job_v1beta1 + types.data_item = types.data_item_v1beta1 + types.data_labeling_job = types.data_labeling_job_v1beta1 + types.dataset = types.dataset_v1beta1 + types.dataset_service = types.dataset_service_v1beta1 + types.deployed_model_ref = types.deployed_model_ref_v1beta1 + types.encryption_spec = types.encryption_spec_v1beta1 + types.endpoint = types.endpoint_v1beta1 + types.endpoint_service = types.endpoint_service_v1beta1 + types.env_var = types.env_var_v1beta1 + types.explanation = types.explanation_v1beta1 + types.explanation_metadata = types.explanation_metadata_v1beta1 + types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1beta1 + types.io = types.io_v1beta1 + types.job_service = types.job_service_v1beta1 + types.job_state = types.job_state_v1beta1 + types.machine_resources = types.machine_resources_v1beta1 + types.manual_batch_tuning_parameters = 
types.manual_batch_tuning_parameters_v1beta1 + types.model = types.model_v1beta1 + types.model_evaluation = types.model_evaluation_v1beta1 + types.model_evaluation_slice = types.model_evaluation_slice_v1beta1 + types.model_service = types.model_service_v1beta1 + types.operation = types.operation_v1beta1 + types.pipeline_service = types.pipeline_service_v1beta1 + types.pipeline_state = types.pipeline_state_v1beta1 + types.prediction_service = types.prediction_service_v1beta1 + types.specialist_pool = types.specialist_pool_v1beta1 + types.specialist_pool_service = types.specialist_pool_service_v1beta1 + types.training_pipeline = types.training_pipeline_v1beta1 + +if DEFAULT_VERSION == V1: + + services.dataset_service_client = services.dataset_service_client_v1 + services.endpoint_service_client = services.endpoint_service_client_v1 + services.job_service_client = services.job_service_client_v1 + services.model_service_client = services.model_service_client_v1 + services.pipeline_service_client = services.pipeline_service_client_v1 + services.prediction_service_client = services.prediction_service_client_v1 + services.specialist_pool_service_client = services.specialist_pool_service_client_v1 + + types.accelerator_type = types.accelerator_type_v1 + types.annotation = types.annotation_v1 + types.annotation_spec = types.annotation_spec_v1 + types.batch_prediction_job = types.batch_prediction_job_v1 + types.completion_stats = types.completion_stats_v1 + types.custom_job = types.custom_job_v1 + types.data_item = types.data_item_v1 + types.data_labeling_job = types.data_labeling_job_v1 + types.dataset = types.dataset_v1 + types.dataset_service = types.dataset_service_v1 + types.deployed_model_ref = types.deployed_model_ref_v1 + types.encryption_spec = types.encryption_spec_v1 + types.endpoint = types.endpoint_v1 + types.endpoint_service = types.endpoint_service_v1 + types.env_var = types.env_var_v1 + types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1 + 
types.io = types.io_v1 + types.job_service = types.job_service_v1 + types.job_state = types.job_state_v1 + types.machine_resources = types.machine_resources_v1 + types.manual_batch_tuning_parameters = types.manual_batch_tuning_parameters_v1 + types.model = types.model_v1 + types.model_evaluation = types.model_evaluation_v1 + types.model_evaluation_slice = types.model_evaluation_slice_v1 + types.model_service = types.model_service_v1 + types.operation = types.operation_v1 + types.pipeline_service = types.pipeline_service_v1 + types.pipeline_state = types.pipeline_state_v1 + types.prediction_service = types.prediction_service_v1 + types.specialist_pool = types.specialist_pool_v1 + types.specialist_pool_service = types.specialist_pool_service_v1 + types.training_pipeline = types.training_pipeline_v1 + +__all__ = ( + DEFAULT_VERSION, + V1BETA1, + V1, + services, + types, +) diff --git a/google/cloud/aiplatform/compat/services/__init__.py b/google/cloud/aiplatform/compat/services/__init__.py new file mode 100644 index 0000000000..0888c27fbb --- /dev/null +++ b/google/cloud/aiplatform/compat/services/__init__.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from google.cloud.aiplatform_v1beta1.services.dataset_service import ( + client as dataset_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + client as endpoint_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.job_service import ( + client as job_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.model_service import ( + client as model_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( + client as pipeline_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.prediction_service import ( + client as prediction_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( + client as specialist_pool_service_client_v1beta1, +) + +from google.cloud.aiplatform_v1.services.dataset_service import ( + client as dataset_service_client_v1, +) +from google.cloud.aiplatform_v1.services.endpoint_service import ( + client as endpoint_service_client_v1, +) +from google.cloud.aiplatform_v1.services.job_service import ( + client as job_service_client_v1, +) +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client_v1, +) +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client_v1, +) +from google.cloud.aiplatform_v1.services.prediction_service import ( + client as prediction_service_client_v1, +) +from google.cloud.aiplatform_v1.services.specialist_pool_service import ( + client as specialist_pool_service_client_v1, +) + +__all__ = ( + # v1 + dataset_service_client_v1, + endpoint_service_client_v1, + job_service_client_v1, + model_service_client_v1, + pipeline_service_client_v1, + prediction_service_client_v1, + specialist_pool_service_client_v1, + # v1beta1 + dataset_service_client_v1beta1, + endpoint_service_client_v1beta1, + job_service_client_v1beta1, + 
model_service_client_v1beta1, + pipeline_service_client_v1beta1, + prediction_service_client_v1beta1, + specialist_pool_service_client_v1beta1, +) diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py new file mode 100644 index 0000000000..d03e0d2f3a --- /dev/null +++ b/google/cloud/aiplatform/compat/types/__init__.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from google.cloud.aiplatform_v1beta1.types import ( + accelerator_type as accelerator_type_v1beta1, + annotation as annotation_v1beta1, + annotation_spec as annotation_spec_v1beta1, + batch_prediction_job as batch_prediction_job_v1beta1, + completion_stats as completion_stats_v1beta1, + custom_job as custom_job_v1beta1, + data_item as data_item_v1beta1, + data_labeling_job as data_labeling_job_v1beta1, + dataset as dataset_v1beta1, + dataset_service as dataset_service_v1beta1, + deployed_model_ref as deployed_model_ref_v1beta1, + encryption_spec as encryption_spec_v1beta1, + endpoint as endpoint_v1beta1, + endpoint_service as endpoint_service_v1beta1, + env_var as env_var_v1beta1, + explanation as explanation_v1beta1, + explanation_metadata as explanation_metadata_v1beta1, + hyperparameter_tuning_job as hyperparameter_tuning_job_v1beta1, + io as io_v1beta1, + job_service as job_service_v1beta1, + job_state as job_state_v1beta1, + machine_resources as machine_resources_v1beta1, + manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1beta1, + model as model_v1beta1, + model_evaluation as model_evaluation_v1beta1, + model_evaluation_slice as model_evaluation_slice_v1beta1, + model_service as model_service_v1beta1, + operation as operation_v1beta1, + pipeline_service as pipeline_service_v1beta1, + pipeline_state as pipeline_state_v1beta1, + prediction_service as prediction_service_v1beta1, + specialist_pool as specialist_pool_v1beta1, + specialist_pool_service as specialist_pool_service_v1beta1, + training_pipeline as training_pipeline_v1beta1, +) +from google.cloud.aiplatform_v1.types import ( + accelerator_type as accelerator_type_v1, + annotation as annotation_v1, + annotation_spec as annotation_spec_v1, + batch_prediction_job as batch_prediction_job_v1, + completion_stats as completion_stats_v1, + custom_job as custom_job_v1, + data_item as data_item_v1, + data_labeling_job as data_labeling_job_v1, + dataset as dataset_v1, + dataset_service as 
dataset_service_v1, + deployed_model_ref as deployed_model_ref_v1, + encryption_spec as encryption_spec_v1, + endpoint as endpoint_v1, + endpoint_service as endpoint_service_v1, + env_var as env_var_v1, + hyperparameter_tuning_job as hyperparameter_tuning_job_v1, + io as io_v1, + job_service as job_service_v1, + job_state as job_state_v1, + machine_resources as machine_resources_v1, + manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1, + model as model_v1, + model_evaluation as model_evaluation_v1, + model_evaluation_slice as model_evaluation_slice_v1, + model_service as model_service_v1, + operation as operation_v1, + pipeline_service as pipeline_service_v1, + pipeline_state as pipeline_state_v1, + prediction_service as prediction_service_v1, + specialist_pool as specialist_pool_v1, + specialist_pool_service as specialist_pool_service_v1, + training_pipeline as training_pipeline_v1, +) + +__all__ = ( + # v1 + accelerator_type_v1, + annotation_v1, + annotation_spec_v1, + batch_prediction_job_v1, + completion_stats_v1, + custom_job_v1, + data_item_v1, + data_labeling_job_v1, + dataset_v1, + dataset_service_v1, + deployed_model_ref_v1, + encryption_spec_v1, + endpoint_v1, + endpoint_service_v1, + env_var_v1, + hyperparameter_tuning_job_v1, + io_v1, + job_service_v1, + job_state_v1, + machine_resources_v1, + manual_batch_tuning_parameters_v1, + model_v1, + model_evaluation_v1, + model_evaluation_slice_v1, + model_service_v1, + operation_v1, + pipeline_service_v1, + pipeline_state_v1, + prediction_service_v1, + specialist_pool_v1, + specialist_pool_service_v1, + training_pipeline_v1, + # v1beta1 + accelerator_type_v1beta1, + annotation_v1beta1, + annotation_spec_v1beta1, + batch_prediction_job_v1beta1, + completion_stats_v1beta1, + custom_job_v1beta1, + data_item_v1beta1, + data_labeling_job_v1beta1, + dataset_v1beta1, + dataset_service_v1beta1, + deployed_model_ref_v1beta1, + encryption_spec_v1beta1, + endpoint_v1beta1, + endpoint_service_v1beta1, + 
env_var_v1beta1, + explanation_v1beta1, + explanation_metadata_v1beta1, + hyperparameter_tuning_job_v1beta1, + io_v1beta1, + job_service_v1beta1, + job_state_v1beta1, + machine_resources_v1beta1, + manual_batch_tuning_parameters_v1beta1, + model_v1beta1, + model_evaluation_v1beta1, + model_evaluation_slice_v1beta1, + model_service_v1beta1, + operation_v1beta1, + pipeline_service_v1beta1, + pipeline_state_v1beta1, + prediction_service_v1beta1, + specialist_pool_v1beta1, + specialist_pool_service_v1beta1, + training_pipeline_v1beta1, +) diff --git a/google/cloud/aiplatform/constants.py b/google/cloud/aiplatform/constants.py new file mode 100644 index 0000000000..62c28009c2 --- /dev/null +++ b/google/cloud/aiplatform/constants.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +DEFAULT_REGION = "us-central1" +SUPPORTED_REGIONS = ("us-central1", "europe-west4", "asia-east1") +API_BASE_PATH = "aiplatform.googleapis.com" + +# Batch Prediction +BATCH_PREDICTION_INPUT_STORAGE_FORMATS = ( + "jsonl", + "csv", + "tf-record", + "tf-record-gzip", + "bigquery", + "file-list", +) +BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS = ("jsonl", "csv", "bigquery") + +MOBILE_TF_MODEL_TYPES = { + "MOBILE_TF_LOW_LATENCY_1", + "MOBILE_TF_VERSATILE_1", + "MOBILE_TF_HIGH_ACCURACY_1", +} + +# TODO(b/177079208): Use EPCL Enums for validating Model Types +# Defined by gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_* +# Format: "prediction_type": set() of model_type's +# +# NOTE: When adding a new prediction_type's, ensure it fits the pattern +# "automl_image_{prediction_type}_*" used by the YAML schemas on GCS +AUTOML_IMAGE_PREDICTION_MODEL_TYPES = { + "classification": {"CLOUD"} | MOBILE_TF_MODEL_TYPES, + "object_detection": {"CLOUD_HIGH_ACCURACY_1", "CLOUD_LOW_LATENCY_1"} + | MOBILE_TF_MODEL_TYPES, +} + +AUTOML_VIDEO_PREDICTION_MODEL_TYPES = { + "classification": {"CLOUD"} | {"MOBILE_VERSATILE_1"}, + "action_recognition": {"CLOUD"} | {"MOBILE_VERSATILE_1"}, + "object_tracking": {"CLOUD"} + | { + "MOBILE_VERSATILE_1", + "MOBILE_CORAL_VERSATILE_1", + "MOBILE_CORAL_LOW_LATENCY_1", + "MOBILE_JETSON_VERSATILE_1", + "MOBILE_JETSON_LOW_LATENCY_1", + }, +} diff --git a/google/cloud/aiplatform/datasets/__init__.py b/google/cloud/aiplatform/datasets/__init__.py new file mode 100644 index 0000000000..57e2bad45d --- /dev/null +++ b/google/cloud/aiplatform/datasets/__init__.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.cloud.aiplatform.datasets.dataset import _Dataset +from google.cloud.aiplatform.datasets.tabular_dataset import TabularDataset +from google.cloud.aiplatform.datasets.image_dataset import ImageDataset +from google.cloud.aiplatform.datasets.text_dataset import TextDataset +from google.cloud.aiplatform.datasets.video_dataset import VideoDataset + + +__all__ = ( + "_Dataset", + "TabularDataset", + "ImageDataset", + "TextDataset", + "VideoDataset", +) diff --git a/google/cloud/aiplatform/datasets/_datasources.py b/google/cloud/aiplatform/datasets/_datasources.py new file mode 100644 index 0000000000..eefd1b04fd --- /dev/null +++ b/google/cloud/aiplatform/datasets/_datasources.py @@ -0,0 +1,236 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import abc +from typing import Optional, Dict, Sequence, Union +from google.cloud.aiplatform import schema + +from google.cloud.aiplatform.compat.types import ( + io as gca_io, + dataset as gca_dataset, +) + + +class Datasource(abc.ABC): + """An abstract class that sets dataset_metadata""" + + @property + @abc.abstractmethod + def dataset_metadata(self): + """Dataset Metadata.""" + pass + + +class DatasourceImportable(abc.ABC): + """An abstract class that sets import_data_config""" + + @property + @abc.abstractmethod + def import_data_config(self): + """Import Data Config.""" + pass + + +class TabularDatasource(Datasource): + """Datasource for creating a tabular dataset for AI Platform""" + + def __init__( + self, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bq_source: Optional[str] = None, + ): + """Creates a tabular datasource + + Args: + gcs_source (Union[str, Sequence[str]]): + Cloud Storage URI of one or more files. Only CSV files are supported. + The first line of the CSV file is used as the header. + If there are multiple files, the header is the first line of + the lexicographically first file, the other files must either + contain the exact same header or omit the header. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + bq_source (str): + The URI of a BigQuery table. + example: + "bq://project.dataset.table_name" + + Raises: + ValueError if source configuration is not valid. 
+ """ + + dataset_metadata = None + + if gcs_source and isinstance(gcs_source, str): + gcs_source = [gcs_source] + + if gcs_source and bq_source: + raise ValueError("Only one of gcs_source or bq_source can be set.") + + if not any([gcs_source, bq_source]): + raise ValueError("One of gcs_source or bq_source must be set.") + + if gcs_source: + dataset_metadata = {"input_config": {"gcs_source": {"uri": gcs_source}}} + elif bq_source: + dataset_metadata = {"input_config": {"bigquery_source": {"uri": bq_source}}} + + self._dataset_metadata = dataset_metadata + + @property + def dataset_metadata(self) -> Optional[Dict]: + """Dataset Metadata.""" + return self._dataset_metadata + + +class NonTabularDatasource(Datasource): + """Datasource for creating an empty non-tabular dataset for AI Platform""" + + @property + def dataset_metadata(self) -> Optional[Dict]: + return None + + +class NonTabularDatasourceImportable(NonTabularDatasource, DatasourceImportable): + """Datasource for creating a non-tabular dataset for AI Platform and importing data to the dataset""" + + def __init__( + self, + gcs_source: Union[str, Sequence[str]], + import_schema_uri: str, + data_item_labels: Optional[Dict] = None, + ): + """Creates a non-tabular datasource + + Args: + gcs_source (Union[str, Sequence[str]]): + Required. The Google Cloud Storage location for the input content. + Google Cloud Storage URI(-s) to the input file(s). May contain + wildcards. For more information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + import_schema_uri (str): + Required. Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. The schema is defined as an + `OpenAPI 3.0.2 Schema + data_item_labels (Dict): + Labels that will be applied to newly imported DataItems. 
If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to those + of the already existing one, and if labels with identical + key is imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside index file referenced by + ``import_schema_uri``, + e.g. jsonl file. + """ + super().__init__() + self._gcs_source = [gcs_source] if isinstance(gcs_source, str) else gcs_source + self._import_schema_uri = import_schema_uri + self._data_item_labels = data_item_labels + + @property + def import_data_config(self) -> gca_dataset.ImportDataConfig: + """Import Data Config.""" + return gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=self._gcs_source), + import_schema_uri=self._import_schema_uri, + data_item_labels=self._data_item_labels, + ) + + +def create_datasource( + metadata_schema_uri: str, + import_schema_uri: Optional[str] = None, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bq_source: Optional[str] = None, + data_item_labels: Optional[Dict] = None, +) -> Datasource: + """Creates a datasource + Args: + metadata_schema_uri (str): + Required. Points to a YAML file stored on Google Cloud Storage + describing additional information about the Dataset. The schema + is defined as an OpenAPI 3.0.2 Schema Object. The schema files + that can be used here are found in gs://google-cloud- + aiplatform/schema/dataset/metadata/. + import_schema_uri (str): + Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. 
The schema is defined as an + `OpenAPI 3.0.2 Schema + gcs_source (Union[str, Sequence[str]]): + The Google Cloud Storage location for the input content. + Google Cloud Storage URI(-s) to the input file(s). May contain + wildcards. For more information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + bq_source (str): + BigQuery URI to the input table. + example: + "bq://project.dataset.table_name" + data_item_labels (Dict): + Labels that will be applied to newly imported DataItems. If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to those + of the already existing one, and if labels with identical + key is imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside index file referenced by + ``import_schema_uri``, + e.g. jsonl file. 
+ + Returns: + datasource (Datasource) + + Raises: + ValueError when below scenarios happen + - import_schema_uri is identified for creating TabularDatasource + - either import_schema_uri or gcs_source is missing for creating NonTabularDatasourceImportable + """ + + if metadata_schema_uri == schema.dataset.metadata.tabular: + if import_schema_uri: + raise ValueError("tabular dataset does not support data import.") + return TabularDatasource(gcs_source, bq_source) + + if not import_schema_uri and not gcs_source: + return NonTabularDatasource() + elif import_schema_uri and gcs_source: + return NonTabularDatasourceImportable( + gcs_source, import_schema_uri, data_item_labels + ) + else: + raise ValueError( + "nontabular dataset requires both import_schema_uri and gcs_source for data import." + ) diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py new file mode 100644 index 0000000000..922ce8930b --- /dev/null +++ b/google/cloud/aiplatform/datasets/dataset.py @@ -0,0 +1,577 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Optional, Sequence, Dict, Tuple, Union, List + +from google.api_core import operation +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import base +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform.compat.services import dataset_service_client +from google.cloud.aiplatform.compat.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + io as gca_io, +) +from google.cloud.aiplatform.datasets import _datasources + +_LOGGER = base.Logger(__name__) + + +class _Dataset(base.AiPlatformResourceNounWithFutureManager): + """Managed dataset resource for AI Platform""" + + client_class = utils.DatasetClientWithOverride + _is_client_prediction_client = False + _resource_noun = "datasets" + _getter_method = "get_dataset" + _list_method = "list_datasets" + _delete_method = "delete_dataset" + + _supported_metadata_schema_uris: Tuple[str] = () + + def __init__( + self, + dataset_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an existing managed dataset given a dataset name or ID. + + Args: + dataset_name (str): + Required. A fully-qualified dataset resource name or dataset ID. + Example: "projects/123/locations/us-central1/datasets/456" or + "456" when project and location are initialized or passed. + project (str): + Optional project to retrieve dataset from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional location to retrieve dataset from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. 
+ + """ + + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=dataset_name, + ) + self._gca_resource = self._get_gca_resource(resource_name=dataset_name) + self._validate_metadata_schema_uri() + + @property + def metadata_schema_uri(self) -> str: + """The metadata schema uri of this dataset resource.""" + return self._gca_resource.metadata_schema_uri + + def _validate_metadata_schema_uri(self) -> None: + """Validate the metadata_schema_uri of retrieved dataset resource. + + Raises: + ValueError if the dataset type of the retrieved dataset resource is + not supported by the class. + """ + if self._supported_metadata_schema_uris and ( + self.metadata_schema_uri not in self._supported_metadata_schema_uris + ): + raise ValueError( + f"{self.__class__.__name__} class can not be used to retrieve " + f"dataset resource {self.resource_name}, check the dataset type" + ) + + @classmethod + def create( + cls, + display_name: str, + metadata_schema_uri: str, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bq_source: Optional[str] = None, + import_schema_uri: Optional[str] = None, + data_item_labels: Optional[Dict] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> "_Dataset": + """Creates a new dataset and optionally imports data into dataset when + source and import_schema_uri are passed. + + Args: + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + metadata_schema_uri (str): + Required. Points to a YAML file stored on Google Cloud Storage + describing additional information about the Dataset. The schema + is defined as an OpenAPI 3.0.2 Schema Object. 
The schema files + that can be used here are found in gs://google-cloud- + aiplatform/schema/dataset/metadata/. + gcs_source (Union[str, Sequence[str]]): + Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + bq_source (str): + BigQuery URI to the input table. + example: + "bq://project.dataset.table_name" + import_schema_uri (str): + Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. The schema is defined as an + `OpenAPI 3.0.2 Schema + Object `__. + data_item_labels (Dict): + Labels that will be applied to newly imported DataItems. If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to these + of the already existing one, and if labels with identical + key is imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside index file refenced by + [import_schema_uri][google.cloud.aiplatform.v1beta1.ImportDataConfig.import_schema_uri], + e.g. jsonl file. + project (str): + Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. 
+ request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the request as metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the dataset. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + dataset (Dataset): + Instantiated representation of the managed dataset resource. + + """ + + utils.validate_display_name(display_name) + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + datasource = _datasources.create_datasource( + metadata_schema_uri=metadata_schema_uri, + import_schema_uri=import_schema_uri, + gcs_source=gcs_source, + bq_source=bq_source, + data_item_labels=data_item_labels, + ) + + return cls._create_and_import( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + datasource=datasource, + project=project or initializer.global_config.project, + location=location or initializer.global_config.location, + credentials=credentials or initializer.global_config.credentials, + request_metadata=request_metadata, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + sync=sync, + ) + + @classmethod + @base.optional_sync() + def _create_and_import( + cls, + api_client: 
dataset_service_client.DatasetServiceClient, + parent: str, + display_name: str, + metadata_schema_uri: str, + datasource: _datasources.Datasource, + project: str, + location: str, + credentials: Optional[auth_credentials.Credentials], + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None, + sync: bool = True, + ) -> "_Dataset": + """Creates a new dataset and optionally imports data into dataset when + source and import_schema_uri are passed. + + Args: + api_client (dataset_service_client.DatasetServiceClient): + An instance of DatasetServiceClient with the correct api_endpoint + already set based on user's preferences. + parent (str): + Required. Also known as common location path, that usually contains the + project and location that the user provided to the upstream method. + Example: "projects/my-prj/locations/us-central1" + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + metadata_schema_uri (str): + Required. Points to a YAML file stored on Google Cloud Storage + describing additional information about the Dataset. The schema + is defined as an OpenAPI 3.0.2 Schema Object. The schema files + that can be used here are found in gs://google-cloud- + aiplatform/schema/dataset/metadata/. + datasource (_datasources.Datasource): + Required. Datasource for creating a dataset for AI Platform. + project (str): + Required. Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Required. Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (Optional[auth_credentials.Credentials]): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the request as metadata. 
+ encryption_spec (Optional[gca_encryption_spec.EncryptionSpec]): + Optional. The Cloud KMS customer managed encryption key used to protect the dataset. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + dataset (Dataset): + Instantiated representation of the managed dataset resource. + """ + + create_dataset_lro = cls._create( + api_client=api_client, + parent=parent, + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + datasource=datasource, + request_metadata=request_metadata, + encryption_spec=encryption_spec, + ) + + _LOGGER.log_create_with_lro(cls, create_dataset_lro) + + created_dataset = create_dataset_lro.result() + + _LOGGER.log_create_complete(cls, created_dataset, "ds") + + dataset_obj = cls( + dataset_name=created_dataset.name, + project=project, + location=location, + credentials=credentials, + ) + + # Import if import datasource is DatasourceImportable + if isinstance(datasource, _datasources.DatasourceImportable): + dataset_obj._import_and_wait(datasource) + + return dataset_obj + + def _import_and_wait(self, datasource): + _LOGGER.log_action_start_against_resource( + "Importing", "data", self, + ) + + import_lro = self._import(datasource=datasource) + + _LOGGER.log_action_started_against_resource_with_lro( + "Import", "data", self.__class__, import_lro + ) + + import_lro.result() + + _LOGGER.log_action_completed_against_resource("data", "imported", self) + + @classmethod + def _create( + cls, + api_client: dataset_service_client.DatasetServiceClient, + parent: str, + display_name: str, + metadata_schema_uri: str, + datasource: _datasources.Datasource, + 
request_metadata: Sequence[Tuple[str, str]] = (), + encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None, + ) -> operation.Operation: + """Creates a new managed dataset by directly calling API client. + + Args: + api_client (dataset_service_client.DatasetServiceClient): + An instance of DatasetServiceClient with the correct api_endpoint + already set based on user's preferences. + parent (str): + Required. Also known as common location path, that usually contains the + project and location that the user provided to the upstream method. + Example: "projects/my-prj/locations/us-central1" + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + metadata_schema_uri (str): + Required. Points to a YAML file stored on Google Cloud Storage + describing additional information about the Dataset. The schema + is defined as an OpenAPI 3.0.2 Schema Object. The schema files + that can be used here are found in gs://google-cloud- + aiplatform/schema/dataset/metadata/. + datasource (_datasources.Datasource): + Required. Datasource for creating a dataset for AI Platform. + request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the create_dataset + request as metadata. Usually to specify special dataset config. + encryption_spec (Optional[gca_encryption_spec.EncryptionSpec]): + Optional. The Cloud KMS customer managed encryption key used to protect the dataset. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + Returns: + operation (Operation): + An object representing a long-running operation. 
+ """ + + gapic_dataset = gca_dataset.Dataset( + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + metadata=datasource.dataset_metadata, + encryption_spec=encryption_spec, + ) + + return api_client.create_dataset( + parent=parent, dataset=gapic_dataset, metadata=request_metadata + ) + + def _import( + self, datasource: _datasources.DatasourceImportable, + ) -> operation.Operation: + """Imports data into managed dataset by directly calling API client. + + Args: + datasource (_datasources.DatasourceImportable): + Required. Datasource for importing data to an existing dataset for AI Platform. + + Returns: + operation (Operation): + An object representing a long-running operation. + """ + return self.api_client.import_data( + name=self.resource_name, import_configs=[datasource.import_data_config] + ) + + @base.optional_sync(return_input_arg="self") + def import_data( + self, + gcs_source: Union[str, Sequence[str]], + import_schema_uri: str, + data_item_labels: Optional[Dict] = None, + sync: bool = True, + ) -> "_Dataset": + """Upload data to existing managed dataset. + + Args: + gcs_source (Union[str, Sequence[str]]): + Required. Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + import_schema_uri (str): + Required. Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. The schema is defined as an + `OpenAPI 3.0.2 Schema + Object `__. + data_item_labels (Dict): + Labels that will be applied to newly imported DataItems. 
If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to these + of the already existing one, and if labels with identical + key is imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside index file refenced by + [import_schema_uri][google.cloud.aiplatform.v1beta1.ImportDataConfig.import_schema_uri], + e.g. jsonl file. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + dataset (Dataset): + Instantiated representation of the managed dataset resource. + """ + datasource = _datasources.create_datasource( + metadata_schema_uri=self.metadata_schema_uri, + import_schema_uri=import_schema_uri, + gcs_source=gcs_source, + data_item_labels=data_item_labels, + ) + + self._import_and_wait(datasource=datasource) + return self + + # TODO(b/174751568) add optional sync support + def export_data(self, output_dir: str) -> Sequence[str]: + """Exports data to output dir to GCS. + + Args: + output_dir (str): + Required. The Google Cloud Storage location where the output is to + be written to. In the given directory a new directory will be + created with name: + ``export-data--`` + where timestamp is in YYYYMMDDHHMMSS format. All export + output will be written into that directory. Inside that + directory, annotations with the same schema will be grouped + into sub directories which are named with the corresponding + annotations' schema title. 
Inside these sub directories, a + schema.yaml will be created to describe the output format. + + If the uri doesn't end with '/', a '/' will be automatically + appended. The directory is created if it doesn't exist. + + Returns: + exported_files (Sequence[str]): + All of the files that are exported in this export operation. + """ + self.wait() + + # TODO(b/171311614): Add support for BiqQuery export path + export_data_config = gca_dataset.ExportDataConfig( + gcs_destination=gca_io.GcsDestination(output_uri_prefix=output_dir) + ) + + _LOGGER.log_action_start_against_resource("Exporting", "data", self) + + export_lro = self.api_client.export_data( + name=self.resource_name, export_config=export_data_config + ) + + _LOGGER.log_action_started_against_resource_with_lro( + "Export", "data", self.__class__, export_lro + ) + + export_data_response = export_lro.result() + + _LOGGER.log_action_completed_against_resource("data", "export", self) + + return export_data_response.exported_files + + def update(self): + raise NotImplementedError("Update dataset has not been implemented yet") + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[base.AiPlatformResourceNoun]: + """List all instances of this Dataset resource. + + Example Usage: + + aiplatform.TabularDataset.list( + filter='labels.my_key="my_value"', + order_by='display_name' + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. 
If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[base.AiPlatformResourceNoun] - A list of Dataset resource objects + """ + + dataset_subclass_filter = ( + lambda gapic_obj: gapic_obj.metadata_schema_uri + in cls._supported_metadata_schema_uris + ) + + return cls._list_with_local_order( + cls_filter=dataset_subclass_filter, + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) diff --git a/google/cloud/aiplatform/datasets/image_dataset.py b/google/cloud/aiplatform/datasets/image_dataset.py new file mode 100644 index 0000000000..cea13014d8 --- /dev/null +++ b/google/cloud/aiplatform/datasets/image_dataset.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Optional, Sequence, Dict, Tuple, Union + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform.datasets import _datasources +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import utils + + +class ImageDataset(datasets._Dataset): + """Managed image dataset resource for AI Platform""" + + _supported_metadata_schema_uris: Optional[Tuple[str]] = ( + schema.dataset.metadata.image, + ) + + @classmethod + def create( + cls, + display_name: str, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + import_schema_uri: Optional[str] = None, + data_item_labels: Optional[Dict] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> "ImageDataset": + """Creates a new image dataset and optionally imports data into dataset when + source and import_schema_uri are passed. + + Args: + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + gcs_source (Union[str, Sequence[str]]): + Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + import_schema_uri (str): + Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. The schema is defined as an + `OpenAPI 3.0.2 Schema + Object `__. + data_item_labels (Dict): + Labels that will be applied to newly imported DataItems. 
If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to these + of the already existing one, and if labels with identical + key is imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside index file refenced by + [import_schema_uri][google.cloud.aiplatform.v1beta1.ImportDataConfig.import_schema_uri], + e.g. jsonl file. + project (str): + Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the request as metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the dataset. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. 
+ + Returns: + image_dataset (ImageDataset): + Instantiated representation of the managed image dataset resource. + + """ + + utils.validate_display_name(display_name) + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + metadata_schema_uri = schema.dataset.metadata.image + + datasource = _datasources.create_datasource( + metadata_schema_uri=metadata_schema_uri, + import_schema_uri=import_schema_uri, + gcs_source=gcs_source, + data_item_labels=data_item_labels, + ) + + return cls._create_and_import( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + datasource=datasource, + project=project or initializer.global_config.project, + location=location or initializer.global_config.location, + credentials=credentials or initializer.global_config.credentials, + request_metadata=request_metadata, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + sync=sync, + ) diff --git a/google/cloud/aiplatform/datasets/tabular_dataset.py b/google/cloud/aiplatform/datasets/tabular_dataset.py new file mode 100644 index 0000000000..3dd217aad7 --- /dev/null +++ b/google/cloud/aiplatform/datasets/tabular_dataset.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Optional, Sequence, Tuple, Union + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform.datasets import _datasources +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import utils + + +class TabularDataset(datasets._Dataset): + """Managed tabular dataset resource for AI Platform""" + + _supported_metadata_schema_uris: Optional[Tuple[str]] = ( + schema.dataset.metadata.tabular, + ) + + @classmethod + def create( + cls, + display_name: str, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bq_source: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> "TabularDataset": + """Creates a new tabular dataset. + + Args: + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + gcs_source (Union[str, Sequence[str]]): + Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + bq_source (str): + BigQuery URI to the input table. + example: + "bq://project.dataset.table_name" + project (str): + Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. 
+ request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the request as metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the dataset. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + tabular_dataset (TabularDataset): + Instantiated representation of the managed tabular dataset resource. + + """ + + utils.validate_display_name(display_name) + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + metadata_schema_uri = schema.dataset.metadata.tabular + + datasource = _datasources.create_datasource( + metadata_schema_uri=metadata_schema_uri, + gcs_source=gcs_source, + bq_source=bq_source, + ) + + return cls._create_and_import( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + datasource=datasource, + project=project or initializer.global_config.project, + location=location or initializer.global_config.location, + credentials=credentials or initializer.global_config.credentials, + request_metadata=request_metadata, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + sync=sync, + ) + + def import_data(self): + raise NotImplementedError( + f"{self.__class__.__name__} 
class does not support 'import_data'" + ) diff --git a/google/cloud/aiplatform/datasets/text_dataset.py b/google/cloud/aiplatform/datasets/text_dataset.py new file mode 100644 index 0000000000..2b791e5c82 --- /dev/null +++ b/google/cloud/aiplatform/datasets/text_dataset.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, Sequence, Dict, Tuple, Union + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform.datasets import _datasources +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import utils + + +class TextDataset(datasets._Dataset): + """Managed text dataset resource for AI Platform""" + + _supported_metadata_schema_uris: Optional[Tuple[str]] = ( + schema.dataset.metadata.text, + ) + + @classmethod + def create( + cls, + display_name: str, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + import_schema_uri: Optional[str] = None, + data_item_labels: Optional[Dict] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> "TextDataset": + """Creates a new text dataset and optionally 
imports data into dataset when + source and import_schema_uri are passed. + + Example Usage: + ds = aiplatform.TextDataset.create( + display_name='my-dataset', + gcs_source='gs://my-bucket/dataset.csv', + import_schema_uri=aiplatform.schema.dataset.ioformat.text.multi_label_classification + ) + + Args: + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + gcs_source (Union[str, Sequence[str]]): + Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + import_schema_uri (str): + Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. The schema is defined as an + `OpenAPI 3.0.2 Schema + Object `__. + data_item_labels (Dict): + Labels that will be applied to newly imported DataItems. If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to these + of the already existing one, and if labels with identical + key is imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside index file refenced by + [import_schema_uri][google.cloud.aiplatform.v1beta1.ImportDataConfig.import_schema_uri], + e.g. jsonl file. + project (str): + Project to upload this model to. Overrides project set in + aiplatform.init. 
+ location (str): + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the request as metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the dataset. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + text_dataset (TextDataset): + Instantiated representation of the managed text dataset resource. 
+ + """ + + utils.validate_display_name(display_name) + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + metadata_schema_uri = schema.dataset.metadata.text + + datasource = _datasources.create_datasource( + metadata_schema_uri=metadata_schema_uri, + import_schema_uri=import_schema_uri, + gcs_source=gcs_source, + data_item_labels=data_item_labels, + ) + + return cls._create_and_import( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + datasource=datasource, + project=project or initializer.global_config.project, + location=location or initializer.global_config.location, + credentials=credentials or initializer.global_config.credentials, + request_metadata=request_metadata, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + sync=sync, + ) diff --git a/google/cloud/aiplatform/datasets/video_dataset.py b/google/cloud/aiplatform/datasets/video_dataset.py new file mode 100644 index 0000000000..c50298f99a --- /dev/null +++ b/google/cloud/aiplatform/datasets/video_dataset.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Optional, Sequence, Dict, Tuple, Union + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform.datasets import _datasources +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import utils + + +class VideoDataset(datasets._Dataset): + """Managed video dataset resource for AI Platform""" + + _supported_metadata_schema_uris: Optional[Tuple[str]] = ( + schema.dataset.metadata.video, + ) + + @classmethod + def create( + cls, + display_name: str, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + import_schema_uri: Optional[str] = None, + data_item_labels: Optional[Dict] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> "VideoDataset": + """Creates a new video dataset and optionally imports data into dataset when + source and import_schema_uri are passed. + + Args: + display_name (str): + Required. The user-defined name of the Dataset. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + gcs_source (Union[str, Sequence[str]]): + Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + examples: + str: "gs://bucket/file.csv" + Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"] + import_schema_uri (str): + Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. The schema is defined as an + `OpenAPI 3.0.2 Schema + Object `__. + data_item_labels (Dict): + Labels that will be applied to newly imported DataItems. 
If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to these + of the already existing one, and if labels with identical + key is imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside index file refenced by + [import_schema_uri][google.cloud.aiplatform.v1beta1.ImportDataConfig.import_schema_uri], + e.g. jsonl file. + project (str): + Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + request_metadata (Sequence[Tuple[str, str]]): + Strings which should be sent along with the request as metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the dataset. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. 
+ + Returns: + video_dataset (VideoDataset): + Instantiated representation of the managed video dataset resource. + + """ + + utils.validate_display_name(display_name) + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + metadata_schema_uri = schema.dataset.metadata.video + + datasource = _datasources.create_datasource( + metadata_schema_uri=metadata_schema_uri, + import_schema_uri=import_schema_uri, + gcs_source=gcs_source, + data_item_labels=data_item_labels, + ) + + return cls._create_and_import( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + datasource=datasource, + project=project or initializer.global_config.project, + location=location or initializer.global_config.location, + credentials=credentials or initializer.global_config.credentials, + request_metadata=request_metadata, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + sync=sync, + ) diff --git a/google/cloud/aiplatform/explain/__init__.py b/google/cloud/aiplatform/explain/__init__.py new file mode 100644 index 0000000000..61b9181834 --- /dev/null +++ b/google/cloud/aiplatform/explain/__init__.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from google.cloud.aiplatform.compat.types import ( + explanation_metadata_v1beta1 as explanation_metadata, + explanation_v1beta1 as explanation, +) + +ExplanationMetadata = explanation_metadata.ExplanationMetadata + +# ExplanationMetadata subclasses +InputMetadata = ExplanationMetadata.InputMetadata +OutputMetadata = ExplanationMetadata.OutputMetadata + +# InputMetadata subclasses +Encoding = InputMetadata.Encoding +FeatureValueDomain = InputMetadata.FeatureValueDomain +Visualization = InputMetadata.Visualization + + +ExplanationParameters = explanation.ExplanationParameters +FeatureNoiseSigma = explanation.FeatureNoiseSigma + +# Classes used by ExplanationParameters +IntegratedGradientsAttribution = explanation.IntegratedGradientsAttribution + +SampledShapleyAttribution = explanation.SampledShapleyAttribution +SmoothGradConfig = explanation.SmoothGradConfig +XraiAttribution = explanation.XraiAttribution + + +__all__ = ( + "Encoding", + "ExplanationMetadata", + "ExplanationParameters", + "FeatureNoiseSigma", + "FeatureValueDomain", + "InputMetadata", + "IntegratedGradientsAttribution", + "OutputMetadata", + "SampledShapleyAttribution", + "SmoothGradConfig", + "Visualization", + "XraiAttribution", +) diff --git a/google/cloud/aiplatform/helpers/_decorators.py b/google/cloud/aiplatform/helpers/_decorators.py index 5d9aa28bea..95aac31c4f 100644 --- a/google/cloud/aiplatform/helpers/_decorators.py +++ b/google/cloud/aiplatform/helpers/_decorators.py @@ -68,3 +68,5 @@ def _from_map(map_): marshal = Marshal(name="google.cloud.aiplatform.v1beta1") marshal.register(Value, ConversionValueRule(marshal=marshal)) +marshal = Marshal(name="google.cloud.aiplatform.v1") +marshal.register(Value, ConversionValueRule(marshal=marshal)) diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py new file mode 100644 index 0000000000..b84a006d02 --- /dev/null +++ b/google/cloud/aiplatform/initializer.py @@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- + 
+# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from concurrent import futures +import logging +import pkg_resources +import os +from typing import Optional, Type, Union + +from google.api_core import client_options +from google.api_core import gapic_v1 +import google.auth +from google.auth import credentials as auth_credentials +from google.auth.exceptions import GoogleAuthError + +from google.cloud.aiplatform import compat +from google.cloud.aiplatform import constants +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform.compat.types import ( + encryption_spec as gca_encryption_spec_compat, + encryption_spec_v1 as gca_encryption_spec_v1, + encryption_spec_v1beta1 as gca_encryption_spec_v1beta1, +) + + +class _Config: + """Stores common parameters and options for API calls.""" + + def __init__(self): + self._project = None + self._experiment = None + self._location = None + self._staging_bucket = None + self._credentials = None + self._encryption_spec_key_name = None + + def init( + self, + *, + project: Optional[str] = None, + location: Optional[str] = None, + experiment: Optional[str] = None, + staging_bucket: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + ): + """Updates common initalization parameters with provided options. + + Args: + project (str): The default project to use when making API calls. 
+ location (str): The default location to use when making API calls. If not + set defaults to us-central-1 + experiment (str): The experiment to assign + staging_bucket (str): The default staging bucket to use to stage artifacts + when making API calls. In the form gs://... + credentials (google.auth.crendentials.Credentials): The default custom + credentials to use when making API calls. If not provided crendentials + will be ascertained from the environment. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect a resource. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this resource and all sub-resources will be secured by this key. + """ + if project: + self._project = project + if location: + utils.validate_region(location) + self._location = location + if experiment: + logging.warning("Experiments currently not supported.") + self._experiment = experiment + if staging_bucket: + self._staging_bucket = staging_bucket + if credentials: + self._credentials = credentials + if encryption_spec_key_name: + self._encryption_spec_key_name = encryption_spec_key_name + + def get_encryption_spec( + self, + encryption_spec_key_name: Optional[str], + select_version: Optional[str] = compat.DEFAULT_VERSION, + ) -> Optional[ + Union[ + gca_encryption_spec_v1.EncryptionSpec, + gca_encryption_spec_v1beta1.EncryptionSpec, + ] + ]: + """Creates a gca_encryption_spec.EncryptionSpec instance from the given key name. + If the provided key name is None, it uses the default key name if provided. + + Args: + encryption_spec_key_name (Optional[str]): The default encryption key name to use when creating resources. 
+ select_version: The default version is set to compat.DEFAULT_VERSION + """ + kms_key_name = encryption_spec_key_name or self.encryption_spec_key_name + encryption_spec = None + if kms_key_name: + gca_encryption_spec = gca_encryption_spec_compat + if select_version == compat.V1BETA1: + gca_encryption_spec = gca_encryption_spec_v1beta1 + encryption_spec = gca_encryption_spec.EncryptionSpec( + kms_key_name=kms_key_name + ) + return encryption_spec + + @property + def project(self) -> str: + """Default project.""" + if self._project: + return self._project + + project_not_found_exception_str = ( + "Unable to find your project. Please provide a project ID by:" + "\n- Passing a constructor argument" + "\n- Using aiplatform.init()" + "\n- Setting a GCP environment variable" + ) + + try: + _, project_id = google.auth.default() + except GoogleAuthError: + raise GoogleAuthError(project_not_found_exception_str) + + if not project_id: + raise ValueError(project_not_found_exception_str) + + return project_id + + @property + def location(self) -> str: + """Default location.""" + return self._location or constants.DEFAULT_REGION + + @property + def experiment(self) -> Optional[str]: + """Default experiment, if provided.""" + return self._experiment + + @property + def staging_bucket(self) -> Optional[str]: + """Default staging bucket, if provided.""" + return self._staging_bucket + + @property + def credentials(self) -> Optional[auth_credentials.Credentials]: + """Default credentials.""" + if self._credentials: + return self._credentials + logger = logging.getLogger("google.auth._default") + logging_warning_filter = utils.LoggingWarningFilter() + logger.addFilter(logging_warning_filter) + credentials, _ = google.auth.default() + logger.removeFilter(logging_warning_filter) + return credentials + + @property + def encryption_spec_key_name(self) -> Optional[str]: + """Default encryption spec key name, if provided.""" + return self._encryption_spec_key_name + + def 
get_client_options( + self, location_override: Optional[str] = None + ) -> client_options.ClientOptions: + """Creates GAPIC client_options using location and type. + + Args: + location_override (str): + Set this parameter to get client options for a location different from + location set by initializer. Must be a GCP region supported by AI + Platform (Unified). + + Returns: + clients_options (google.api_core.client_options.ClientOptions): + A ClientOptions object set with regionalized API endpoint, i.e. + { "api_endpoint": "us-central1-aiplatform.googleapis.com" } or + { "api_endpoint": "asia-east1-aiplatform.googleapis.com" } + """ + if not (self.location or location_override): + raise ValueError( + "No location found. Provide or initialize SDK with a location." + ) + + region = location_override or self.location + region = region.lower() + + utils.validate_region(region) + + return client_options.ClientOptions( + api_endpoint=f"{region}-{constants.API_BASE_PATH}" + ) + + def common_location_path( + self, project: Optional[str] = None, location: Optional[str] = None + ) -> str: + """Get parent resource with optional project and location override. + + Args: + project (str): GCP project. If not provided will use the current project. + location (str): Location. If not provided will use the current location. + Returns: + resource_parent: Formatted parent resource string. + """ + if location: + utils.validate_region(location) + + return "/".join( + [ + "projects", + project or self.project, + "locations", + location or self.location, + ] + ) + + def create_client( + self, + client_class: Type[utils.AiPlatformServiceClientWithOverride], + credentials: Optional[auth_credentials.Credentials] = None, + location_override: Optional[str] = None, + prediction_client: bool = False, + ) -> utils.AiPlatformServiceClientWithOverride: + """Instantiates a given AiPlatformServiceClient with optional overrides. 
+ + Args: + client_class (utils.AiPlatformServiceClientWithOverride): + (Required) An AI Platform Service Client with optional overrides. + credentials (auth_credentials.Credentials): + Custom auth credentials. If not provided will use the current config. + location_override (str): Optional location override. + prediction_client (str): Optional flag to use a prediction endpoint. + Returns: + client: Instantiated AI Platform Service client with optional overrides + """ + gapic_version = pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version + client_info = gapic_v1.client_info.ClientInfo( + gapic_version=gapic_version, user_agent=f"model-builder/{gapic_version}" + ) + + kwargs = { + "credentials": credentials or self.credentials, + "client_options": self.get_client_options( + location_override=location_override + ), + "client_info": client_info, + } + + return client_class(**kwargs) + + +# global config to store init parameters: ie, aiplatform.init(project=..., location=...) +global_config = _Config() + +global_pool = futures.ThreadPoolExecutor( + max_workers=min(32, max(4, (os.cpu_count() or 0) * 5)) +) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py new file mode 100644 index 0000000000..a7f2bbd31d --- /dev/null +++ b/google/cloud/aiplatform/jobs.py @@ -0,0 +1,795 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Iterable, Optional, Union, Sequence, Dict, List + +import abc +import sys +import time +import logging + +from google.cloud import storage +from google.cloud import bigquery + +from google.auth import credentials as auth_credentials + +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import compat +from google.cloud.aiplatform import constants +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform.compat.services import job_service_client +from google.cloud.aiplatform.compat.types import ( + io as gca_io_compat, + io_v1beta1 as gca_io_v1beta1, + job_state as gca_job_state, + batch_prediction_job as gca_bp_job_compat, + batch_prediction_job_v1 as gca_bp_job_v1, + batch_prediction_job_v1beta1 as gca_bp_job_v1beta1, + machine_resources as gca_machine_resources_compat, + machine_resources_v1beta1 as gca_machine_resources_v1beta1, + explanation_v1beta1 as gca_explanation_v1beta1, +) + +logging.basicConfig(level=logging.INFO, stream=sys.stdout) +_LOGGER = base.Logger(__name__) + +_JOB_COMPLETE_STATES = ( + gca_job_state.JobState.JOB_STATE_SUCCEEDED, + gca_job_state.JobState.JOB_STATE_FAILED, + gca_job_state.JobState.JOB_STATE_CANCELLED, + gca_job_state.JobState.JOB_STATE_PAUSED, +) + +_JOB_ERROR_STATES = ( + gca_job_state.JobState.JOB_STATE_FAILED, + gca_job_state.JobState.JOB_STATE_CANCELLED, +) + + +class _Job(base.AiPlatformResourceNounWithFutureManager): + """ + Class that represents a general Job resource in AI Platform (Unified). + Cannot be directly instantiated. + + Serves as base class to specific Job types, i.e. BatchPredictionJob or + DataLabelingJob to re-use shared functionality. + + Subclasses requires one class attribute: + + _getter_method (str): The name of JobServiceClient getter method for specific + Job type, i.e. 
'get_custom_job' for CustomJob + _cancel_method (str): The name of the specific JobServiceClient cancel method + _delete_method (str): The name of the specific JobServiceClient delete method + """ + + client_class = utils.JobpointClientWithOverride + _is_client_prediction_client = False + + def __init__( + self, + job_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """ + Retrives Job subclass resource by calling a subclass-specific getter method. + + Args: + job_name (str): + Required. A fully-qualified job resource name or job ID. + Example: "projects/123/locations/us-central1/batchPredictionJobs/456" or + "456" when project, location and job_type are initialized or passed. + project: Optional[str] = None, + Optional project to retrieve Job subclass from. If not set, + project set in aiplatform.init will be used. + location: Optional[str] = None, + Optional location to retrieve Job subclass from. If not set, + location set in aiplatform.init will be used. + credentials: Optional[auth_credentials.Credentials] = None, + Custom credentials to use. If not set, credentials set in + aiplatform.init will be used. + """ + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=job_name, + ) + self._gca_resource = self._get_gca_resource(resource_name=job_name) + + @property + def state(self) -> gca_job_state.JobState: + """Fetch Job again and return the current JobState. + + Returns: + state (job_state.JobState): + Enum that describes the state of a AI Platform job. 
+ """ + + # Fetch the Job again for most up-to-date job state + self._sync_gca_resource() + + return self._gca_resource.state + + @property + @abc.abstractmethod + def _job_type(cls) -> str: + """Job type.""" + pass + + @property + @abc.abstractmethod + def _cancel_method(cls) -> str: + """Name of cancellation method for cancelling the specific job type.""" + pass + + def _dashboard_uri(self) -> Optional[str]: + """Helper method to compose the dashboard uri where job can be viewed.""" + fields = utils.extract_fields_from_resource_name(self.resource_name) + url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/{self._job_type}/{fields.id}?project={fields.project}" + return url + + def _block_until_complete(self): + """Helper method to block and check on job until complete. + + Raises: + RuntimeError: If job failed or cancelled. + + """ + + # Used these numbers so failures surface fast + wait = 5 # start at five seconds + log_wait = 5 + max_wait = 60 * 5 # 5 minute wait + multiplier = 2 # scale wait by 2 every iteration + + previous_time = time.time() + while self.state not in _JOB_COMPLETE_STATES: + current_time = time.time() + if current_time - previous_time >= log_wait: + _LOGGER.info( + "%s %s current state:\n%s" + % ( + self.__class__.__name__, + self._gca_resource.name, + self._gca_resource.state, + ) + ) + log_wait = min(log_wait * multiplier, max_wait) + previous_time = current_time + time.sleep(wait) + + _LOGGER.log_action_completed_against_resource("", "run", self) + + # Error is only populated when the job state is + # JOB_STATE_FAILED or JOB_STATE_CANCELLED. 
+ if self.state in _JOB_ERROR_STATES: + raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error) + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[base.AiPlatformResourceNoun]: + """List all instances of this Job Resource. + + Example Usage: + + aiplatform.BatchPredictionJobs.list( + filter='state="JOB_STATE_SUCCEEDED" AND display_name="my_job"', + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[AiPlatformResourceNoun] - A list of Job resource objects + """ + + return cls._list_with_local_order( + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) + + def cancel(self) -> None: + """Cancels this Job. Success of cancellation is not guaranteed. 
Use `Job.state` + property to verify if cancellation was successful.""" + + _LOGGER.log_action_start_against_resource("Cancelling", "run", self) + getattr(self.api_client, self._cancel_method)(name=self.resource_name) + + +class BatchPredictionJob(_Job): + + _resource_noun = "batchPredictionJobs" + _getter_method = "get_batch_prediction_job" + _list_method = "list_batch_prediction_jobs" + _cancel_method = "cancel_batch_prediction_job" + _delete_method = "delete_batch_prediction_job" + _job_type = "batch-predictions" + + def __init__( + self, + batch_prediction_job_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """ + Retrieves a BatchPredictionJob resource and instantiates its representation. + + Args: + batch_prediction_job_name (str): + Required. A fully-qualified BatchPredictionJob resource name or ID. + Example: "projects/.../locations/.../batchPredictionJobs/456" or + "456" when project and location are initialized or passed. + project: Optional[str] = None, + Optional project to retrieve BatchPredictionJob from. If not set, + project set in aiplatform.init will be used. + location: Optional[str] = None, + Optional location to retrieve BatchPredictionJob from. If not set, + location set in aiplatform.init will be used. + credentials: Optional[auth_credentials.Credentials] = None, + Custom credentials to use. If not set, credentials set in + aiplatform.init will be used. 
+ """ + + super().__init__( + job_name=batch_prediction_job_name, + project=project, + location=location, + credentials=credentials, + ) + + @classmethod + def create( + cls, + job_display_name: str, + model_name: str, + instances_format: str = "jsonl", + predictions_format: str = "jsonl", + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bigquery_source: Optional[str] = None, + gcs_destination_prefix: Optional[str] = None, + bigquery_destination_prefix: Optional[str] = None, + model_parameters: Optional[Dict] = None, + machine_type: Optional[str] = None, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + starting_replica_count: Optional[int] = None, + max_replica_count: Optional[int] = None, + generate_explanation: Optional[bool] = False, + explanation_metadata: Optional["aiplatform.explain.ExplanationMetadata"] = None, + explanation_parameters: Optional[ + "aiplatform.explain.ExplanationParameters" + ] = None, + labels: Optional[dict] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> "BatchPredictionJob": + """Create a batch prediction job. + + Args: + job_display_name (str): + Required. The user-defined name of the BatchPredictionJob. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + model_name (str): + Required. A fully-qualified model resource name or model ID. + Example: "projects/123/locations/us-central1/models/456" or + "456" when project and location are initialized or passed. + instances_format (str): + Required. The format in which instances are given, must be one + of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip", + or "file-list". Default is "jsonl" when using `gcs_source`. If a + `bigquery_source` is provided, this is overriden to "bigquery". + predictions_format (str): + Required. 
The format in which AI Platform gives the + predictions, must be one of "jsonl", "csv", or "bigquery". + Default is "jsonl" when using `gcs_destination_prefix`. If a + `bigquery_destination_prefix` is provided, this is overriden to + "bigquery". + gcs_source (Optional[Sequence[str]]): + Google Cloud Storage URI(-s) to your instances to run + batch prediction on. They must match `instances_format`. + May contain wildcards. For more information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + bigquery_source (Optional[str]): + BigQuery URI to a table, up to 2000 characters long. For example: + `projectId.bqDatasetId.bqTableId` + gcs_destination_prefix (Optional[str]): + The Google Cloud Storage location of the directory where the + output is to be written to. In the given directory a new + directory is created. Its name is + ``prediction--``, where + timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 format. + Inside of it files ``predictions_0001.``, + ``predictions_0002.``, ..., + ``predictions_N.`` are created where + ```` depends on chosen ``predictions_format``, + and N may equal 0001 and depends on the total number of + successfully predicted instances. If the Model has both + ``instance`` and ``prediction`` schemata defined then each such + file contains predictions as per the ``predictions_format``. + If prediction for any instance failed (partially or + completely), then an additional ``errors_0001.``, + ``errors_0002.``,..., ``errors_N.`` + files are created (N depends on total number of failed + predictions). These files contain the failed instances, as + per their schema, followed by an additional ``error`` field + which as value has ```google.rpc.Status`` `__ + containing only ``code`` and ``message`` fields. + bigquery_destination_prefix (Optional[str]): + The BigQuery project location where the output is to be + written to. 
In the given project a new dataset is created + with name + ``prediction__`` where + is made BigQuery-dataset-name compatible (for example, most + special characters become underscores), and timestamp is in + YYYY_MM_DDThh_mm_ss_sssZ "based on ISO-8601" format. In the + dataset two tables will be created, ``predictions``, and + ``errors``. If the Model has both ``instance`` and ``prediction`` + schemata defined then the tables have columns as follows: + The ``predictions`` table contains instances for which the + prediction succeeded, it has columns as per a concatenation + of the Model's instance and prediction schemata. The + ``errors`` table contains rows for which the prediction has + failed, it has instance columns, as per the instance schema, + followed by a single "errors" column, which as values has + ```google.rpc.Status`` `__ represented as a STRUCT, + and containing only ``code`` and ``message``. + model_parameters (Optional[Dict]): + The parameters that govern the predictions. The schema of + the parameters may be specified via the Model's `parameters_schema_uri`. + machine_type (Optional[str]): + The type of machine for running batch prediction on + dedicated resources. Not specifying machine type will result in + batch prediction job being run with automatic resources. + accelerator_type (Optional[str]): + The type of accelerator(s) that may be attached + to the machine as per `accelerator_count`. Only used if + `machine_type` is set. + accelerator_count (Optional[int]): + The number of accelerators to attach to the + `machine_type`. Only used if `machine_type` is set. + starting_replica_count (Optional[int]): + The number of machine replicas used at the start of the batch + operation. If not set, AI Platform decides starting number, not + greater than `max_replica_count`. Only used if `machine_type` is + set. + max_replica_count (Optional[int]): + The maximum number of machine replicas the batch operation may + be scaled to. 
Only used if `machine_type` is set. + Default is 10. + generate_explanation (bool): + Optional. Generate explanation along with the batch prediction + results. This will cause the batch prediction output to include + explanations based on the `prediction_format`: + - `bigquery`: output includes a column named `explanation`. The value + is a struct that conforms to the [aiplatform.gapic.Explanation] object. + - `jsonl`: The JSON objects on each line include an additional entry + keyed `explanation`. The value of the entry is a JSON object that + conforms to the [aiplatform.gapic.Explanation] object. + - `csv`: Generating explanations for CSV format is not supported. + explanation_metadata (aiplatform.explain.ExplanationMetadata): + Optional. Explanation metadata configuration for this BatchPredictionJob. + Can be specified only if `generate_explanation` is set to `True`. + + This value overrides the value of `Model.explanation_metadata`. + All fields of `explanation_metadata` are optional in the request. If + a field of the `explanation_metadata` object is not populated, the + corresponding field of the `Model.explanation_metadata` object is inherited. + For more details, see `Ref docs ` + explanation_parameters (aiplatform.explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + Can be specified only if `generate_explanation` is set to `True`. + + This value overrides the value of `Model.explanation_parameters`. + All fields of `explanation_parameters` are optional in the request. If + a field of the `explanation_parameters` object is not populated, the + corresponding field of the `Model.explanation_parameters` object is inherited. + For more details, see `Ref docs ` + labels (Optional[dict]): + The labels with user-defined metadata to organize your + BatchPredictionJobs. 
Label keys and values can be no longer than + 64 characters (Unicode codepoints), can only contain lowercase + letters, numeric characters, underscores and dashes. + International characters are allowed. See https://goo.gl/xmQnxf + for more information and examples of labels. + credentials (Optional[auth_credentials.Credentials]): + Custom credentials to use to create this batch prediction + job. Overrides credentials set in aiplatform.init. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the job. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If this is set, then all + resources created by the BatchPredictionJob will + be encrypted with the provided encryption key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + (jobs.BatchPredictionJob): + Instantiated representation of the created batch prediction job. + + """ + + utils.validate_display_name(job_display_name) + + model_name = utils.full_resource_name( + resource_name=model_name, + resource_noun="models", + project=project, + location=location, + ) + + # Raise error if both or neither source URIs are provided + if bool(gcs_source) == bool(bigquery_source): + raise ValueError( + "Please provide either a gcs_source or bigquery_source, " + "but not both." + ) + + # Raise error if both or neither destination prefixes are provided + if bool(gcs_destination_prefix) == bool(bigquery_destination_prefix): + raise ValueError( + "Please provide either a gcs_destination_prefix or " + "bigquery_destination_prefix, but not both." 
+ ) + + # Raise error if unsupported instance format is provided + if instances_format not in constants.BATCH_PREDICTION_INPUT_STORAGE_FORMATS: + raise ValueError( + f"{instances_format} is not an accepted instances format " + f"type. Please choose from: {constants.BATCH_PREDICTION_INPUT_STORAGE_FORMATS}" + ) + + # Raise error if unsupported prediction format is provided + if predictions_format not in constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS: + raise ValueError( + f"{predictions_format} is not an accepted prediction format " + f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}" + ) + gca_bp_job = gca_bp_job_compat + gca_io = gca_io_compat + gca_machine_resources = gca_machine_resources_compat + select_version = compat.DEFAULT_VERSION + if generate_explanation: + gca_bp_job = gca_bp_job_v1beta1 + gca_io = gca_io_v1beta1 + gca_machine_resources = gca_machine_resources_v1beta1 + select_version = compat.V1BETA1 + + gapic_batch_prediction_job = gca_bp_job.BatchPredictionJob() + + # Required Fields + gapic_batch_prediction_job.display_name = job_display_name + gapic_batch_prediction_job.model = model_name + + input_config = gca_bp_job.BatchPredictionJob.InputConfig() + output_config = gca_bp_job.BatchPredictionJob.OutputConfig() + + if bigquery_source: + input_config.instances_format = "bigquery" + input_config.bigquery_source = gca_io.BigQuerySource() + input_config.bigquery_source.input_uri = bigquery_source + else: + input_config.instances_format = instances_format + input_config.gcs_source = gca_io.GcsSource( + uris=gcs_source if type(gcs_source) == list else [gcs_source] + ) + + if bigquery_destination_prefix: + output_config.predictions_format = "bigquery" + output_config.bigquery_destination = gca_io.BigQueryDestination() + + bq_dest_prefix = bigquery_destination_prefix + + if not bq_dest_prefix.startswith("bq://"): + bq_dest_prefix = f"bq://{bq_dest_prefix}" + + output_config.bigquery_destination.output_uri = bq_dest_prefix 
+ else: + output_config.predictions_format = predictions_format + output_config.gcs_destination = gca_io.GcsDestination( + output_uri_prefix=gcs_destination_prefix + ) + + gapic_batch_prediction_job.input_config = input_config + gapic_batch_prediction_job.output_config = output_config + + # Optional Fields + gapic_batch_prediction_job.encryption_spec = initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name, + select_version=select_version, + ) + + if model_parameters: + gapic_batch_prediction_job.model_parameters = model_parameters + + # Custom Compute + if machine_type: + + machine_spec = gca_machine_resources.MachineSpec() + machine_spec.machine_type = machine_type + machine_spec.accelerator_type = accelerator_type + machine_spec.accelerator_count = accelerator_count + + dedicated_resources = gca_machine_resources.BatchDedicatedResources() + + dedicated_resources.machine_spec = machine_spec + dedicated_resources.starting_replica_count = starting_replica_count + dedicated_resources.max_replica_count = max_replica_count + + gapic_batch_prediction_job.dedicated_resources = dedicated_resources + + gapic_batch_prediction_job.manual_batch_tuning_parameters = None + + # User Labels + gapic_batch_prediction_job.labels = labels + + # Explanations + if generate_explanation: + gapic_batch_prediction_job.generate_explanation = generate_explanation + + if explanation_metadata or explanation_parameters: + gapic_batch_prediction_job.explanation_spec = gca_explanation_v1beta1.ExplanationSpec( + metadata=explanation_metadata, parameters=explanation_parameters + ) + + # TODO (b/174502913): Support private feature once released + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + return cls._create( + api_client=api_client, + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + batch_prediction_job=gapic_batch_prediction_job, + 
generate_explanation=generate_explanation, + project=project or initializer.global_config.project, + location=location or initializer.global_config.location, + credentials=credentials or initializer.global_config.credentials, + sync=sync, + ) + + @classmethod + @base.optional_sync() + def _create( + cls, + api_client: job_service_client.JobServiceClient, + parent: str, + batch_prediction_job: Union[ + gca_bp_job_v1beta1.BatchPredictionJob, gca_bp_job_v1.BatchPredictionJob + ], + generate_explanation: bool, + project: str, + location: str, + credentials: Optional[auth_credentials.Credentials], + sync: bool = True, + ) -> "BatchPredictionJob": + """Create a batch prediction job. + + Args: + api_client (dataset_service_client.DatasetServiceClient): + Required. An instance of DatasetServiceClient with the correct api_endpoint + already set based on user's preferences. + batch_prediction_job (gca_bp_job.BatchPredictionJob): + Required. a batch prediction job proto for creating a batch prediction job on AI Platform. + generate_explanation (bool): + Required. Generate explanation along with the batch prediction + results. + parent (str): + Required. Also known as common location path, that usually contains the + project and location that the user provided to the upstream method. + Example: "projects/my-prj/locations/us-central1" + project (str): + Required. Project to upload this model to. Overrides project set in + aiplatform.init. + location (str): + Required. Location to upload this model to. Overrides location set in + aiplatform.init. + credentials (Optional[auth_credentials.Credentials]): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + + Returns: + (jobs.BatchPredictionJob): + Instantiated representation of the created batch prediction job. + + Raises: + ValueError: + If no or multiple source or destinations are provided. Also, if + provided instances_format or predictions_format are not supported + by AI Platform. 
+ + """ + # select v1beta1 if explain else use default v1 + if generate_explanation: + api_client = api_client.select_version(compat.V1BETA1) + + _LOGGER.log_create_with_lro(cls) + + gca_batch_prediction_job = api_client.create_batch_prediction_job( + parent=parent, batch_prediction_job=batch_prediction_job + ) + + batch_prediction_job = cls( + batch_prediction_job_name=gca_batch_prediction_job.name, + project=project, + location=location, + credentials=credentials, + ) + + _LOGGER.log_create_complete(cls, batch_prediction_job._gca_resource, "bpj") + + _LOGGER.info( + "View Batch Prediction Job:\n%s" % batch_prediction_job._dashboard_uri() + ) + + batch_prediction_job._block_until_complete() + + return batch_prediction_job + + def iter_outputs( + self, bq_max_results: Optional[int] = 100 + ) -> Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]: + """Returns an Iterable object to traverse the output files, either a list + of GCS Blobs or a BigQuery RowIterator depending on the output config set + when the BatchPredictionJob was created. + + Args: + bq_max_results: Optional[int] = 100 + Limit on rows to retrieve from prediction table in BigQuery dataset. + Only used when retrieving predictions from a bigquery_destination_prefix. + Default is 100. + + Returns: + Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]: + Either a list of GCS Blob objects within the prediction output + directory or an iterable BigQuery RowIterator with predictions. + + Raises: + RuntimeError: + If BatchPredictionJob is in a JobState other than SUCCEEDED, + since outputs cannot be retrieved until the Job has finished. + NotImplementedError: + If BatchPredictionJob succeeded and output_info does not have a + GCS or BQ output provided. 
+ """ + + if self.state != gca_job_state.JobState.JOB_STATE_SUCCEEDED: + raise RuntimeError( + f"Cannot read outputs until BatchPredictionJob has succeeded, " + f"current state: {self._gca_resource.state}" + ) + + output_info = self._gca_resource.output_info + + # GCS Destination, return Blobs + if output_info.gcs_output_directory: + + # Build a Storage Client using the same credentials as JobServiceClient + storage_client = storage.Client( + credentials=self.api_client._transport._credentials + ) + + gcs_bucket, gcs_prefix = utils.extract_bucket_and_prefix_from_gcs_path( + output_info.gcs_output_directory + ) + + blobs = storage_client.list_blobs(gcs_bucket, prefix=gcs_prefix) + + return blobs + + # BigQuery Destination, return RowIterator + elif output_info.bigquery_output_dataset: + + # Build a BigQuery Client using the same credentials as JobServiceClient + bq_client = bigquery.Client( + credentials=self.api_client._transport._credentials + ) + + # Format from service is `bq://projectId.bqDatasetId` + bq_dataset = output_info.bigquery_output_dataset + + if bq_dataset.startswith("bq://"): + bq_dataset = bq_dataset[5:] + + # # Split project ID and BQ dataset ID + _, bq_dataset_id = bq_dataset.split(".", 1) + + row_iterator = bq_client.list_rows( + table=f"{bq_dataset_id}.predictions", max_results=bq_max_results + ) + + return row_iterator + + # Unknown Destination type + else: + raise NotImplementedError( + f"Unsupported batch prediction output location, here are details" + f"on your prediction output:\n{output_info}" + ) + + +class CustomJob(_Job): + _resource_noun = "customJobs" + _getter_method = "get_custom_job" + _list_method = "list_custom_job" + _cancel_method = "cancel_custom_job" + _delete_method = "delete_custom_job" + _job_type = "training" + pass + + +class DataLabelingJob(_Job): + _resource_noun = "dataLabelingJobs" + _getter_method = "get_data_labeling_job" + _list_method = "list_data_labeling_jobs" + _cancel_method = "cancel_data_labeling_job" + 
_delete_method = "delete_data_labeling_job" + _job_type = "labeling-tasks" + pass + + +class HyperparameterTuningJob(_Job): + _resource_noun = "hyperparameterTuningJobs" + _getter_method = "get_hyperparameter_tuning_job" + _list_method = "list_hyperparameter_tuning_jobs" + _cancel_method = "cancel_hyperparameter_tuning_job" + _delete_method = "delete_hyperparameter_tuning_job" + pass diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py new file mode 100644 index 0000000000..d96b681695 --- /dev/null +++ b/google/cloud/aiplatform/models.py @@ -0,0 +1,1997 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import proto +from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform import base +from google.cloud.aiplatform import compat +from google.cloud.aiplatform import explain +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import jobs +from google.cloud.aiplatform import models +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform.compat.services import endpoint_service_client + +from google.cloud.aiplatform.compat.types import ( + encryption_spec as gca_encryption_spec, + endpoint as gca_endpoint_compat, + endpoint_v1 as gca_endpoint_v1, + endpoint_v1beta1 as gca_endpoint_v1beta1, + explanation_v1beta1 as gca_explanation_v1beta1, + machine_resources as gca_machine_resources_compat, + machine_resources_v1beta1 as gca_machine_resources_v1beta1, + model as gca_model_compat, + model_v1beta1 as gca_model_v1beta1, + env_var as gca_env_var_compat, + env_var_v1beta1 as gca_env_var_v1beta1, +) + +from google.protobuf import json_format + + +_LOGGER = base.Logger(__name__) + + +class Prediction(NamedTuple): + """Prediction class envelopes returned Model predictions and the Model id. + + Attributes: + predictions: + The predictions that are the output of the predictions + call. The schema of any single prediction may be specified via + Endpoint's DeployedModels' [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + deployed_model_id: + ID of the Endpoint's DeployedModel that served this prediction. + explanations: + The explanations of the Model's predictions. It has the same number + of elements as instances to be explained. Default is None. 
+ """ + + predictions: Dict[str, List] + deployed_model_id: str + explanations: Optional[Sequence[gca_explanation_v1beta1.Explanation]] = None + + +class Endpoint(base.AiPlatformResourceNounWithFutureManager): + + client_class = utils.EndpointClientWithOverride + _is_client_prediction_client = False + _resource_noun = "endpoints" + _getter_method = "get_endpoint" + _list_method = "list_endpoints" + _delete_method = "delete_endpoint" + + def __init__( + self, + endpoint_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an endpoint resource. + + Args: + endpoint_name (str): + Required. A fully-qualified endpoint resource name or endpoint ID. + Example: "projects/123/locations/us-central1/endpoints/456" or + "456" when project and location are initialized or passed. + project (str): + Optional. Project to retrieve endpoint from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve endpoint from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. 
+ """ + + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=endpoint_name, + ) + self._gca_resource = self._get_gca_resource(resource_name=endpoint_name) + self._prediction_client = self._instantiate_prediction_client( + location=location or initializer.global_config.location, + credentials=credentials, + ) + + @classmethod + def create( + cls, + display_name: str, + description: Optional[str] = None, + labels: Optional[Dict] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + sync=True, + ) -> "Endpoint": + """Creates a new endpoint. + + Args: + display_name (str): + Required. The user-defined name of the Endpoint. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + project (str): + Required. Project to retrieve endpoint from. If not set, project + set in aiplatform.init will be used. + location (str): + Required. Location to retrieve endpoint from. If not set, location + set in aiplatform.init will be used. + description (str): + Optional. The description of the Endpoint. + labels (Dict): + Optional. The labels with user-defined metadata to + organize your Endpoints. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + encryption_spec_key_name (Optional[str]): + Optional. 
The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Endpoint and all sub-resources of this Endpoint will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + endpoint (endpoint.Endpoint): + Created endpoint. + """ + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + utils.validate_display_name(display_name) + + project = project or initializer.global_config.project + location = location or initializer.global_config.location + + return cls._create( + api_client=api_client, + display_name=display_name, + project=project, + location=location, + description=description, + labels=labels, + metadata=metadata, + credentials=credentials, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name + ), + sync=sync, + ) + + @classmethod + @base.optional_sync() + def _create( + cls, + api_client: endpoint_service_client.EndpointServiceClient, + display_name: str, + project: str, + location: str, + description: Optional[str] = None, + labels: Optional[Dict] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None, + sync=True, + ) -> "Endpoint": + """ + Creates a new endpoint by calling the API client. + Args: + api_client (EndpointServiceClient): + Required. 
An instance of EndpointServiceClient with the correct + api_endpoint already set based on user's preferences. + display_name (str): + Required. The user-defined name of the Endpoint. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + project (str): + Required. Project to retrieve endpoint from. If not set, project + set in aiplatform.init will be used. + location (str): + Required. Location to retrieve endpoint from. If not set, location + set in aiplatform.init will be used. + description (str): + Optional. The description of the Endpoint. + labels (Dict): + Optional. The labels with user-defined metadata to + organize your Endpoints. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + encryption_spec (Optional[gca_encryption_spec.EncryptionSpec]): + Optional. The Cloud KMS customer managed encryption key used to protect the dataset. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Dataset and all sub-resources of this Dataset will be secured by this key. + sync (bool): + Whether to create this endpoint synchronously. + Returns: + endpoint (endpoint.Endpoint): + Created endpoint. 
+ """ + + parent = initializer.global_config.common_location_path( + project=project, location=location + ) + + gapic_endpoint = gca_endpoint_compat.Endpoint( + display_name=display_name, + description=description, + labels=labels, + encryption_spec=encryption_spec, + ) + + operation_future = api_client.create_endpoint( + parent=parent, endpoint=gapic_endpoint, metadata=metadata + ) + + _LOGGER.log_create_with_lro(cls, operation_future) + + created_endpoint = operation_future.result() + + _LOGGER.log_create_complete(cls, created_endpoint, "endpoint") + + return cls( + endpoint_name=created_endpoint.name, + project=project, + location=location, + credentials=credentials, + ) + + @staticmethod + def _allocate_traffic( + traffic_split: Dict[str, int], traffic_percentage: int, + ) -> Dict[str, int]: + """ + Allocates desired traffic to new deployed model and scales traffic of + older deployed models. + + Args: + traffic_split (Dict[str, int]): + Required. Current traffic split of deployed models in endpoint. + traffic_percentage (int): + Required. Desired traffic to new deployed model. + Returns: + new_traffic_split (Dict[str, int]): + Traffic split to use. + """ + new_traffic_split = {} + old_models_traffic = 100 - traffic_percentage + if old_models_traffic: + unallocated_traffic = old_models_traffic + for deployed_model in traffic_split: + current_traffic = traffic_split[deployed_model] + new_traffic = int(current_traffic / 100 * old_models_traffic) + new_traffic_split[deployed_model] = new_traffic + unallocated_traffic -= new_traffic + # will likely under-allocate. make total 100. 
+ for deployed_model in new_traffic_split: + if unallocated_traffic == 0: + break + new_traffic_split[deployed_model] += 1 + unallocated_traffic -= 1 + + new_traffic_split["0"] = traffic_percentage + + return new_traffic_split + + @staticmethod + def _unallocate_traffic( + traffic_split: Dict[str, int], deployed_model_id: str, + ) -> Dict[str, int]: + """ + Sets deployed model id's traffic to 0 and scales the traffic of other + deployed models. + + Args: + traffic_split (Dict[str, int]): + Required. Current traffic split of deployed models in endpoint. + deployed_model_id (str): + Required. Desired traffic to new deployed model. + Returns: + new_traffic_split (Dict[str, int]): + Traffic split to use. + """ + new_traffic_split = traffic_split.copy() + del new_traffic_split[deployed_model_id] + deployed_model_id_traffic = traffic_split[deployed_model_id] + traffic_percent_left = 100 - deployed_model_id_traffic + + if traffic_percent_left: + unallocated_traffic = 100 + for deployed_model in new_traffic_split: + current_traffic = traffic_split[deployed_model] + new_traffic = int(current_traffic / traffic_percent_left * 100) + new_traffic_split[deployed_model] = new_traffic + unallocated_traffic -= new_traffic + # will likely under-allocate. make total 100. + for deployed_model in new_traffic_split: + if unallocated_traffic == 0: + break + new_traffic_split[deployed_model] += 1 + unallocated_traffic -= 1 + + new_traffic_split[deployed_model_id] = 0 + + return new_traffic_split + + @staticmethod + def _validate_deploy_args( + min_replica_count: int, + max_replica_count: int, + accelerator_type: Optional[str], + deployed_model_display_name: Optional[str], + traffic_split: Optional[Dict[str, int]], + traffic_percentage: int, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + ): + """Helper method to validate deploy arguments. + + Args: + min_replica_count (int): + Required. 
The minimum number of machine replicas this deployed + model will be always deployed on. If traffic against it increases, + it may dynamically be deployed onto more replicas, and as traffic + decreases, some of these extra replicas may be freed. + max_replica_count (int): + Required. The maximum number of replicas this deployed model may + be deployed on when the traffic against it increases. If requested + value is too large, the deployment will error, but if deployment + succeeds then the ability to scale the model to that many replicas + is guaranteed (barring service outages). If traffic against the + deployed model increases beyond what its replicas at maximum may + handle, a portion of the traffic will be dropped. If this value + is not provided, the larger value of min_replica_count or 1 will + be used. If value provided is smaller than min_replica_count, it + will automatically be increased to be min_replica_count. + accelerator_type (str): + Required. Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + deployed_model_display_name (str): + Required. The display name of the DeployedModel. If not provided + upon creation, the Model's display_name is used. + traffic_split (Dict[str, int]): + Required. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + traffic_percentage (int): + Required. Desired traffic to newly deployed model. Defaults to + 0 if there are pre-existing deployed models. Defaults to 100 if + there are no pre-existing deployed models. 
Negative values should + not be provided. Traffic of previously deployed models at the endpoint + will be scaled down to accommodate new deployed model's traffic. + Should not be provided if traffic_split is provided. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` + + Raises: + ValueError if Min or Max replica is negative. Traffic percentage > 100 or + < 0. Or if traffic_split does not sum to 100. + + ValueError if either explanation_metadata or explanation_parameters + but not both are specified. + """ + if min_replica_count < 0: + raise ValueError("Min replica cannot be negative.") + if max_replica_count < 0: + raise ValueError("Max replica cannot be negative.") + if deployed_model_display_name is not None: + utils.validate_display_name(deployed_model_display_name) + + if traffic_split is None: + if traffic_percentage > 100: + raise ValueError("Traffic percentage cannot be greater than 100.") + if traffic_percentage < 0: + raise ValueError("Traffic percentage cannot be negative.") + + elif traffic_split: + # TODO(b/172678233) verify every referenced deployed model exists + if sum(traffic_split.values()) != 100: + raise ValueError( + "Sum of all traffic within traffic split needs to be 100." + ) + + if bool(explanation_metadata) != bool(explanation_parameters): + raise ValueError( + "Both `explanation_metadata` and `explanation_parameters` should be specified or None." 
+ ) + + # Raises ValueError if invalid accelerator + if accelerator_type: + utils.validate_accelerator_type(accelerator_type) + + def deploy( + self, + model: "Model", + deployed_model_display_name: Optional[str] = None, + traffic_percentage: int = 0, + traffic_split: Optional[Dict[str, int]] = None, + machine_type: Optional[str] = None, + min_replica_count: int = 1, + max_replica_count: int = 1, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + sync=True, + ) -> None: + """ + Deploys a Model to the Endpoint. + + Args: + model (aiplatform.Model): + Required. Model to be deployed. + deployed_model_display_name (str): + Optional. The display name of the DeployedModel. If not provided + upon creation, the Model's display_name is used. + traffic_percentage (int): + Optional. Desired traffic to newly deployed model. Defaults to + 0 if there are pre-existing deployed models. Defaults to 100 if + there are no pre-existing deployed models. Negative values should + not be provided. Traffic of previously deployed models at the endpoint + will be scaled down to accommodate new deployed model's traffic. + Should not be provided if traffic_split is provided. + traffic_split (Dict[str, int]): + Optional. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + machine_type (str): + Optional. The type of machine. 
Not specifying machine type will + result in model to be deployed with automatic resources. + min_replica_count (int): + Optional. The minimum number of machine replicas this deployed + model will be always deployed on. If traffic against it increases, + it may dynamically be deployed onto more replicas, and as traffic + decreases, some of these extra replicas may be freed. + max_replica_count (int): + Optional. The maximum number of replicas this deployed model may + be deployed on when the traffic against it increases. If requested + value is too large, the deployment will error, but if deployment + succeeds then the ability to scale the model to that many replicas + is guaranteed (barring service outages). If traffic against the + deployed model increases beyond what its replicas at maximum may + handle, a portion of the traffic will be dropped. If this value + is not provided, the larger value of min_replica_count or 1 will + be used. If value provided is smaller than min_replica_count, it + will automatically be increased to be min_replica_count. + accelerator_type (str): + Optional. Hardware accelerator type. Must also set accelerator_count if used. + One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, + NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + Optional. The number of accelerators to attach to a worker replica. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. 
+ sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + """ + + self._validate_deploy_args( + min_replica_count, + max_replica_count, + accelerator_type, + deployed_model_display_name, + traffic_split, + traffic_percentage, + explanation_metadata, + explanation_parameters, + ) + + self._deploy( + model=model, + deployed_model_display_name=deployed_model_display_name, + traffic_percentage=traffic_percentage, + traffic_split=traffic_split, + machine_type=machine_type, + min_replica_count=min_replica_count, + max_replica_count=max_replica_count, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + metadata=metadata, + sync=sync, + ) + + @base.optional_sync() + def _deploy( + self, + model: "Model", + deployed_model_display_name: Optional[str] = None, + traffic_percentage: Optional[int] = 0, + traffic_split: Optional[Dict[str, int]] = None, + machine_type: Optional[str] = None, + min_replica_count: Optional[int] = 1, + max_replica_count: Optional[int] = 1, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + sync=True, + ) -> None: + """ + Deploys a Model to the Endpoint. + + Args: + model (aiplatform.Model): + Required. Model to be deployed. + deployed_model_display_name (str): + Optional. The display name of the DeployedModel. If not provided + upon creation, the Model's display_name is used. + traffic_percentage (int): + Optional. Desired traffic to newly deployed model. Defaults to + 0 if there are pre-existing deployed models. 
Defaults to 100 if + there are no pre-existing deployed models. Negative values should + not be provided. Traffic of previously deployed models at the endpoint + will be scaled down to accommodate new deployed model's traffic. + Should not be provided if traffic_split is provided. + traffic_split (Dict[str, int]): + Optional. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + machine_type (str): + Optional. The type of machine. Not specifying machine type will + result in model to be deployed with automatic resources. + min_replica_count (int): + Optional. The minimum number of machine replicas this deployed + model will be always deployed on. If traffic against it increases, + it may dynamically be deployed onto more replicas, and as traffic + decreases, some of these extra replicas may be freed. + max_replica_count (int): + Optional. The maximum number of replicas this deployed model may + be deployed on when the traffic against it increases. If requested + value is too large, the deployment will error, but if deployment + succeeds then the ability to scale the model to that many replicas + is guaranteed (barring service outages). If traffic against the + deployed model increases beyond what its replicas at maximum may + handle, a portion of the traffic will be dropped. If this value + is not provided, the larger value of min_replica_count or 1 will + be used. If value provided is smaller than min_replica_count, it + will automatically be increased to be min_replica_count. + accelerator_type (str): + Optional. Hardware accelerator type. Must also set accelerator_count if used. 
+ One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, + NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + Optional. The number of accelerators to attach to a worker replica. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Raises: + ValueError if there is not current traffic split and traffic percentage + is not 0 or 100. 
+ """ + _LOGGER.log_action_start_against_resource( + f"Deploying Model {model.resource_name} to", "", self + ) + + self._deploy_call( + self.api_client, + self.resource_name, + model.resource_name, + self._gca_resource.traffic_split, + deployed_model_display_name=deployed_model_display_name, + traffic_percentage=traffic_percentage, + traffic_split=traffic_split, + machine_type=machine_type, + min_replica_count=min_replica_count, + max_replica_count=max_replica_count, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + metadata=metadata, + ) + + _LOGGER.log_action_completed_against_resource("model", "deployed", self) + + self._sync_gca_resource() + + @classmethod + def _deploy_call( + cls, + api_client: endpoint_service_client.EndpointServiceClient, + endpoint_resource_name: str, + model_resource_name: str, + endpoint_resource_traffic_split: Optional[proto.MapField] = None, + deployed_model_display_name: Optional[str] = None, + traffic_percentage: Optional[int] = 0, + traffic_split: Optional[Dict[str, int]] = None, + machine_type: Optional[str] = None, + min_replica_count: Optional[int] = 1, + max_replica_count: Optional[int] = 1, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + ): + """Helper method to deploy model to endpoint. + + Args: + api_client (endpoint_service_client.EndpointServiceClient): + Required. endpoint_service_client.EndpointServiceClient to make call. + endpoint_resource_name (str): + Required. Endpoint resource name to deploy model to. + model_resource_name (str): + Required. Model resource name of Model to deploy. + endpoint_resource_traffic_split (proto.MapField): + Optional. 
Endpoint current resource traffic split. + deployed_model_display_name (str): + Optional. The display name of the DeployedModel. If not provided + upon creation, the Model's display_name is used. + traffic_percentage (int): + Optional. Desired traffic to newly deployed model. Defaults to + 0 if there are pre-existing deployed models. Defaults to 100 if + there are no pre-existing deployed models. Negative values should + not be provided. Traffic of previously deployed models at the endpoint + will be scaled down to accommodate new deployed model's traffic. + Should not be provided if traffic_split is provided. + traffic_split (Dict[str, int]): + Optional. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + machine_type (str): + Optional. The type of machine. Not specifying machine type will + result in model to be deployed with automatic resources. + min_replica_count (int): + Optional. The minimum number of machine replicas this deployed + model will be always deployed on. If traffic against it increases, + it may dynamically be deployed onto more replicas, and as traffic + decreases, some of these extra replicas may be freed. + max_replica_count (int): + Optional. The maximum number of replicas this deployed model may + be deployed on when the traffic against it increases. If requested + value is too large, the deployment will error, but if deployment + succeeds then the ability to scale the model to that many replicas + is guaranteed (barring service outages). 
If traffic against the + deployed model increases beyond what its replicas at maximum may + handle, a portion of the traffic will be dropped. If this value + is not provided, the larger value of min_replica_count or 1 will + be used. If value provided is smaller than min_replica_count, it + will automatically be increased to be min_replica_count. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Raises: + ValueError if there is not current traffic split and traffic percentage + is not 0 or 100. + ValueError if only `explanation_metadata` or `explanation_parameters` + is specified. + """ + + max_replica_count = max(min_replica_count, max_replica_count) + + if bool(accelerator_type) != bool(accelerator_count): + raise ValueError( + "Both `accelerator_type` and `accelerator_count` should be specified or None." 
+ ) + + gca_endpoint = gca_endpoint_compat + gca_machine_resources = gca_machine_resources_compat + if explanation_metadata and explanation_parameters: + gca_endpoint = gca_endpoint_v1beta1 + gca_machine_resources = gca_machine_resources_v1beta1 + + if machine_type: + machine_spec = gca_machine_resources.MachineSpec(machine_type=machine_type) + + if accelerator_type and accelerator_count: + utils.validate_accelerator_type(accelerator_type) + machine_spec.accelerator_type = accelerator_type + machine_spec.accelerator_count = accelerator_count + + dedicated_resources = gca_machine_resources.DedicatedResources( + machine_spec=machine_spec, + min_replica_count=min_replica_count, + max_replica_count=max_replica_count, + ) + deployed_model = gca_endpoint.DeployedModel( + dedicated_resources=dedicated_resources, + model=model_resource_name, + display_name=deployed_model_display_name, + ) + else: + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=min_replica_count, + max_replica_count=max_replica_count, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=model_resource_name, + display_name=deployed_model_display_name, + ) + + # Service will throw error if both metadata and parameters are not provided + if explanation_metadata and explanation_parameters: + api_client = api_client.select_version(compat.V1BETA1) + explanation_spec = gca_endpoint.explanation.ExplanationSpec() + explanation_spec.metadata = explanation_metadata + explanation_spec.parameters = explanation_parameters + deployed_model.explanation_spec = explanation_spec + + if traffic_split is None: + # new model traffic needs to be 100 if no pre-existing models + if not endpoint_resource_traffic_split: + # default scenario + if traffic_percentage == 0: + traffic_percentage = 100 + # verify user specified 100 + elif traffic_percentage < 100: + raise ValueError( + """There are currently no deployed models so the traffic + percentage 
for this deployed model needs to be 100.""" + ) + traffic_split = cls._allocate_traffic( + traffic_split=dict(endpoint_resource_traffic_split), + traffic_percentage=traffic_percentage, + ) + + operation_future = api_client.deploy_model( + endpoint=endpoint_resource_name, + deployed_model=deployed_model, + traffic_split=traffic_split, + metadata=metadata, + ) + + _LOGGER.log_action_started_against_resource_with_lro( + "Deploy", "model", cls, operation_future + ) + + operation_future.result() + + def undeploy( + self, + deployed_model_id: str, + traffic_split: Optional[Dict[str, int]] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + sync=True, + ) -> None: + """Undeploys a deployed model. + + Proportionally adjusts the traffic_split among the remaining deployed + models of the endpoint. + + Args: + deployed_model_id (str): + Required. The ID of the DeployedModel to be undeployed from the + Endpoint. + traffic_split (Dict[str, int]): + Optional. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + """ + if traffic_split is not None: + if deployed_model_id in traffic_split and traffic_split[deployed_model_id]: + raise ValueError("Model being undeployed should have 0 traffic.") + if sum(traffic_split.values()) != 100: + # TODO(b/172678233) verify every referenced deployed model exists + raise ValueError( + "Sum of all traffic within traffic split needs to be 100." 
+ ) + + self._undeploy( + deployed_model_id=deployed_model_id, + traffic_split=traffic_split, + metadata=metadata, + sync=sync, + ) + + @base.optional_sync() + def _undeploy( + self, + deployed_model_id: str, + traffic_split: Optional[Dict[str, int]] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + sync=True, + ) -> None: + """Undeploys a deployed model. + + Proportionally adjusts the traffic_split among the remaining deployed + models of the endpoint. + + Args: + deployed_model_id (str): + Required. The ID of the DeployedModel to be undeployed from the + Endpoint. + traffic_split (Dict[str, int]): + Optional. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. 
+ """ + current_traffic_split = traffic_split or dict(self._gca_resource.traffic_split) + + if deployed_model_id in current_traffic_split: + current_traffic_split = self._unallocate_traffic( + traffic_split=current_traffic_split, + deployed_model_id=deployed_model_id, + ) + current_traffic_split.pop(deployed_model_id) + + _LOGGER.log_action_start_against_resource("Undeploying", "model", self) + + operation_future = self.api_client.undeploy_model( + endpoint=self.resource_name, + deployed_model_id=deployed_model_id, + traffic_split=current_traffic_split, + metadata=metadata, + ) + + _LOGGER.log_action_started_against_resource_with_lro( + "Undeploy", "model", self.__class__, operation_future + ) + + # block before returning + operation_future.result() + + _LOGGER.log_action_completed_against_resource("model", "undeployed", self) + + # update local resource + self._sync_gca_resource() + + @staticmethod + def _instantiate_prediction_client( + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> utils.PredictionClientWithOverride: + + """Helper method to instantiates prediction client with optional overrides for this endpoint. + + Args: + location (str): The location of this endpoint. + credentials (google.auth.credentials.Credentials): + Optional custom credentials to use when accessing interacting with + the prediction client. + Returns: + prediction_client (prediction_service_client.PredictionServiceClient): + Initalized prediction client with optional overrides. + """ + return initializer.global_config.create_client( + client_class=utils.PredictionClientWithOverride, + credentials=credentials, + location_override=location, + prediction_client=True, + ) + + def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction: + """Make a prediction against this Endpoint. + + Args: + instances (List): + Required. The instances that are the input to the + prediction call. 
A DeployedModel may have an upper limit + on the number of instances it supports per request, and + when it is exceeded the prediction call errors in case + of AutoML Models, or, in case of customer created + Models, the behaviour is as documented by that Model. + The schema of any single instance may be specified via + Endpoint's DeployedModels' + [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``instance_schema_uri``. + parameters (Dict): + The parameters that govern the prediction. The schema of + the parameters may be specified via Endpoint's + DeployedModels' [Model's + ][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``parameters_schema_uri``. + Returns: + prediction: Prediction with returned predictions and Model Id. + + """ + self.wait() + + prediction_response = self._prediction_client.predict( + endpoint=self.resource_name, instances=instances, parameters=parameters + ) + + return Prediction( + predictions=[ + json_format.MessageToDict(item) + for item in prediction_response.predictions.pb + ], + deployed_model_id=prediction_response.deployed_model_id, + ) + + def explain( + self, + instances: List[Dict], + parameters: Optional[Dict] = None, + deployed_model_id: Optional[str] = None, + ) -> Prediction: + """Make a prediction with explanations against this Endpoint. + + Example usage: + response = my_endpoint.explain(instances=[...]) + my_explanations = response.explanations + + Args: + instances (List): + Required. The instances that are the input to the + prediction call. A DeployedModel may have an upper limit + on the number of instances it supports per request, and + when it is exceeded the prediction call errors in case + of AutoML Models, or, in case of customer created + Models, the behaviour is as documented by that Model. 
+ The schema of any single instance may be specified via + Endpoint's DeployedModels' + [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``instance_schema_uri``. + parameters (Dict): + The parameters that govern the prediction. The schema of + the parameters may be specified via Endpoint's + DeployedModels' [Model's + ][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``parameters_schema_uri``. + deployed_model_id (str): + Optional. If specified, this ExplainRequest will be served by the + chosen DeployedModel, overriding this Endpoint's traffic split. + Returns: + prediction: Prediction with returned predictions, explanations and Model Id. + """ + self.wait() + + explain_response = self._prediction_client.select_version( + compat.V1BETA1 + ).explain( + endpoint=self.resource_name, + instances=instances, + parameters=parameters, + deployed_model_id=deployed_model_id, + ) + + return Prediction( + predictions=[ + json_format.MessageToDict(item) + for item in explain_response.predictions.pb + ], + deployed_model_id=explain_response.deployed_model_id, + explanations=explain_response.explanations, + ) + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List["models.Endpoint"]: + """List all Endpoint resource instances. + + Example Usage: + + aiplatform.Endpoint.list( + filter='labels.my_label="my_label_value" OR display_name=!"old_endpoint"', + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. 
Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[models.Endpoint] - A list of Endpoint resource objects + """ + + return cls._list_with_local_order( + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) + + def list_models( + self, + ) -> Sequence[ + Union[gca_endpoint_v1.DeployedModel, gca_endpoint_v1beta1.DeployedModel] + ]: + """Returns a list of the models deployed to this Endpoint. + + Returns: + deployed_models (Sequence[aiplatform.gapic.DeployedModel]): + A list of the models deployed in this Endpoint. + """ + self._sync_gca_resource() + return self._gca_resource.deployed_models + + def undeploy_all(self, sync: bool = True) -> "Endpoint": + """Undeploys every model deployed to this Endpoint. + + Args: + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + """ + self._sync_gca_resource() + + for deployed_model in self._gca_resource.deployed_models: + self._undeploy(deployed_model_id=deployed_model.id, sync=sync) + + return self + + def delete(self, force: bool = False, sync: bool = True) -> None: + """Deletes this AI Platform Endpoint resource. If force is set to True, + all models on this Endpoint will be undeployed prior to deletion. + + Args: + force (bool): + Required. If force is set to True, all deployed models on this + Endpoint will be undeployed first. 
Default is False. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Raises: + FailedPrecondition: If models are deployed on this Endpoint and force = False. + """ + if force: + self.undeploy_all(sync=sync) + + super().delete(sync=sync) + + +class Model(base.AiPlatformResourceNounWithFutureManager): + + client_class = utils.ModelClientWithOverride + _is_client_prediction_client = False + _resource_noun = "models" + _getter_method = "get_model" + _list_method = "list_models" + _delete_method = "delete_model" + + @property + def uri(self): + """Uri of the model.""" + return self._gca_resource.artifact_uri + + @property + def description(self): + """Description of the model.""" + return self._gca_resource.description + + def __init__( + self, + model_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves the model resource and instantiates its representation. + + Args: + model_name (str): + Required. A fully-qualified model resource name or model ID. + Example: "projects/123/locations/us-central1/models/456" or + "456" when project and location are initialized or passed. + project (str): + Optional project to retrieve model from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional location to retrieve model from. If not set, location + set in aiplatform.init will be used. + credentials: Optional[auth_credentials.Credentials]=None, + Custom credentials to use to upload this model. If not set, + credentials set in aiplatform.init will be used. 
+ """ + + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=model_name, + ) + self._gca_resource = self._get_gca_resource(resource_name=model_name) + + # TODO(b/170979552) Add support for predict schemata + # TODO(b/170979926) Add support for metadata and metadata schema + @classmethod + @base.optional_sync() + def upload( + cls, + display_name: str, + serving_container_image_uri: str, + *, + artifact_uri: Optional[str] = None, + serving_container_predict_route: Optional[str] = None, + serving_container_health_route: Optional[str] = None, + description: Optional[str] = None, + serving_container_command: Optional[Sequence[str]] = None, + serving_container_args: Optional[Sequence[str]] = None, + serving_container_environment_variables: Optional[Dict[str, str]] = None, + serving_container_ports: Optional[Sequence[int]] = None, + instance_schema_uri: Optional[str] = None, + parameters_schema_uri: Optional[str] = None, + prediction_schema_uri: Optional[str] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + sync=True, + ) -> "Model": + """Uploads a model and returns a Model representing the uploaded Model resource. + + Example usage: + + my_model = Model.upload( + display_name='my-model', + artifact_uri='gs://my-model/saved-model' + serving_container_image_uri='tensorflow/serving' + ) + + Args: + display_name (str): + Required. The display name of the Model. The name can be up to 128 + characters long and can be consist of any UTF-8 characters. + serving_container_image_uri (str): + Required. The URI of the Model serving container. + artifact_uri (str): + Optional. 
The path to the directory containing the Model artifact and + any of its supporting files. Leave blank for custom container prediction. + Not present for AutoML Models. + serving_container_predict_route (str): + Optional. An HTTP path to send prediction requests to the container, and + which must be supported by it. If not specified a default HTTP path will + be used by AI Platform. + serving_container_health_route (str): + Optional. An HTTP path to send health check requests to the container, and which + must be supported by it. If not specified a standard HTTP path will be + used by AI Platform. + description (str): + The description of the model. + serving_container_command: Optional[Sequence[str]]=None, + The command with which the container is run. Not executed within a + shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + serving_container_args: Optional[Sequence[str]]=None, + The arguments to the command. The Docker image's CMD is used if this is + not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + serving_container_environment_variables: Optional[Dict[str, str]]=None, + The environment variables that are to be present in the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. 
+ serving_container_ports: Optional[Sequence[int]]=None, + Declaration of ports that are exposed by the container. This field is + primarily informational, it gives AI Platform information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + instance_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + parameters_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + prediction_schema_uri (str): + Optional. 
Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` + project: Optional[str]=None, + Project to upload this model to. Overrides project set in + aiplatform.init. + location: Optional[str]=None, + Location to upload this model to. Overrides location set in + aiplatform.init. + credentials: Optional[auth_credentials.Credentials]=None, + Custom credentials to use to upload this model. Overrides credentials + set in aiplatform.init. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Model and all sub-resources of this Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + Returns: + model: Instantiated representation of the uploaded model resource. 
+ Raises: + ValueError if only `explanation_metadata` or `explanation_parameters` + is specified. + """ + utils.validate_display_name(display_name) + + if bool(explanation_metadata) != bool(explanation_parameters): + raise ValueError( + "Both `explanation_metadata` and `explanation_parameters` should be specified or None." + ) + + gca_endpoint = gca_endpoint_compat + gca_model = gca_model_compat + gca_env_var = gca_env_var_compat + if explanation_metadata and explanation_parameters: + gca_endpoint = gca_endpoint_v1beta1 + gca_model = gca_model_v1beta1 + gca_env_var = gca_env_var_v1beta1 + + api_client = cls._instantiate_client(location, credentials) + env = None + ports = None + + if serving_container_environment_variables: + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in serving_container_environment_variables.items() + ] + if serving_container_ports: + ports = [ + gca_model.Port(container_port=port) for port in serving_container_ports + ] + + container_spec = gca_model.ModelContainerSpec( + image_uri=serving_container_image_uri, + command=serving_container_command, + args=serving_container_args, + env=env, + ports=ports, + predict_route=serving_container_predict_route, + health_route=serving_container_health_route, + ) + + model_predict_schemata = None + if any([instance_schema_uri, parameters_schema_uri, prediction_schema_uri]): + model_predict_schemata = gca_model.PredictSchemata( + instance_schema_uri=instance_schema_uri, + parameters_schema_uri=parameters_schema_uri, + prediction_schema_uri=prediction_schema_uri, + ) + + # TODO(b/182388545) initializer.global_config.get_encryption_spec from a sync function + encryption_spec = initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name, + ) + + managed_model = gca_model.Model( + display_name=display_name, + description=description, + container_spec=container_spec, + predict_schemata=model_predict_schemata, + 
encryption_spec=encryption_spec, + ) + + if artifact_uri: + managed_model.artifact_uri = artifact_uri + + # Override explanation_spec if both required fields are provided + if explanation_metadata and explanation_parameters: + api_client = api_client.select_version(compat.V1BETA1) + explanation_spec = gca_endpoint.explanation.ExplanationSpec() + explanation_spec.metadata = explanation_metadata + explanation_spec.parameters = explanation_parameters + managed_model.explanation_spec = explanation_spec + + lro = api_client.upload_model( + parent=initializer.global_config.common_location_path(project, location), + model=managed_model, + ) + + _LOGGER.log_create_with_lro(cls, lro) + + model_upload_response = lro.result() + + this_model = cls(model_upload_response.model) + + _LOGGER.log_create_complete(cls, this_model._gca_resource, "model") + + return this_model + + # TODO(b/172502059) support deploying with endpoint resource name + def deploy( + self, + endpoint: Optional["Endpoint"] = None, + deployed_model_display_name: Optional[str] = None, + traffic_percentage: Optional[int] = 0, + traffic_split: Optional[Dict[str, int]] = None, + machine_type: Optional[str] = None, + min_replica_count: Optional[int] = 1, + max_replica_count: Optional[int] = 1, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync=True, + ) -> Endpoint: + """ + Deploys model to endpoint. Endpoint will be created if unspecified. + + Args: + endpoint ("Endpoint"): + Optional. Endpoint to deploy model to. If not specified, endpoint + display name will be model display name+'_endpoint'. + deployed_model_display_name (str): + Optional. The display name of the DeployedModel. 
If not provided + upon creation, the Model's display_name is used. + traffic_percentage (int): + Optional. Desired traffic to newly deployed model. Defaults to + 0 if there are pre-existing deployed models. Defaults to 100 if + there are no pre-existing deployed models. Negative values should + not be provided. Traffic of previously deployed models at the endpoint + will be scaled down to accommodate new deployed model's traffic. + Should not be provided if traffic_split is provided. + traffic_split (Dict[str, int]): + Optional. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + machine_type (str): + Optional. The type of machine. Not specifying machine type will + result in model to be deployed with automatic resources. + min_replica_count (int): + Optional. The minimum number of machine replicas this deployed + model will be always deployed on. If traffic against it increases, + it may dynamically be deployed onto more replicas, and as traffic + decreases, some of these extra replicas may be freed. + max_replica_count (int): + Optional. The maximum number of replicas this deployed model may + be deployed on when the traffic against it increases. If requested + value is too large, the deployment will error, but if deployment + succeeds then the ability to scale the model to that many replicas + is guaranteed (barring service outages). If traffic against the + deployed model increases beyond what its replicas at maximum may + handle, a portion of the traffic will be dropped. If this value + is not provided, the smaller value of min_replica_count or 1 will + be used. 
+ accelerator_type (str): + Optional. Hardware accelerator type. Must also set accelerator_count if used. + One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, + NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + Optional. The number of accelerators to attach to a worker replica. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Model and all sub-resources of this Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + endpoint ("Endpoint"): + Endpoint with the deployed model. 
+ + """ + + Endpoint._validate_deploy_args( + min_replica_count, + max_replica_count, + accelerator_type, + deployed_model_display_name, + traffic_split, + traffic_percentage, + explanation_metadata, + explanation_parameters, + ) + + return self._deploy( + endpoint=endpoint, + deployed_model_display_name=deployed_model_display_name, + traffic_percentage=traffic_percentage, + traffic_split=traffic_split, + machine_type=machine_type, + min_replica_count=min_replica_count, + max_replica_count=max_replica_count, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + metadata=metadata, + encryption_spec_key_name=encryption_spec_key_name + or initializer.global_config.encryption_spec_key_name, + sync=sync, + ) + + @base.optional_sync(return_input_arg="endpoint", bind_future_to_self=False) + def _deploy( + self, + endpoint: Optional["Endpoint"] = None, + deployed_model_display_name: Optional[str] = None, + traffic_percentage: Optional[int] = 0, + traffic_split: Optional[Dict[str, int]] = None, + machine_type: Optional[str] = None, + min_replica_count: Optional[int] = 1, + max_replica_count: Optional[int] = 1, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = (), + encryption_spec_key_name: Optional[str] = None, + sync: bool = True, + ) -> Endpoint: + """ + Deploys model to endpoint. Endpoint will be created if unspecified. + + Args: + endpoint ("Endpoint"): + Optional. Endpoint to deploy model to. If not specified, endpoint + display name will be model display name+'_endpoint'. + deployed_model_display_name (str): + Optional. The display name of the DeployedModel. If not provided + upon creation, the Model's display_name is used. 
+ traffic_percentage (int): + Optional. Desired traffic to newly deployed model. Defaults to + 0 if there are pre-existing deployed models. Defaults to 100 if + there are no pre-existing deployed models. Negative values should + not be provided. Traffic of previously deployed models at the endpoint + will be scaled down to accommodate new deployed model's traffic. + Should not be provided if traffic_split is provided. + traffic_split (Dict[str, int]): + Optional. A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this map, then it receives + no traffic. The traffic percentage values must add up to 100, or + map must be empty if the Endpoint is to not accept any traffic at + the moment. Key for model being deployed is "0". Should not be + provided if traffic_percentage is provided. + machine_type (str): + Optional. The type of machine. Not specifying machine type will + result in model to be deployed with automatic resources. + min_replica_count (int): + Optional. The minimum number of machine replicas this deployed + model will be always deployed on. If traffic against it increases, + it may dynamically be deployed onto more replicas, and as traffic + decreases, some of these extra replicas may be freed. + max_replica_count (int): + Optional. The maximum number of replicas this deployed model may + be deployed on when the traffic against it increases. If requested + value is too large, the deployment will error, but if deployment + succeeds then the ability to scale the model to that many replicas + is guaranteed (barring service outages). If traffic against the + deployed model increases beyond what its replicas at maximum may + handle, a portion of the traffic will be dropped. If this value + is not provided, the smaller value of min_replica_count or 1 will + be used. + accelerator_type (str): + Optional. Hardware accelerator type. 
Must also set accelerator_count if used. + One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, + NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + Optional. The number of accelerators to attach to a worker replica. + explanation_metadata (explain.ExplanationMetadata): + Optional. Metadata describing the Model's input and output for explanation. + Both `explanation_metadata` and `explanation_parameters` must be + passed together when used. For more details, see + `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + For more details, see `Ref docs ` + metadata (Sequence[Tuple[str, str]]): + Optional. Strings which should be sent along with the request as + metadata. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Model and all sub-resources of this Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + endpoint ("Endpoint"): + Endpoint with the deployed model. 
+ """ + + if endpoint is None: + display_name = self.display_name[:118] + "_endpoint" + endpoint = Endpoint.create( + display_name=display_name, + project=self.project, + location=self.location, + credentials=self.credentials, + encryption_spec_key_name=encryption_spec_key_name, + ) + + _LOGGER.log_action_start_against_resource("Deploying model to", "", endpoint) + + Endpoint._deploy_call( + endpoint.api_client, + endpoint.resource_name, + self.resource_name, + endpoint._gca_resource.traffic_split, + deployed_model_display_name=deployed_model_display_name, + traffic_percentage=traffic_percentage, + traffic_split=traffic_split, + machine_type=machine_type, + min_replica_count=min_replica_count, + max_replica_count=max_replica_count, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + metadata=metadata, + ) + + _LOGGER.log_action_completed_against_resource("model", "deployed", endpoint) + + endpoint._sync_gca_resource() + + return endpoint + + def batch_predict( + self, + job_display_name: str, + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bigquery_source: Optional[str] = None, + instances_format: str = "jsonl", + gcs_destination_prefix: Optional[str] = None, + bigquery_destination_prefix: Optional[str] = None, + predictions_format: str = "jsonl", + model_parameters: Optional[Dict] = None, + machine_type: Optional[str] = None, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + starting_replica_count: Optional[int] = None, + max_replica_count: Optional[int] = None, + generate_explanation: Optional[bool] = False, + explanation_metadata: Optional[explain.ExplanationMetadata] = None, + explanation_parameters: Optional[explain.ExplanationParameters] = None, + labels: Optional[dict] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + 
sync: bool = True, + ) -> jobs.BatchPredictionJob: + """Creates a batch prediction job using this Model and outputs prediction + results to the provided destination prefix in the specified + `predictions_format`. One source and one destination prefix are required. + + Example usage: + + my_model.batch_predict( + job_display_name="prediction-123", + gcs_source="gs://example-bucket/instances.csv", + instances_format="csv", + bigquery_destination_prefix="projectId.bqDatasetId.bqTableId" + ) + + Args: + job_display_name (str): + Required. The user-defined name of the BatchPredictionJob. + The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + gcs_source: Optional[Sequence[str]] = None + Google Cloud Storage URI(-s) to your instances to run + batch prediction on. They must match `instances_format`. + May contain wildcards. For more information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + bigquery_source: Optional[str] = None + BigQuery URI to a table, up to 2000 characters long. For example: + `projectId.bqDatasetId.bqTableId` + instances_format: str = "jsonl" + Required. The format in which instances are given, must be one + of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip", + or "file-list". Default is "jsonl" when using `gcs_source`. If a + `bigquery_source` is provided, this is overriden to "bigquery". + gcs_destination_prefix: Optional[str] = None + The Google Cloud Storage location of the directory where the + output is to be written to. In the given directory a new + directory is created. Its name is + ``prediction--``, where + timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 format. + Inside of it files ``predictions_0001.``, + ``predictions_0002.``, ..., + ``predictions_N.`` are created where + ```` depends on chosen ``predictions_format``, + and N may equal 0001 and depends on the total number of + successfully predicted instances. 
If the Model has both + ``instance`` and ``prediction`` schemata defined then each such + file contains predictions as per the ``predictions_format``. + If prediction for any instance failed (partially or + completely), then an additional ``errors_0001.``, + ``errors_0002.``,..., ``errors_N.`` + files are created (N depends on total number of failed + predictions). These files contain the failed instances, as + per their schema, followed by an additional ``error`` field + which as value has ```google.rpc.Status`` `__ + containing only ``code`` and ``message`` fields. + bigquery_destination_prefix: Optional[str] = None + The BigQuery project location where the output is to be + written to. In the given project a new dataset is created + with name + ``prediction__`` where + is made BigQuery-dataset-name compatible (for example, most + special characters become underscores), and timestamp is in + YYYY_MM_DDThh_mm_ss_sssZ "based on ISO-8601" format. In the + dataset two tables will be created, ``predictions``, and + ``errors``. If the Model has both ``instance`` and ``prediction`` + schemata defined then the tables have columns as follows: + The ``predictions`` table contains instances for which the + prediction succeeded, it has columns as per a concatenation + of the Model's instance and prediction schemata. The + ``errors`` table contains rows for which the prediction has + failed, it has instance columns, as per the instance schema, + followed by a single "errors" column, which as values has + ```google.rpc.Status`` `__ represented as a STRUCT, + and containing only ``code`` and ``message``. + predictions_format: str = "jsonl" + Required. The format in which AI Platform gives the + predictions, must be one of "jsonl", "csv", or "bigquery". + Default is "jsonl" when using `gcs_destination_prefix`. If a + `bigquery_destination_prefix` is provided, this is overriden to + "bigquery". + model_parameters: Optional[Dict] = None + Optional. 
The parameters that govern the predictions. The schema of + the parameters may be specified via the Model's `parameters_schema_uri`. + machine_type: Optional[str] = None + Optional. The type of machine for running batch prediction on + dedicated resources. Not specifying machine type will result in + batch prediction job being run with automatic resources. + accelerator_type: Optional[str] = None + Optional. The type of accelerator(s) that may be attached + to the machine as per `accelerator_count`. Only used if + `machine_type` is set. + accelerator_count: Optional[int] = None + Optional. The number of accelerators to attach to the + `machine_type`. Only used if `machine_type` is set. + starting_replica_count: Optional[int] = None + The number of machine replicas used at the start of the batch + operation. If not set, AI Platform decides starting number, not + greater than `max_replica_count`. Only used if `machine_type` is + set. + max_replica_count: Optional[int] = None + The maximum number of machine replicas the batch operation may + be scaled to. Only used if `machine_type` is set. + Default is 10. + generate_explanation (bool): + Optional. Generate explanation along with the batch prediction + results. This will cause the batch prediction output to include + explanations based on the `prediction_format`: + - `bigquery`: output includes a column named `explanation`. The value + is a struct that conforms to the [aiplatform.gapic.Explanation] object. + - `jsonl`: The JSON objects on each line include an additional entry + keyed `explanation`. The value of the entry is a JSON object that + conforms to the [aiplatform.gapic.Explanation] object. + - `csv`: Generating explanations for CSV format is not supported. + explanation_metadata (explain.ExplanationMetadata): + Optional. Explanation metadata configuration for this BatchPredictionJob. + Can be specified only if `generate_explanation` is set to `True`. 
+ + This value overrides the value of `Model.explanation_metadata`. + All fields of `explanation_metadata` are optional in the request. If + a field of the `explanation_metadata` object is not populated, the + corresponding field of the `Model.explanation_metadata` object is inherited. + For more details, see `Ref docs ` + explanation_parameters (explain.ExplanationParameters): + Optional. Parameters to configure explaining for Model's predictions. + Can be specified only if `generate_explanation` is set to `True`. + + This value overrides the value of `Model.explanation_parameters`. + All fields of `explanation_parameters` are optional in the request. If + a field of the `explanation_parameters` object is not populated, the + corresponding field of the `Model.explanation_parameters` object is inherited. + For more details, see `Ref docs ` + labels: Optional[dict] = None + Optional. The labels with user-defined metadata to organize your + BatchPredictionJobs. Label keys and values can be no longer than + 64 characters (Unicode codepoints), can only contain lowercase + letters, numeric characters, underscores and dashes. + International characters are allowed. See https://goo.gl/xmQnxf + for more information and examples of labels. + credentials: Optional[auth_credentials.Credentials] = None + Optional. Custom credentials to use to create this batch prediction + job. Overrides credentials set in aiplatform.init. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Model and all sub-resources of this Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. 
+ Returns: + (jobs.BatchPredictionJob): + Instantiated representation of the created batch prediction job. + + """ + self.wait() + + return jobs.BatchPredictionJob.create( + job_display_name=job_display_name, + model_name=self.resource_name, + instances_format=instances_format, + predictions_format=predictions_format, + gcs_source=gcs_source, + bigquery_source=bigquery_source, + gcs_destination_prefix=gcs_destination_prefix, + bigquery_destination_prefix=bigquery_destination_prefix, + model_parameters=model_parameters, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + starting_replica_count=starting_replica_count, + max_replica_count=max_replica_count, + generate_explanation=generate_explanation, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + labels=labels, + project=self.project, + location=self.location, + credentials=credentials or self.credentials, + encryption_spec_key_name=encryption_spec_key_name, + sync=sync, + ) + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List["models.Model"]: + """List all Model resource instances. + + Example Usage: + + aiplatform.Model.list( + filter='labels.my_label="my_label_value" AND display_name="my_model"', + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. 
Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[models.Model] - A list of Model resource objects + """ + + return cls._list( + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) diff --git a/google/cloud/aiplatform/schema.py b/google/cloud/aiplatform/schema.py new file mode 100644 index 0000000000..04d2f026a1 --- /dev/null +++ b/google/cloud/aiplatform/schema.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""Namespaced AI Platform Schemas.""" + + +class training_job: + class definition: + custom_task = "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml" + automl_tabular = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml" + automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml" + automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml" + automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml" + automl_text_extraction = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_extraction_1.0.0.yaml" + automl_text_sentiment = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_sentiment_1.0.0.yaml" + automl_video_action_recognition = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_action_recognition_1.0.0.yaml" + automl_video_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_classification_1.0.0.yaml" + automl_video_object_tracking = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_object_tracking_1.0.0.yaml" + + +class dataset: + class metadata: + tabular = ( + "gs://google-cloud-aiplatform/schema/dataset/metadata/tabular_1.0.0.yaml" + ) + image = "gs://google-cloud-aiplatform/schema/dataset/metadata/image_1.0.0.yaml" + text = "gs://google-cloud-aiplatform/schema/dataset/metadata/text_1.0.0.yaml" + video = "gs://google-cloud-aiplatform/schema/dataset/metadata/video_1.0.0.yaml" + + class ioformat: + class image: + multi_label_classification = "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_classification_multi_label_io_format_1.0.0.yaml" + single_label_classification = 
"gs://google-cloud-aiplatform/schema/dataset/ioformat/image_classification_single_label_io_format_1.0.0.yaml" + bounding_box = "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_bounding_box_io_format_1.0.0.yaml" + image_segmentation = "gs://google-cloud-aiplatform/schema/dataset/ioformat/image_segmentation_io_format_1.0.0.yaml" + + class text: + multi_label_classification = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_classification_multi_label_io_format_1.0.0.yaml" + single_label_classification = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_classification_single_label_io_format_1.0.0.yaml" + extraction = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_extraction_io_format_1.0.0.yaml" + sentiment = "gs://google-cloud-aiplatform/schema/dataset/ioformat/text_sentiment_io_format_1.0.0.yaml" + + class video: + action_recognition = "gs://google-cloud-aiplatform/schema/dataset/ioformat/video_action_recognition_io_format_1.0.0.yaml" + classification = "gs://google-cloud-aiplatform/schema/dataset/ioformat/video_classification_io_format_1.0.0.yaml" + object_tracking = "gs://google-cloud-aiplatform/schema/dataset/ioformat/video_object_tracking_io_format_1.0.0.yaml" + + class annotation: + class image: + classification = "gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml" + bounding_box = "gs://google-cloud-aiplatform/schema/dataset/annotation/image_bounding_box_1.0.0.yaml" + segmentation = "gs://google-cloud-aiplatform/schema/dataset/annotation/image_segmentation_1.0.0.yaml" + + class text: + classification = "gs://google-cloud-aiplatform/schema/dataset/annotation/text_classification_1.0.0.yaml" + extraction = "gs://google-cloud-aiplatform/schema/dataset/annotation/text_extraction_1.0.0.yaml" + sentiment = "gs://google-cloud-aiplatform/schema/dataset/annotation/text_sentiment_1.0.0.yaml" + + class video: + classification = 
"gs://google-cloud-aiplatform/schema/dataset/annotation/video_classification_1.0.0.yaml" + object_tracking = "gs://google-cloud-aiplatform/schema/dataset/annotation/video_object_tracking_1.0.0.yaml" + action_recognition = "gs://google-cloud-aiplatform/schema/dataset/annotation/video_action_recognition_1.0.0.yaml" diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py new file mode 100644 index 0000000000..220a34637e --- /dev/null +++ b/google/cloud/aiplatform/training_jobs.py @@ -0,0 +1,4362 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import datetime +import functools +import logging +import pathlib +import shutil +import subprocess +import sys +import tempfile +import time +from typing import Callable, Dict, List, Optional, NamedTuple, Sequence, Tuple, Union + +import abc + +from google.auth import credentials as auth_credentials +from google.cloud.aiplatform import base +from google.cloud.aiplatform import constants +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import models +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform.compat.types import ( + accelerator_type as gca_accelerator_type, + env_var as gca_env_var, + io as gca_io, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) + +from google.cloud.aiplatform.v1.schema.trainingjob import ( + definition_v1 as training_job_inputs, +) + +from google.cloud import storage +from google.rpc import code_pb2 + +import proto + + +logging.basicConfig(level=logging.INFO, stream=sys.stdout) +_LOGGER = base.Logger(__name__) + +_PIPELINE_COMPLETE_STATES = set( + [ + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED, + gca_pipeline_state.PipelineState.PIPELINE_STATE_CANCELLED, + gca_pipeline_state.PipelineState.PIPELINE_STATE_PAUSED, + ] +) + + +class _TrainingJob(base.AiPlatformResourceNounWithFutureManager): + + client_class = utils.PipelineClientWithOverride + _is_client_prediction_client = False + _resource_noun = "trainingPipelines" + _getter_method = "get_training_pipeline" + _list_method = "list_training_pipelines" + _delete_method = "delete_training_pipeline" + + def __init__( + self, + display_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: 
Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + ): + """Constructs a Training Job. + + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + project (str): + Optional project to retrieve model from. If not set, project set in + aiplatform.init will be used. + location (str): + Optional location to retrieve model from. If not set, location set in + aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional credentials to use to retrieve the model. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. 
+ """ + utils.validate_display_name(display_name) + + super().__init__(project=project, location=location, credentials=credentials) + self._display_name = display_name + self._project = project + self._training_encryption_spec = initializer.global_config.get_encryption_spec( + encryption_spec_key_name=training_encryption_spec_key_name + ) + self._model_encryption_spec = initializer.global_config.get_encryption_spec( + encryption_spec_key_name=model_encryption_spec_key_name + ) + self._gca_resource = None + + @property + @classmethod + @abc.abstractmethod + def _supported_training_schemas(cls) -> Tuple[str]: + """List of supported schemas for this training job""" + + pass + + @classmethod + def get( + cls, + resource_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "_TrainingJob": + """Get Training Job for the given resource_name. + + Args: + resource_name (str): + Required. A fully-qualified resource name or ID. + project (str): + Optional project to retrieve dataset from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional location to retrieve dataset from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Custom credentials to use to upload this model. Overrides + credentials set in aiplatform.init. + + Raises: + ValueError: If the retrieved training job's training task definition + doesn't match the custom training task definition. + + Returns: + An AI Platform Training Job + """ + + # Create job with dummy parameters + # These parameters won't be used as user can not run the job again. + # If they try, an exception will be raised. 
+ self = cls._empty_constructor( + project=project, + location=location, + credentials=credentials, + resource_name=resource_name, + ) + + self._gca_resource = self._get_gca_resource(resource_name=resource_name) + + if ( + self._gca_resource.training_task_definition + not in cls._supported_training_schemas + ): + raise ValueError( + f"The retrieved job's training task definition " + f"is {self._gca_resource.training_task_definition}, " + f"which is not compatible with {cls.__name__}." + ) + + return self + + @property + @abc.abstractmethod + def _model_upload_fail_string(self) -> str: + """Helper property for model upload failure.""" + + pass + + @abc.abstractmethod + def run(self) -> Optional[models.Model]: + """Runs the training job. Should call _run_job internally""" + pass + + @staticmethod + def _create_input_data_config( + dataset: Optional[datasets._Dataset] = None, + annotation_schema_uri: Optional[str] = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + gcs_destination_uri_prefix: Optional[str] = None, + bigquery_destination: Optional[str] = None, + ) -> Optional[gca_training_pipeline.InputDataConfig]: + """Constructs a input data config to pass to the training pipeline. + + Args: + dataset (datasets._Dataset): + The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. 
The schema is defined as an OpenAPI 3.0.2 + [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files + that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/, + note that the chosen schema must be consistent with + ``metadata`` + of the Dataset specified by + ``dataset_id``. + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + gcs_destination_uri_prefix (str): + Optional. The Google Cloud Storage location. 
+ + The AI Platform environment variables representing Google + Cloud Storage data URIs will always be represented in the + Google Cloud Storage wildcard format to support sharded + data. + + - AIP_DATA_FORMAT = "jsonl". + - AIP_TRAINING_DATA_URI = "gcs_destination/training-*" + - AIP_VALIDATION_DATA_URI = "gcs_destination/validation-*" + - AIP_TEST_DATA_URI = "gcs_destination/test-*". + bigquery_destination (str): + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + """ + + input_data_config = None + if dataset: + # Create fraction split spec + fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=training_fraction_split, + validation_fraction=validation_fraction_split, + test_fraction=test_fraction_split, + ) + + # Create predefined split spec + predefined_split = None + if predefined_split_column_name: + if ( + dataset._gca_resource.metadata_schema_uri + != schema.dataset.metadata.tabular + ): + raise ValueError( + "A pre-defined split may only be used with a tabular Dataset" + ) + + predefined_split = gca_training_pipeline.PredefinedSplit( + key=predefined_split_column_name + ) + + # Create GCS destination + gcs_destination = None + if gcs_destination_uri_prefix: + gcs_destination = gca_io.GcsDestination( + output_uri_prefix=gcs_destination_uri_prefix + ) + + # TODO(b/177416223) validate managed BQ dataset is passed in + bigquery_destination_proto = None + if bigquery_destination: + 
bigquery_destination_proto = gca_io.BigQueryDestination( + output_uri=bigquery_destination + ) + + # create input data config + input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=fraction_split, + predefined_split=predefined_split, + dataset_id=dataset.name, + annotation_schema_uri=annotation_schema_uri, + gcs_destination=gcs_destination, + bigquery_destination=bigquery_destination_proto, + ) + + return input_data_config + + def _run_job( + self, + training_task_definition: str, + training_task_inputs: Union[dict, proto.Message], + dataset: Optional[datasets._Dataset], + training_fraction_split: float, + validation_fraction_split: float, + test_fraction_split: float, + annotation_schema_uri: Optional[str] = None, + predefined_split_column_name: Optional[str] = None, + model: Optional[gca_model.Model] = None, + gcs_destination_uri_prefix: Optional[str] = None, + bigquery_destination: Optional[str] = None, + ) -> Optional[models.Model]: + """Runs the training job. + + Args: + training_task_definition (str): + Required. A Google Cloud Storage path to the + YAML file that defines the training task which + is responsible for producing the model artifact, + and may also include additional auxiliary work. + The definition files that can be used here are + found in gs://google-cloud- + aiplatform/schema/trainingjob/definition/. Note: + The URI given on output will be immutable and + probably different, including the URI scheme, + than the one given on input. The output URI will + point to a location where the user only has a + read access. + training_task_inputs (Union[dict, proto.Message]): + Required. The training task's input that corresponds to the training_task_definition parameter. + dataset (datasets._Dataset): + The dataset within the same Project from which data will be used to train the Model. 
The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files + that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/, + note that the chosen schema must be consistent with + ``metadata`` + of the Dataset specified by + ``dataset_id``. + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. 
The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + model (~.model.Model): + Optional. Describes the Model that may be uploaded (via + [ModelService.UploadMode][]) by this TrainingPipeline. The + TrainingPipeline's + ``training_task_definition`` + should make clear whether this Model description should be + populated, and if there are any special requirements + regarding how it should be filled. If nothing is mentioned + in the + ``training_task_definition``, + then it should be assumed that this field should not be + filled and the training task either uploads the Model + without a need of this information, or that training task + does not support uploading a Model as part of the pipeline. + When the Pipeline's state becomes + ``PIPELINE_STATE_SUCCEEDED`` and the trained Model had been + uploaded into AI Platform, then the model_to_upload's + resource ``name`` + is populated. The Model is always uploaded into the Project + and Location in which this pipeline is. + gcs_destination_uri_prefix (str): + Optional. The Google Cloud Storage location. + + The AI Platform environment variables representing Google + Cloud Storage data URIs will always be represented in the + Google Cloud Storage wildcard format to support sharded + data. + + - AIP_DATA_FORMAT = "jsonl". + - AIP_TRAINING_DATA_URI = "gcs_destination/training-*" + - AIP_VALIDATION_DATA_URI = "gcs_destination/validation-*" + - AIP_TEST_DATA_URI = "gcs_destination/test-*". + bigquery_destination (str): + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. 
All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + """ + + input_data_config = self._create_input_data_config( + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + gcs_destination_uri_prefix=gcs_destination_uri_prefix, + bigquery_destination=bigquery_destination, + ) + + # create training pipeline + training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=self._display_name, + training_task_definition=training_task_definition, + training_task_inputs=training_task_inputs, + model_to_upload=model, + input_data_config=input_data_config, + encryption_spec=self._training_encryption_spec, + ) + + training_pipeline = self.api_client.create_training_pipeline( + parent=initializer.global_config.common_location_path( + self.project, self.location + ), + training_pipeline=training_pipeline, + ) + + self._gca_resource = training_pipeline + + _LOGGER.info("View Training:\n%s" % self._dashboard_uri()) + + model = self._get_model() + + if model is None: + _LOGGER.warning( + "Training did not produce a Managed Model returning None. " + + self._model_upload_fail_string + ) + + return model + + def _is_waiting_to_run(self) -> bool: + """Returns True if the Job is pending on upstream tasks False otherwise.""" + self._raise_future_exception() + if self._latest_future: + _LOGGER.info( + "Training Job is waiting for upstream SDK tasks to complete before" + " launching." 
+ ) + return True + return False + + @property + def state(self) -> Optional[gca_pipeline_state.PipelineState]: + """Current training state.""" + + if self._assert_has_run(): + return + + self._sync_gca_resource() + return self._gca_resource.state + + def get_model(self, sync=True) -> models.Model: + """AI Platform Model produced by this training, if one was produced. + + Args: + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: AI Platform Model produced by this training + + Raises: + RuntimeError if training failed or if a model was not produced by this training. + """ + + self._assert_has_run() + if not self._gca_resource.model_to_upload: + raise RuntimeError(self._model_upload_fail_string) + + return self._force_get_model(sync=sync) + + @base.optional_sync() + def _force_get_model(self, sync: bool = True) -> models.Model: + """AI Platform Model produced by this training, if one was produced. + + Args: + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: AI Platform Model produced by this training + + Raises: + RuntimeError if training failed or if a model was not produced by this training. + """ + model = self._get_model() + + if model is None: + raise RuntimeError(self._model_upload_fail_string) + + return model + + def _get_model(self) -> Optional[models.Model]: + """Helper method to get and instantiate the Model to Upload. + + Returns: + model: AI Platform Model if training succeeded and produced an AI Platform + Model. None otherwise. + + Raises: + RuntimeError if Training failed. 
+ """ + self._block_until_complete() + + if self.has_failed: + raise RuntimeError( + f"Training Pipeline {self.resource_name} failed. No model available." + ) + + if not self._gca_resource.model_to_upload: + return None + + if self._gca_resource.model_to_upload.name: + fields = utils.extract_fields_from_resource_name( + self._gca_resource.model_to_upload.name + ) + + return models.Model( + fields.id, project=fields.project, location=fields.location, + ) + + def _block_until_complete(self): + """Helper method to block and check on job until complete.""" + + # Used these numbers so failures surface fast + wait = 5 # start at five seconds + log_wait = 5 + max_wait = 60 * 5 # 5 minute wait + multiplier = 2 # scale wait by 2 every iteration + + previous_time = time.time() + while self.state not in _PIPELINE_COMPLETE_STATES: + current_time = time.time() + if current_time - previous_time >= log_wait: + _LOGGER.info( + "%s %s current state:\n%s" + % ( + self.__class__.__name__, + self._gca_resource.name, + self._gca_resource.state, + ) + ) + log_wait = min(log_wait * multiplier, max_wait) + previous_time = current_time + time.sleep(wait) + + self._raise_failure() + + _LOGGER.log_action_completed_against_resource("run", "completed", self) + + if self._gca_resource.model_to_upload and not self.has_failed: + _LOGGER.info( + "Model available at %s" % self._gca_resource.model_to_upload.name + ) + + def _raise_failure(self): + """Helper method to raise failure if TrainingPipeline fails. + + Raises: + RuntimeError: If training failed.""" + + if self._gca_resource.error.code != code_pb2.OK: + raise RuntimeError("Training failed with:\n%s" % self._gca_resource.error) + + @property + def has_failed(self) -> bool: + """Returns True if training has failed. 
False otherwise.""" + self._assert_has_run() + return self.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED + + def _dashboard_uri(self) -> str: + """Helper method to compose the dashboard uri where training can be viewed.""" + fields = utils.extract_fields_from_resource_name(self.resource_name) + url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/training/{fields.id}?project={fields.project}" + return url + + def _sync_gca_resource(self): + """Helper method to sync the local gca_source against the service.""" + self._gca_resource = self.api_client.get_training_pipeline( + name=self.resource_name + ) + + @property + def _has_run(self) -> bool: + """Helper property to check if this training job has been run.""" + return self._gca_resource is not None + + def _assert_has_run(self) -> bool: + """Helper method to assert that this training has run.""" + if not self._has_run: + if self._is_waiting_to_run(): + return True + raise RuntimeError( + "TrainingPipeline has not been launched. You must run this" + " TrainingPipeline using TrainingPipeline.run. " + ) + return False + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List["base.AiPlatformResourceNoune"]: + """List all instances of this TrainingJob resource. + + Example Usage: + + aiplatform.CustomTrainingJob.list( + filter='display_name="experiment_a27"', + order_by='create_time desc' + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. 
Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[AiPlatformResourceNoun] - A list of TrainingJob resource objects + """ + + training_job_subclass_filter = ( + lambda gapic_obj: gapic_obj.training_task_definition + in cls._supported_training_schemas + ) + + return cls._list_with_local_order( + cls_filter=training_job_subclass_filter, + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) + + def cancel(self) -> None: + """Starts asynchronous cancellation on the TrainingJob. The server + makes a best effort to cancel the job, but success is not guaranteed. + On successful cancellation, the TrainingJob is not deleted; instead it + becomes a job with state set to `CANCELLED`. + + Raises: + RuntimeError if this TrainingJob has not started running. + """ + if not self._has_run: + raise RuntimeError( + "This TrainingJob has not been launched, use the `run()` method " + "to start. `cancel()` can only be called on a job that is running." + ) + self.api_client.cancel_training_pipeline(name=self.resource_name) + + +def _timestamped_gcs_dir(root_gcs_path: str, dir_name_prefix: str) -> str: + """Composes a timestamped GCS directory. + + Args: + root_gcs_path: GCS path to put the timestamped directory. + dir_name_prefix: Prefix to add the timestamped directory. + Returns: + Timestamped gcs directory path in root_gcs_path. 
+ """ + timestamp = datetime.datetime.now().isoformat(sep="-", timespec="milliseconds") + dir_name = "-".join([dir_name_prefix, timestamp]) + if root_gcs_path.endswith("/"): + root_gcs_path = root_gcs_path[:-1] + gcs_path = "/".join([root_gcs_path, dir_name]) + if not gcs_path.startswith("gs://"): + return "gs://" + gcs_path + return gcs_path + + +def _timestamped_copy_to_gcs( + local_file_path: str, + gcs_dir: str, + project: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, +) -> str: + """Copies a local file to a GCS path. + + The file copied to GCS is the name of the local file prepended with an + "aiplatform-{timestamp}-" string. + + Args: + local_file_path (str): Required. Local file to copy to GCS. + gcs_dir (str): + Required. The GCS directory to copy to. + project (str): + Project that contains the staging bucket. Default will be used if not + provided. Model Builder callers should pass this in. + credentials (auth_credentials.Credentials): + Custom credentials to use with bucket. Model Builder callers should pass + this in. + Returns: + gcs_path (str): The path of the copied file in gcs. + """ + + gcs_bucket, gcs_blob_prefix = utils.extract_bucket_and_prefix_from_gcs_path(gcs_dir) + + local_file_name = pathlib.Path(local_file_path).name + timestamp = datetime.datetime.now().isoformat(sep="-", timespec="milliseconds") + blob_path = "-".join(["aiplatform", timestamp, local_file_name]) + + if gcs_blob_prefix: + blob_path = "/".join([gcs_blob_prefix, blob_path]) + + # TODO(b/171202993) add user agent + client = storage.Client(project=project, credentials=credentials) + bucket = client.bucket(gcs_bucket) + blob = bucket.blob(blob_path) + blob.upload_from_filename(local_file_path) + + gcs_path = "".join(["gs://", "/".join([blob.bucket.name, blob.name])]) + return gcs_path + + +def _get_python_executable() -> str: + """Returns Python executable. + + Raises: + EnvironmentError if Python executable is not found. 
+ Returns: + Python executable to use for setuptools packaging. + """ + + python_executable = sys.executable + + if not python_executable: + raise EnvironmentError("Cannot find Python executable for packaging.") + return python_executable + + +class _TrainingScriptPythonPackager: + """Converts a Python script into Python package suitable for aiplatform training. + + Copies the script to specified location. + + Class Attributes: + _TRAINER_FOLDER: Constant folder name to build package. + _ROOT_MODULE: Constant root name of module. + _TEST_MODULE_NAME: Constant name of module that will store script. + _SETUP_PY_VERSION: Constant version of this created python package. + _SETUP_PY_TEMPLATE: Constant template used to generate setup.py file. + _SETUP_PY_SOURCE_DISTRIBUTION_CMD: + Constant command to generate the source distribution package. + + Attributes: + script_path: local path of script to package + requirements: list of Python dependencies to add to package + + Usage: + + packager = TrainingScriptPythonPackager('my_script.py', ['pandas', 'pytorch']) + gcs_path = packager.package_and_copy_to_gcs( + gcs_staging_dir='my-bucket', + project='my-prject') + module_name = packager.module_name + + The package after installed can be executed as: + python -m aiplatform_custom_trainer_script.task + + """ + + _TRAINER_FOLDER = "trainer" + _ROOT_MODULE = "aiplatform_custom_trainer_script" + _TASK_MODULE_NAME = "task" + _SETUP_PY_VERSION = "0.1" + + _SETUP_PY_TEMPLATE = """from setuptools import find_packages +from setuptools import setup + +setup( + name='{name}', + version='{version}', + packages=find_packages(), + install_requires=({requirements}), + include_package_data=True, + description='My training application.' +)""" + + _SETUP_PY_SOURCE_DISTRIBUTION_CMD = "setup.py sdist --formats=gztar" + + # Module name that can be executed during training. ie. 
python -m + module_name = f"{_ROOT_MODULE}.{_TASK_MODULE_NAME}" + + def __init__(self, script_path: str, requirements: Optional[Sequence[str]] = None): + """Initializes packager. + + Args: + script_path (str): Required. Local path to script. + requirements (Sequence[str]): + List of python packages dependencies of script. + """ + + self.script_path = script_path + self.requirements = requirements or [] + + def make_package(self, package_directory: str) -> str: + """Converts script into a Python package suitable for python module execution. + + Args: + package_directory (str): Directory to build package in. + Returns: + source_distribution_path (str): Path to built package. + Raises: + RunTimeError if package creation fails. + """ + # The root folder to builder the package in + package_path = pathlib.Path(package_directory) + + # Root directory of the package + trainer_root_path = package_path / self._TRAINER_FOLDER + + # The root module of the python package + trainer_path = trainer_root_path / self._ROOT_MODULE + + # __init__.py path in root module + init_path = trainer_path / "__init__.py" + + # The module that will contain the script + script_out_path = trainer_path / f"{self._TASK_MODULE_NAME}.py" + + # The path to setup.py in the package. + setup_py_path = trainer_root_path / "setup.py" + + # The path to the generated source distribution. + source_distribution_path = ( + trainer_root_path + / "dist" + / f"{self._ROOT_MODULE}-{self._SETUP_PY_VERSION}.tar.gz" + ) + + trainer_root_path.mkdir() + trainer_path.mkdir() + + # Make empty __init__.py + with init_path.open("w"): + pass + + # Format the setup.py file. + setup_py_output = self._SETUP_PY_TEMPLATE.format( + name=self._ROOT_MODULE, + requirements=",".join(f'"{r}"' for r in self.requirements), + version=self._SETUP_PY_VERSION, + ) + + # Write setup.py + with setup_py_path.open("w") as fp: + fp.write(setup_py_output) + + # Copy script as module of python package. 
+ shutil.copy(self.script_path, script_out_path) + + # Run setup.py to create the source distribution. + setup_cmd = [ + _get_python_executable() + ] + self._SETUP_PY_SOURCE_DISTRIBUTION_CMD.split() + + p = subprocess.Popen( + args=setup_cmd, + cwd=trainer_root_path, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + output, error = p.communicate() + + # Raise informative error if packaging fails. + if p.returncode != 0: + raise RuntimeError( + "Packaging of training script failed with code %d\n%s \n%s" + % (p.returncode, output.decode(), error.decode()) + ) + + return str(source_distribution_path) + + def package_and_copy(self, copy_method: Callable[[str], str]) -> str: + """Packages the script and executes copy with given copy_method. + + Args: + copy_method Callable[[str], str] + Takes a string path, copies to a desired location, and returns the + output path location. + Returns: + output_path str: Location of copied package. + """ + + with tempfile.TemporaryDirectory() as tmpdirname: + source_distribution_path = self.make_package(tmpdirname) + output_location = copy_method(source_distribution_path) + _LOGGER.info("Training script copied to:\n%s." % output_location) + return output_location + + def package_and_copy_to_gcs( + self, + gcs_staging_dir: str, + project: str = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> str: + """Packages script in Python package and copies package to GCS bucket. + + Args + gcs_staging_dir (str): Required. GCS Staging directory. + project (str): Required. Project where GCS Staging bucket is located. + credentials (auth_credentials.Credentials): + Optional credentials used with GCS client. + Returns: + GCS location of Python package. 
+ """ + + copy_method = functools.partial( + _timestamped_copy_to_gcs, + gcs_dir=gcs_staging_dir, + project=project, + credentials=credentials, + ) + return self.package_and_copy(copy_method=copy_method) + + +class _MachineSpec(NamedTuple): + """Specification container for Machine specs used for distributed training. + + Usage: + + spec = _MachineSpec( + replica_count=10, + machine_type='n1-standard-4', + accelerator_count=2, + accelerator_type='NVIDIA_TESLA_K80') + + Note that container and python package specs are not stored with this spec. + """ + + replica_count: int = 0 + machine_type: str = "n1-standard-4" + accelerator_count: int = 0 + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED" + + def _get_accelerator_type(self) -> Optional[str]: + """Validates accelerator_type and returns the name of the accelerator. + + Returns: + None if no accelerator or valid accelerator name. + + Raise: + ValueError if accelerator type is invalid. + """ + + # Raises ValueError if invalid accelerator_type + utils.validate_accelerator_type(self.accelerator_type) + + accelerator_enum = getattr( + gca_accelerator_type.AcceleratorType, self.accelerator_type + ) + + if ( + accelerator_enum + != gca_accelerator_type.AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED + ): + return self.accelerator_type + + @property + def spec_dict(self) -> Dict[str, Union[int, str, Dict[str, Union[int, str]]]]: + """Return specification as a Dict.""" + spec = { + "machineSpec": {"machineType": self.machine_type}, + "replicaCount": self.replica_count, + } + accelerator_type = self._get_accelerator_type() + if accelerator_type and self.accelerator_count: + spec["machineSpec"]["acceleratorType"] = accelerator_type + spec["machineSpec"]["acceleratorCount"] = self.accelerator_count + + return spec + + @property + def is_empty(self) -> bool: + """Returns True is replica_count > 0 False otherwise.""" + return self.replica_count <= 0 + + +class _DistributedTrainingSpec(NamedTuple): + """Configuration for 
distributed training worker pool specs. + + AI Platform Training expects configuration in this order: + [ + chief spec, # can only have one replica + worker spec, + parameter server spec, + evaluator spec + ] + + Usage: + + dist_training_spec = _DistributedTrainingSpec( + chief_spec = _MachineSpec( + replica_count=1, + machine_type='n1-standard-4', + accelerator_count=2, + accelerator_type='NVIDIA_TESLA_K80' + ), + worker_spec = _MachineSpec( + replica_count=10, + machine_type='n1-standard-4', + accelerator_count=2, + accelerator_type='NVIDIA_TESLA_K80' + ) + ) + + """ + + chief_spec: _MachineSpec = _MachineSpec() + worker_spec: _MachineSpec = _MachineSpec() + parameter_server_spec: _MachineSpec = _MachineSpec() + evaluator_spec: _MachineSpec = _MachineSpec() + + @property + def pool_specs( + self, + ) -> List[Dict[str, Union[int, str, Dict[str, Union[int, str]]]]]: + """Return each pools spec in correct order for AI Platform as a list of dicts. + + Also removes specs if they are empty but leaves specs in if there unusual + specifications to not break the ordering in AI Platform Training. + ie. 0 chief replica, 10 worker replica, 3 ps replica + + Returns: + Order list of worker pool specs suitable for AI Platform Training. + """ + if self.chief_spec.replica_count > 1: + raise ValueError("Chief spec replica count cannot be greater than 1.") + + spec_order = [ + self.chief_spec, + self.worker_spec, + self.parameter_server_spec, + self.evaluator_spec, + ] + specs = [s.spec_dict for s in spec_order] + for i in reversed(range(len(spec_order))): + if spec_order[i].is_empty: + specs.pop() + else: + break + return specs + + @classmethod + def chief_worker_pool( + cls, + replica_count: int = 0, + machine_type: str = "n1-standard-4", + accelerator_count: int = 0, + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + ) -> "_DistributedTrainingSpec": + """Parameterizes Config to support only chief with worker replicas. 
+ + For replica is assigned to chief and the remainder to workers. All spec have the + same machine type, accelerator count, and accelerator type. + + Args: + replica_count (int): + The number of worker replicas. Assigns 1 chief replica and + replica_count - 1 worker replicas. + machine_type (str): + The type of machine to use for training. + accelerator_type (str): + Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + The number of accelerators to attach to a worker replica. + + Returns: + _DistributedTrainingSpec representing one chief and n workers all of same + type. If replica_count <= 0 then an empty spec is returned. + """ + if replica_count <= 0: + return cls() + + chief_spec = _MachineSpec( + replica_count=1, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + ) + + worker_spec = _MachineSpec( + replica_count=replica_count - 1, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + ) + + return cls(chief_spec=chief_spec, worker_spec=worker_spec) + + +class _CustomTrainingJob(_TrainingJob): + """ABC for Custom Training Pipelines.. 
+ """ + + _supported_training_schemas = (schema.training_job.definition.custom_task,) + + def __init__( + self, + display_name: str, + container_uri: str, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + ): + """ + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + container_uri (str): + Required: Uri of the training container image in the GCR. + model_serving_container_image_uri (str): + If the training produces a managed AI Platform Model, the URI of the + Model serving container suitable for serving the model produced by the + training script. + model_serving_container_predict_route (str): + If the training produces a managed AI Platform Model, An HTTP path to + send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by AI Platform. + model_serving_container_health_route (str): + If the training produces a managed AI Platform Model, an HTTP path to + send health check requests to the container, and which must be supported + by it. 
If not specified a standard HTTP path will be used by AI + Platform. + model_serving_container_command (Sequence[str]): + The command with which the container is run. Not executed within a + shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + model_serving_container_args (Sequence[str]): + The arguments to the command. The Docker image's CMD is used if this is + not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + model_serving_container_environment_variables (Dict[str, str]): + The environment variables that are to be present in the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + model_serving_container_ports (Sequence[int]): + Declaration of ports that are exposed by the container. This field is + primarily informational, it gives AI Platform information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + model_description (str): + The description of the Model. + model_instance_schema_uri (str): + Optional. 
Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + model_parameters_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + model_prediction_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + project (str): + Project to run training in. 
Overrides project set in aiplatform.init. + location (str): + Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + staging_bucket (str): + Bucket used to stage source and training artifacts. Overrides + staging_bucket set in aiplatform.init. 
+ """ + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + ) + + self._container_uri = container_uri + + model_predict_schemata = None + if any( + [ + model_instance_schema_uri, + model_parameters_schema_uri, + model_prediction_schema_uri, + ] + ): + model_predict_schemata = gca_model.PredictSchemata( + instance_schema_uri=model_instance_schema_uri, + parameters_schema_uri=model_parameters_schema_uri, + prediction_schema_uri=model_prediction_schema_uri, + ) + + # Create the container spec + env = None + ports = None + + if model_serving_container_environment_variables: + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in model_serving_container_environment_variables.items() + ] + + if model_serving_container_ports: + ports = [ + gca_model.Port(container_port=port) + for port in model_serving_container_ports + ] + + container_spec = gca_model.ModelContainerSpec( + image_uri=model_serving_container_image_uri, + command=model_serving_container_command, + args=model_serving_container_args, + env=env, + ports=ports, + predict_route=model_serving_container_predict_route, + health_route=model_serving_container_health_route, + ) + + # create model payload + self._managed_model = gca_model.Model( + description=model_description, + predict_schemata=model_predict_schemata, + container_spec=container_spec, + encryption_spec=self._model_encryption_spec, + ) + + self._staging_bucket = ( + staging_bucket or initializer.global_config.staging_bucket + ) + + if not self._staging_bucket: + raise RuntimeError( + "staging_bucket should be set in TrainingJob constructor or " + "set using aiplatform.init(staging_bucket='gs://my-bucket')" + ) + + def _prepare_and_validate_run( + self, + model_display_name: Optional[str] = None, + replica_count: int = 0, + 
machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + ) -> Tuple[_DistributedTrainingSpec, Optional[gca_model.Model]]: + """Create worker pool specs and managed model as well validating the run. + + Args: + model_display_name (str): + If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + replica_count (int): + The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + machine_type (str): + The type of machine to use for training. + accelerator_type (str): + Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + The number of accelerators to attach to a worker replica. + Returns: + Worker pools specs and managed model for run. + + Raises: + RuntimeError if Training job has already been run or model_display_name was + provided but required arguments were not provided in constructor. + + """ + + if self._is_waiting_to_run(): + raise RuntimeError("Custom Training is already scheduled to run.") + + if self._has_run: + raise RuntimeError("Custom Training has already run.") + + # if args needed for model is incomplete + if model_display_name and not self._managed_model.container_spec.image_uri: + raise RuntimeError( + """model_display_name was provided but + model_serving_container_image_uri was not provided when this + custom pipeline was constructed. 
+ """ + ) + + # validates args and will raise + worker_pool_specs = _DistributedTrainingSpec.chief_worker_pool( + replica_count=replica_count, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + ).pool_specs + + managed_model = self._managed_model + if model_display_name: + utils.validate_display_name(model_display_name) + managed_model.display_name = model_display_name + else: + managed_model = None + + return worker_pool_specs, managed_model + + def _prepare_training_task_inputs_and_output_dir( + self, + worker_pool_specs: _DistributedTrainingSpec, + base_output_dir: Optional[str] = None, + ) -> Tuple[Dict, str]: + """Prepares training task inputs and output directory for custom job. + + Args: + worker_pools_spec (_DistributedTrainingSpec): + Worker pools pecs required to run job. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + Returns: + Training task inputs and Output directory for custom job. + """ + + # default directory if not given + base_output_dir = base_output_dir or _timestamped_gcs_dir( + self._staging_bucket, "aiplatform-custom-training" + ) + + _LOGGER.info("Training Output directory:\n%s " % base_output_dir) + + training_task_inputs = { + "workerPoolSpecs": worker_pool_specs, + "baseOutputDirectory": {"output_uri_prefix": base_output_dir}, + } + + return training_task_inputs, base_output_dir + + @property + def _model_upload_fail_string(self) -> str: + """Helper property for model upload failure.""" + return ( + f"Training Pipeline {self.resource_name} is not configured to upload a " + "Model. Create the Training Pipeline with " + "model_serving_container_image_uri and model_display_name passed in. " + "Ensure that your training script saves to model to " + "os.environ['AIP_MODEL_DIR']." 
+ ) + + +# TODO(b/172368325) add scheduling, custom_job.Scheduling +class CustomTrainingJob(_CustomTrainingJob): + """Class to launch a Custom Training Job in AI Platform using a script. + + Takes a training implementation as a python script and executes that script + in Cloud AI Platform Training. + """ + + def __init__( + self, + display_name: str, + script_path: str, + container_uri: str, + requirements: Optional[Sequence[str]] = None, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + ): + """Constructs a Custom Training Job from a Python script. 
+ + job = aiplatform.CustomTrainingJob( + display_name='test-train', + script_path='test_script.py', + requirements=['pandas', 'numpy'], + container_uri='gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest', + model_serving_container_image_uri='gcr.io/my-trainer/serving:1', + model_serving_container_predict_route='predict', + model_serving_container_health_route='metadata) + + Usage with Dataset: + + ds = aiplatform.TabularDataset( + 'projects/my-project/locations/us-central1/datasets/12345') + + job.run(ds, replica_count=1, model_display_name='my-trained-model') + + Usage without Dataset: + + job.run(replica_count=1, model_display_name='my-trained-model) + + + TODO(b/169782082) add documentation about traning utilities + To ensure your model gets saved in AI Platform, write your saved model to + os.environ["AIP_MODEL_DIR"] in your provided training script. + + + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + script_path (str): Required. Local path to training script. + container_uri (str): + Required: Uri of the training container image in the GCR. + requirements (Sequence[str]): + List of python packages dependencies of script. + model_serving_container_image_uri (str): + If the training produces a managed AI Platform Model, the URI of the + Model serving container suitable for serving the model produced by the + training script. + model_serving_container_predict_route (str): + If the training produces a managed AI Platform Model, An HTTP path to + send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by AI Platform. + model_serving_container_health_route (str): + If the training produces a managed AI Platform Model, an HTTP path to + send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI + Platform. 
+ model_serving_container_command (Sequence[str]): + The command with which the container is run. Not executed within a + shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + model_serving_container_args (Sequence[str]): + The arguments to the command. The Docker image's CMD is used if this is + not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + model_serving_container_environment_variables (Dict[str, str]): + The environment variables that are to be present in the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + model_serving_container_ports (Sequence[int]): + Declaration of ports that are exposed by the container. This field is + primarily informational, it gives AI Platform information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + model_description (str): + The description of the Model. + model_instance_schema_uri (str): + Optional. 
Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + model_parameters_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + model_prediction_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + project (str): + Project to run training in. 
Overrides project set in aiplatform.init. + location (str): + Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + staging_bucket (str): + Bucket used to stage source and training artifacts. Overrides + staging_bucket set in aiplatform.init. 
+ """ + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + container_uri=container_uri, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_description=model_description, + staging_bucket=staging_bucket, + ) + + self._requirements = requirements + self._script_path = script_path + + # TODO(b/172365904) add filter split, training_pipeline.FilterSplit + # TODO(b/172368070) add timestamp split, training_pipeline.TimestampSplit + def run( + self, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + base_output_dir: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + replica_count: int = 0, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + 
sync=True, + ) -> Optional[models.Model]: + """Runs the custom training job. + + Distributed Training Support: + If replica count = 1 then one chief replica will be provisioned. If + replica_count > 1 the remainder will be provisioned as a worker replica pool. + ie: replica_count = 10 will result in 1 chief and 9 workers + All replicas have same machine_type, accelerator_type, and accelerator_count + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform.If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): + AI Platform to fit this training against. Custom training script should + retrieve datasets through passed in environment variables uris: + + os.environ["AIP_TRAINING_DATA_URI"] + os.environ["AIP_VALIDATION_DATA_URI"] + os.environ["AIP_TEST_DATA_URI"] + + Additionally the dataset format is passed in as: + + os.environ["AIP_DATA_FORMAT"] + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files + that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/, + note that the chosen schema must be consistent with + ``metadata`` + of the Dataset specified by + ``dataset_id``. 
+ + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + model_display_name (str): + If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + bigquery_destination (str): + Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + args (List[Unions[str, int, float]]): + Command line arguments to be passed to the Python script. + replica_count (int): + The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + machine_type (str): + The type of machine to use for training. + accelerator_type (str): + Hardware accelerator type. 
One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + The number of accelerators to attach to a worker replica. + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. 
+ """ + worker_pool_specs, managed_model = self._prepare_and_validate_run( + model_display_name=model_display_name, + replica_count=replica_count, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + ) + + # make and copy package + python_packager = _TrainingScriptPythonPackager( + script_path=self._script_path, requirements=self._requirements + ) + + return self._run( + python_packager=python_packager, + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + worker_pool_specs=worker_pool_specs, + managed_model=managed_model, + args=args, + base_output_dir=base_output_dir, + bigquery_destination=bigquery_destination, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + sync=sync, + ) + + @base.optional_sync(construct_object_on_arg="managed_model") + def _run( + self, + python_packager: _TrainingScriptPythonPackager, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ], + annotation_schema_uri: Optional[str], + worker_pool_specs: _DistributedTrainingSpec, + managed_model: Optional[gca_model.Model] = None, + args: Optional[List[Union[str, float, int]]] = None, + base_output_dir: Optional[str] = None, + bigquery_destination: Optional[str] = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + sync=True, + ) -> Optional[models.Model]: + """Packages local script and launches training_job. + + Args: + python_packager (_TrainingScriptPythonPackager): + Required. Python Packager pointing to training script locally. 
+ dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): + AI Platform to fit this training against. + annotation_schema_uri (str): + Google Cloud Storage URI that points to a YAML file describing + annotation schema. + worker_pool_specs (_DistributedTrainingSpec): + Worker pool specs required to run job. + managed_model (gca_model.Model): + Model proto if this script produces a Managed Model. + args (List[Union[str, int, float]]): + Command line arguments to be passed to the Python script. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + bigquery_destination (str): + Provide this field if `dataset` is a BigQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI = "bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns.
The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + """ + package_gcs_uri = python_packager.package_and_copy_to_gcs( + gcs_staging_dir=self._staging_bucket, + project=self.project, + credentials=self.credentials, + ) + + for spec in worker_pool_specs: + spec["pythonPackageSpec"] = { + "executorImageUri": self._container_uri, + "pythonModule": python_packager.module_name, + "packageUris": [package_gcs_uri], + } + + if args: + spec["pythonPackageSpec"]["args"] = args + + ( + training_task_inputs, + base_output_dir, + ) = self._prepare_training_task_inputs_and_output_dir( + worker_pool_specs, base_output_dir + ) + + model = self._run_job( + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=training_task_inputs, + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + model=managed_model, + gcs_destination_uri_prefix=base_output_dir, + bigquery_destination=bigquery_destination, + ) + + return model + + +class CustomContainerTrainingJob(_CustomTrainingJob): + """Class to launch a Custom Training Job in AI Platform using a Container.""" + + def __init__( + self, + 
display_name: str, + container_uri: str, + command: Sequence[str] = None, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + ): + """Constructs a Custom Container Training Job. + + job = aiplatform.CustomContainerTrainingJob( + display_name='test-train', + container_uri='gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest', + command=['python3', 'run_script.py'], + model_serving_container_image_uri='gcr.io/my-trainer/serving:1', + model_serving_container_predict_route='predict', + model_serving_container_health_route='metadata') + + Usage with Dataset: + + ds = aiplatform.TabularDataset( + 'projects/my-project/locations/us-central1/datasets/12345') + + job.run(ds, replica_count=1, model_display_name='my-trained-model') + + Usage without Dataset: + + job.run(replica_count=1, model_display_name='my-trained-model') + + + TODO(b/169782082) add documentation about training utilities + To ensure your model gets saved in AI Platform, write your saved model to + os.environ["AIP_MODEL_DIR"] in your provided training script. + + + Args: + display_name (str): + Required.
The user-defined name of this TrainingPipeline. + container_uri (str): + Required: Uri of the training container image in the GCR. + command (Sequence[str]): + The command to be invoked when the container is started. + It overrides the entrypoint instruction in Dockerfile when provided + model_serving_container_image_uri (str): + If the training produces a managed AI Platform Model, the URI of the + Model serving container suitable for serving the model produced by the + training script. + model_serving_container_predict_route (str): + If the training produces a managed AI Platform Model, An HTTP path to + send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by AI Platform. + model_serving_container_health_route (str): + If the training produces a managed AI Platform Model, an HTTP path to + send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI + Platform. + model_serving_container_command (Sequence[str]): + The command with which the container is run. Not executed within a + shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + model_serving_container_args (Sequence[str]): + The arguments to the command. The Docker image's CMD is used if this is + not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). 
Escaped references will + never be expanded, regardless of whether the variable exists or not. + model_serving_container_environment_variables (Dict[str, str]): + The environment variables that are to be present in the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + model_serving_container_ports (Sequence[int]): + Declaration of ports that are exposed by the container. This field is + primarily informational, it gives AI Platform information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + model_description (str): + The description of the Model. + model_instance_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + model_parameters_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. 
Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + model_prediction_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + project (str): + Project to run training in. Overrides project set in aiplatform.init. + location (str): + Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. 
Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + staging_bucket (str): + Bucket used to stage source and training artifacts. Overrides + staging_bucket set in aiplatform.init. + """ + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + container_uri=container_uri, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_description=model_description, + staging_bucket=staging_bucket, + ) + + self._command = command + + # TODO(b/172365904) add filter split, training_pipeline.FilterSplit + # TODO(b/172368070) add timestamp split, training_pipeline.TimestampSplit + def run( + self, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + base_output_dir: Optional[str] = None, + bigquery_destination: Optional[str] = 
None, + args: Optional[List[Union[str, float, int]]] = None, + replica_count: int = 0, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + sync=True, + ) -> Optional[models.Model]: + """Runs the custom training job. + + Distributed Training Support: + If replica count = 1 then one chief replica will be provisioned. If + replica_count > 1 the remainder will be provisioned as a worker replica pool. + ie: replica_count = 10 will result in 1 chief and 9 workers + All replicas have same machine_type, accelerator_type, and accelerator_count + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): + AI Platform to fit this training against. Custom training script should + retrieve datasets through passed in environment variables uris: + + os.environ["AIP_TRAINING_DATA_URI"] + os.environ["AIP_VALIDATION_DATA_URI"] + os.environ["AIP_TEST_DATA_URI"] + + Additionally the dataset format is passed in as: + + os.environ["AIP_DATA_FORMAT"] + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. 
The schema is defined as an OpenAPI 3.0.2 + [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files + that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/, + note that the chosen schema must be consistent with + ``metadata`` + of the Dataset specified by + ``dataset_id``. + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + model_display_name (str): + If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + bigquery_destination (str): + Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + args (List[Unions[str, int, float]]): + Command line arguments to be passed to the Python script. + replica_count (int): + The number of worker replicas. 
If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + machine_type (str): + The type of machine to use for training. + accelerator_type (str): + Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + The number of accelerators to attach to a worker replica. + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + + Raises: + RuntimeError if Training job has already been run, staging_bucket has not + been set, or model_display_name was provided but required arguments + were not provided in constructor. 
+ """ + worker_pool_specs, managed_model = self._prepare_and_validate_run( + model_display_name=model_display_name, + replica_count=replica_count, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + ) + + return self._run( + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + worker_pool_specs=worker_pool_specs, + managed_model=managed_model, + args=args, + base_output_dir=base_output_dir, + bigquery_destination=bigquery_destination, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + sync=sync, + ) + + @base.optional_sync(construct_object_on_arg="managed_model") + def _run( + self, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ], + annotation_schema_uri: Optional[str], + worker_pool_specs: _DistributedTrainingSpec, + managed_model: Optional[gca_model.Model] = None, + args: Optional[List[Union[str, float, int]]] = None, + base_output_dir: Optional[str] = None, + bigquery_destination: Optional[str] = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + sync=True, + ) -> Optional[models.Model]: + """Packages local script and launches training_job. + Args: + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): + AI Platform to fit this training against. + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. + worker_pools_spec (_DistributedTrainingSpec): + Worker pools pecs required to run job. + managed_model (gca_model.Model): + Model proto if this script produces a Managed Model. 
+ args (List[Unions[str, int, float]]): + Command line arguments to be passed to the Python script. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + bigquery_destination (str): + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. 
+        """Constructs an AutoML Tabular Training Job.
+                target column (two distinct values -> binary, 3 or more distinct values
+                -> multi class).
+ If an input column has no transformations on it, such a column is + ignored by the training, except for the targetColumn, which should have + no transformations defined on. + optimization_objective_recall_value (float): + Optional. Required when maximize-precision-at-recall optimizationObjective was + picked, represents the recall value at which the optimization is done. + + The minimum value is 0 and the maximum is 1.0. + optimization_objective_precision_value (float): + Optional. Required when maximize-recall-at-precision optimizationObjective was + picked, represents the precision value at which the optimization is + done. + + The minimum value is 0 and the maximum is 1.0. + project (str): + Optional. Project to run training in. Overrides project set in aiplatform.init. + location (str): + Optional. Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. 
+ The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + """ + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + ) + self._column_transformations = column_transformations + self._optimization_objective = optimization_objective + self._optimization_prediction_type = optimization_prediction_type + self._optimization_objective_recall_value = optimization_objective_recall_value + self._optimization_objective_precision_value = ( + optimization_objective_precision_value + ) + + def run( + self, + dataset: datasets.TabularDataset, + target_column: str, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + weight_column: Optional[str] = None, + budget_milli_node_hours: int = 1000, + model_display_name: Optional[str] = None, + disable_early_stopping: bool = False, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset (datasets.TabularDataset): + Required. The dataset within the same Project from which data will be used to train the Model. 
The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + training_fraction_split (float): + Required. The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + Required. The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + Required. The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + weight_column (str): + Optional. Name of the column that should be used as the weight column. + Higher values in this column give more importance to the row + during Model training. The column must have numeric values between 0 and + 10000 inclusively, and 0 value means that the row is ignored. + If the weight column field is not set, then all rows are assumed to have + equal weight of 1. + budget_milli_node_hours (int): + Optional. The train budget of creating this Model, expressed in milli node + hours i.e. 1,000 value in this field means 1 node hour. + The training cost of the model will not exceed this budget. 
+                the Model. The name can be up to 128 characters long and can consist
+                of any UTF-8 characters.
+ """ + + if self._is_waiting_to_run(): + raise RuntimeError("AutoML Tabular Training is already scheduled to run.") + + if self._has_run: + raise RuntimeError("AutoML Tabular Training has already run.") + + return self._run( + dataset=dataset, + target_column=target_column, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + weight_column=weight_column, + budget_milli_node_hours=budget_milli_node_hours, + model_display_name=model_display_name, + disable_early_stopping=disable_early_stopping, + sync=sync, + ) + + @base.optional_sync() + def _run( + self, + dataset: datasets.TabularDataset, + target_column: str, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + weight_column: Optional[str] = None, + budget_milli_node_hours: int = 1000, + model_display_name: Optional[str] = None, + disable_early_stopping: bool = False, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset (datasets.TabularDataset): + Required. The dataset within the same Project from which data will be used to train the Model. 
The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + training_fraction_split (float): + Required. The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + Required. The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + Required. The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + weight_column (str): + Optional. Name of the column that should be used as the weight column. + Higher values in this column give more importance to the row + during Model training. The column must have numeric values between 0 and + 10000 inclusively, and 0 value means that the row is ignored. + If the weight column field is not set, then all rows are assumed to have + equal weight of 1. + budget_milli_node_hours (int): + Optional. The train budget of creating this Model, expressed in milli node + hours i.e. 1,000 value in this field means 1 node hour. + The training cost of the model will not exceed this budget. 
+                used, if further training no longer brings significant improvement
+                to the model.
+ """ + + training_task_definition = schema.training_job.definition.automl_tabular + + training_task_inputs_dict = { + # required inputs + "targetColumn": target_column, + "transformations": self._column_transformations, + "trainBudgetMilliNodeHours": budget_milli_node_hours, + # optional inputs + "weightColumnName": weight_column, + "disableEarlyStopping": disable_early_stopping, + "optimizationObjective": self._optimization_objective, + "predictionType": self._optimization_prediction_type, + "optimizationObjectiveRecallValue": self._optimization_objective_recall_value, + "optimizationObjectivePrecisionValue": self._optimization_objective_precision_value, + } + + if model_display_name is None: + model_display_name = self._display_name + + model = gca_model.Model( + display_name=model_display_name, + encryption_spec=self._model_encryption_spec, + ) + + return self._run_job( + training_task_definition=training_task_definition, + training_task_inputs=training_task_inputs_dict, + dataset=dataset, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + model=model, + ) + + @property + def _model_upload_fail_string(self) -> str: + """Helper property for model upload failure.""" + return ( + f"Training Pipeline {self.resource_name} is not configured to upload a " + "Model." 
+                "classification" - Predict one out of multiple target labels
+                that is picked for each image.
+                "object_detection" - Predict bounding boxes and the matching
+                target labels for the objects found within each image.
+                or Core ML model and used on a mobile or edge device afterwards.
The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + Raises: + ValueError: When an invalid prediction_type or model_type is provided. + """ + + valid_model_types = constants.AUTOML_IMAGE_PREDICTION_MODEL_TYPES.get( + prediction_type, None + ) + + if not valid_model_types: + raise ValueError( + f"'{prediction_type}' is not a supported prediction type for AutoML Image Training. " + f"Please choose one of: {tuple(constants.AUTOML_IMAGE_PREDICTION_MODEL_TYPES.keys())}." + ) + + # Override default model_type for object_detection + if model_type == "CLOUD" and prediction_type == "object_detection": + model_type = "CLOUD_HIGH_ACCURACY_1" + + if model_type not in valid_model_types: + raise ValueError( + f"'{model_type}' is not a supported model_type for prediction_type of '{prediction_type}'. " + f"Please choose one of: {tuple(valid_model_types)}" + ) + + if base_model and prediction_type != "classification": + raise ValueError( + "Training with a `base_model` is only supported in AutoML Image Classification. 
" + f"However '{prediction_type}' was provided as `prediction_type`." + ) + + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + ) + + self._model_type = model_type + self._prediction_type = prediction_type + self._multi_label = multi_label + self._base_model = base_model + + def run( + self, + dataset: datasets.ImageDataset, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + budget_milli_node_hours: int = 1000, + model_display_name: Optional[str] = None, + disable_early_stopping: bool = False, + sync: bool = True, + ) -> models.Model: + """Runs the AutoML Image training job and returns a model. + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset (datasets.ImageDataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + training_fraction_split: float = 0.8 + Required. The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split: float = 0.1 + Required. 
+                can be up to 128 characters long and can consist of any UTF-8
+                characters. If not provided upon creation, the job's display_name is used.
+ + Raises: + RuntimeError: If Training job has already been run or is waiting to run. + """ + + if self._is_waiting_to_run(): + raise RuntimeError("AutoML Image Training is already scheduled to run.") + + if self._has_run: + raise RuntimeError("AutoML Image Training has already run.") + + return self._run( + dataset=dataset, + base_model=self._base_model, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + budget_milli_node_hours=budget_milli_node_hours, + model_display_name=model_display_name, + disable_early_stopping=disable_early_stopping, + sync=sync, + ) + + @base.optional_sync() + def _run( + self, + dataset: datasets.ImageDataset, + base_model: Optional[models.Model] = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + budget_milli_node_hours: int = 1000, + model_display_name: Optional[str] = None, + disable_early_stopping: bool = False, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset (datasets.ImageDataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. 
+                base_model will be overwritten with this value. If not provided upon
+                creation, the job's display_name is used.
+                used, if further training no longer brings significant improvement
+                to the model.
training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + model=model_tbt, + ) + + @property + def _model_upload_fail_string(self) -> str: + """Helper property for model upload failure.""" + return ( + f"AutoML Image Training Pipeline {self.resource_name} is not " + "configured to upload a Model." + ) + + +class CustomPythonPackageTrainingJob(_CustomTrainingJob): + """Class to launch a Custom Training Job in AI Platform using a Python Package. + + Takes a training implementation as a python package and executes that package + in Cloud AI Platform Training. + """ + + def __init__( + self, + display_name: str, + python_package_gcs_uri: str, + python_module_name: str, + container_uri: str, + model_serving_container_image_uri: Optional[str] = None, + model_serving_container_predict_route: Optional[str] = None, + model_serving_container_health_route: Optional[str] = None, + model_serving_container_command: Optional[Sequence[str]] = None, + model_serving_container_args: Optional[Sequence[str]] = None, + model_serving_container_environment_variables: Optional[Dict[str, str]] = None, + model_serving_container_ports: Optional[Sequence[int]] = None, + model_description: Optional[str] = None, + model_instance_schema_uri: Optional[str] = None, + model_parameters_schema_uri: Optional[str] = None, + model_prediction_schema_uri: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + staging_bucket: Optional[str] = None, + ): + """Constructs a Custom Training Job from a Python Package. 
+ + job = aiplatform.CustomPythonPackageTrainingJob( + display_name='test-train', + python_package_gcs_uri='gs://my-bucket/my-python-package.tar.gz', + python_module_name='my-training-python-package.task', + container_uri='gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest', + model_serving_container_image_uri='gcr.io/my-trainer/serving:1', + model_serving_container_predict_route='predict', + model_serving_container_health_route='metadata + ) + + Usage with Dataset: + + ds = aiplatform.TabularDataset( + 'projects/my-project/locations/us-central1/datasets/12345' + ) + + job.run( + ds, + replica_count=1, + model_display_name='my-trained-model' + ) + + Usage without Dataset: + + job.run( + replica_count=1, + model_display_name='my-trained-model' + ) + + To ensure your model gets saved in AI Platform, write your saved model to + os.environ["AIP_MODEL_DIR"] in your provided training script. + + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + python_package_gcs_uri (str): + Required: GCS location of the training python package. + python_module_name (str): + Required: The module name of the training python package. + container_uri (str): + Required: Uri of the training container image in the GCR. + model_serving_container_image_uri (str): + If the training produces a managed AI Platform Model, the URI of the + Model serving container suitable for serving the model produced by the + training script. + model_serving_container_predict_route (str): + If the training produces a managed AI Platform Model, An HTTP path to + send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by AI Platform. + model_serving_container_health_route (str): + If the training produces a managed AI Platform Model, an HTTP path to + send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI + Platform. 
+ model_serving_container_command (Sequence[str]): + The command with which the container is run. Not executed within a + shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + model_serving_container_args (Sequence[str]): + The arguments to the command. The Docker image's CMD is used if this is + not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + model_serving_container_environment_variables (Dict[str, str]): + The environment variables that are to be present in the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + model_serving_container_ports (Sequence[int]): + Declaration of ports that are exposed by the container. This field is + primarily informational, it gives AI Platform information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + model_description (str): + The description of the Model. + model_instance_schema_uri (str): + Optional. 
Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + model_parameters_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + model_prediction_schema_uri (str): + Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + project (str): + Project to run training in. 
Overrides project set in aiplatform.init. + location (str): + Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + staging_bucket (str): + Bucket used to stage source and training artifacts. Overrides + staging_bucket set in aiplatform.init. 
+ """ + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + container_uri=container_uri, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_description=model_description, + staging_bucket=staging_bucket, + ) + + self._package_gcs_uri = python_package_gcs_uri + self._python_module = python_module_name + + def run( + self, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ] = None, + annotation_schema_uri: Optional[str] = None, + model_display_name: Optional[str] = None, + base_output_dir: Optional[str] = None, + bigquery_destination: Optional[str] = None, + args: Optional[List[Union[str, float, int]]] = None, + replica_count: int = 0, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + sync=True, + ) -> Optional[models.Model]: + """Runs the custom training job. 
+ + Distributed Training Support: + If replica count = 1 then one chief replica will be provisioned. If + replica_count > 1 the remainder will be provisioned as a worker replica pool. + ie: replica_count = 10 will result in 1 chief and 9 workers + All replicas have same machine_type, accelerator_type, and accelerator_count + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform.If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): + AI Platform to fit this training against. Custom training script should + retrieve datasets through passed in environement variables uris: + + os.environ["AIP_TRAINING_DATA_URI"] + os.environ["AIP_VALIDATION_DATA_URI"] + os.environ["AIP_TEST_DATA_URI"] + + Additionally the dataset format is passed in as: + + os.environ["AIP_DATA_FORMAT"] + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files + that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/, + note that the chosen schema must be consistent with + ``metadata`` + of the Dataset specified by + ``dataset_id``. + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. 
+ + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + model_display_name (str): + If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + bigquery_destination (str): + Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + args (List[Unions[str, int, float]]): + Command line arguments to be passed to the Python script. + replica_count (int): + The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + machine_type (str): + The type of machine to use for training. + accelerator_type (str): + Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4, TPU_V2, TPU_V3 + accelerator_count (int): + The number of accelerators to attach to a worker replica. + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. 
This is ignored if Dataset is not provided. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. 
+ """ + worker_pool_specs, managed_model = self._prepare_and_validate_run( + model_display_name=model_display_name, + replica_count=replica_count, + machine_type=machine_type, + accelerator_count=accelerator_count, + accelerator_type=accelerator_type, + ) + + return self._run( + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + worker_pool_specs=worker_pool_specs, + managed_model=managed_model, + args=args, + base_output_dir=base_output_dir, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + bigquery_destination=bigquery_destination, + sync=sync, + ) + + @base.optional_sync(construct_object_on_arg="managed_model") + def _run( + self, + dataset: Optional[ + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ], + annotation_schema_uri: Optional[str], + worker_pool_specs: _DistributedTrainingSpec, + managed_model: Optional[gca_model.Model] = None, + args: Optional[List[Union[str, float, int]]] = None, + base_output_dir: Optional[str] = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + predefined_split_column_name: Optional[str] = None, + bigquery_destination: Optional[str] = None, + sync=True, + ) -> Optional[models.Model]: + """Packages local script and launches training_job. + + Args: + dataset ( + Union[ + datasets.ImageDataset, + datasets.TabularDataset, + datasets.TextDataset, + datasets.VideoDataset, + ] + ): + AI Platform to fit this training against. + annotation_schema_uri (str): + Google Cloud Storage URI points to a YAML file describing + annotation schema. + worker_pools_spec (_DistributedTrainingSpec): + Worker pools pecs required to run job. + managed_model (gca_model.Model): + Model proto if this script produces a Managed Model. 
+ args (List[Unions[str, int, float]]): + Command line arguments to be passed to the Python script. + base_output_dir (str): + GCS output directory of job. If not provided a + timestamped directory in the staging directory will be used. + training_fraction_split (float): + The fraction of the input data that is to be + used to train the Model. + validation_fraction_split (float): + The fraction of the input data that is to be + used to validate the Model. + test_fraction_split (float): + The fraction of the input data that is to be + used to evaluate the Model. + predefined_split_column_name (str): + Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular Datasets. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. 
+ """ + for spec in worker_pool_specs: + spec["pythonPackageSpec"] = { + "executorImageUri": self._container_uri, + "pythonModule": self._python_module, + "packageUris": [self._package_gcs_uri], + } + + if args: + spec["pythonPackageSpec"]["args"] = args + + ( + training_task_inputs, + base_output_dir, + ) = self._prepare_training_task_inputs_and_output_dir( + worker_pool_specs, base_output_dir + ) + + model = self._run_job( + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=training_task_inputs, + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=predefined_split_column_name, + model=managed_model, + gcs_destination_uri_prefix=base_output_dir, + bigquery_destination=bigquery_destination, + ) + + return model + + +class AutoMLVideoTrainingJob(_TrainingJob): + + _supported_training_schemas = ( + schema.training_job.definition.automl_video_classification, + schema.training_job.definition.automl_video_object_tracking, + schema.training_job.definition.automl_video_action_recognition, + ) + + def __init__( + self, + display_name: str, + prediction_type: str = "classification", + model_type: str = "CLOUD", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + ): + """Constructs a AutoML Video Training Job. + + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + prediction_type (str): + The type of prediction the Model is to produce, one of: + "classification" - A video classification model classifies shots + and segments in your videos according to your own defined labels. 
+ "object_tracking" - A video object tracking model detects and tracks + multiple objects in shots and segments. You can use these + models to track objects in your videos according to your + own pre-defined, custom labels. + "action_recognition" - A video action reconition model pinpoints + the location of actions with short temporal durations (~1 second). + model_type: str = "CLOUD" + Required. One of the following: + "CLOUD" - available for "classification", "object_tracking" and "action_recognition" + A Model best tailored to be used within Google Cloud, + and which cannot be exported. + "MOBILE_VERSATILE_1" - available for "classification", "object_tracking" and "action_recognition" + A model that, in addition to being available within Google + Cloud, can also be exported (see ModelService.ExportModel) + as a TensorFlow or TensorFlow Lite model and used on a + mobile or edge device with afterwards. + "MOBILE_CORAL_VERSATILE_1" - available only for "object_tracking" + A versatile model that is meant to be exported (see + ModelService.ExportModel) and used on a Google Coral device. + "MOBILE_CORAL_LOW_LATENCY_1" - available only for "object_tracking" + A model that trades off quality for low latency, to be + exported (see ModelService.ExportModel) and used on a + Google Coral device. + "MOBILE_JETSON_VERSATILE_1" - available only for "object_tracking" + A versatile model that is meant to be exported (see + ModelService.ExportModel) and used on an NVIDIA Jetson device. + "MOBILE_JETSON_LOW_LATENCY_1" - available only for "object_tracking" + A model that trades off quality for low latency, to be + exported (see ModelService.ExportModel) and used on an + NVIDIA Jetson device. + project (str): + Optional. Project to run training in. Overrides project set in aiplatform.init. + location (str): + Optional. Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. 
Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + Raises: + ValueError: When an invalid prediction_type and/or model_type is provided. + """ + valid_model_types = constants.AUTOML_VIDEO_PREDICTION_MODEL_TYPES.get( + prediction_type, None + ) + + if not valid_model_types: + raise ValueError( + f"'{prediction_type}' is not a supported prediction type for AutoML Video Training. " + f"Please choose one of: {tuple(constants.AUTOML_VIDEO_PREDICTION_MODEL_TYPES.keys())}." + ) + + if model_type not in valid_model_types: + raise ValueError( + f"'{model_type}' is not a supported model_type for prediction_type of '{prediction_type}'. 
" + f"Please choose one of: {tuple(valid_model_types)}" + ) + + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + ) + + self._model_type = model_type + self._prediction_type = prediction_type + + def run( + self, + dataset: datasets.VideoDataset, + training_fraction_split: float = 0.8, + test_fraction_split: float = 0.2, + model_display_name: Optional[str] = None, + sync: bool = True, + ) -> models.Model: + """Runs the AutoML Image training job and returns a model. + + Data fraction splits: + ``training_fraction_split``, and ``test_fraction_split`` may optionally + be provided, they must sum to up to 1. If none of the fractions are set, + by default roughly 80% of data will be used for training, and 20% for test. + + Args: + dataset (datasets.VideoDataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + training_fraction_split: float = 0.8 + Required. The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + test_fraction_split: float = 0.2 + Required. The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + model_display_name (str): + Optional. The display name of the managed AI Platform Model. The name + can be up to 128 characters long and can be consist of any UTF-8 + characters. If not provided upon creation, the job's display_name is used. 
+ sync: bool = True + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + + Raises: + RuntimeError: If Training job has already been run or is waiting to run. + """ + + if self._is_waiting_to_run(): + raise RuntimeError("AutoML Video Training is already scheduled to run.") + + if self._has_run: + raise RuntimeError("AutoML Video Training has already run.") + + return self._run( + dataset=dataset, + training_fraction_split=training_fraction_split, + test_fraction_split=test_fraction_split, + model_display_name=model_display_name, + sync=sync, + ) + + @base.optional_sync() + def _run( + self, + dataset: datasets.VideoDataset, + training_fraction_split: float = 0.8, + test_fraction_split: float = 0.2, + model_display_name: Optional[str] = None, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + Data fraction splits: + Any of ``training_fraction_split``, and ``test_fraction_split`` may optionally + be provided, they must sum to up to 1. If none of the fractions are set, + by default roughly 80% of data will be used for training, and 20% for test. + + Args: + dataset (datasets.VideoDataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + training_fraction_split (float): + Required. The fraction of the input data that is to be + used to train the Model. 
This is ignored if Dataset is not provided. + test_fraction_split (float): + Required. The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + model_display_name (str): + Optional. The display name of the managed AI Platform Model. The name + can be up to 128 characters long and can be consist of any UTF-8 + characters. If a `base_model` was provided, the display_name in the + base_model will be overritten with this value. If not provided upon + creation, the job's display_name is used. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + """ + + # Retrieve the objective-specific training task schema based on prediction_type + training_task_definition = getattr( + schema.training_job.definition, f"automl_video_{self._prediction_type}" + ) + + training_task_inputs_dict = { + "modelType": self._model_type, + } + + # gca Model to be trained + model_tbt = gca_model.Model(encryption_spec=self._model_encryption_spec) + model_tbt.display_name = model_display_name or self._display_name + + return self._run_job( + training_task_definition=training_task_definition, + training_task_inputs=training_task_inputs_dict, + dataset=dataset, + training_fraction_split=training_fraction_split, + validation_fraction_split=0.0, + test_fraction_split=test_fraction_split, + model=model_tbt, + ) + + @property + def _model_upload_fail_string(self) -> str: + """Helper property for model upload failure.""" + return ( + f"AutoML Video Training Pipeline {self.resource_name} is not " + "configured to upload a Model." 
+ ) + + +class AutoMLTextTrainingJob(_TrainingJob): + _supported_training_schemas = ( + schema.training_job.definition.automl_text_classification, + schema.training_job.definition.automl_text_extraction, + schema.training_job.definition.automl_text_sentiment, + ) + + def __init__( + self, + display_name: str, + prediction_type: str, + multi_label: bool = False, + sentiment_max: int = 10, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + training_encryption_spec_key_name: Optional[str] = None, + model_encryption_spec_key_name: Optional[str] = None, + ): + """Constructs a AutoML Text Training Job. + + Args: + display_name (str): + Required. The user-defined name of this TrainingPipeline. + prediction_type (str): + The type of prediction the Model is to produce, one of: + "classification" - A classification model analyzes text data and + returns a list of categories that apply to the text found in the data. + AI Platform offers both single-label and multi-label text classification models. + "extraction" - An entity extraction model inspects text data + for known entities referenced in the data and + labels those entities in the text. + "sentiment" - A sentiment analysis model inspects text data and identifies the + prevailing emotional opinion within it, especially to determine a writer's attitude + as positive, negative, or neutral. + multi_label (bool): + Required and only applicable for text classification task. If false, a single-label (multi-class) Model will be trained (i.e. + assuming that for each text snippet just up to one annotation may be + applicable). If true, a multi-label Model will be trained (i.e. + assuming that for each text snippet multiple annotations may be + applicable). + sentiment_max (int): + Required and only applicable for sentiment task. A sentiment is expressed as an integer + ordinal, where higher value means a more + positive sentiment. 
The range of sentiments that + will be used is between 0 and sentimentMax + (inclusive on both ends), and all the values in + the range must be represented in the dataset + before a model can be created. + Only the Annotations with this sentimentMax will + be used for training. sentimentMax value must be + between 1 and 10 (inclusive). + project (str): + Optional. Project to run training in. Overrides project set in aiplatform.init. + location (str): + Optional. Location to run training in. Overrides location set in aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to run call training service. Overrides + credentials set in aiplatform.init. + training_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + + Overrides encryption_spec_key_name set in aiplatform.init. + model_encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. 
+ """ + super().__init__( + display_name=display_name, + project=project, + location=location, + credentials=credentials, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + ) + + training_task_definition: str + training_task_inputs_dict: proto.Message + + if prediction_type == "classification": + training_task_definition = ( + schema.training_job.definition.automl_text_classification + ) + + training_task_inputs_dict = training_job_inputs.AutoMlTextClassificationInputs( + multi_label=multi_label + ) + elif prediction_type == "extraction": + training_task_definition = ( + schema.training_job.definition.automl_text_extraction + ) + + training_task_inputs_dict = training_job_inputs.AutoMlTextExtractionInputs() + elif prediction_type == "sentiment": + training_task_definition = ( + schema.training_job.definition.automl_text_sentiment + ) + + training_task_inputs_dict = training_job_inputs.AutoMlTextSentimentInputs( + sentiment_max=sentiment_max + ) + else: + raise ValueError( + "Prediction type must be one of 'classification', 'extraction', or 'sentiment'." + ) + + self._training_task_definition = training_task_definition + self._training_task_inputs_dict = training_task_inputs_dict + + def run( + self, + dataset: datasets.TextDataset, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + model_display_name: Optional[str] = None, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform. 
If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset (datasets.TextDataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + training_fraction_split: float = 0.8 + Required. The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split: float = 0.1 + Required. The fraction of the input data that is to be + used to validate the Model. This is ignored if Dataset is not provided. + test_fraction_split: float = 0.1 + Required. The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + model_display_name (str): + Optional. The display name of the managed AI Platform Model. + The name can be up to 128 characters long and can consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + Returns: + model: The trained AI Platform Model resource. + + Raises: + RuntimeError if Training job has already been run or is waiting to run. 
+ """ + + if self._is_waiting_to_run(): + raise RuntimeError("AutoML Text Training is already scheduled to run.") + + if self._has_run: + raise RuntimeError("AutoML Text Training has already run.") + + return self._run( + dataset=dataset, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + model_display_name=model_display_name, + sync=sync, + ) + + @base.optional_sync() + def _run( + self, + dataset: datasets.TextDataset, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + model_display_name: Optional[str] = None, + sync: bool = True, + ) -> models.Model: + """Runs the training job and returns a model. + + Data fraction splits: + Any of ``training_fraction_split``, ``validation_fraction_split`` and + ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If + the provided ones sum to less than 1, the remainder is assigned to sets as + decided by AI Platform. If none of the fractions are set, by default roughly 80% + of data will be used for training, 10% for validation, and 10% for test. + + Args: + dataset (datasets.TextDataset): + Required. The dataset within the same Project from which data will be used to train the Model. The + Dataset must use schema compatible with Model being trained, + and what is compatible should be described in the used + TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition]. + For Text Datasets, all their data is exported to + training, to pick and choose from. + training_fraction_split (float): + Required. The fraction of the input data that is to be + used to train the Model. This is ignored if Dataset is not provided. + validation_fraction_split (float): + Required. The fraction of the input data that is to be + used to validate the Model. 
This is ignored if Dataset is not provided. + test_fraction_split (float): + Required. The fraction of the input data that is to be + used to evaluate the Model. This is ignored if Dataset is not provided. + model_display_name (str): + Optional. If the script produces a managed AI Platform Model. The display name of + the Model. The name can be up to 128 characters long and can consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + sync (bool): + Whether to execute this method synchronously. If False, this method + will be executed in concurrent Future and any downstream object will + be immediately returned and synced when the Future has completed. + + Returns: + model: The trained AI Platform Model resource or None if training did not + produce an AI Platform Model. + """ + + if model_display_name is None: + model_display_name = self._display_name + + model = gca_model.Model( + display_name=model_display_name, + encryption_spec=self._model_encryption_spec, + ) + + return self._run_job( + training_task_definition=self._training_task_definition, + training_task_inputs=self._training_task_inputs_dict, + dataset=dataset, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + predefined_split_column_name=None, + model=model, + ) + + @property + def _model_upload_fail_string(self) -> str: + """Helper property for model upload failure.""" + return ( + f"AutoML Text Training Pipeline {self.resource_name} is not " + "configured to upload a Model." 
+ ) diff --git a/google/cloud/aiplatform/training_utils.py b/google/cloud/aiplatform/training_utils.py new file mode 100644 index 0000000000..a93ecaa1ce --- /dev/null +++ b/google/cloud/aiplatform/training_utils.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import os + +from typing import Dict, Optional + + +class EnvironmentVariables: + """Passes on OS' environment variables""" + + @property + def training_data_uri(self) -> Optional[str]: + """ + Returns: + Cloud Storage URI of a directory intended for training data. None if + environment variable not set. + """ + return os.environ.get("AIP_TRAINING_DATA_URI") + + @property + def validation_data_uri(self) -> Optional[str]: + """ + Returns: + Cloud Storage URI of a directory intended for validation data. None + if environment variable not set. + """ + return os.environ.get("AIP_VALIDATION_DATA_URI") + + @property + def test_data_uri(self) -> Optional[str]: + """ + Returns: + Cloud Storage URI of a directory intended for test data. None if + environment variable not set. + """ + return os.environ.get("AIP_TEST_DATA_URI") + + @property + def model_dir(self) -> Optional[str]: + """ + Returns: + Cloud Storage URI of a directory intended for saving model artefacts. + None if environment variable not set. 
+ """ + return os.environ.get("AIP_MODEL_DIR") + + @property + def checkpoint_dir(self) -> Optional[str]: + """ + Returns: + Cloud Storage URI of a directory intended for saving checkpoints. + None if environment variable not set. + """ + return os.environ.get("AIP_CHECKPOINT_DIR") + + @property + def tensorboard_log_dir(self) -> Optional[str]: + """ + Returns: + Cloud Storage URI of a directory intended for saving TensorBoard logs. + None if environment variable not set. + """ + return os.environ.get("AIP_TENSORBOARD_LOG_DIR") + + @property + def cluster_spec(self) -> Optional[Dict]: + """ + Returns: + json string as described in https://cloud.google.com/ai-platform-unified/docs/training/distributed-training#cluster-variables + None if environment variable not set. + """ + cluster_spec_env = os.environ.get("CLUSTER_SPEC") + if cluster_spec_env is not None: + return json.loads(cluster_spec_env) + else: + return None + + @property + def tf_config(self) -> Optional[Dict]: + """ + Returns: + json string as described in https://cloud.google.com/ai-platform-unified/docs/training/distributed-training#tf-config + None if environment variable not set. + """ + tf_config_env = os.environ.get("TF_CONFIG") + if tf_config_env is not None: + return json.loads(tf_config_env) + else: + return None diff --git a/google/cloud/aiplatform/utils.py b/google/cloud/aiplatform/utils.py new file mode 100644 index 0000000000..7584c7d02e --- /dev/null +++ b/google/cloud/aiplatform/utils.py @@ -0,0 +1,469 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import abc +from collections import namedtuple +import logging +import re +from typing import Any, Match, Optional, Type, TypeVar, Tuple + +from google.api_core import client_options +from google.api_core import gapic_v1 +from google.auth import credentials as auth_credentials +from google.cloud.aiplatform import compat +from google.cloud.aiplatform import constants +from google.cloud.aiplatform import initializer + +from google.cloud.aiplatform.compat.services import ( + dataset_service_client_v1beta1, + endpoint_service_client_v1beta1, + job_service_client_v1beta1, + model_service_client_v1beta1, + pipeline_service_client_v1beta1, + prediction_service_client_v1beta1, +) +from google.cloud.aiplatform.compat.services import ( + dataset_service_client_v1, + endpoint_service_client_v1, + job_service_client_v1, + model_service_client_v1, + pipeline_service_client_v1, + prediction_service_client_v1, +) + +from google.cloud.aiplatform.compat.types import ( + accelerator_type as gca_accelerator_type, +) + +AiPlatformServiceClient = TypeVar( + "AiPlatformServiceClient", + # v1beta1 + dataset_service_client_v1beta1.DatasetServiceClient, + endpoint_service_client_v1beta1.EndpointServiceClient, + model_service_client_v1beta1.ModelServiceClient, + prediction_service_client_v1beta1.PredictionServiceClient, + pipeline_service_client_v1beta1.PipelineServiceClient, + job_service_client_v1beta1.JobServiceClient, + # v1 + dataset_service_client_v1.DatasetServiceClient, + endpoint_service_client_v1.EndpointServiceClient, + model_service_client_v1.ModelServiceClient, 
+ prediction_service_client_v1.PredictionServiceClient, + pipeline_service_client_v1.PipelineServiceClient, + job_service_client_v1.JobServiceClient, +) + +# TODO(b/170334193): Add support for resource names with non-integer IDs +# TODO(b/170334098): Add support for resource names more than one level deep +RESOURCE_NAME_PATTERN = re.compile( + r"^projects\/(?P[\w-]+)\/locations\/(?P[\w-]+)\/(?P\w+)\/(?P\d+)$" +) +RESOURCE_ID_PATTERN = re.compile(r"^\d+$") + +Fields = namedtuple("Fields", ["project", "location", "resource", "id"],) + + +def _match_to_fields(match: Match) -> Optional[Fields]: + """Normalize RegEx groups from resource name pattern Match to class Fields""" + if not match: + return None + + return Fields( + project=match["project"], + location=match["location"], + resource=match["resource"], + id=match["id"], + ) + + +def validate_id(resource_id: str) -> bool: + """Validate int64 resource ID number""" + return bool(RESOURCE_ID_PATTERN.match(resource_id)) + + +def extract_fields_from_resource_name( + resource_name: str, resource_noun: Optional[str] = None +) -> Optional[Fields]: + """Validates and returns extracted fields from a fully-qualified resource name. + Returns None if name is invalid. + + Args: + resource_name (str): + Required. A fully-qualified AI Platform (Unified) resource name + + resource_noun (str): + A plural resource noun to validate the resource name against. + For example, you would pass "datasets" to validate + "projects/123/locations/us-central1/datasets/456". + + Returns: + fields (Fields): + A named tuple containing four extracted fields from a resource name: + project, location, resource, and id. These fields can be used for + subsequent method calls in the SDK. 
+ """ + fields = _match_to_fields(RESOURCE_NAME_PATTERN.match(resource_name)) + + if not fields: + return None + if resource_noun and fields.resource != resource_noun: + return None + + return fields + + +def full_resource_name( + resource_name: str, + resource_noun: str, + project: Optional[str] = None, + location: Optional[str] = None, +) -> str: + """ + Returns fully qualified resource name. + + Args: + resource_name (str): + Required. A fully-qualified AI Platform (Unified) resource name or + resource ID. + resource_noun (str): + A plural resource noun to validate the resource name against. + For example, you would pass "datasets" to validate + "projects/123/locations/us-central1/datasets/456". + project (str): + Optional project to retrieve resource_noun from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional location to retrieve resource_noun from. If not set, location + set in aiplatform.init will be used. + + Returns: + resource_name (str): + A fully-qualified AI Platform (Unified) resource name. + + Raises: + ValueError: + If resource name, resource ID or project ID not provided. + """ + validate_resource_noun(resource_noun) + # Fully qualified resource name, i.e. "projects/.../locations/.../datasets/12345" + valid_name = extract_fields_from_resource_name( + resource_name=resource_name, resource_noun=resource_noun + ) + + user_project = project or initializer.global_config.project + user_location = location or initializer.global_config.location + + # Partial resource name (i.e. 
"12345") with known project and location + if ( + not valid_name + and validate_project(user_project) + and validate_region(user_location) + and validate_id(resource_name) + ): + resource_name = f"projects/{user_project}/locations/{user_location}/{resource_noun}/{resource_name}" + # Invalid resource_name parameter + elif not valid_name: + raise ValueError(f"Please provide a valid {resource_noun[:-1]} name or ID") + + return resource_name + + +# TODO(b/172286889) validate resource noun +def validate_resource_noun(resource_noun: str) -> bool: + """Validates resource noun. + + Args: + resource_noun: resource noun to validate + Returns: + bool: True if no errors raised + Raises: + ValueError: If resource noun not supported. + """ + if resource_noun: + return True + raise ValueError("Please provide a valid resource noun") + + +# TODO(b/172288287) validate project +def validate_project(project: str) -> bool: + """Validates project. + + Args: + project: project to validate + Returns: + bool: True if no errors raised + Raises: + ValueError: If project does not exist. + """ + if project: + return True + raise ValueError("Please provide a valid project ID") + + +# TODO(b/172932277) verify display name only contains utf-8 chars +def validate_display_name(display_name: str): + """Verify display name is at most 128 chars + + Args: + display_name: display name to verify + Raises: + ValueError: display name is longer than 128 characters + """ + if len(display_name) > 128: + raise ValueError("Display name needs to be less than 128 characters.") + + +def validate_region(region: str) -> bool: + """Validates region against supported regions. + + Args: + region: region to validate + Returns: + bool: True if no errors raised + Raises: + ValueError: If region is not in supported regions. 
+ """ + if not region: + raise ValueError( + f"Please provide a region, select from {constants.SUPPORTED_REGIONS}" + ) + + region = region.lower() + if region not in constants.SUPPORTED_REGIONS: + raise ValueError( + f"Unsupported region for AI Platform, select from {constants.SUPPORTED_REGIONS}" + ) + + return True + + +def validate_accelerator_type(accelerator_type: str) -> bool: + """Validates user provided accelerator_type string for training and prediction + + Args: + accelerator_type (str): + Represents a hardware accelerator type. + Returns: + bool: True if valid accelerator_type + Raises: + ValueError if accelerator type is invalid. + """ + if accelerator_type not in gca_accelerator_type.AcceleratorType._member_names_: + raise ValueError( + f"Given accelerator_type `{accelerator_type}` invalid. " + f"Choose one of {gca_accelerator_type.AcceleratorType._member_names_}" + ) + return True + + +def extract_bucket_and_prefix_from_gcs_path(gcs_path: str) -> Tuple[str, Optional[str]]: + """Given a complete GCS path, return the bucket name and prefix as a tuple. + + Example Usage: + + bucket, prefix = extract_bucket_and_prefix_from_gcs_path( + "gs://example-bucket/path/to/folder" + ) + + # bucket = "example-bucket" + # prefix = "path/to/folder" + + Args: + gcs_path (str): + Required. A full path to a Google Cloud Storage folder or resource. + Can optionally include "gs://" prefix or end in a trailing slash "/". + + Returns: + Tuple[str, Optional[str]] + A (bucket, prefix) pair from provided GCS path. If a prefix is not + present, a None will be returned in its place. 
+ """ + if gcs_path.startswith("gs://"): + gcs_path = gcs_path[5:] + if gcs_path.endswith("/"): + gcs_path = gcs_path[:-1] + + gcs_parts = gcs_path.split("/", 1) + gcs_bucket = gcs_parts[0] + gcs_blob_prefix = None if len(gcs_parts) == 1 else gcs_parts[1] + + return (gcs_bucket, gcs_blob_prefix) + + +class ClientWithOverride: + class WrappedClient: + """Wrapper class for client that creates client at API invocation time.""" + + def __init__( + self, + client_class: Type[AiPlatformServiceClient], + client_options: client_options.ClientOptions, + client_info: gapic_v1.client_info.ClientInfo, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Stores parameters needed to instantiate client. + + client_class (AiPlatformServiceClient): + Required. Class of the client to use. + client_options (client_options.ClientOptions): + Required. Client options to pass to client. + client_info (gapic_v1.client_info.ClientInfo): + Required. Client info to pass to client. + credentials (auth_credentials.credentials): + Optional. Client credentials to pass to client. 
+ """ + + self._client_class = client_class + self._credentials = credentials + self._client_options = client_options + self._client_info = client_info + + def __getattr__(self, name: str) -> Any: + """Instantiates client and returns attribute of the client.""" + temporary_client = self._client_class( + credentials=self._credentials, + client_options=self._client_options, + client_info=self._client_info, + ) + return getattr(temporary_client, name) + + @property + @abc.abstractmethod + def _is_temporary(self) -> bool: + pass + + @property + @classmethod + @abc.abstractmethod + def _default_version(self) -> str: + pass + + @property + @classmethod + @abc.abstractmethod + def _version_map(self) -> Tuple: + pass + + def __init__( + self, + client_options: client_options.ClientOptions, + client_info: gapic_v1.client_info.ClientInfo, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Stores parameters needed to instantiate client. + + client_options (client_options.ClientOptions): + Required. Client options to pass to client. + client_info (gapic_v1.client_info.ClientInfo): + Required. Client info to pass to client. + credentials (auth_credentials.credentials): + Optional. Client credentials to pass to client. 
+ """ + + self._clients = { + version: self.WrappedClient( + client_class=client_class, + client_options=client_options, + client_info=client_info, + credentials=credentials, + ) + if self._is_temporary + else client_class( + client_options=client_options, + client_info=client_info, + credentials=credentials, + ) + for version, client_class in self._version_map + } + + def __getattr__(self, name: str) -> Any: + """Instantiates client and returns attribute of the client.""" + return getattr(self._clients[self._default_version], name) + + def select_version(self, version: str) -> AiPlatformServiceClient: + return self._clients[version] + + +class DatasetClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, dataset_service_client_v1.DatasetServiceClient), + (compat.V1BETA1, dataset_service_client_v1beta1.DatasetServiceClient), + ) + + +class EndpointClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, endpoint_service_client_v1.EndpointServiceClient), + (compat.V1BETA1, endpoint_service_client_v1beta1.EndpointServiceClient), + ) + + +class JobpointClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, job_service_client_v1.JobServiceClient), + (compat.V1BETA1, job_service_client_v1beta1.JobServiceClient), + ) + + +class ModelClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, model_service_client_v1.ModelServiceClient), + (compat.V1BETA1, model_service_client_v1beta1.ModelServiceClient), + ) + + +class PipelineClientWithOverride(ClientWithOverride): + _is_temporary = True + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, pipeline_service_client_v1.PipelineServiceClient), + (compat.V1BETA1, 
pipeline_service_client_v1beta1.PipelineServiceClient), + ) + + +class PredictionClientWithOverride(ClientWithOverride): + _is_temporary = False + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, prediction_service_client_v1.PredictionServiceClient), + (compat.V1BETA1, prediction_service_client_v1beta1.PredictionServiceClient), + ) + + +AiPlatformServiceClientWithOverride = TypeVar( + "AiPlatformServiceClientWithOverride", + DatasetClientWithOverride, + EndpointClientWithOverride, + JobpointClientWithOverride, + ModelClientWithOverride, + PipelineClientWithOverride, + PredictionClientWithOverride, +) + + +class LoggingWarningFilter(logging.Filter): + def filter(self, record): + return record.levelname == logging.WARNING diff --git a/noxfile.py b/noxfile.py index 35270f664f..2cb95f3d6d 100644 --- a/noxfile.py +++ b/noxfile.py @@ -204,9 +204,7 @@ def docfx(session): """Build the docfx yaml files for this library.""" session.install("-e", ".") - # sphinx-docfx-yaml supports up to sphinx version 1.5.5. - # https://github.com/docascode/sphinx-docfx-yaml/issues/97 - session.install("sphinx==1.5.5", "alabaster", "recommonmark", "sphinx-docfx-yaml") + session.install("sphinx", "alabaster", "recommonmark", "gcp-sphinx-docfx-yaml") shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) session.run( diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py new file mode 100644 index 0000000000..580c6a962d --- /dev/null +++ b/samples/model-builder/conftest.py @@ -0,0 +1,205 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch + +from google.cloud import aiplatform +import pytest + + +@pytest.fixture +def mock_sdk_init(): + with patch.object(aiplatform, "init") as mock: + yield mock + + +""" +---------------------------------------------------------------------------- +Dataset Fixtures +---------------------------------------------------------------------------- +""" + +"""Dataset objects returned by SomeDataset(), create(), import_data(), etc. """ + + +@pytest.fixture +def mock_image_dataset(): + mock = MagicMock(aiplatform.datasets.ImageDataset) + yield mock + + +@pytest.fixture +def mock_tabular_dataset(): + mock = MagicMock(aiplatform.datasets.TabularDataset) + yield mock + + +@pytest.fixture +def mock_text_dataset(): + mock = MagicMock(aiplatform.datasets.TextDataset) + yield mock + + +@pytest.fixture +def mock_video_dataset(): + mock = MagicMock(aiplatform.datasets.VideoDataset) + yield mock + + +"""Mocks for getting an existing Dataset, i.e. ds = aiplatform.ImageDataset(...) 
""" + + +@pytest.fixture +def mock_get_image_dataset(mock_image_dataset): + with patch.object(aiplatform, "ImageDataset") as mock_get_image_dataset: + mock_get_image_dataset.return_value = mock_image_dataset + yield mock_get_image_dataset + + +@pytest.fixture +def mock_get_tabular_dataset(mock_tabular_dataset): + with patch.object(aiplatform, "TabularDataset") as mock_get_tabular_dataset: + mock_get_tabular_dataset.return_value = mock_tabular_dataset + yield mock_get_tabular_dataset + + +@pytest.fixture +def mock_get_text_dataset(mock_text_dataset): + with patch.object(aiplatform, "TextDataset") as mock_get_text_dataset: + mock_get_text_dataset.return_value = mock_text_dataset + yield mock_get_text_dataset + + +@pytest.fixture +def mock_get_video_dataset(mock_video_dataset): + with patch.object(aiplatform, "VideoDataset") as mock_get_video_dataset: + mock_get_video_dataset.return_value = mock_video_dataset + yield mock_get_video_dataset + + +"""Mocks for creating a new Dataset, i.e. aiplatform.ImageDataset.create(...) 
""" + + +@pytest.fixture +def mock_create_image_dataset(mock_image_dataset): + with patch.object(aiplatform.ImageDataset, "create") as mock_create_image_dataset: + mock_create_image_dataset.return_value = mock_image_dataset + yield mock_create_image_dataset + + +@pytest.fixture +def mock_create_tabular_dataset(mock_tabular_dataset): + with patch.object( + aiplatform.TabularDataset, "create" + ) as mock_create_tabular_dataset: + mock_create_tabular_dataset.return_value = mock_tabular_dataset + yield mock_create_tabular_dataset + + +@pytest.fixture +def mock_create_text_dataset(mock_text_dataset): + with patch.object(aiplatform.TextDataset, "create") as mock_create_text_dataset: + mock_create_text_dataset.return_value = mock_text_dataset + yield mock_create_text_dataset + + +@pytest.fixture +def mock_create_video_dataset(mock_video_dataset): + with patch.object(aiplatform.VideoDataset, "create") as mock_create_video_dataset: + mock_create_video_dataset.return_value = mock_video_dataset + yield mock_create_video_dataset + + +"""Mocks for SomeDataset.import_data() """ + + +@pytest.fixture +def mock_import_text_dataset(mock_text_dataset): + with patch.object(mock_text_dataset, "import_data") as mock: + yield mock + + +""" +---------------------------------------------------------------------------- +TrainingJob Fixtures +---------------------------------------------------------------------------- +""" + + +@pytest.fixture +def mock_init_automl_image_training_job(): + with patch.object( + aiplatform.training_jobs.AutoMLImageTrainingJob, "__init__" + ) as mock: + mock.return_value = None + yield mock + + +@pytest.fixture +def mock_run_automl_image_training_job(): + with patch.object(aiplatform.training_jobs.AutoMLImageTrainingJob, "run") as mock: + yield mock + + +""" +---------------------------------------------------------------------------- +Model Fixtures +---------------------------------------------------------------------------- +""" + + +@pytest.fixture +def 
mock_init_model(): + with patch.object(aiplatform.models.Model, "__init__") as mock: + mock.return_value = None + yield mock + + +@pytest.fixture +def mock_batch_predict_model(): + with patch.object(aiplatform.models.Model, "batch_predict") as mock: + yield mock + + +""" +---------------------------------------------------------------------------- +Job Fixtures +---------------------------------------------------------------------------- +""" + + +@pytest.fixture +def mock_create_batch_prediction_job(): + with patch.object(aiplatform.jobs.BatchPredictionJob, "create") as mock: + yield mock + + +""" +---------------------------------------------------------------------------- +Endpoint Fixtures +---------------------------------------------------------------------------- +""" + + +@pytest.fixture +def mock_endpoint(): + mock = MagicMock(aiplatform.models.Endpoint) + yield mock + + +@pytest.fixture +def mock_get_endpoint(mock_endpoint): + with patch.object(aiplatform, "Endpoint") as mock_get_endpoint: + mock_get_endpoint.return_value = mock_endpoint + yield mock_get_endpoint diff --git a/samples/model-builder/create_and_import_dataset_image_sample.py b/samples/model-builder/create_and_import_dataset_image_sample.py new file mode 100644 index 0000000000..bab7c8a59c --- /dev/null +++ b/samples/model-builder/create_and_import_dataset_image_sample.py @@ -0,0 +1,44 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_and_import_dataset_image_sample] +def create_and_import_dataset_image_sample( + project: str, + location: str, + display_name: str, + src_uris: Union[str, List[str]], + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.ImageDataset.create( + display_name=display_name, + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification, + sync=sync, + ) + + ds.wait() + + print(ds.display_name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_create_and_import_dataset_image_sample] diff --git a/samples/model-builder/create_and_import_dataset_image_sample_test.py b/samples/model-builder/create_and_import_dataset_image_sample_test.py new file mode 100644 index 0000000000..6991ff3a13 --- /dev/null +++ b/samples/model-builder/create_and_import_dataset_image_sample_test.py @@ -0,0 +1,41 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from google.cloud.aiplatform import schema + +import create_and_import_dataset_image_sample +import test_constants as constants + + +def test_create_and_import_dataset_image_sample( + mock_sdk_init, mock_create_image_dataset +): + + create_and_import_dataset_image_sample.create_and_import_dataset_image_sample( + project=constants.PROJECT, + location=constants.LOCATION, + src_uris=constants.GCS_SOURCES, + display_name=constants.DISPLAY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_create_image_dataset.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.image.single_label_classification, + sync=True, + ) diff --git a/samples/model-builder/create_and_import_dataset_text_sample.py b/samples/model-builder/create_and_import_dataset_text_sample.py new file mode 100644 index 0000000000..e3321020bf --- /dev/null +++ b/samples/model-builder/create_and_import_dataset_text_sample.py @@ -0,0 +1,44 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_and_import_dataset_text_sample] +def create_and_import_dataset_text_sample( + project: str, + location: str, + display_name: str, + src_uris: Union[str, List[str]], + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.TextDataset.create( + display_name=display_name, + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.text.single_label_classification, + sync=sync, + ) + + ds.wait() + + print(ds.display_name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_create_and_import_dataset_text_sample] diff --git a/samples/model-builder/create_and_import_dataset_text_sample_test.py b/samples/model-builder/create_and_import_dataset_text_sample_test.py new file mode 100644 index 0000000000..e41082d06f --- /dev/null +++ b/samples/model-builder/create_and_import_dataset_text_sample_test.py @@ -0,0 +1,39 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from google.cloud.aiplatform import schema + +import create_and_import_dataset_text_sample +import test_constants as constants + + +def test_create_and_import_dataset_text_sample(mock_sdk_init, mock_create_text_dataset): + + create_and_import_dataset_text_sample.create_and_import_dataset_text_sample( + project=constants.PROJECT, + location=constants.LOCATION, + src_uris=constants.GCS_SOURCES, + display_name=constants.DISPLAY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_create_text_dataset.assert_called_once_with( + display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.text.single_label_classification, + sync=True, + ) diff --git a/samples/model-builder/create_batch_prediction_job_sample.py b/samples/model-builder/create_batch_prediction_job_sample.py new file mode 100644 index 0000000000..9bd5c697a5 --- /dev/null +++ b/samples/model-builder/create_batch_prediction_job_sample.py @@ -0,0 +1,49 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Sequence, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_batch_prediction_job_sample] +def create_batch_prediction_job_sample( + project: str, + location: str, + model_resource_name: str, + job_display_name: str, + gcs_source: Union[str, Sequence[str]], + gcs_destination: str, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + my_model = aiplatform.Model(model_resource_name) + + batch_prediction_job = my_model.batch_predict( + job_display_name=job_display_name, + gcs_source=gcs_source, + gcs_destination_prefix=gcs_destination, + sync=sync, + ) + + batch_prediction_job.wait() + + print(batch_prediction_job.display_name) + print(batch_prediction_job.resource_name) + print(batch_prediction_job.state) + return batch_prediction_job + + +# [END aiplatform_sdk_create_batch_prediction_job_sample] diff --git a/samples/model-builder/create_batch_prediction_job_sample_test.py b/samples/model-builder/create_batch_prediction_job_sample_test.py new file mode 100644 index 0000000000..f39c1020b5 --- /dev/null +++ b/samples/model-builder/create_batch_prediction_job_sample_test.py @@ -0,0 +1,42 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import create_batch_prediction_job_sample +import test_constants as constants + + +def test_create_batch_prediction_job_sample( + mock_sdk_init, mock_init_model, mock_batch_predict_model +): + + create_batch_prediction_job_sample.create_batch_prediction_job_sample( + project=constants.PROJECT, + location=constants.LOCATION, + model_resource_name=constants.MODEL_NAME, + job_display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + gcs_destination=constants.GCS_DESTINATION, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_init_model.assert_called_once_with(constants.MODEL_NAME) + mock_batch_predict_model.assert_called_once_with( + job_display_name=constants.DISPLAY_NAME, + gcs_source=constants.GCS_SOURCES, + gcs_destination_prefix=constants.GCS_DESTINATION, + sync=True, + ) diff --git a/samples/model-builder/create_training_pipeline_image_classification_sample.py b/samples/model-builder/create_training_pipeline_image_classification_sample.py new file mode 100644 index 0000000000..050d40af82 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_image_classification_sample.py @@ -0,0 +1,57 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from google.cloud import aiplatform + + +# [START aiplatform_sdk_create_training_pipeline_image_classification_sample] +def create_training_pipeline_image_classification_sample( + project: str, + display_name: str, + dataset_id: int, + location: str = "us-central1", + model_display_name: str = None, + training_fraction_split: float = 0.8, + validation_fraction_split: float = 0.1, + test_fraction_split: float = 0.1, + budget_milli_node_hours: int = 8000, + disable_early_stopping: bool = False, + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + job = aiplatform.AutoMLImageTrainingJob(display_name=display_name) + + my_image_ds = aiplatform.ImageDataset(dataset_id) + + model = job.run( + dataset=my_image_ds, + model_display_name=model_display_name, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + budget_milli_node_hours=budget_milli_node_hours, + disable_early_stopping=disable_early_stopping, + sync=sync, + ) + + model.wait() + + print(model.display_name) + print(model.resource_name) + print(model.uri) + return model + + +# [END aiplatform_sdk_create_training_pipeline_image_classification_sample] diff --git a/samples/model-builder/create_training_pipeline_image_classification_sample_test.py b/samples/model-builder/create_training_pipeline_image_classification_sample_test.py new file mode 100644 index 0000000000..c49e0e5f05 --- /dev/null +++ b/samples/model-builder/create_training_pipeline_image_classification_sample_test.py @@ -0,0 +1,57 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import create_training_pipeline_image_classification_sample +import test_constants as constants + + +def test_create_training_pipeline_image_classification_sample( + mock_sdk_init, + mock_image_dataset, + mock_init_automl_image_training_job, + mock_run_automl_image_training_job, + mock_get_image_dataset, +): + + create_training_pipeline_image_classification_sample.create_training_pipeline_image_classification_sample( + project=constants.PROJECT, + display_name=constants.DISPLAY_NAME, + dataset_id=constants.RESOURCE_ID, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + ) + + mock_get_image_dataset.assert_called_once_with(constants.RESOURCE_ID) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + mock_init_automl_image_training_job.assert_called_once_with( + display_name=constants.DISPLAY_NAME + ) + mock_run_automl_image_training_job.assert_called_once_with( + dataset=mock_image_dataset, + model_display_name=constants.DISPLAY_NAME_2, + training_fraction_split=constants.TRAINING_FRACTION_SPLIT, + validation_fraction_split=constants.VALIDATION_FRACTION_SPLIT, + test_fraction_split=constants.TEST_FRACTION_SPLIT, + budget_milli_node_hours=constants.BUDGET_MILLI_NODE_HOURS_8000, + disable_early_stopping=False, + sync=True, + ) diff 
--git a/samples/model-builder/import_data_text_classification_single_label_sample.py b/samples/model-builder/import_data_text_classification_single_label_sample.py new file mode 100644 index 0000000000..c63cc3f1d1 --- /dev/null +++ b/samples/model-builder/import_data_text_classification_single_label_sample.py @@ -0,0 +1,44 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_import_data_text_classification_single_label_sample] +def import_data_text_classification_single_label( + project: str, + location: str, + dataset: str, + src_uris: Union[str, List[str]], + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.TextDataset(dataset) + ds.import_data( + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.text.single_label_classification, + sync=sync, + ) + + ds.wait() + + print(ds.display_name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_import_data_text_classification_single_label_sample] diff --git a/samples/model-builder/import_data_text_classification_single_label_sample_test.py b/samples/model-builder/import_data_text_classification_single_label_sample_test.py new file mode 100644 index 0000000000..1765ab013e --- /dev/null +++ b/samples/model-builder/import_data_text_classification_single_label_sample_test.py @@ -0,0 +1,43 @@ +# Copyright 2021 
Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from google.cloud.aiplatform import schema + +import import_data_text_classification_single_label_sample +import test_constants as constants + + +def test_import_data_text_classification_single_label_sample( + mock_sdk_init, mock_get_text_dataset, mock_import_text_dataset +): + + import_data_text_classification_single_label_sample.import_data_text_classification_single_label( + project=constants.PROJECT, + location=constants.LOCATION, + dataset=constants.DATASET_NAME, + src_uris=constants.GCS_SOURCES, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_text_dataset.assert_called_once_with(constants.DATASET_NAME) + + mock_import_text_dataset.assert_called_once_with( + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.text.single_label_classification, + sync=True, + ) diff --git a/samples/model-builder/import_data_text_entity_extraction_sample.py b/samples/model-builder/import_data_text_entity_extraction_sample.py new file mode 100644 index 0000000000..7e00d57632 --- /dev/null +++ b/samples/model-builder/import_data_text_entity_extraction_sample.py @@ -0,0 +1,44 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_import_data_text_entity_extraction_sample] +def import_data_text_entity_extraction_sample( + project: str, + location: str, + dataset: str, + src_uris: Union[str, List[str]], + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.TextDataset(dataset) + ds.import_data( + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.text.extraction, + sync=sync, + ) + + ds.wait() + + print(ds.display_name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_import_data_text_entity_extraction_sample] diff --git a/samples/model-builder/import_data_text_entity_extraction_sample_test.py b/samples/model-builder/import_data_text_entity_extraction_sample_test.py new file mode 100644 index 0000000000..a3b93e9200 --- /dev/null +++ b/samples/model-builder/import_data_text_entity_extraction_sample_test.py @@ -0,0 +1,45 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from google.cloud.aiplatform import schema + +import import_data_text_entity_extraction_sample +import test_constants as constants + + +def test_import_data_text_entity_extraction_sample( + mock_sdk_init, mock_get_text_dataset, mock_import_text_dataset +): + + import_data_text_entity_extraction_sample.import_data_text_entity_extraction_sample( + project=constants.PROJECT, + location=constants.LOCATION, + dataset=constants.DATASET_NAME, + src_uris=constants.GCS_SOURCES, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_text_dataset.assert_called_once_with( + constants.DATASET_NAME, + ) + + mock_import_text_dataset.assert_called_once_with( + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.text.extraction, + sync=True, + ) diff --git a/samples/model-builder/import_data_text_sentiment_analysis_sample.py b/samples/model-builder/import_data_text_sentiment_analysis_sample.py new file mode 100644 index 0000000000..3861a1102a --- /dev/null +++ b/samples/model-builder/import_data_text_sentiment_analysis_sample.py @@ -0,0 +1,44 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Union + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_import_data_text_sentiment_analysis_sample] +def import_data_text_sentiment_analysis_sample( + project: str, + location: str, + dataset: str, + src_uris: Union[str, List[str]], + sync: bool = True, +): + aiplatform.init(project=project, location=location) + + ds = aiplatform.TextDataset(dataset) + ds.import_data( + gcs_source=src_uris, + import_schema_uri=aiplatform.schema.dataset.ioformat.text.sentiment, + sync=sync, + ) + + ds.wait() + + print(ds.display_name) + print(ds.resource_name) + return ds + + +# [END aiplatform_sdk_import_data_text_sentiment_analysis_sample] diff --git a/samples/model-builder/import_data_text_sentiment_analysis_sample_test.py b/samples/model-builder/import_data_text_sentiment_analysis_sample_test.py new file mode 100644 index 0000000000..2134d66b35 --- /dev/null +++ b/samples/model-builder/import_data_text_sentiment_analysis_sample_test.py @@ -0,0 +1,45 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from google.cloud.aiplatform import schema + +import import_data_text_sentiment_analysis_sample +import test_constants as constants + + +def test_import_data_text_sentiment_analysis_sample( + mock_sdk_init, mock_get_text_dataset, mock_import_text_dataset +): + + import_data_text_sentiment_analysis_sample.import_data_text_sentiment_analysis_sample( + project=constants.PROJECT, + location=constants.LOCATION, + dataset=constants.DATASET_NAME, + src_uris=constants.GCS_SOURCES, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, location=constants.LOCATION + ) + + mock_get_text_dataset.assert_called_once_with( + constants.DATASET_NAME, + ) + + mock_import_text_dataset.assert_called_once_with( + gcs_source=constants.GCS_SOURCES, + import_schema_uri=schema.dataset.ioformat.text.sentiment, + sync=True, + ) diff --git a/samples/model-builder/init_sample.py b/samples/model-builder/init_sample.py new file mode 100644 index 0000000000..8ced169ec4 --- /dev/null +++ b/samples/model-builder/init_sample.py @@ -0,0 +1,40 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +from google.auth import credentials as auth_credentials +from google.cloud import aiplatform + + +# [START aiplatform_sdk_init_sample] +def init_sample( + project: Optional[str] = None, + location: Optional[str] = None, + experiment: Optional[str] = None, + staging_bucket: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, +): + aiplatform.init( + project=project, + location=location, + experiment=experiment, + staging_bucket=staging_bucket, + credentials=credentials, + encryption_spec_key_name=encryption_spec_key_name, + ) + + +# [END aiplatform_sdk_init_sample] diff --git a/samples/model-builder/init_sample_test.py b/samples/model-builder/init_sample_test.py new file mode 100644 index 0000000000..3c4684a255 --- /dev/null +++ b/samples/model-builder/init_sample_test.py @@ -0,0 +1,38 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import init_sample +import test_constants as constants + + +def test_init_sample(mock_sdk_init): + + init_sample.init_sample( + project=constants.PROJECT, + location=constants.LOCATION_EUROPE, + experiment=constants.EXPERIMENT_NAME, + staging_bucket=constants.STAGING_BUCKET, + credentials=constants.CREDENTIALS, + encryption_spec_key_name=constants.ENCRYPTION_SPEC_KEY_NAME, + ) + + mock_sdk_init.assert_called_once_with( + project=constants.PROJECT, + location=constants.LOCATION_EUROPE, + experiment=constants.EXPERIMENT_NAME, + staging_bucket=constants.STAGING_BUCKET, + credentials=constants.CREDENTIALS, + encryption_spec_key_name=constants.ENCRYPTION_SPEC_KEY_NAME, + ) diff --git a/samples/model-builder/noxfile.py b/samples/model-builder/noxfile.py new file mode 100644 index 0000000000..83bf446de2 --- /dev/null +++ b/samples/model-builder/noxfile.py @@ -0,0 +1,221 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +from pathlib import Path +import sys + +import nox + +# WARNING - WARNING - WARNING - WARNING - WARNING +# WARNING - WARNING - WARNING - WARNING - WARNING +# DO NOT EDIT THIS FILE EVER! +# WARNING - WARNING - WARNING - WARNING - WARNING +# WARNING - WARNING - WARNING - WARNING - WARNING + +# Copy `noxfile_config.py` to your directory and modify it instead. + + +# `TEST_CONFIG` dict is a configuration hook that allows users to +# modify the test configurations. 
The values here should be in sync +# with `noxfile_config.py`. Users will copy `noxfile_config.py` into +# their directory and modify it. + +TEST_CONFIG = { + # You can opt out from the test for specific Python versions. + "ignored_versions": ["2.7"], + # An envvar key for determining the project id to use. Change it + # to 'BUILD_SPECIFIC_GCLOUD_PROJECT' if you want to opt in using a + # build specific Cloud project. You can also use your own string + # to use your own Cloud project. + "gcloud_project_env": "GOOGLE_CLOUD_PROJECT", + # 'gcloud_project_env': 'BUILD_SPECIFIC_GCLOUD_PROJECT', + # A dictionary you want to inject into your test. Don't put any + # secrets here. These values will override predefined values. + "envs": {}, +} + + +try: + # Ensure we can import noxfile_config in the project's directory. + sys.path.append(".") + from noxfile_config import TEST_CONFIG_OVERRIDE +except ImportError as e: + print("No user noxfile_config found: detail: {}".format(e)) + TEST_CONFIG_OVERRIDE = {} + +# Update the TEST_CONFIG with the user supplied values. +TEST_CONFIG.update(TEST_CONFIG_OVERRIDE) + + +def get_pytest_env_vars(): + """Returns a dict for pytest invocation.""" + ret = {} + + # Override the GCLOUD_PROJECT and the alias. + env_key = TEST_CONFIG["gcloud_project_env"] + # This should error out if not set. + ret["GOOGLE_CLOUD_PROJECT"] = os.environ[env_key] + + # Apply user supplied envs. + ret.update(TEST_CONFIG["envs"]) + return ret + + +# DO NOT EDIT - automatically generated. +# All versions used to tested samples. +ALL_VERSIONS = ["2.7", "3.6", "3.7", "3.8"] + +# Any default versions that should be ignored. 
+IGNORED_VERSIONS = TEST_CONFIG["ignored_versions"] + +TESTED_VERSIONS = sorted([v for v in ALL_VERSIONS if v not in IGNORED_VERSIONS]) + +INSTALL_LIBRARY_FROM_SOURCE = bool(os.environ.get("INSTALL_LIBRARY_FROM_SOURCE", False)) +# +# Style Checks +# + + +def _determine_local_import_names(start_dir): + """Determines all import names that should be considered "local". + + This is used when running the linter to insure that import order is + properly checked. + """ + file_ext_pairs = [os.path.splitext(path) for path in os.listdir(start_dir)] + return [ + basename + for basename, extension in file_ext_pairs + if extension == ".py" + or os.path.isdir(os.path.join(start_dir, basename)) + and basename not in ("__pycache__") + ] + + +# Linting with flake8. +# +# We ignore the following rules: +# E203: whitespace before ‘:’ +# E266: too many leading ‘#’ for block comment +# E501: line too long +# I202: Additional newline in a section of imports +# +# We also need to specify the rules which are ignored by default: +# ['E226', 'W504', 'E126', 'E123', 'W503', 'E24', 'E704', 'E121'] +FLAKE8_COMMON_ARGS = [ + "--show-source", + "--builtin=gettext", + "--max-complexity=20", + "--import-order-style=google", + "--exclude=.nox,.cache,env,lib,generated_pb2,*_pb2.py,*_pb2_grpc.py", + "--ignore=E121,E123,E126,E203,E226,E24,E266,E501,E704,W503,W504,I202", + "--max-line-length=88", +] + + +@nox.session +def lint(session): + session.install("flake8", "flake8-import-order") + + local_names = _determine_local_import_names(".") + args = FLAKE8_COMMON_ARGS + [ + "--application-import-names", + ",".join(local_names), + ".", + ] + session.run("flake8", *args) + + +# +# Sample Tests +# + + +PYTEST_COMMON_ARGS = ["--junitxml=sponge_log.xml"] + + +def _session_tests(session, post_install=None): + """Runs py.test for a particular project.""" + if os.path.exists("requirements.txt"): + session.install("-r", "requirements.txt") + + if os.path.exists("requirements-test.txt"): + session.install("-r", 
"requirements-test.txt") + + if INSTALL_LIBRARY_FROM_SOURCE: + session.install("-e", _get_repo_root()) + + if post_install: + post_install(session) + + session.run( + "pytest", + *(PYTEST_COMMON_ARGS + session.posargs), + # Pytest will return 5 when no tests are collected. This can happen + # on travis where slow and flaky tests are excluded. + # See http://doc.pytest.org/en/latest/_modules/_pytest/main.html + success_codes=[0, 5], + env=get_pytest_env_vars() + ) + + +@nox.session(python=ALL_VERSIONS) +def py(session): + """Runs py.test for a sample using the specified version of Python.""" + if session.python in TESTED_VERSIONS: + _session_tests(session) + else: + session.skip( + "SKIPPED: {} tests are disabled for this sample.".format(session.python) + ) + + +# +# Readmegen +# + + +def _get_repo_root(): + """ Returns the root folder of the project. """ + # Get root of this repository. Assume we don't have directories nested deeper than 10 items. + p = Path(os.getcwd()) + for i in range(10): + if p is None: + break + if Path(p / ".git").exists(): + return str(p) + p = p.parent + raise Exception("Unable to detect repository root.") + + +GENERATED_READMES = sorted([x for x in Path(".").rglob("*.rst.in")]) + + +@nox.session +@nox.parametrize("path", GENERATED_READMES) +def readmegen(session, path): + """(Re-)generates the readme for a sample.""" + session.install("jinja2", "pyyaml") + dir_ = os.path.dirname(path) + + if os.path.exists(os.path.join(dir_, "requirements.txt")): + session.install("-r", os.path.join(dir_, "requirements.txt")) + + in_file = os.path.join(dir_, "README.rst.in") + session.run( + "python", _get_repo_root() + "/scripts/readme-gen/readme_gen.py", in_file + ) diff --git a/samples/model-builder/predict_text_classification_single_label_sample.py b/samples/model-builder/predict_text_classification_single_label_sample.py new file mode 100644 index 0000000000..195b519750 --- /dev/null +++ 
# [START aiplatform_sdk_predict_text_classification_single_label_sample]
def predict_text_classification_single_label_sample(
    project, location, endpoint, content
):
    """Send one text instance to a deployed classification endpoint and
    print every prediction returned."""
    aiplatform.init(project=project, location=location)

    # Resolve the endpoint resource without shadowing the parameter.
    model_endpoint = aiplatform.Endpoint(endpoint)

    prediction_response = model_endpoint.predict(
        instances=[{"content": content}], parameters={}
    )

    for prediction in prediction_response.predictions:
        print(prediction)


# [END aiplatform_sdk_predict_text_classification_single_label_sample]
def test_predict_text_classification_single_label_sample(
    mock_sdk_init, mock_get_endpoint
):
    """The sample must init the SDK once and resolve the endpoint once."""
    predict_text_classification_single_label_sample.predict_text_classification_single_label_sample(
        project=constants.PROJECT,
        location=constants.LOCATION,
        endpoint=constants.ENDPOINT_NAME,
        content=constants.PREDICTION_TEXT_INSTANCE,
    )

    # SDK initialized exactly once with the test project/location.
    mock_sdk_init.assert_called_once_with(
        project=constants.PROJECT, location=constants.LOCATION
    )
    # Endpoint constructed exactly once from the given resource name.
    mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,)
# [START aiplatform_sdk_predict_text_entity_extraction_sample]
def predict_text_entity_extraction_sample(project, location, endpoint_id, content):
    """Send one text instance to a deployed entity-extraction endpoint and
    print every prediction returned."""
    aiplatform.init(project=project, location=location)

    model_endpoint = aiplatform.Endpoint(endpoint_id)

    prediction_response = model_endpoint.predict(
        instances=[{"content": content}], parameters={}
    )

    for prediction in prediction_response.predictions:
        print(prediction)


# [END aiplatform_sdk_predict_text_entity_extraction_sample]
def test_predict_text_entity_extraction_sample(mock_sdk_init, mock_get_endpoint):
    """The sample must init the SDK once and resolve the endpoint once."""
    predict_text_entity_extraction_sample.predict_text_entity_extraction_sample(
        project=constants.PROJECT,
        location=constants.LOCATION,
        endpoint_id=constants.ENDPOINT_NAME,
        content=constants.PREDICTION_TEXT_INSTANCE,
    )

    # SDK initialized exactly once with the test project/location.
    mock_sdk_init.assert_called_once_with(
        project=constants.PROJECT, location=constants.LOCATION
    )
    # Endpoint constructed exactly once from the given resource name.
    mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,)
# [START aiplatform_sdk_predict_text_sentiment_analysis_sample]
def predict_text_sentiment_analysis_sample(project, location, endpoint_id, content):
    """Send one text instance to a deployed sentiment-analysis endpoint and
    print every prediction returned."""
    aiplatform.init(project=project, location=location)

    model_endpoint = aiplatform.Endpoint(endpoint_id)

    prediction_response = model_endpoint.predict(
        instances=[{"content": content}], parameters={}
    )

    for prediction in prediction_response.predictions:
        print(prediction)


# [END aiplatform_sdk_predict_text_sentiment_analysis_sample]
def test_predict_text_sentiment_analysis_sample(mock_sdk_init, mock_get_endpoint):
    """The sample must init the SDK once and resolve the endpoint once."""
    predict_text_sentiment_analysis_sample.predict_text_sentiment_analysis_sample(
        project=constants.PROJECT,
        location=constants.LOCATION,
        endpoint_id=constants.ENDPOINT_NAME,
        content=constants.PREDICTION_TEXT_INSTANCE,
    )

    # SDK initialized exactly once with the test project/location.
    mock_sdk_init.assert_called_once_with(
        project=constants.PROJECT, location=constants.LOCATION
    )
    # Endpoint constructed exactly once from the given resource name.
    mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,)
from random import randint
from uuid import uuid4

from google.auth import credentials

# Project / location identifiers shared across the model-builder sample tests.
PROJECT = "abc"
LOCATION = "us-central1"
LOCATION_EUROPE = "europe-west4"
LOCATION_ASIA = "asia-east1"
PARENT = f"projects/{PROJECT}/locations/{LOCATION}"

# Random display names so parallel test runs never collide.
DISPLAY_NAME = str(uuid4())
DISPLAY_NAME_2 = str(uuid4())

STAGING_BUCKET = "gs://my-staging-bucket"
EXPERIMENT_NAME = "fraud-detection-trial-72"
CREDENTIALS = credentials.AnonymousCredentials()

# Random 8-digit numeric resource IDs.
RESOURCE_ID = str(randint(10000000, 99999999))
RESOURCE_ID_2 = str(randint(10000000, 99999999))

# Fully qualified resource names derived from the IDs above.
BATCH_PREDICTION_JOB_NAME = f"{PARENT}/batchPredictionJobs/{RESOURCE_ID}"
DATASET_NAME = f"{PARENT}/datasets/{RESOURCE_ID}"
ENDPOINT_NAME = f"{PARENT}/endpoints/{RESOURCE_ID}"
MODEL_NAME = f"{PARENT}/models/{RESOURCE_ID}"
TRAINING_JOB_NAME = f"{PARENT}/trainingJobs/{RESOURCE_ID}"

GCS_SOURCES = ["gs://bucket1/source1.jsonl", "gs://bucket7/source4.jsonl"]
GCS_DESTINATION = "gs://bucket3/output-dir/"

# Dataset split fractions (sum to 1.0).
TRAINING_FRACTION_SPLIT = 0.7
TEST_FRACTION_SPLIT = 0.15
VALIDATION_FRACTION_SPLIT = 0.15

BUDGET_MILLI_NODE_HOURS_8000 = 8000

ENCRYPTION_SPEC_KEY_NAME = f"{PARENT}/keyRings/{RESOURCE_ID}/cryptoKeys/{RESOURCE_ID_2}"

PREDICTION_TEXT_INSTANCE = "This is some text for testing NLP prediction output"
--git a/samples/snippets/create_hyperparameter_tuning_job_python_package_sample_test.py b/samples/snippets/create_hyperparameter_tuning_job_python_package_sample_test.py index f430fc38ed..7bb5ec5ac3 100644 --- a/samples/snippets/create_hyperparameter_tuning_job_python_package_sample_test.py +++ b/samples/snippets/create_hyperparameter_tuning_job_python_package_sample_test.py @@ -26,7 +26,7 @@ ) EXECUTOR_IMAGE_URI = "us.gcr.io/cloud-aiplatform/training/tf-gpu.2-1:latest" -PACKAGE_URI = "gs://ucaip-test-us-central1/training/pythonpackages/trainer.tar.bz2" +PACKAGE_URI = "gs://cloud-samples-data-us-central1/ai-platform-unified/training/python-packages/trainer.tar.bz2" PYTHON_MODULE = "trainer.hptuning_trainer" diff --git a/samples/snippets/create_hyperparameter_tuning_job_sample_test.py b/samples/snippets/create_hyperparameter_tuning_job_sample_test.py index ad1f0ae4db..9a16bdcb9c 100644 --- a/samples/snippets/create_hyperparameter_tuning_job_sample_test.py +++ b/samples/snippets/create_hyperparameter_tuning_job_sample_test.py @@ -21,7 +21,7 @@ import helpers PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT") -CONTAINER_IMAGE_URI = "gcr.io/ucaip-test/ucaip-training-test:latest" +CONTAINER_IMAGE_URI = "gcr.io/ucaip-sample-tests/ucaip-training-test:latest" @pytest.fixture(scope="function", autouse=True) diff --git a/samples/snippets/create_training_pipeline_custom_training_managed_dataset_sample_test.py b/samples/snippets/create_training_pipeline_custom_training_managed_dataset_sample_test.py index 82725f3847..2323163c9e 100644 --- a/samples/snippets/create_training_pipeline_custom_training_managed_dataset_sample_test.py +++ b/samples/snippets/create_training_pipeline_custom_training_managed_dataset_sample_test.py @@ -30,7 +30,7 @@ ANNOTATION_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml" TRAINING_CONTAINER_SPEC_IMAGE_URI = ( - "gcr.io/ucaip-test/custom-container-managed-dataset:latest" + 
"gcr.io/ucaip-sample-tests/custom-container-managed-dataset:latest" ) MODEL_CONTAINER_SPEC_IMAGE_URI = "gcr.io/cloud-aiplatform/prediction/tf-gpu.1-15:latest" diff --git a/samples/snippets/requirements.txt b/samples/snippets/requirements.txt index b9fd33d5c1..481213275f 100644 --- a/samples/snippets/requirements.txt +++ b/samples/snippets/requirements.txt @@ -1,3 +1,3 @@ pytest==6.2.2 google-cloud-storage>=1.26.0, <2.0.0dev -google-cloud-aiplatform==0.5.1 +google-cloud-aiplatform==0.6.0 diff --git a/setup.py b/setup.py index cc19d7a867..771eb17ee9 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ import setuptools # type: ignore name = "google-cloud-aiplatform" -version = "0.6.0" +version = "0.7.0" description = "Cloud AI Platform API client library" package_root = os.path.abspath(os.path.dirname(__file__)) @@ -46,6 +46,7 @@ "google-api-core[grpc] >= 1.22.2, < 2.0.0dev", "proto-plus >= 1.10.1", "google-cloud-storage >= 1.26.0, < 2.0.0dev", + "google-cloud-bigquery >= 1.15.0, < 3.0.0dev", ), python_requires=">=3.6", scripts=[], diff --git a/tests/system/aiplatform/test_dataset.py b/tests/system/aiplatform/test_dataset.py new file mode 100644 index 0000000000..e18390a76a --- /dev/null +++ b/tests/system/aiplatform/test_dataset.py @@ -0,0 +1,287 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class TestDataset:
    """System tests covering dataset get/create/import/export via the SDK."""

    def setup_method(self):
        # Reload so every test starts with fresh global SDK state.
        importlib.reload(initializer)
        importlib.reload(aiplatform)

    @pytest.fixture()
    def shared_state(self):
        # Mutable dict used to pass state between fixtures and tests.
        state = {}
        yield state

    @pytest.fixture()
    def create_staging_bucket(self, shared_state):
        # Create a uniquely named temporary staging bucket for export tests.
        bucket_name = f"temp-sdk-integration-{uuid.uuid4()}"
        client = storage.Client()
        client.create_bucket(bucket_name)
        shared_state["storage_client"] = client
        shared_state["staging_bucket"] = bucket_name
        yield

    @pytest.fixture()
    def delete_staging_bucket(self, shared_state):
        # Teardown-only: remove the temp bucket and close the storage client.
        yield
        client = shared_state["storage_client"]

        # Delete temp staging bucket and its contents.
        client.get_bucket(shared_state["staging_bucket"]).delete(force=True)

        # Close the Storage Client's underlying HTTP sessions.
        client._http._auth_request.session.close()
        client._http.close()

    @pytest.fixture()
    def dataset_gapic_client(self):
        # Raw GAPIC client used to create/inspect/delete datasets directly.
        yield dataset_service.DatasetServiceClient(
            client_options=client_options.ClientOptions(
                api_endpoint=_TEST_API_ENDPOINT
            )
        )

    @pytest.fixture()
    def create_text_dataset(self, dataset_gapic_client, shared_state):
        # Create an empty text dataset and record its resource name.
        text_dataset = gca_dataset.Dataset(
            display_name=f"temp_sdk_integration_test_create_text_dataset_{uuid.uuid4()}",
            metadata_schema_uri=aiplatform.schema.dataset.metadata.text,
        )
        create_lro = dataset_gapic_client.create_dataset(
            parent=_TEST_PARENT, dataset=text_dataset
        )
        shared_state["dataset_name"] = create_lro.result().name
        yield

    @pytest.fixture()
    def create_tabular_dataset(self, dataset_gapic_client, shared_state):
        # Create an empty tabular dataset and record its resource name.
        tabular_dataset = gca_dataset.Dataset(
            display_name=f"temp_sdk_integration_test_create_tabular_dataset_{uuid.uuid4()}",
            metadata_schema_uri=aiplatform.schema.dataset.metadata.tabular,
        )
        create_lro = dataset_gapic_client.create_dataset(
            parent=_TEST_PARENT, dataset=tabular_dataset
        )
        shared_state["dataset_name"] = create_lro.result().name
        yield

    @pytest.fixture()
    def create_image_dataset(self, dataset_gapic_client, shared_state):
        # Create an empty image dataset and record its resource name.
        image_dataset = gca_dataset.Dataset(
            display_name=f"temp_sdk_integration_test_create_image_dataset_{uuid.uuid4()}",
            metadata_schema_uri=aiplatform.schema.dataset.metadata.image,
        )
        create_lro = dataset_gapic_client.create_dataset(
            parent=_TEST_PARENT, dataset=image_dataset
        )
        shared_state["dataset_name"] = create_lro.result().name
        yield

    @pytest.fixture()
    def delete_new_dataset(self, dataset_gapic_client, shared_state):
        # Teardown-only: delete whichever dataset the test created.
        yield
        assert shared_state["dataset_name"]

        dataset_gapic_client.delete_dataset(
            name=shared_state["dataset_name"]
        ).result()

        shared_state["dataset_name"] = None

    # TODO(vinnys): Remove pytest skip once persistent resources are accessible
    @pytest.mark.skip(reason="System tests cannot access persistent test resources")
    def test_get_existing_dataset(self):
        """Retrieve a known existing dataset and confirm the SDK resolves it."""
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        flowers_dataset = aiplatform.ImageDataset(dataset_name=_TEST_IMAGE_DATASET_ID)
        assert flowers_dataset.name == _TEST_IMAGE_DATASET_ID
        assert flowers_dataset.display_name == _TEST_DATASET_DISPLAY_NAME

    def test_get_nonexistent_dataset(self):
        """Retrieving a dataset that does not exist must raise NotFound."""
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        # AI Platform service returns 404 for the unknown dataset ID "0".
        with pytest.raises(exceptions.NotFound):
            aiplatform.ImageDataset(dataset_name="0")

    @pytest.mark.usefixtures("create_text_dataset", "delete_new_dataset")
    def test_get_new_dataset_and_import(self, dataset_gapic_client, shared_state):
        """Import a text dataset into a new, empty dataset and verify the
        expected number of data items were imported."""
        assert shared_state["dataset_name"]
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        my_dataset = aiplatform.TextDataset(dataset_name=shared_state["dataset_name"])

        items_before = dataset_gapic_client.list_data_items(
            parent=my_dataset.resource_name
        )
        assert len(list(items_before)) == 0

        # Blocking call to import
        my_dataset.import_data(
            gcs_source=_TEST_TEXT_ENTITY_EXTRACTION_GCS_SOURCE,
            import_schema_uri=_TEST_TEXT_ENTITY_IMPORT_SCHEMA,
        )

        items_after = dataset_gapic_client.list_data_items(
            parent=my_dataset.resource_name
        )
        assert len(list(items_after)) == 469

    @vpcsc_config.skip_if_inside_vpcsc
    @pytest.mark.usefixtures("delete_new_dataset")
    def test_create_and_import_image_dataset(self, dataset_gapic_client, shared_state):
        """Create an image object-detection dataset with Dataset.create() and
        confirm the images were imported."""
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        img_dataset = aiplatform.ImageDataset.create(
            display_name=f"temp_sdk_integration_create_and_import_dataset_{uuid.uuid4()}",
            gcs_source=_TEST_IMAGE_OBJECT_DETECTION_GCS_SOURCE,
            import_schema_uri=_TEST_IMAGE_OBJ_DET_IMPORT_SCHEMA,
        )
        shared_state["dataset_name"] = img_dataset.resource_name

        items = dataset_gapic_client.list_data_items(parent=img_dataset.resource_name)
        assert len(list(items)) == 14

    @pytest.mark.usefixtures("delete_new_dataset")
    def test_create_tabular_dataset(self, dataset_gapic_client, shared_state):
        """Create a tabular dataset with Dataset.create() and confirm it
        references the GCS source it was built from."""
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

        tabular_dataset = aiplatform.TabularDataset.create(
            display_name=f"temp_sdk_integration_create_and_import_dataset_{uuid.uuid4()}",
            gcs_source=[_TEST_TABULAR_CLASSIFICATION_GCS_SOURCE],
        )

        gapic_dataset = tabular_dataset._gca_resource
        shared_state["dataset_name"] = tabular_dataset.resource_name

        gapic_metadata = json_format.MessageToDict(gapic_dataset._pb.metadata)
        gcs_source_uris = gapic_metadata["inputConfig"]["gcsSource"]["uri"]

        assert len(gcs_source_uris) == 1
        assert gcs_source_uris[0] == _TEST_TABULAR_CLASSIFICATION_GCS_SOURCE
        assert (
            gapic_dataset.metadata_schema_uri
            == aiplatform.schema.dataset.metadata.tabular
        )

    # TODO(vinnys): Remove pytest skip once persistent resources are accessible
    @pytest.mark.skip(reason="System tests cannot access persistent test resources")
    @pytest.mark.usefixtures("create_staging_bucket", "delete_staging_bucket")
    def test_export_data(self, shared_state):
        """Export an existing text dataset to the temp staging bucket and
        verify at least one exported file actually exists in GCS."""
        assert shared_state["staging_bucket"]
        assert shared_state["storage_client"]

        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            staging_bucket=f"gs://{shared_state['staging_bucket']}",
        )

        text_dataset = aiplatform.TextDataset(dataset_name=_TEST_TEXT_DATASET_ID)

        exported_files = text_dataset.export_data(
            output_dir=f"gs://{shared_state['staging_bucket']}"
        )

        assert len(exported_files)  # Ensure at least one GCS path was returned

        bucket_name, prefix = utils.extract_bucket_and_prefix_from_gcs_path(
            exported_files[0]
        )
        blob = shared_state["storage_client"].get_bucket(bucket_name).get_blob(prefix)
        assert blob  # Verify the returned GCS export path exists
_TEST_TRAINING_TASK_INPUTS = json_format.ParseDict(
    {
        "modelType": "CLOUD",
        "budgetMilliNodeHours": _TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
        "multiLabel": False,
        "disableEarlyStopping": _TEST_TRAINING_DISABLE_EARLY_STOPPING,
    },
    struct_pb2.Value(),
)

_TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL = json_format.ParseDict(
    {
        "modelType": "CLOUD",
        "budgetMilliNodeHours": _TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
        "multiLabel": False,
        "disableEarlyStopping": _TEST_TRAINING_DISABLE_EARLY_STOPPING,
        "baseModelId": _TEST_MODEL_ID,
    },
    struct_pb2.Value(),
)

# Dataset split fractions used by the training-job tests.
_TEST_FRACTION_SPLIT_TRAINING = 0.6
_TEST_FRACTION_SPLIT_VALIDATION = 0.2
_TEST_FRACTION_SPLIT_TEST = 0.2

_TEST_MODEL_NAME = (
    f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_MODEL_ID}"
)

_TEST_PIPELINE_RESOURCE_NAME = (
    f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/trainingPipeline/12345"
)

# CMEK encryption keys/specs used to verify encryption plumbing.
_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default"
_TEST_DEFAULT_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec(
    kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME
)

_TEST_PIPELINE_ENCRYPTION_KEY_NAME = "key_pipeline"
_TEST_PIPELINE_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec(
    kms_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME
)

_TEST_MODEL_ENCRYPTION_KEY_NAME = "key_model"
_TEST_MODEL_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec(
    kms_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME
)


@pytest.fixture
def mock_pipeline_service_create():
    # Patch create_training_pipeline to return an already-succeeded pipeline.
    with mock.patch.object(
        pipeline_service_client.PipelineServiceClient, "create_training_pipeline"
    ) as create_mock:
        create_mock.return_value = gca_training_pipeline.TrainingPipeline(
            name=_TEST_PIPELINE_RESOURCE_NAME,
            state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
            model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME),
        )
        yield create_mock


@pytest.fixture
def mock_pipeline_service_get():
    # Patch get_training_pipeline to report success with an uploaded model.
    with mock.patch.object(
        pipeline_service_client.PipelineServiceClient, "get_training_pipeline"
    ) as get_mock:
        get_mock.return_value = gca_training_pipeline.TrainingPipeline(
            name=_TEST_PIPELINE_RESOURCE_NAME,
            state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
            model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME),
        )
        yield get_mock


@pytest.fixture
def mock_pipeline_service_create_and_get_with_fail():
    # Creation appears to start (RUNNING) but subsequent polling reports FAILED.
    with mock.patch.object(
        pipeline_service_client.PipelineServiceClient, "create_training_pipeline"
    ) as create_mock:
        create_mock.return_value = gca_training_pipeline.TrainingPipeline(
            name=_TEST_PIPELINE_RESOURCE_NAME,
            state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING,
        )

        with mock.patch.object(
            pipeline_service_client.PipelineServiceClient, "get_training_pipeline"
        ) as get_mock:
            get_mock.return_value = gca_training_pipeline.TrainingPipeline(
                name=_TEST_PIPELINE_RESOURCE_NAME,
                state=gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED,
            )

            yield create_mock, get_mock


@pytest.fixture
def mock_model_service_get():
    # Patch get_model to return an empty Model proto.
    with mock.patch.object(
        model_service_client.ModelServiceClient, "get_model"
    ) as get_model_mock:
        get_model_mock.return_value = gca_model.Model()
        yield get_model_mock


@pytest.fixture
def mock_dataset_image():
    # MagicMock standing in for an ImageDataset resource with a real proto.
    ds = mock.MagicMock(datasets.ImageDataset)
    ds.name = _TEST_DATASET_NAME
    ds._latest_future = None
    ds._exception = None
    ds._gca_resource = gca_dataset.Dataset(
        display_name=_TEST_DATASET_DISPLAY_NAME,
        metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_IMAGE,
        labels={},
        name=_TEST_DATASET_NAME,
        metadata={},
    )
    return ds
= gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description="This is the mock Model's description", + name=_TEST_MODEL_NAME, + ) + yield model + + +class TestAutoMLImageTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_all_parameters(self, mock_model_image): + """Ensure all private members are set correctly at initalization""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_ICN, + model_type=_TEST_MODEL_TYPE_MOBILE, + base_model=mock_model_image, + multi_label=True, + ) + + assert job._display_name == _TEST_DISPLAY_NAME + assert job._model_type == _TEST_MODEL_TYPE_MOBILE + assert job._prediction_type == _TEST_PREDICTION_TYPE_ICN + assert job._multi_label is True + assert job._base_model == mock_model_image + + def test_init_wrong_parameters(self, mock_model_image): + """Ensure correct exceptions are raised when initializing with invalid args""" + + aiplatform.init(project=_TEST_PROJECT) + + with pytest.raises(ValueError, match=r"not a supported prediction type"): + training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, prediction_type="abcdefg", + ) + + with pytest.raises(ValueError, match=r"not a supported model_type for"): + training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type="classification", + model_type=_TEST_MODEL_TYPE_IOD, + ) + + with pytest.raises(ValueError, match=r"`base_model` is only supported"): + training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_IOD, + base_model=mock_model_image, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + 
mock_dataset_image, + mock_model_service_get, + mock_model_image, + sync, + ): + """Create and run an AutoML ICN training job, verify calls and return value""" + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, base_model=mock_model_image + ) + + model_from_job = job.run( + dataset=mock_dataset_image, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model_image._gca_resource.description, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_image.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_image_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_BASE_MODEL, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + 
mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + assert job._gca_resource is mock_pipeline_service_get.return_value + assert model_from_job._gca_resource is mock_model_service_get.return_value + assert job.get_model()._gca_resource is mock_model_service_get.return_value + assert not job.has_failed + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures("mock_pipeline_service_get") + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_if_no_model_display_name( + self, + mock_pipeline_service_create, + mock_dataset_image, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLImageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + training_encryption_spec_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME, + model_encryption_spec_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME, + ) + + model_from_job = job.run( + dataset=mock_dataset_image, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + # Test that if defaults to the job display name + true_managed_model = gca_model.Model( + display_name=_TEST_DISPLAY_NAME, encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_image.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + 
training_task_definition=schema.training_job.definition.automl_image_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_PIPELINE_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises(self, mock_dataset_image, sync): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLImageTrainingJob(display_name=_TEST_DISPLAY_NAME,) + + job.run( + dataset=mock_dataset_image, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_dataset_image, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, mock_pipeline_service_create_and_get_with_fail, mock_dataset_image, sync + ): + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLImageTrainingJob(display_name=_TEST_DISPLAY_NAME,) + + with pytest.raises(RuntimeError): + job.run( + model_display_name=_TEST_MODEL_DISPLAY_NAME, + dataset=mock_dataset_image, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, 
+ test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + def test_raises_before_run_is_called(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLImageTrainingJob(display_name=_TEST_DISPLAY_NAME,) + + with pytest.raises(RuntimeError): + job.get_model() + + with pytest.raises(RuntimeError): + job.has_failed + + with pytest.raises(RuntimeError): + job.state diff --git a/tests/unit/aiplatform/test_automl_tabular_training_jobs.py b/tests/unit/aiplatform/test_automl_tabular_training_jobs.py new file mode 100644 index 0000000000..62cab4b3c3 --- /dev/null +++ b/tests/unit/aiplatform/test_automl_tabular_training_jobs.py @@ -0,0 +1,441 @@ +import importlib +import pytest +from unittest import mock + +from google.cloud import aiplatform + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import training_jobs + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client, +) +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) +from google.protobuf import json_format +from google.protobuf import struct_pb2 + +_TEST_BUCKET_NAME = "test-bucket" +_TEST_GCS_PATH_WITHOUT_BUCKET = "path/to/folder" +_TEST_GCS_PATH = f"{_TEST_BUCKET_NAME}/{_TEST_GCS_PATH_WITHOUT_BUCKET}" +_TEST_GCS_PATH_WITH_TRAILING_SLASH = f"{_TEST_GCS_PATH}/" +_TEST_PROJECT = "test-project" + +_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name" +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_DISPLAY_NAME = "test-display-name" 
+_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image" +_TEST_METADATA_SCHEMA_URI_TABULAR = schema.dataset.metadata.tabular +_TEST_METADATA_SCHEMA_URI_NONTABULAR = schema.dataset.metadata.image + +_TEST_TRAINING_COLUMN_TRANSFORMATIONS = [ + {"auto": {"column_name": "sepal_width"}}, + {"auto": {"column_name": "sepal_length"}}, + {"auto": {"column_name": "petal_length"}}, + {"auto": {"column_name": "petal_width"}}, +] +_TEST_TRAINING_TARGET_COLUMN = "target" +_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS = 1000 +_TEST_TRAINING_WEIGHT_COLUMN = "weight" +_TEST_TRAINING_DISABLE_EARLY_STOPPING = True +_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME = "minimize-log-loss" +_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE = "classification" +_TEST_TRAINING_TASK_INPUTS = json_format.ParseDict( + { + # required inputs + "targetColumn": _TEST_TRAINING_TARGET_COLUMN, + "transformations": _TEST_TRAINING_COLUMN_TRANSFORMATIONS, + "trainBudgetMilliNodeHours": _TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + # optional inputs + "weightColumnName": _TEST_TRAINING_WEIGHT_COLUMN, + "disableEarlyStopping": _TEST_TRAINING_DISABLE_EARLY_STOPPING, + "predictionType": _TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + "optimizationObjective": _TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + "optimizationObjectiveRecallValue": None, + "optimizationObjectivePrecisionValue": None, + }, + struct_pb2.Value(), +) + +_TEST_DATASET_NAME = "test-dataset-name" + +_TEST_MODEL_DISPLAY_NAME = "model-display-name" +_TEST_TRAINING_FRACTION_SPLIT = 0.6 +_TEST_VALIDATION_FRACTION_SPLIT = 0.2 +_TEST_TEST_FRACTION_SPLIT = 0.2 +_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split" + +_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test/ouput/python/trainer.tar.gz" + +_TEST_MODEL_NAME = "projects/my-project/locations/us-central1/models/12345" + +_TEST_PIPELINE_RESOURCE_NAME = ( + "projects/my-project/locations/us-central1/trainingPipeline/12345" +) + +# CMEK encryption +_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default" 
+_TEST_DEFAULT_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME +) + +_TEST_PIPELINE_ENCRYPTION_KEY_NAME = "key_pipeline" +_TEST_PIPELINE_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME +) + +_TEST_MODEL_ENCRYPTION_KEY_NAME = "key_model" +_TEST_MODEL_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME +) + + +@pytest.fixture +def mock_pipeline_service_create(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_create_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_get(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_get_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_create_and_get_with_fail(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ) + + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + 
mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED, + ) + + yield mock_create_training_pipeline, mock_get_training_pipeline + + +@pytest.fixture +def mock_model_service_get(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as mock_get_model: + mock_get_model.return_value = gca_model.Model() + yield mock_get_model + + +@pytest.fixture +def mock_dataset_tabular(): + ds = mock.MagicMock(datasets.TabularDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +@pytest.fixture +def mock_dataset_nontabular(): + ds = mock.MagicMock(datasets.ImageDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +class TestAutoMLTabularTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_tabular, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + 
optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + optimization_objective_recall_value=None, + optimization_objective_precision_value=None, + ) + + model_from_job = job.run( + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_dataset_tabular.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_tabular, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + 
parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures("mock_pipeline_service_get") + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_if_no_model_display_name( + self, + mock_pipeline_service_create, + mock_dataset_tabular, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + optimization_objective_recall_value=None, + optimization_objective_precision_value=None, + training_encryption_spec_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME, + model_encryption_spec_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME, + ) + + model_from_job = job.run( + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, 
+ validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + # Test that if defaults to the job display name + true_managed_model = gca_model.Model( + display_name=_TEST_DISPLAY_NAME, encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_tabular.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_tabular, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_PIPELINE_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises(self, mock_dataset_tabular, sync): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + optimization_objective_recall_value=None, + optimization_objective_precision_value=None, + ) + + job.run( + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + with 
pytest.raises(RuntimeError): + job.run( + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, mock_pipeline_service_create_and_get_with_fail, mock_dataset_tabular, sync + ): + + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + optimization_objective_recall_value=None, + optimization_objective_precision_value=None, + ) + + with pytest.raises(RuntimeError): + job.run( + model_display_name=_TEST_MODEL_DISPLAY_NAME, + dataset=mock_dataset_tabular, + target_column=_TEST_TRAINING_TARGET_COLUMN, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + def test_raises_before_run_is_called(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.AutoMLTabularTrainingJob( + display_name=_TEST_DISPLAY_NAME, + optimization_prediction_type=_TEST_TRAINING_OPTIMIZATION_PREDICTION_TYPE, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + optimization_objective_recall_value=None, + optimization_objective_precision_value=None, + ) + + with 
pytest.raises(RuntimeError): + job.get_model() + + with pytest.raises(RuntimeError): + job.has_failed + + with pytest.raises(RuntimeError): + job.state diff --git a/tests/unit/aiplatform/test_automl_text_training_jobs.py b/tests/unit/aiplatform/test_automl_text_training_jobs.py new file mode 100644 index 0000000000..101ff79ef5 --- /dev/null +++ b/tests/unit/aiplatform/test_automl_text_training_jobs.py @@ -0,0 +1,618 @@ +import pytest +import importlib +from unittest import mock + +from google.cloud import aiplatform + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import models +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import training_jobs + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client, +) +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) +from google.cloud.aiplatform.v1.schema.trainingjob import ( + definition_v1 as training_job_inputs, +) + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name" +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_DISPLAY_NAME = "test-display-name" +_TEST_METADATA_SCHEMA_URI_TEXT = schema.dataset.metadata.text + +_TEST_PREDICTION_TYPE_CLASSIFICATION = "classification" +_TEST_CLASSIFICATION_MULTILABEL = True +_TEST_PREDICTION_TYPE_EXTRACTION = "extraction" +_TEST_PREDICTION_TYPE_SENTIMENT = "sentiment" +_TEST_SENTIMENT_MAX = 10 + +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_MODEL_DISPLAY_NAME = "model-display-name" +_TEST_MODEL_ID = "98777645321" + +_TEST_TRAINING_TASK_INPUTS_CLASSIFICATION = 
training_job_inputs.AutoMlTextClassificationInputs( + multi_label=_TEST_CLASSIFICATION_MULTILABEL +) +_TEST_TRAINING_TASK_INPUTS_EXTRACTION = training_job_inputs.AutoMlTextExtractionInputs() +_TEST_TRAINING_TASK_INPUTS_SENTIMENT = training_job_inputs.AutoMlTextSentimentInputs( + sentiment_max=_TEST_SENTIMENT_MAX +) + +_TEST_FRACTION_SPLIT_TRAINING = 0.6 +_TEST_FRACTION_SPLIT_VALIDATION = 0.2 +_TEST_FRACTION_SPLIT_TEST = 0.2 + +_TEST_MODEL_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_MODEL_ID}" +) + +_TEST_PIPELINE_RESOURCE_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/trainingPipeline/12345" +) + +# CMEK encryption +_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default" +_TEST_DEFAULT_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME +) + +_TEST_PIPELINE_ENCRYPTION_KEY_NAME = "key_pipeline" +_TEST_PIPELINE_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME +) + +_TEST_MODEL_ENCRYPTION_KEY_NAME = "key_model" +_TEST_MODEL_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME +) + + +@pytest.fixture +def mock_pipeline_service_create(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_create_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_get(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + 
state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_get_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_create_and_get_with_fail(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ) + + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED, + ) + + yield mock_create_training_pipeline, mock_get_training_pipeline + + +@pytest.fixture +def mock_model_service_get(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as mock_get_model: + mock_get_model.return_value = gca_model.Model() + yield mock_get_model + + +@pytest.fixture +def mock_dataset_text(): + ds = mock.MagicMock(datasets.TextDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +@pytest.fixture +def mock_model(): + model = mock.MagicMock(models.Model) + model.name = _TEST_MODEL_ID + model._latest_future = None + model._gca_resource = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, name=_TEST_MODEL_NAME, + ) + yield model + + +class TestAutoMLTextTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + 
initializer.global_pool.shutdown(wait=True) + + def test_init_all_parameters_classification(self): + """Ensure all private members are set correctly at initalization""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_CLASSIFICATION, + multi_label=_TEST_CLASSIFICATION_MULTILABEL, + ) + + assert job._display_name == _TEST_DISPLAY_NAME + assert ( + job._training_task_definition + == schema.training_job.definition.automl_text_classification + ) + assert ( + job._training_task_inputs_dict + == training_job_inputs.AutoMlTextClassificationInputs( + multi_label=_TEST_CLASSIFICATION_MULTILABEL + ) + ) + + def test_init_all_parameters_extraction(self): + """Ensure all private members are set correctly at initalization""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_EXTRACTION, + ) + + assert job._display_name == _TEST_DISPLAY_NAME + assert ( + job._training_task_definition + == schema.training_job.definition.automl_text_extraction + ) + assert ( + job._training_task_inputs_dict + == training_job_inputs.AutoMlTextExtractionInputs() + ) + + def test_init_all_parameters_sentiment(self): + """Ensure all private members are set correctly at initalization""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_SENTIMENT, + sentiment_max=_TEST_SENTIMENT_MAX, + ) + + assert job._display_name == _TEST_DISPLAY_NAME + assert ( + job._training_task_definition + == schema.training_job.definition.automl_text_sentiment + ) + assert ( + job._training_task_inputs_dict + == training_job_inputs.AutoMlTextSentimentInputs( + sentiment_max=_TEST_SENTIMENT_MAX + ) + ) + + @pytest.mark.usefixtures("mock_pipeline_service_get") + 
@pytest.mark.parametrize("sync", [True, False]) + def test_init_aiplatform_with_encryption_key_name_and_create_training_job( + self, + mock_pipeline_service_create, + mock_dataset_text, + mock_model_service_get, + sync, + ): + """ + Initiate aiplatform with encryption key name. + Create and run an AutoML Text Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_CLASSIFICATION, + multi_label=_TEST_CLASSIFICATION_MULTILABEL, + ) + + model_from_job = job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_text_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_CLASSIFICATION, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + 
parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_classification( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_text, + mock_model_service_get, + sync, + ): + """Create and run an AutoML Text Classification training job, verify calls and return value""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_CLASSIFICATION, + multi_label=_TEST_CLASSIFICATION_MULTILABEL, + training_encryption_spec_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME, + model_encryption_spec_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME, + ) + + model_from_job = job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_text_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_CLASSIFICATION, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + 
encryption_spec=_TEST_PIPELINE_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + assert job._gca_resource is mock_pipeline_service_get.return_value + assert model_from_job._gca_resource is mock_model_service_get.return_value + assert job.get_model()._gca_resource is mock_model_service_get.return_value + assert not job.has_failed + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_extraction( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_text, + mock_model_service_get, + sync, + ): + """Create and run an AutoML Text Extraction training job, verify calls and return value""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_EXTRACTION, + ) + + model_from_job = job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model(display_name=_TEST_MODEL_DISPLAY_NAME) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + 
display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_text_extraction, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_EXTRACTION, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + assert job._gca_resource is mock_pipeline_service_get.return_value + assert model_from_job._gca_resource is mock_model_service_get.return_value + assert job.get_model()._gca_resource is mock_model_service_get.return_value + assert not job.has_failed + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_sentiment( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_text, + mock_model_service_get, + sync, + ): + """Create and run an AutoML Text Sentiment training job, verify calls and return value""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_SENTIMENT, + sentiment_max=10, + ) + + model_from_job = job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model(display_name=_TEST_MODEL_DISPLAY_NAME) + + 
true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_text_sentiment, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_SENTIMENT, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + assert job._gca_resource is mock_pipeline_service_get.return_value + assert model_from_job._gca_resource is mock_model_service_get.return_value + assert job.get_model()._gca_resource is mock_model_service_get.return_value + assert not job.has_failed + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures("mock_pipeline_service_get") + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_if_no_model_display_name( + self, + mock_pipeline_service_create, + mock_dataset_text, + mock_model_service_get, + mock_model, + sync, + ): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type="classification", + multi_label=True, + ) + + model_from_job = job.run( + dataset=mock_dataset_text, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + model_display_name=None, # Omit model_display_name + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + 
validation_fraction=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + # Test that it defaults to the job display name + true_managed_model = gca_model.Model(display_name=_TEST_DISPLAY_NAME) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_text.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_text_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_CLASSIFICATION, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises(self, mock_dataset_text, sync): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type="classification", + multi_label=True, + ) + + job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_dataset_text, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, 
mock_pipeline_service_create_and_get_with_fail, mock_dataset_text, sync + ): + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLTextTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_CLASSIFICATION, + multi_label=_TEST_CLASSIFICATION_MULTILABEL, + ) + + with pytest.raises(RuntimeError): + job.run( + model_display_name=_TEST_MODEL_DISPLAY_NAME, + dataset=mock_dataset_text, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + validation_fraction_split=_TEST_FRACTION_SPLIT_VALIDATION, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() diff --git a/tests/unit/aiplatform/test_automl_video_training_jobs.py b/tests/unit/aiplatform/test_automl_video_training_jobs.py new file mode 100644 index 0000000000..66f1692fcf --- /dev/null +++ b/tests/unit/aiplatform/test_automl_video_training_jobs.py @@ -0,0 +1,463 @@ +import pytest +import importlib +from unittest import mock + +from google.protobuf import json_format +from google.protobuf import struct_pb2 + +from google.cloud import aiplatform + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import models +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import training_jobs + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client, +) +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name" +_TEST_DATASET_NAME = 
"test-dataset-name" +_TEST_DISPLAY_NAME = "test-display-name" +_TEST_METADATA_SCHEMA_URI_VIDEO = schema.dataset.metadata.video + +_TEST_MODEL_TYPE_CLOUD = "CLOUD" +_TEST_MODEL_TYPE_MOBILE = "MOBILE_VERSATILE_1" + +_TEST_PREDICTION_TYPE_VAR = "action_recognition" +_TEST_PREDICTION_TYPE_VCN = "classification" +_TEST_PREDICTION_TYPE_VOR = "object_tracking" + +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_MODEL_DISPLAY_NAME = "model-display-name" +_TEST_MODEL_ID = "98777645321" # TODO + +_TEST_TRAINING_TASK_INPUTS = json_format.ParseDict( + {"modelType": "CLOUD"}, struct_pb2.Value(), +) + +_TEST_FRACTION_SPLIT_TRAINING = 0.8 +_TEST_FRACTION_SPLIT_TEST = 0.2 + +_TEST_MODEL_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_MODEL_ID}" +) + +_TEST_PIPELINE_RESOURCE_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/trainingPipeline/12345" +) + +# CMEK encryption +_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default" +_TEST_DEFAULT_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME +) + +_TEST_PIPELINE_ENCRYPTION_KEY_NAME = "key_pipeline" +_TEST_PIPELINE_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME +) + +_TEST_MODEL_ENCRYPTION_KEY_NAME = "key_model" +_TEST_MODEL_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME +) + + +@pytest.fixture +def mock_pipeline_service_create(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_create_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_get(): + with mock.patch.object( + 
pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_get_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_create_and_get_with_fail(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ) + + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED, + ) + + yield mock_create_training_pipeline, mock_get_training_pipeline + + +@pytest.fixture +def mock_model_service_get(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as mock_get_model: + mock_get_model.return_value = gca_model.Model() + yield mock_get_model + + +@pytest.fixture +def mock_dataset_video(): + ds = mock.MagicMock(datasets.VideoDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_VIDEO, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +@pytest.fixture +def mock_model(): + model = mock.MagicMock(models.Model) + model.name = _TEST_MODEL_ID + model._latest_future = None + model._exception = None + model._gca_resource = 
gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, name=_TEST_MODEL_NAME, + ) + yield model + + +class TestAutoMLVideoTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_all_parameters(self): + """Ensure all private members are set correctly at initialization""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_VCN, + model_type=_TEST_MODEL_TYPE_CLOUD, + ) + + assert job._display_name == _TEST_DISPLAY_NAME + assert job._model_type == _TEST_MODEL_TYPE_CLOUD + assert job._prediction_type == _TEST_PREDICTION_TYPE_VCN + + def test_init_wrong_parameters(self): + """Ensure correct exceptions are raised when initializing with invalid args""" + + aiplatform.init(project=_TEST_PROJECT) + + with pytest.raises(ValueError, match=r"not a supported prediction type"): + training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, prediction_type="abcdefg", + ) + + with pytest.raises(ValueError, match=r"not a supported model_type for"): + training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_VCN, + model_type="abcdefg", + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_init_aiplatform_with_encryption_key_name_and_create_training_job( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_video, + mock_model_service_get, + mock_model, + sync, + ): + """ + Initiate aiplatform with encryption key name. 
+ Create and run an AutoML Video Classification training job, verify calls and return value + """ + + aiplatform.init( + project=_TEST_PROJECT, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_VCN, + model_type=_TEST_MODEL_TYPE_CLOUD, + ) + + model_from_job = job.run( + dataset=mock_dataset_video, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model._gca_resource.description, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_video.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_video_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + assert job._gca_resource is mock_pipeline_service_get.return_value + assert model_from_job._gca_resource is mock_model_service_get.return_value + assert job.get_model()._gca_resource is mock_model_service_get.return_value + assert not 
job.has_failed + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_dataset_video, + mock_model_service_get, + mock_model, + sync, + ): + """Create and run an AutoML Video classification training job, verify calls and return value""" + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_VCN, + model_type=_TEST_MODEL_TYPE_CLOUD, + training_encryption_spec_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME, + model_encryption_spec_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME, + ) + + model_from_job = job.run( + dataset=mock_dataset_video, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=mock_model._gca_resource.description, + encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_video.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_video_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_PIPELINE_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + 
parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + assert job._gca_resource is mock_pipeline_service_get.return_value + assert model_from_job._gca_resource is mock_model_service_get.return_value + assert job.get_model()._gca_resource is mock_model_service_get.return_value + assert not job.has_failed + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures("mock_pipeline_service_get") + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_if_no_model_display_name( + self, + mock_pipeline_service_create, + mock_dataset_video, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLVideoTrainingJob( + display_name=_TEST_DISPLAY_NAME, + prediction_type=_TEST_PREDICTION_TYPE_VCN, + model_type=_TEST_MODEL_TYPE_CLOUD, + ) + + model_from_job = job.run( + dataset=mock_dataset_video, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + ) + + if not sync: + model_from_job.wait() + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction=_TEST_FRACTION_SPLIT_TEST, + ) + + # Test that it defaults to the job display name + true_managed_model = gca_model.Model(display_name=_TEST_DISPLAY_NAME) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, dataset_id=mock_dataset_video.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.automl_video_classification, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + 
mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises( + self, mock_dataset_video, sync, + ): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLVideoTrainingJob(display_name=_TEST_DISPLAY_NAME,) + + job.run( + dataset=mock_dataset_video, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_dataset_video, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, mock_pipeline_service_create_and_get_with_fail, mock_dataset_video, sync + ): + + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLVideoTrainingJob(display_name=_TEST_DISPLAY_NAME,) + + with pytest.raises(RuntimeError): + job.run( + model_display_name=_TEST_MODEL_DISPLAY_NAME, + dataset=mock_dataset_video, + training_fraction_split=_TEST_FRACTION_SPLIT_TRAINING, + test_fraction_split=_TEST_FRACTION_SPLIT_TEST, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + def test_raises_before_run_is_called(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.AutoMLVideoTrainingJob(display_name=_TEST_DISPLAY_NAME,) + + with pytest.raises(RuntimeError): + job.get_model() + + with pytest.raises(RuntimeError): + job.has_failed + + with pytest.raises(RuntimeError): + job.state 
diff --git a/tests/unit/aiplatform/test_base.py b/tests/unit/aiplatform/test_base.py new file mode 100644 index 0000000000..97f35b9476 --- /dev/null +++ b/tests/unit/aiplatform/test_base.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from importlib import reload +import pytest +import time +from typing import Optional + +from google.cloud.aiplatform import base +from google.cloud.aiplatform import initializer + + +class _TestClass(base.FutureManager): + def __init__(self, x): + self.x = x + super().__init__() + + @classmethod + def _empty_constructor(cls): + self = cls.__new__(cls) + base.FutureManager.__init__(self) + self.x = None + return self + + def _sync_object_with_future_result(self, result): + self.x = result.x + + @classmethod + @base.optional_sync() + def create(cls, x: int, sync=True) -> "_TestClass": + time.sleep(1) + return cls(x) + + @base.optional_sync() + def add(self, a: "_TestClass", sync=True) -> None: + time.sleep(1) + return self._add(a=a, sync=sync) + + def _add(self, a: "_TestClass", sync=True) -> None: + self.x = self.x + a.x + + +class _TestClassDownStream(_TestClass): + @base.optional_sync(construct_object_on_arg="a") + def add_and_create_new( + self, a: Optional["_TestClass"] = None, sync=True + ) -> _TestClass: + time.sleep(1) + if a: + return _TestClass(self.x + a.x) + return None + + @base.optional_sync(return_input_arg="a", bind_future_to_self=False) + 
def add_to_input_arg(self, a: "_TestClass", sync=True) -> _TestClass: + time.sleep(1) + a._add(self) + return a + + +class TestFutureManager: + def setup_method(self): + reload(initializer) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_task(self, sync): + a = _TestClass.create(10, sync=sync) + if not sync: + assert a.x is None + assert a._latest_future is not None + a.wait() + assert a._latest_future is None + assert a.x == 10 + assert isinstance(a, _TestClass) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_add_task(self, sync): + _latest_future = None + + a = _TestClass.create(10, sync=sync) + b = _TestClass.create(7, sync=sync) + if not sync: + assert a.x is None + assert a._latest_future is not None + assert b.x is None + assert b._latest_future is not None + _latest_future = b._latest_future + + b.add(a, sync=sync) + + if not sync: + assert b._latest_future is not _latest_future + b.wait() + + assert a._latest_future is None + assert a.x == 10 + assert b._latest_future is None + assert b.x == 17 + assert isinstance(a, _TestClass) + assert isinstance(b, _TestClass) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_add_and_create_new_task(self, sync): + _latest_future = None + + a = _TestClass.create(10, sync=sync) + b = _TestClassDownStream.create(7, sync=sync) + if not sync: + assert a.x is None + assert a._latest_future is not None + assert b.x is None + assert b._latest_future is not None + _latest_future = b._latest_future + + c = b.add_and_create_new(a, sync=sync) + + if not sync: + assert b._latest_future is not _latest_future + assert c.x is None + assert c._latest_future is not None + c.wait() + + assert a._latest_future is None + assert a.x == 10 + assert b._latest_future is None + assert b.x == 7 + assert c._latest_future is None + assert c.x == 17 + assert isinstance(a, _TestClass) + assert 
isinstance(b, _TestClassDownStream) + assert isinstance(c, _TestClass) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_add_and_not_create_new_task(self, sync): + _latest_future = None + + b = _TestClassDownStream.create(7, sync=sync) + if not sync: + assert b.x is None + assert b._latest_future is not None + _latest_future = b._latest_future + + c = b.add_and_create_new(None, sync=sync) + + if not sync: + assert b._latest_future is not _latest_future + b.wait() + + assert c is None + + assert b._latest_future is None + assert b.x == 7 + assert isinstance(b, _TestClassDownStream) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_add_return_arg(self, sync): + _latest_future = None + + a = _TestClass.create(10, sync=sync) + b = _TestClassDownStream.create(7, sync=sync) + if not sync: + assert a.x is None + assert a._latest_future is not None + assert b.x is None + assert b._latest_future is not None + _latest_future = b._latest_future + + c = b.add_to_input_arg(a, sync=sync) + + if not sync: + assert b._latest_future is _latest_future + assert c.x is None + assert c._latest_future is not None + assert c is a + c.wait() + + assert a._latest_future is None + assert a.x == 17 + assert b._latest_future is None + assert b.x == 7 + assert c._latest_future is None + assert c.x == 17 + assert isinstance(a, _TestClass) + assert isinstance(b, _TestClassDownStream) + assert isinstance(c, _TestClass) diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py new file mode 100644 index 0000000000..5d1f92fe79 --- /dev/null +++ b/tests/unit/aiplatform/test_datasets.py @@ -0,0 +1,1176 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +import pytest + +from unittest import mock +from importlib import reload +from unittest.mock import patch + +from google.api_core import operation +from google.auth.exceptions import GoogleAuthError +from google.auth import credentials as auth_credentials + +from google.cloud import aiplatform + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema + +from google.cloud.aiplatform_v1.services.dataset_service import ( + client as dataset_service_client, +) + +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + dataset_service as gca_dataset_service, + encryption_spec as gca_encryption_spec, + io as gca_io, +) + +# project +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_ALT_PROJECT = "test-project_alt" + +_TEST_ALT_LOCATION = "europe-west4" +_TEST_INVALID_LOCATION = "us-central2" + +# dataset +_TEST_ID = "1028944691210842416" +_TEST_DISPLAY_NAME = "my_dataset_1234" +_TEST_DATA_LABEL_ITEMS = None + +_TEST_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/datasets/{_TEST_ID}" +_TEST_ALT_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_ALT_LOCATION}/datasets/{_TEST_ID}" +) +_TEST_INVALID_NAME = f"prj/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/{_TEST_ID}" + +# metadata_schema_uri +_TEST_METADATA_SCHEMA_URI_TABULAR = schema.dataset.metadata.tabular +_TEST_METADATA_SCHEMA_URI_NONTABULAR = schema.dataset.metadata.image 
+_TEST_METADATA_SCHEMA_URI_IMAGE = schema.dataset.metadata.image +_TEST_METADATA_SCHEMA_URI_TEXT = schema.dataset.metadata.text +_TEST_METADATA_SCHEMA_URI_VIDEO = schema.dataset.metadata.video + +# import_schema_uri +_TEST_IMPORT_SCHEMA_URI_IMAGE = ( + schema.dataset.ioformat.image.single_label_classification +) +_TEST_IMPORT_SCHEMA_URI_TEXT = schema.dataset.ioformat.text.single_label_classification +_TEST_IMPORT_SCHEMA_URI = schema.dataset.ioformat.image.single_label_classification +_TEST_IMPORT_SCHEMA_URI_VIDEO = schema.dataset.ioformat.video.classification + +# datasources +_TEST_SOURCE_URI_GCS = "gs://my-bucket/my_index_file.jsonl" +_TEST_SOURCE_URIS_GCS = [ + "gs://my-bucket/index_file_1.jsonl", + "gs://my-bucket/index_file_2.jsonl", + "gs://my-bucket/index_file_3.jsonl", +] +_TEST_SOURCE_URI_BQ = "bigquery://my-project/my-dataset" +_TEST_INVALID_SOURCE_URIS = ["gs://my-bucket/index_file_1.jsonl", 123] + +# request_metadata +_TEST_REQUEST_METADATA = () + +# dataset_metadata +_TEST_NONTABULAR_DATASET_METADATA = None +_TEST_METADATA_TABULAR_GCS = { + "input_config": {"gcs_source": {"uri": [_TEST_SOURCE_URI_GCS]}} +} +_TEST_METADATA_TABULAR_BQ = { + "input_config": {"bigquery_source": {"uri": _TEST_SOURCE_URI_BQ}} +} + +# CMEK encryption +_TEST_ENCRYPTION_KEY_NAME = "key_1234" +_TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_KEY_NAME +) + +# misc +_TEST_OUTPUT_DIR = "gs://my-output-bucket" + +_TEST_DATASET_LIST = [ + gca_dataset.Dataset( + display_name="a", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR + ), + gca_dataset.Dataset( + display_name="d", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR + ), + gca_dataset.Dataset( + display_name="b", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR + ), + gca_dataset.Dataset( + display_name="e", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT + ), + gca_dataset.Dataset( + display_name="c", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR + ), +] + 
+_TEST_LIST_FILTER = 'display_name="abc"' +_TEST_LIST_ORDER_BY = "create_time desc" + + +@pytest.fixture +def get_dataset_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + name=_TEST_NAME, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def get_dataset_without_name_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def get_dataset_image_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_IMAGE, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + name=_TEST_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def get_dataset_tabular_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + metadata=_TEST_METADATA_TABULAR_BQ, + name=_TEST_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def get_dataset_text_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + 
display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + name=_TEST_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def get_dataset_video_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "get_dataset" + ) as get_dataset_mock: + get_dataset_mock.return_value = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_VIDEO, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + name=_TEST_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_dataset_mock + + +@pytest.fixture +def create_dataset_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "create_dataset" + ) as create_dataset_mock: + create_dataset_lro_mock = mock.Mock(operation.Operation) + create_dataset_lro_mock.result.return_value = gca_dataset.Dataset( + name=_TEST_NAME, + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + create_dataset_mock.return_value = create_dataset_lro_mock + yield create_dataset_mock + + +@pytest.fixture +def delete_dataset_mock(): + with mock.patch.object( + dataset_service_client.DatasetServiceClient, "delete_dataset" + ) as delete_dataset_mock: + delete_dataset_lro_mock = mock.Mock(operation.Operation) + delete_dataset_lro_mock.result.return_value = ( + gca_dataset_service.DeleteDatasetRequest() + ) + delete_dataset_mock.return_value = delete_dataset_lro_mock + yield delete_dataset_mock + + +@pytest.fixture +def import_data_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "import_data" + ) as import_data_mock: + import_data_mock.return_value = mock.Mock(operation.Operation) + yield import_data_mock + + +@pytest.fixture +def export_data_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "export_data" + ) as 
export_data_mock: + export_data_mock.return_value = mock.Mock(operation.Operation) + yield export_data_mock + + +@pytest.fixture +def list_datasets_mock(): + with patch.object( + dataset_service_client.DatasetServiceClient, "list_datasets" + ) as list_datasets_mock: + list_datasets_mock.return_value = _TEST_DATASET_LIST + yield list_datasets_mock + + +# TODO(b/171333554): Move reusable test fixtures to conftest.py file +class TestDataset: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_dataset(self, get_dataset_mock): + aiplatform.init(project=_TEST_PROJECT) + datasets._Dataset(dataset_name=_TEST_NAME) + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_dataset_with_id_only_with_project_and_location( + self, get_dataset_mock + ): + aiplatform.init(project=_TEST_PROJECT) + datasets._Dataset( + dataset_name=_TEST_ID, project=_TEST_PROJECT, location=_TEST_LOCATION + ) + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_dataset_with_project_and_location(self, get_dataset_mock): + aiplatform.init(project=_TEST_PROJECT) + datasets._Dataset( + dataset_name=_TEST_NAME, project=_TEST_PROJECT, location=_TEST_LOCATION + ) + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_dataset_with_alt_project_and_location(self, get_dataset_mock): + aiplatform.init(project=_TEST_PROJECT) + datasets._Dataset( + dataset_name=_TEST_NAME, project=_TEST_ALT_PROJECT, location=_TEST_LOCATION + ) + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + def test_init_dataset_with_project_and_alt_location(self): + aiplatform.init(project=_TEST_PROJECT) + with pytest.raises(RuntimeError): + datasets._Dataset( + dataset_name=_TEST_NAME, + project=_TEST_PROJECT, + location=_TEST_ALT_LOCATION, + ) + + def test_init_dataset_with_id_only(self, get_dataset_mock): + aiplatform.init(project=_TEST_PROJECT, 
location=_TEST_LOCATION) + datasets._Dataset(dataset_name=_TEST_ID) + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_without_name_mock") + @patch.dict( + os.environ, {"GOOGLE_CLOUD_PROJECT": "", "GOOGLE_APPLICATION_CREDENTIALS": ""} + ) + def test_init_dataset_with_id_only_without_project_or_location(self): + with pytest.raises(GoogleAuthError): + datasets._Dataset( + dataset_name=_TEST_ID, + credentials=auth_credentials.AnonymousCredentials(), + ) + + def test_init_dataset_with_location_override(self, get_dataset_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + datasets._Dataset(dataset_name=_TEST_ID, location=_TEST_ALT_LOCATION) + get_dataset_mock.assert_called_once_with(name=_TEST_ALT_NAME) + + @pytest.mark.usefixtures("get_dataset_mock") + def test_init_dataset_with_invalid_name(self): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + datasets._Dataset(dataset_name=_TEST_INVALID_NAME) + + @pytest.mark.usefixtures("get_dataset_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_init_aiplatform_with_encryption_key_name_and_create_dataset( + self, create_dataset_mock, sync + ): + aiplatform.init( + project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + my_dataset = datasets._Dataset.create( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_mock") + @pytest.mark.parametrize("sync", [True, 
False]) + def test_create_dataset_nontabular(self, create_dataset_mock, sync): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets._Dataset.create( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_mock") + def test_create_dataset_tabular(self, create_dataset_mock): + aiplatform.init(project=_TEST_PROJECT) + + datasets._Dataset.create( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + bq_source=_TEST_SOURCE_URI_BQ, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + metadata=_TEST_METADATA_TABULAR_BQ, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_import_dataset( + self, create_dataset_mock, import_data_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets._Dataset.create( + display_name=_TEST_DISPLAY_NAME, + gcs_source=_TEST_SOURCE_URI_GCS, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + import_schema_uri=_TEST_IMPORT_SCHEMA_URI, + data_item_labels=_TEST_DATA_LABEL_ITEMS, + 
encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI, + data_item_labels=_TEST_DATA_LABEL_ITEMS, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + @pytest.mark.usefixtures("get_dataset_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_import_data(self, import_data_mock, sync): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets._Dataset(dataset_name=_TEST_NAME) + + my_dataset.import_data( + gcs_source=_TEST_SOURCE_URI_GCS, + import_schema_uri=_TEST_IMPORT_SCHEMA_URI, + data_item_labels=_TEST_DATA_LABEL_ITEMS, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI, + data_item_labels=_TEST_DATA_LABEL_ITEMS, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + @pytest.mark.usefixtures("get_dataset_mock") + def test_export_data(self, export_data_mock): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets._Dataset(dataset_name=_TEST_NAME) + + my_dataset.export_data(output_dir=_TEST_OUTPUT_DIR) + + expected_export_config = gca_dataset.ExportDataConfig( + 
gcs_destination=gca_io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_DIR) + ) + + export_data_mock.assert_called_once_with( + name=_TEST_NAME, export_config=expected_export_config + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_then_import( + self, create_dataset_mock, import_data_mock, get_dataset_mock, sync + ): + + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets._Dataset.create( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + my_dataset.import_data( + gcs_source=_TEST_SOURCE_URI_GCS, + import_schema_uri=_TEST_IMPORT_SCHEMA_URI, + data_item_labels=_TEST_DATA_LABEL_ITEMS, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI, + data_item_labels=_TEST_DATA_LABEL_ITEMS, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + get_dataset_mock.assert_called_once_with(name=_TEST_NAME) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + @pytest.mark.usefixtures("get_dataset_tabular_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_delete_dataset(self, delete_dataset_mock, sync): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME) + my_dataset.delete(sync=sync) + + if not sync: + my_dataset.wait() + + 
delete_dataset_mock.assert_called_once_with(name=my_dataset.resource_name) + + +class TestImageDataset: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_dataset_image(self, get_dataset_image_mock): + aiplatform.init(project=_TEST_PROJECT) + datasets.ImageDataset(dataset_name=_TEST_NAME) + get_dataset_image_mock.assert_called_once_with(name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_tabular_mock") + def test_init_dataset_non_image(self): + aiplatform.init(project=_TEST_PROJECT) + with pytest.raises(ValueError): + datasets.ImageDataset(dataset_name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_image_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_dataset(self, create_dataset_mock, sync): + aiplatform.init( + project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + my_dataset = datasets.ImageDataset.create( + display_name=_TEST_DISPLAY_NAME, sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_IMAGE, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_image_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_import_dataset( + self, create_dataset_mock, import_data_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.ImageDataset.create( + display_name=_TEST_DISPLAY_NAME, + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = 
gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_IMAGE, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, + ) + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + @pytest.mark.usefixtures("get_dataset_image_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_import_data(self, import_data_mock, sync): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.ImageDataset(dataset_name=_TEST_NAME) + + my_dataset.import_data( + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_then_import( + self, create_dataset_mock, import_data_mock, get_dataset_image_mock, sync + ): + + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.ImageDataset.create( + display_name=_TEST_DISPLAY_NAME, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + my_dataset.import_data( + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = 
gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_IMAGE, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + get_dataset_image_mock.assert_called_once_with(name=_TEST_NAME) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_IMAGE, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + +class TestTabularDataset: + def setup_method(self): + reload(initializer) + reload(aiplatform) + aiplatform.init(project=_TEST_PROJECT) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_dataset_tabular(self, get_dataset_tabular_mock): + + datasets.TabularDataset(dataset_name=_TEST_NAME) + get_dataset_tabular_mock.assert_called_once_with(name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_image_mock") + def test_init_dataset_non_tabular(self): + + with pytest.raises(ValueError): + datasets.TabularDataset(dataset_name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_tabular_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_dataset_with_default_encryption_key( + self, create_dataset_mock, sync + ): + aiplatform.init( + project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + my_dataset = datasets.TabularDataset.create( + display_name=_TEST_DISPLAY_NAME, bq_source=_TEST_SOURCE_URI_BQ, sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + 
metadata=_TEST_METADATA_TABULAR_BQ, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_tabular_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_dataset(self, create_dataset_mock, sync): + + my_dataset = datasets.TabularDataset.create( + display_name=_TEST_DISPLAY_NAME, + bq_source=_TEST_SOURCE_URI_BQ, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + metadata=_TEST_METADATA_TABULAR_BQ, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_tabular_mock") + def test_no_import_data_method(self): + + my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME) + + with pytest.raises(NotImplementedError): + my_dataset.import_data() + + def test_list_dataset(self, list_datasets_mock): + + ds_list = aiplatform.TabularDataset.list( + filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY + ) + + list_datasets_mock.assert_called_once_with( + request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER} + ) + + # Ensure returned list is smaller since it filtered out non-tabular datasets + assert len(ds_list) < len(_TEST_DATASET_LIST) + + for ds in ds_list: + assert type(ds) == aiplatform.TabularDataset + + +class TestTextDataset: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_dataset_text(self, get_dataset_text_mock): + aiplatform.init(project=_TEST_PROJECT) + 
datasets.TextDataset(dataset_name=_TEST_NAME) + get_dataset_text_mock.assert_called_once_with(name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_image_mock") + def test_init_dataset_non_text(self): + aiplatform.init(project=_TEST_PROJECT) + with pytest.raises(ValueError): + datasets.TextDataset(dataset_name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_text_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_dataset(self, create_dataset_mock, sync): + aiplatform.init( + project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + my_dataset = datasets.TextDataset.create( + display_name=_TEST_DISPLAY_NAME, sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_text_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_import_dataset( + self, create_dataset_mock, import_data_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.TextDataset.create( + display_name=_TEST_DISPLAY_NAME, + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + expected_import_config = 
gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, + ) + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + @pytest.mark.usefixtures("get_dataset_text_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_import_data(self, import_data_mock, sync): + aiplatform.init( + project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME + ) + + my_dataset = datasets.TextDataset(dataset_name=_TEST_NAME) + + my_dataset.import_data( + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_then_import( + self, create_dataset_mock, import_data_mock, get_dataset_text_mock, sync + ): + + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.TextDataset.create( + display_name=_TEST_DISPLAY_NAME, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + my_dataset.import_data( + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, 
+ ) + + get_dataset_text_mock.assert_called_once_with(name=_TEST_NAME) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_TEXT, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + +class TestVideoDataset: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_dataset_video(self, get_dataset_video_mock): + aiplatform.init(project=_TEST_PROJECT) + datasets.VideoDataset(dataset_name=_TEST_NAME) + get_dataset_video_mock.assert_called_once_with(name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_tabular_mock") + def test_init_dataset_non_video(self): + aiplatform.init(project=_TEST_PROJECT) + with pytest.raises(ValueError): + datasets.VideoDataset(dataset_name=_TEST_NAME) + + @pytest.mark.usefixtures("get_dataset_video_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_dataset(self, create_dataset_mock, sync): + aiplatform.init( + project=_TEST_PROJECT, encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME + ) + + my_dataset = datasets.VideoDataset.create( + display_name=_TEST_DISPLAY_NAME, sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_VIDEO, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + @pytest.mark.usefixtures("get_dataset_video_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_and_import_dataset( + self, create_dataset_mock, 
import_data_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.VideoDataset.create( + display_name=_TEST_DISPLAY_NAME, + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_VIDEO, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, + ) + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + @pytest.mark.usefixtures("get_dataset_video_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_import_data(self, import_data_mock, sync): + aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.VideoDataset(dataset_name=_TEST_NAME) + + my_dataset.import_data( + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_then_import( + self, create_dataset_mock, import_data_mock, get_dataset_video_mock, sync + ): + + 
aiplatform.init(project=_TEST_PROJECT) + + my_dataset = datasets.VideoDataset.create( + display_name=_TEST_DISPLAY_NAME, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + my_dataset.import_data( + gcs_source=[_TEST_SOURCE_URI_GCS], + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, + sync=sync, + ) + + if not sync: + my_dataset.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=_TEST_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_VIDEO, + metadata=_TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + create_dataset_mock.assert_called_once_with( + parent=_TEST_PARENT, + dataset=expected_dataset, + metadata=_TEST_REQUEST_METADATA, + ) + + get_dataset_video_mock.assert_called_once_with(name=_TEST_NAME) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[_TEST_SOURCE_URI_GCS]), + import_schema_uri=_TEST_IMPORT_SCHEMA_URI_VIDEO, + ) + + import_data_mock.assert_called_once_with( + name=_TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = _TEST_NAME + assert my_dataset._gca_resource == expected_dataset diff --git a/tests/unit/aiplatform/test_end_to_end.py b/tests/unit/aiplatform/test_end_to_end.py new file mode 100644 index 0000000000..69c5517a69 --- /dev/null +++ b/tests/unit/aiplatform/test_end_to_end.py @@ -0,0 +1,462 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + +from importlib import reload + +from google.cloud import aiplatform +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import models +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import training_jobs + +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + io as gca_io, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) + +import test_datasets +from test_datasets import create_dataset_mock # noqa: F401 +from test_datasets import get_dataset_mock # noqa: F401 +from test_datasets import import_data_mock # noqa: F401 + +import test_endpoints +from test_endpoints import create_endpoint_mock # noqa: F401 +from test_endpoints import get_endpoint_mock # noqa: F401 +from test_endpoints import predict_client_predict_mock # noqa: F401 + +from test_models import deploy_model_mock # noqa: F401 + +import test_training_jobs +from test_training_jobs import mock_model_service_get # noqa: F401 +from test_training_jobs import mock_pipeline_service_create # noqa: F401 +from test_training_jobs import mock_pipeline_service_get # noqa: F401 +from test_training_jobs import ( # noqa: F401 + mock_pipeline_service_create_and_get_with_fail, +) +from test_training_jobs import mock_python_package_to_gcs # noqa: F401 + +from google.protobuf import json_format +from google.protobuf import struct_pb2 + +# dataset_encryption +_TEST_ENCRYPTION_KEY_NAME = "key_1234" +_TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_KEY_NAME +) + + +class TestEndToEnd: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.usefixtures( + "get_dataset_mock", + 
"create_endpoint_mock", + "get_endpoint_mock", + "deploy_model_mock", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_dataset_create_to_model_predict( + self, + create_dataset_mock, # noqa: F811 + import_data_mock, # noqa: F811 + predict_client_predict_mock, # noqa: F811 + mock_python_package_to_gcs, # noqa: F811 + mock_pipeline_service_create, # noqa: F811 + mock_model_service_get, # noqa: F811 + mock_pipeline_service_get, # noqa: F811 + sync, + ): + + aiplatform.init( + project=test_datasets._TEST_PROJECT, + staging_bucket=test_training_jobs._TEST_BUCKET_NAME, + credentials=test_training_jobs._TEST_CREDENTIALS, + ) + + my_dataset = aiplatform.ImageDataset.create( + display_name=test_datasets._TEST_DISPLAY_NAME, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + my_dataset.import_data( + gcs_source=test_datasets._TEST_SOURCE_URI_GCS, + import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI, + data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS, + sync=sync, + ) + + job = aiplatform.CustomTrainingJob( + display_name=test_training_jobs._TEST_DISPLAY_NAME, + script_path=test_training_jobs._TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=test_training_jobs._TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=test_training_jobs._TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=test_training_jobs._TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + model_from_job = job.run( + dataset=my_dataset, + base_output_dir=test_training_jobs._TEST_BASE_OUTPUT_DIR, + args=test_training_jobs._TEST_RUN_ARGS, + replica_count=1, + machine_type=test_training_jobs._TEST_MACHINE_TYPE, + accelerator_type=test_training_jobs._TEST_ACCELERATOR_TYPE, + accelerator_count=test_training_jobs._TEST_ACCELERATOR_COUNT, + model_display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME, + 
training_fraction_split=test_training_jobs._TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=test_training_jobs._TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=test_training_jobs._TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + created_endpoint = models.Endpoint.create( + display_name=test_endpoints._TEST_DISPLAY_NAME, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + my_endpoint = model_from_job.deploy( + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, sync=sync + ) + + endpoint_deploy_return = created_endpoint.deploy(model_from_job, sync=sync) + + assert endpoint_deploy_return is None + + if not sync: + my_endpoint.wait() + created_endpoint.wait() + + test_prediction = created_endpoint.predict( + instances=[[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], parameters={"param": 3.0} + ) + + true_prediction = models.Prediction( + predictions=test_endpoints._TEST_PREDICTION, + deployed_model_id=test_endpoints._TEST_ID, + ) + + assert true_prediction == test_prediction + predict_client_predict_mock.assert_called_once_with( + endpoint=test_endpoints._TEST_ENDPOINT_NAME, + instances=[[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], + parameters={"param": 3.0}, + ) + + expected_dataset = gca_dataset.Dataset( + display_name=test_datasets._TEST_DISPLAY_NAME, + metadata_schema_uri=test_datasets._TEST_METADATA_SCHEMA_URI_NONTABULAR, + metadata=test_datasets._TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + expected_import_config = gca_dataset.ImportDataConfig( + gcs_source=gca_io.GcsSource(uris=[test_datasets._TEST_SOURCE_URI_GCS]), + import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI, + data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS, + ) + + create_dataset_mock.assert_called_once_with( + parent=test_datasets._TEST_PARENT, + dataset=expected_dataset, + metadata=test_datasets._TEST_REQUEST_METADATA, + ) + + import_data_mock.assert_called_once_with( + name=test_datasets._TEST_NAME, 
import_configs=[expected_import_config] + ) + + expected_dataset.name = test_datasets._TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + mock_python_package_to_gcs.assert_called_once_with( + gcs_staging_dir=test_training_jobs._TEST_BUCKET_NAME, + project=test_training_jobs._TEST_PROJECT, + credentials=initializer.global_config.credentials, + ) + + true_args = test_training_jobs._TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": test_training_jobs._TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": test_training_jobs._TEST_MACHINE_TYPE, + "acceleratorType": test_training_jobs._TEST_ACCELERATOR_TYPE, + "acceleratorCount": test_training_jobs._TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [test_training_jobs._TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=test_training_jobs._TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=test_training_jobs._TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=test_training_jobs._TEST_TEST_FRACTION_SPLIT, + ) + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=test_training_jobs._TEST_SERVING_CONTAINER_IMAGE, + predict_route=test_training_jobs._TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=test_training_jobs._TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME, + container_spec=true_container_spec, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=my_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=test_training_jobs._TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + 
display_name=test_training_jobs._TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": { + "output_uri_prefix": test_training_jobs._TEST_BASE_OUTPUT_DIR + }, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with( + name=test_training_jobs._TEST_MODEL_NAME + ) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures( + "get_dataset_mock", + "create_endpoint_mock", + "get_endpoint_mock", + "deploy_model_mock", + ) + def test_dataset_create_to_model_predict_with_pipeline_fail( + self, + create_dataset_mock, # noqa: F811 + import_data_mock, # noqa: F811 + mock_python_package_to_gcs, # noqa: F811 + mock_pipeline_service_create_and_get_with_fail, # noqa: F811 + mock_model_service_get, # noqa: F811 + ): + + sync = False + + aiplatform.init( + project=test_datasets._TEST_PROJECT, + staging_bucket=test_training_jobs._TEST_BUCKET_NAME, + credentials=test_training_jobs._TEST_CREDENTIALS, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + + my_dataset = aiplatform.ImageDataset.create( + display_name=test_datasets._TEST_DISPLAY_NAME, sync=sync, + ) + + my_dataset.import_data( + gcs_source=test_datasets._TEST_SOURCE_URI_GCS, + import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI, + 
data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS, + sync=sync, + ) + + job = aiplatform.CustomTrainingJob( + display_name=test_training_jobs._TEST_DISPLAY_NAME, + script_path=test_training_jobs._TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=test_training_jobs._TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=test_training_jobs._TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=test_training_jobs._TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + created_endpoint = models.Endpoint.create( + display_name=test_endpoints._TEST_DISPLAY_NAME, sync=sync, + ) + + model_from_job = job.run( + dataset=my_dataset, + base_output_dir=test_training_jobs._TEST_BASE_OUTPUT_DIR, + args=test_training_jobs._TEST_RUN_ARGS, + replica_count=1, + machine_type=test_training_jobs._TEST_MACHINE_TYPE, + accelerator_type=test_training_jobs._TEST_ACCELERATOR_TYPE, + accelerator_count=test_training_jobs._TEST_ACCELERATOR_COUNT, + model_display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME, + training_fraction_split=test_training_jobs._TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=test_training_jobs._TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=test_training_jobs._TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + with pytest.raises(RuntimeError): + my_endpoint = model_from_job.deploy(sync=sync) + my_endpoint.wait() + + with pytest.raises(RuntimeError): + endpoint_deploy_return = created_endpoint.deploy(model_from_job, sync=sync) + assert endpoint_deploy_return is None + created_endpoint.wait() + + expected_dataset = gca_dataset.Dataset( + display_name=test_datasets._TEST_DISPLAY_NAME, + metadata_schema_uri=test_datasets._TEST_METADATA_SCHEMA_URI_NONTABULAR, + metadata=test_datasets._TEST_NONTABULAR_DATASET_METADATA, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + expected_import_config = gca_dataset.ImportDataConfig( + 
gcs_source=gca_io.GcsSource(uris=[test_datasets._TEST_SOURCE_URI_GCS]), + import_schema_uri=test_datasets._TEST_IMPORT_SCHEMA_URI, + data_item_labels=test_datasets._TEST_DATA_LABEL_ITEMS, + ) + + create_dataset_mock.assert_called_once_with( + parent=test_datasets._TEST_PARENT, + dataset=expected_dataset, + metadata=test_datasets._TEST_REQUEST_METADATA, + ) + + import_data_mock.assert_called_once_with( + name=test_datasets._TEST_NAME, import_configs=[expected_import_config] + ) + + expected_dataset.name = test_datasets._TEST_NAME + assert my_dataset._gca_resource == expected_dataset + + mock_python_package_to_gcs.assert_called_once_with( + gcs_staging_dir=test_training_jobs._TEST_BUCKET_NAME, + project=test_training_jobs._TEST_PROJECT, + credentials=initializer.global_config.credentials, + ) + + true_args = test_training_jobs._TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": test_training_jobs._TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": test_training_jobs._TEST_MACHINE_TYPE, + "acceleratorType": test_training_jobs._TEST_ACCELERATOR_TYPE, + "acceleratorCount": test_training_jobs._TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": test_training_jobs._TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [test_training_jobs._TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=test_training_jobs._TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=test_training_jobs._TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=test_training_jobs._TEST_TEST_FRACTION_SPLIT, + ) + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=test_training_jobs._TEST_SERVING_CONTAINER_IMAGE, + predict_route=test_training_jobs._TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=test_training_jobs._TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = 
gca_model.Model( + display_name=test_training_jobs._TEST_MODEL_DISPLAY_NAME, + container_spec=true_container_spec, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=my_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=test_training_jobs._TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=test_training_jobs._TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": { + "output_uri_prefix": test_training_jobs._TEST_BASE_OUTPUT_DIR + }, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create_and_get_with_fail[0].assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert ( + job._gca_resource + is mock_pipeline_service_create_and_get_with_fail[1].return_value + ) + + mock_model_service_get.assert_not_called() + + with pytest.raises(RuntimeError): + job.get_model() + + assert job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py new file mode 100644 index 0000000000..ea74c89e5e --- /dev/null +++ b/tests/unit/aiplatform/test_endpoints.py @@ -0,0 +1,1079 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + +from unittest import mock +from importlib import reload +from datetime import datetime, timedelta + +from google.api_core import operation as ga_operation +from google.auth import credentials as auth_credentials + +from google.cloud import aiplatform + +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import models +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + client as endpoint_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.prediction_service import ( + client as prediction_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.types import ( + endpoint as gca_endpoint_v1beta1, + machine_resources as gca_machine_resources_v1beta1, + prediction_service as gca_prediction_service_v1beta1, + endpoint_service as gca_endpoint_service_v1beta1, +) + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.services.endpoint_service import ( + client as endpoint_service_client, +) +from google.cloud.aiplatform_v1.services.prediction_service import ( + client as prediction_service_client, +) +from google.cloud.aiplatform_v1.types import ( + endpoint as gca_endpoint, + model as gca_model, + machine_resources as gca_machine_resources, + prediction_service as gca_prediction_service, + endpoint_service as gca_endpoint_service, + encryption_spec as gca_encryption_spec, +) + +_TEST_PROJECT = "test-project" +_TEST_PROJECT_2 = 
"test-project-2" +_TEST_LOCATION = "us-central1" +_TEST_LOCATION_2 = "europe-west4" + +_TEST_ENDPOINT_NAME = "test-endpoint" +_TEST_DISPLAY_NAME = "test-display-name" +_TEST_DISPLAY_NAME_2 = "test-display-name-2" +_TEST_ID = "1028944691210842416" +_TEST_ID_2 = "4366591682456584192" +_TEST_DESCRIPTION = "test-description" + +_TEST_ENDPOINT_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}" +) +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_MODEL_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_ID}" +) + +_TEST_MODEL_ID = "1028944691210842416" +_TEST_PREDICTION = [[1.0, 2.0, 3.0], [3.0, 3.0, 1.0]] +_TEST_INSTANCES = [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]] +_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) + +_TEST_DEPLOYED_MODELS = [ + gca_endpoint.DeployedModel(id=_TEST_ID, display_name=_TEST_DISPLAY_NAME), + gca_endpoint.DeployedModel(id=_TEST_ID_2, display_name=_TEST_DISPLAY_NAME_2), +] + +_TEST_MACHINE_TYPE = "n1-standard-32" +_TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_P100" +_TEST_ACCELERATOR_COUNT = 2 + +_TEST_EXPLANATIONS = [ + gca_prediction_service_v1beta1.explanation.Explanation(attributions=[]) +] + +_TEST_ATTRIBUTIONS = [ + gca_prediction_service_v1beta1.explanation.Attribution( + baseline_output_value=1.0, + instance_output_value=2.0, + feature_attributions=3.0, + output_index=[1, 2, 3], + output_display_name="abc", + approximation_error=6.0, + output_name="xyz", + ) +] + +_TEST_EXPLANATION_METADATA = aiplatform.explain.ExplanationMetadata( + inputs={ + "features": aiplatform.explain.ExplanationMetadata.InputMetadata( + { + "input_tensor_name": "dense_input", + "encoding": "BAG_OF_FEATURES", + "modality": "numeric", + "index_feature_mapping": ["abc", "def", "ghj"], + } + ) + }, + outputs={ + "medv": aiplatform.explain.ExplanationMetadata.OutputMetadata( + {"output_tensor_name": "dense_2"} + ) + }, +) +_TEST_EXPLANATION_PARAMETERS = 
aiplatform.explain.ExplanationParameters( + {"sampled_shapley_attribution": {"path_count": 10}} +) + +# CMEK encryption +_TEST_ENCRYPTION_KEY_NAME = "key_1234" +_TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_KEY_NAME +) + + +_TEST_ENDPOINT_LIST = [ + gca_endpoint.Endpoint( + display_name="aac", create_time=datetime.now() - timedelta(minutes=15) + ), + gca_endpoint.Endpoint( + display_name="aab", create_time=datetime.now() - timedelta(minutes=5) + ), + gca_endpoint.Endpoint( + display_name="aaa", create_time=datetime.now() - timedelta(minutes=10) + ), +] + +_TEST_LIST_FILTER = 'display_name="abc"' +_TEST_LIST_ORDER_BY_CREATE_TIME = "create_time desc" +_TEST_LIST_ORDER_BY_DISPLAY_NAME = "display_name" + + +@pytest.fixture +def get_endpoint_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + name=_TEST_ENDPOINT_NAME, + encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + yield get_endpoint_mock + + +@pytest.fixture +def get_endpoint_with_models_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + name=_TEST_ENDPOINT_NAME, + deployed_models=_TEST_DEPLOYED_MODELS, + ) + yield get_endpoint_mock + + +@pytest.fixture +def get_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name=_TEST_DISPLAY_NAME, name=_TEST_MODEL_NAME, + ) + yield get_model_mock + + +@pytest.fixture +def create_endpoint_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "create_endpoint" + ) as create_endpoint_mock: + create_endpoint_lro_mock = mock.Mock(ga_operation.Operation) + 
create_endpoint_lro_mock.result.return_value = gca_endpoint.Endpoint( + name=_TEST_ENDPOINT_NAME, display_name=_TEST_DISPLAY_NAME + ) + create_endpoint_mock.return_value = create_endpoint_lro_mock + yield create_endpoint_mock + + +@pytest.fixture +def deploy_model_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "deploy_model" + ) as deploy_model_mock: + deployed_model = gca_endpoint.DeployedModel( + model=_TEST_MODEL_NAME, display_name=_TEST_DISPLAY_NAME, + ) + deploy_model_lro_mock = mock.Mock(ga_operation.Operation) + deploy_model_lro_mock.result.return_value = gca_endpoint_service.DeployModelResponse( + deployed_model=deployed_model, + ) + deploy_model_mock.return_value = deploy_model_lro_mock + yield deploy_model_mock + + +@pytest.fixture +def deploy_model_with_explanations_mock(): + with mock.patch.object( + endpoint_service_client_v1beta1.EndpointServiceClient, "deploy_model" + ) as deploy_model_mock: + deployed_model = gca_endpoint_v1beta1.DeployedModel( + model=_TEST_MODEL_NAME, display_name=_TEST_DISPLAY_NAME, + ) + deploy_model_lro_mock = mock.Mock(ga_operation.Operation) + deploy_model_lro_mock.result.return_value = gca_endpoint_service_v1beta1.DeployModelResponse( + deployed_model=deployed_model, + ) + deploy_model_mock.return_value = deploy_model_lro_mock + yield deploy_model_mock + + +@pytest.fixture +def undeploy_model_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "undeploy_model" + ) as undeploy_model_mock: + undeploy_model_lro_mock = mock.Mock(ga_operation.Operation) + undeploy_model_lro_mock.result.return_value = ( + gca_endpoint_service.UndeployModelResponse() + ) + undeploy_model_mock.return_value = undeploy_model_lro_mock + yield undeploy_model_mock + + +@pytest.fixture +def delete_endpoint_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "delete_endpoint" + ) as delete_endpoint_mock: + delete_endpoint_lro_mock = 
mock.Mock(ga_operation.Operation) + delete_endpoint_lro_mock.result.return_value = ( + gca_endpoint_service.DeleteEndpointRequest() + ) + delete_endpoint_mock.return_value = delete_endpoint_lro_mock + yield delete_endpoint_mock + + +@pytest.fixture +def sdk_private_undeploy_mock(): + """Mocks the high-level Endpoint._undeploy() SDK private method""" + with mock.patch.object(aiplatform.Endpoint, "_undeploy") as sdk_undeploy_mock: + sdk_undeploy_mock.return_value = None + yield sdk_undeploy_mock + + +@pytest.fixture +def sdk_undeploy_all_mock(): + """Mocks the high-level Endpoint.undeploy_all() SDK method""" + with mock.patch.object( + aiplatform.Endpoint, "undeploy_all" + ) as sdk_undeploy_all_mock: + sdk_undeploy_all_mock.return_value = None + yield sdk_undeploy_all_mock + + +@pytest.fixture +def list_endpoints_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "list_endpoints" + ) as list_endpoints_mock: + list_endpoints_mock.return_value = _TEST_ENDPOINT_LIST + yield list_endpoints_mock + + +@pytest.fixture +def create_client_mock(): + with mock.patch.object( + initializer.global_config, "create_client", autospec=True, + ) as create_client_mock: + create_client_mock.return_value = mock.Mock( + spec=endpoint_service_client.EndpointServiceClient + ) + yield create_client_mock + + +@pytest.fixture +def predict_client_predict_mock(): + with mock.patch.object( + prediction_service_client.PredictionServiceClient, "predict" + ) as predict_mock: + predict_mock.return_value = gca_prediction_service.PredictResponse( + deployed_model_id=_TEST_MODEL_ID + ) + predict_mock.return_value.predictions.extend(_TEST_PREDICTION) + yield predict_mock + + +@pytest.fixture +def predict_client_explain_mock(): + with mock.patch.object( + prediction_service_client_v1beta1.PredictionServiceClient, "explain" + ) as predict_mock: + predict_mock.return_value = gca_prediction_service_v1beta1.ExplainResponse( + deployed_model_id=_TEST_MODEL_ID, + ) + 
predict_mock.return_value.predictions.extend(_TEST_PREDICTION) + predict_mock.return_value.explanations.extend(_TEST_EXPLANATIONS) + predict_mock.return_value.explanations[0].attributions.extend( + _TEST_ATTRIBUTIONS + ) + yield predict_mock + + +class TestEndpoint: + def setup_method(self): + reload(initializer) + reload(aiplatform) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_constructor(self, create_client_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + models.Endpoint(_TEST_ENDPOINT_NAME) + create_client_mock.assert_has_calls( + [ + mock.call( + client_class=utils.EndpointClientWithOverride, + credentials=initializer.global_config.credentials, + location_override=_TEST_LOCATION, + prediction_client=False, + ), + mock.call( + client_class=utils.PredictionClientWithOverride, + credentials=None, + location_override=_TEST_LOCATION, + prediction_client=True, + ), + ] + ) + + def test_constructor_with_endpoint_id(self, get_endpoint_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + models.Endpoint(_TEST_ID) + get_endpoint_mock.assert_called_with(name=_TEST_ENDPOINT_NAME) + + def test_constructor_with_endpoint_name(self, get_endpoint_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + models.Endpoint(_TEST_ENDPOINT_NAME) + get_endpoint_mock.assert_called_with(name=_TEST_ENDPOINT_NAME) + + def test_constructor_with_custom_project(self, get_endpoint_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + models.Endpoint(endpoint_name=_TEST_ID, project=_TEST_PROJECT_2) + test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path( + _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID + ) + get_endpoint_mock.assert_called_with(name=test_endpoint_resource_name) + + def test_constructor_with_custom_location(self, 
get_endpoint_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2) + test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path( + _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID + ) + get_endpoint_mock.assert_called_with(name=test_endpoint_resource_name) + + def test_constructor_with_custom_credentials(self, create_client_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + creds = auth_credentials.AnonymousCredentials() + + models.Endpoint(_TEST_ENDPOINT_NAME, credentials=creds) + create_client_mock.assert_has_calls( + [ + mock.call( + client_class=utils.EndpointClientWithOverride, + credentials=creds, + location_override=_TEST_LOCATION, + prediction_client=False, + ), + mock.call( + client_class=utils.PredictionClientWithOverride, + credentials=creds, + location_override=_TEST_LOCATION, + prediction_client=True, + ), + ] + ) + + @pytest.mark.usefixtures("get_endpoint_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_init_aiplatform_with_encryption_key_name_and_create_endpoint( + self, create_endpoint_mock, sync + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + my_endpoint = models.Endpoint.create(display_name=_TEST_DISPLAY_NAME, sync=sync) + + if not sync: + my_endpoint.wait() + + expected_endpoint = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, encryption_spec=_TEST_ENCRYPTION_SPEC + ) + create_endpoint_mock.assert_called_once_with( + parent=_TEST_PARENT, endpoint=expected_endpoint, metadata=(), + ) + + expected_endpoint.name = _TEST_ENDPOINT_NAME + assert my_endpoint._gca_resource == expected_endpoint + + @pytest.mark.usefixtures("get_endpoint_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create(self, create_endpoint_mock, sync): + aiplatform.init(project=_TEST_PROJECT, 
location=_TEST_LOCATION) + my_endpoint = models.Endpoint.create( + display_name=_TEST_DISPLAY_NAME, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + my_endpoint.wait() + + expected_endpoint = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, encryption_spec=_TEST_ENCRYPTION_SPEC + ) + create_endpoint_mock.assert_called_once_with( + parent=_TEST_PARENT, endpoint=expected_endpoint, metadata=(), + ) + + expected_endpoint.name = _TEST_ENDPOINT_NAME + assert my_endpoint._gca_resource == expected_endpoint + + @pytest.mark.usefixtures("get_endpoint_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_with_description(self, create_endpoint_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + my_endpoint = models.Endpoint.create( + display_name=_TEST_DISPLAY_NAME, description=_TEST_DESCRIPTION, sync=sync + ) + if not sync: + my_endpoint.wait() + + expected_endpoint = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, description=_TEST_DESCRIPTION, + ) + create_endpoint_mock.assert_called_once_with( + parent=_TEST_PARENT, endpoint=expected_endpoint, metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy(self, deploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(test_model, sync=sync) + + if not sync: + test_endpoint.wait() + + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=1, max_replica_count=1, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"0": 
100}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_display_name(self, deploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy( + model=test_model, deployed_model_display_name=_TEST_DISPLAY_NAME, sync=sync + ) + + if not sync: + test_endpoint.wait() + + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=1, max_replica_count=1, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=_TEST_DISPLAY_NAME, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_raise_error_traffic_80(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, traffic_percentage=80, sync=sync) + + if not sync: + test_endpoint.wait() + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_raise_error_traffic_120(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, traffic_percentage=120, sync=sync) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) 
+ def test_deploy_raise_error_traffic_negative(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, traffic_percentage=-18, sync=sync) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_raise_error_min_replica(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, min_replica_count=-1, sync=sync) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_raise_error_max_replica(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, max_replica_count=-2, sync=sync) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_raise_error_traffic_split(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, traffic_split={"a": 99}, sync=sync) + + @pytest.mark.usefixtures("get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_traffic_percent(self, deploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as 
get_endpoint_mock: + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + name=_TEST_ENDPOINT_NAME, + traffic_split={"model1": 100}, + ) + + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, traffic_percentage=70, sync=sync) + if not sync: + test_endpoint.wait() + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=1, max_replica_count=1, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"model1": 30, "0": 70}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_traffic_split(self, deploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + name=_TEST_ENDPOINT_NAME, + traffic_split={"model1": 100}, + ) + + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy( + model=test_model, traffic_split={"model1": 30, "0": 70}, sync=sync + ) + + if not sync: + test_endpoint.wait() + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=1, max_replica_count=1, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"model1": 30, "0": 70}, + 
metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_dedicated_resources(self, deploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy( + model=test_model, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + sync=sync, + ) + + if not sync: + test_endpoint.wait() + + expected_machine_spec = gca_machine_resources.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ) + expected_dedicated_resources = gca_machine_resources.DedicatedResources( + machine_spec=expected_machine_spec, + min_replica_count=1, + max_replica_count=1, + ) + expected_deployed_model = gca_endpoint.DeployedModel( + dedicated_resources=expected_dedicated_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=expected_deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_explanations(self, deploy_model_with_explanations_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy( + model=test_model, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + explanation_metadata=_TEST_EXPLANATION_METADATA, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS, + sync=sync, + ) + + if not sync: + test_endpoint.wait() + + 
expected_machine_spec = gca_machine_resources_v1beta1.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ) + expected_dedicated_resources = gca_machine_resources_v1beta1.DedicatedResources( + machine_spec=expected_machine_spec, + min_replica_count=1, + max_replica_count=1, + ) + expected_deployed_model = gca_endpoint_v1beta1.DeployedModel( + dedicated_resources=expected_dedicated_resources, + model=test_model.resource_name, + display_name=None, + explanation_spec=gca_endpoint_v1beta1.explanation.ExplanationSpec( + metadata=_TEST_EXPLANATION_METADATA, + parameters=_TEST_EXPLANATION_PARAMETERS, + ), + ) + deploy_model_with_explanations_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=expected_deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_min_replica_count(self, deploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, min_replica_count=2, sync=sync) + + if not sync: + test_endpoint.wait() + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=2, max_replica_count=2, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_max_replica_count(self, deploy_model_mock, sync): + 
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_endpoint.deploy(model=test_model, max_replica_count=2, sync=sync) + if not sync: + test_endpoint.wait() + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=1, max_replica_count=2, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.parametrize( + "model1, model2, model3, percent", + [ + (100, None, None, 70), + (50, 50, None, 70), + (40, 60, None, 75), + (40, 60, None, 88), + (88, 12, None, 36), + (11, 89, None, 18), + (1, 99, None, 80), + (1, 2, 97, 68), + (99, 1, 0, 22), + (0, 0, 100, 18), + (7, 87, 6, 46), + ], + ) + def test_allocate_traffic(self, model1, model2, model3, percent): + old_split = {} + if model1 is not None: + old_split["model1"] = model1 + if model2 is not None: + old_split["model2"] = model2 + if model3 is not None: + old_split["model3"] = model3 + + new_split = models.Endpoint._allocate_traffic(old_split, percent) + new_split_sum = 0 + for model in new_split: + new_split_sum += new_split[model] + + assert new_split_sum == 100 + assert new_split["0"] == percent + + @pytest.mark.parametrize( + "model1, model2, model3, deployed_model", + [ + (100, None, None, "model1"), + (50, 50, None, "model1"), + (40, 60, None, "model2"), + (40, 60, None, "model1"), + (88, 12, None, "model1"), + (11, 89, None, "model1"), + (1, 99, None, "model2"), + (1, 2, 97, "model1"), + (99, 1, 0, "model2"), + (0, 0, 100, "model3"), + (7, 87, 6, "model2"), + ], + ) + def test_unallocate_traffic(self, model1, model2, model3, deployed_model): + old_split = {} + if model1 is not None: + old_split["model1"] 
= model1 + if model2 is not None: + old_split["model2"] = model2 + if model3 is not None: + old_split["model3"] = model3 + + new_split = models.Endpoint._unallocate_traffic(old_split, deployed_model) + new_split_sum = 0 + for model in new_split: + new_split_sum += new_split[model] + + assert new_split_sum == 100 or new_split_sum == 0 + assert new_split[deployed_model] == 0 + + @pytest.mark.parametrize("sync", [True, False]) + def test_undeploy(self, undeploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + name=_TEST_ENDPOINT_NAME, + traffic_split={"model1": 100}, + ) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + assert dict(test_endpoint._gca_resource.traffic_split) == {"model1": 100} + test_endpoint.undeploy("model1", sync=sync) + if not sync: + test_endpoint.wait() + undeploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model_id="model1", + traffic_split={}, + # traffic_split={"model1": 0}, + metadata=(), + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_undeploy_with_traffic_split(self, undeploy_model_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + name=_TEST_ENDPOINT_NAME, + traffic_split={"model1": 40, "model2": 60}, + ) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_endpoint.undeploy( + deployed_model_id="model1", + traffic_split={"model1": 0, "model2": 100}, + sync=sync, + ) + + if not sync: + test_endpoint.wait() + + undeploy_model_mock.assert_called_once_with( + 
endpoint=test_endpoint.resource_name, + deployed_model_id="model1", + traffic_split={"model2": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures("get_endpoint_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_undeploy_raise_error_traffic_split_total(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_endpoint.undeploy( + deployed_model_id="model1", traffic_split={"model2": 99}, sync=sync + ) + + @pytest.mark.usefixtures("get_endpoint_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_undeploy_raise_error_undeployed_model_traffic(self, sync): + with pytest.raises(ValueError): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_endpoint.undeploy( + deployed_model_id="model1", + traffic_split={"model1": 50, "model2": 50}, + sync=sync, + ) + + def test_predict(self, get_endpoint_mock, predict_client_predict_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + test_endpoint = models.Endpoint(_TEST_ID) + test_prediction = test_endpoint.predict( + instances=_TEST_INSTANCES, parameters={"param": 3.0} + ) + + true_prediction = models.Prediction( + predictions=_TEST_PREDICTION, deployed_model_id=_TEST_ID + ) + + assert true_prediction == test_prediction + predict_client_predict_mock.assert_called_once_with( + endpoint=_TEST_ENDPOINT_NAME, + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + ) + + def test_explain(self, get_endpoint_mock, predict_client_explain_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + test_endpoint = models.Endpoint(_TEST_ID) + test_prediction = test_endpoint.explain( + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + deployed_model_id=_TEST_MODEL_ID, + ) + expected_explanations = _TEST_EXPLANATIONS + 
expected_explanations[0].attributions.extend(_TEST_ATTRIBUTIONS) + + expected_prediction = models.Prediction( + predictions=_TEST_PREDICTION, + deployed_model_id=_TEST_ID, + explanations=expected_explanations, + ) + + assert expected_prediction == test_prediction + predict_client_explain_mock.assert_called_once_with( + endpoint=_TEST_ENDPOINT_NAME, + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + deployed_model_id=_TEST_MODEL_ID, + ) + + def test_list_models(self, get_endpoint_with_models_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + ept = aiplatform.Endpoint(_TEST_ID) + my_models = ept.list_models() + + assert my_models == _TEST_DEPLOYED_MODELS + + @pytest.mark.usefixtures("get_endpoint_with_models_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_undeploy_all(self, sdk_private_undeploy_mock, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + ept = aiplatform.Endpoint(_TEST_ID) + ept.undeploy_all(sync=sync) + + if not sync: + ept.wait() + + # undeploy_all() results in an undeploy() call for each deployed_model + sdk_private_undeploy_mock.assert_has_calls( + [ + mock.call(deployed_model_id=deployed_model.id, sync=sync) + for deployed_model in _TEST_DEPLOYED_MODELS + ], + any_order=True, + ) + + def test_list_endpoint_order_by_time(self, list_endpoints_mock): + """Test call to Endpoint.list() and ensure list is returned in descending order of create_time""" + + ep_list = aiplatform.Endpoint.list( + filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY_CREATE_TIME + ) + + # `order_by` is not passed to API since it is not an accepted field + list_endpoints_mock.assert_called_once_with( + request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER} + ) + + assert len(ep_list) == len(_TEST_ENDPOINT_LIST) + + for ep in ep_list: + assert type(ep) == aiplatform.Endpoint + + assert ep_list[0].create_time > ep_list[1].create_time > ep_list[2].create_time + + def 
test_list_endpoint_order_by_display_name(self, list_endpoints_mock): + """Test call to Endpoint.list() and ensure list is returned in order of display_name""" + + ep_list = aiplatform.Endpoint.list( + filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY_DISPLAY_NAME + ) + + # `order_by` is not passed to API since it is not an accepted field + list_endpoints_mock.assert_called_once_with( + request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER} + ) + + assert len(ep_list) == len(_TEST_ENDPOINT_LIST) + + for ep in ep_list: + assert type(ep) == aiplatform.Endpoint + + assert ( + ep_list[0].display_name < ep_list[1].display_name < ep_list[2].display_name + ) + + @pytest.mark.usefixtures("get_endpoint_with_models_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_delete_endpoint_without_force( + self, sdk_undeploy_all_mock, delete_endpoint_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + ept = aiplatform.Endpoint(_TEST_ID) + ept.delete(sync=sync) + + if not sync: + ept.wait() + + # undeploy_all() should not be called unless force is set to True + sdk_undeploy_all_mock.assert_not_called() + + delete_endpoint_mock.assert_called_once_with(name=_TEST_ENDPOINT_NAME) + + @pytest.mark.usefixtures("get_endpoint_with_models_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_delete_endpoint_with_force( + self, sdk_undeploy_all_mock, delete_endpoint_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + ept = aiplatform.Endpoint(_TEST_ID) + ept.delete(force=True, sync=sync) + + if not sync: + ept.wait() + + # undeploy_all() should be called if force is set to True + sdk_undeploy_all_mock.assert_called_once() + + delete_endpoint_mock.assert_called_once_with(name=_TEST_ENDPOINT_NAME) diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py new file mode 100644 index 0000000000..1d97ad2e9a --- /dev/null +++ 
b/tests/unit/aiplatform/test_initializer.py @@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib +import os +import pytest +from unittest import mock + +import google.auth +from google.auth import credentials + +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import constants +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) + +_TEST_PROJECT = "test-project" +_TEST_PROJECT_2 = "test-project-2" +_TEST_LOCATION = "us-central1" +_TEST_LOCATION_2 = "europe-west4" +_TEST_INVALID_LOCATION = "test-invalid-location" +_TEST_EXPERIMENT = "test-experiment" +_TEST_STAGING_BUCKET = "test-bucket" + + +class TestInit: + def setup_method(self): + importlib.reload(initializer) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_project_sets_project(self): + initializer.global_config.init(project=_TEST_PROJECT) + assert initializer.global_config.project == _TEST_PROJECT + + def test_not_init_project_gets_default_project(self, monkeypatch): + def mock_auth_default(): + return None, _TEST_PROJECT + + monkeypatch.setattr(google.auth, "default", mock_auth_default) + assert initializer.global_config.project == _TEST_PROJECT + + def test_init_location_sets_location(self): + initializer.global_config.init(location=_TEST_LOCATION) + 
assert initializer.global_config.location == _TEST_LOCATION + + def test_not_init_location_gets_default_location(self): + assert initializer.global_config.location == constants.DEFAULT_REGION + + def test_init_location_with_invalid_location_raises(self): + with pytest.raises(ValueError): + initializer.global_config.init(location=_TEST_INVALID_LOCATION) + + def test_init_experiment_sets_experiment(self): + initializer.global_config.init(experiment=_TEST_EXPERIMENT) + assert initializer.global_config.experiment == _TEST_EXPERIMENT + + def test_init_staging_bucket_sets_staging_bucket(self): + initializer.global_config.init(staging_bucket=_TEST_STAGING_BUCKET) + assert initializer.global_config.staging_bucket == _TEST_STAGING_BUCKET + + def test_init_credentials_sets_credentials(self): + creds = credentials.AnonymousCredentials() + initializer.global_config.init(credentials=creds) + assert initializer.global_config.credentials is creds + + def test_common_location_path_returns_parent(self): + initializer.global_config.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + true_resource_parent = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" + assert true_resource_parent == initializer.global_config.common_location_path() + + def test_common_location_path_overrides(self): + initializer.global_config.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + true_resource_parent = ( + f"projects/{_TEST_PROJECT_2}/locations/{_TEST_LOCATION_2}" + ) + assert true_resource_parent == initializer.global_config.common_location_path( + project=_TEST_PROJECT_2, location=_TEST_LOCATION_2 + ) + + def test_create_client_returns_client(self): + initializer.global_config.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + client = initializer.global_config.create_client( + client_class=utils.ModelClientWithOverride + ) + assert client._client_class is model_service_client.ModelServiceClient + assert isinstance(client, utils.ModelClientWithOverride) + assert ( + 
client._transport._host == f"{_TEST_LOCATION}-{constants.API_BASE_PATH}:443" + ) + + def test_create_client_overrides(self): + initializer.global_config.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + creds = credentials.AnonymousCredentials() + client = initializer.global_config.create_client( + client_class=utils.ModelClientWithOverride, + credentials=creds, + location_override=_TEST_LOCATION_2, + prediction_client=True, + ) + assert isinstance(client, utils.ModelClientWithOverride) + assert ( + client._transport._host + == f"{_TEST_LOCATION_2}-{constants.API_BASE_PATH}:443" + ) + assert client._transport._credentials == creds + + def test_create_client_user_agent(self): + initializer.global_config.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + client = initializer.global_config.create_client( + client_class=utils.ModelClientWithOverride + ) + + for wrapped_method in client._transport._wrapped_methods.values(): + # wrapped_method._metadata looks like: + # [('x-goog-api-client', 'model-builder/0.3.1 gl-python/3.7.6 grpc/1.30.0 gax/1.22.2 gapic/0.3.1')] + user_agent = wrapped_method._metadata[0][1] + assert user_agent.startswith("model-builder/") + + @pytest.mark.parametrize( + "init_location, location_override, expected_endpoint", + [ + ("us-central1", None, "us-central1-aiplatform.googleapis.com"), + ("us-central1", "europe-west4", "europe-west4-aiplatform.googleapis.com",), + ("asia-east1", None, "asia-east1-aiplatform.googleapis.com"), + ], + ) + def test_get_client_options( + self, init_location: str, location_override: str, expected_endpoint: str, + ): + initializer.global_config.init(location=init_location) + + assert ( + initializer.global_config.get_client_options( + location_override=location_override + ).api_endpoint + == expected_endpoint + ) + + +class TestThreadPool: + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize( + "cpu_count, expected", [(4, 20), (32, 32), (None, 4), (2, 10)] + 
) + def test_max_workers(self, cpu_count, expected): + with mock.patch.object(os, "cpu_count") as cpu_count_mock: + cpu_count_mock.return_value = cpu_count + importlib.reload(initializer) + assert initializer.global_pool._max_workers == expected diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py new file mode 100644 index 0000000000..acc7317ebb --- /dev/null +++ b/tests/unit/aiplatform/test_jobs.py @@ -0,0 +1,639 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest + +from unittest import mock +from importlib import reload +from unittest.mock import patch + +from google.cloud import storage +from google.cloud import bigquery + +from google.auth import credentials as auth_credentials + +from google.cloud import aiplatform + +from google.cloud.aiplatform import jobs +from google.cloud.aiplatform import initializer + +from google.cloud.aiplatform_v1beta1.services.job_service import ( + client as job_service_client_v1beta1, +) + +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job_v1beta1, + explanation as gca_explanation_v1beta1, + io as gca_io_v1beta1, + machine_resources as gca_machine_resources_v1beta1, +) + +from google.cloud.aiplatform_v1.services.job_service import client as job_service_client + +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, + io as gca_io, + job_state as gca_job_state, +) + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_ID = "1028944691210842416" +_TEST_ALT_ID = "8834795523125638878" +_TEST_DISPLAY_NAME = "my_job_1234" +_TEST_BQ_DATASET_ID = "bqDatasetId" +_TEST_BQ_JOB_ID = "123459876" +_TEST_BQ_MAX_RESULTS = 100 +_TEST_GCS_BUCKET_NAME = "my-bucket" + +_TEST_BQ_PATH = f"bq://projectId.{_TEST_BQ_DATASET_ID}" +_TEST_GCS_BUCKET_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}" +_TEST_GCS_JSONL_SOURCE_URI = f"{_TEST_GCS_BUCKET_PATH}/bp_input_config.jsonl" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" + +_TEST_MODEL_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_ALT_ID}" +) +_TEST_BATCH_PREDICTION_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/batchPredictionJobs/{_TEST_ID}" +_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME = "test-batch-prediction-job" + +_TEST_BATCH_PREDICTION_GCS_SOURCE = "gs://example-bucket/folder/instance.jsonl" +_TEST_BATCH_PREDICTION_GCS_SOURCE_LIST = [ + 
"gs://example-bucket/folder/instance1.jsonl", + "gs://example-bucket/folder/instance2.jsonl", +] +_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX = "gs://example-bucket/folder/output" +_TEST_BATCH_PREDICTION_BQ_PREFIX = "ucaip-sample-tests" +_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL = ( + f"bq://{_TEST_BATCH_PREDICTION_BQ_PREFIX}" +) + +_TEST_JOB_STATE_SUCCESS = gca_job_state.JobState(4) +_TEST_JOB_STATE_RUNNING = gca_job_state.JobState(3) +_TEST_JOB_STATE_PENDING = gca_job_state.JobState(2) + +_TEST_GCS_INPUT_CONFIG = gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io.GcsSource(uris=[_TEST_GCS_JSONL_SOURCE_URI]), +) +_TEST_GCS_OUTPUT_CONFIG = gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + predictions_format="jsonl", + gcs_destination=gca_io.GcsDestination(output_uri_prefix=_TEST_GCS_BUCKET_PATH), +) + +_TEST_BQ_INPUT_CONFIG = gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="bigquery", + bigquery_source=gca_io.BigQuerySource(input_uri=_TEST_BQ_PATH), +) +_TEST_BQ_OUTPUT_CONFIG = gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + predictions_format="bigquery", + bigquery_destination=gca_io.BigQueryDestination(output_uri=_TEST_BQ_PATH), +) + +_TEST_GCS_OUTPUT_INFO = gca_batch_prediction_job.BatchPredictionJob.OutputInfo( + gcs_output_directory=_TEST_GCS_BUCKET_NAME +) +_TEST_BQ_OUTPUT_INFO = gca_batch_prediction_job.BatchPredictionJob.OutputInfo( + bigquery_output_dataset=_TEST_BQ_PATH +) + +_TEST_EMPTY_OUTPUT_INFO = gca_batch_prediction_job.BatchPredictionJob.OutputInfo() + +_TEST_GCS_BLOBS = [ + storage.Blob(name="some/path/prediction.jsonl", bucket=_TEST_GCS_BUCKET_NAME) +] + +_TEST_MACHINE_TYPE = "n1-standard-4" +_TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_P100" +_TEST_ACCELERATOR_COUNT = 2 +_TEST_STARTING_REPLICA_COUNT = 2 +_TEST_MAX_REPLICA_COUNT = 12 + +_TEST_LABEL = {"team": "experimentation", "trial_id": "x435"} + +_TEST_EXPLANATION_METADATA = 
aiplatform.explain.ExplanationMetadata( + inputs={ + "features": { + "input_tensor_name": "dense_input", + "encoding": "BAG_OF_FEATURES", + "modality": "numeric", + "index_feature_mapping": ["abc", "def", "ghj"], + } + }, + outputs={"medv": {"output_tensor_name": "dense_2"}}, +) +_TEST_EXPLANATION_PARAMETERS = aiplatform.explain.ExplanationParameters( + {"sampled_shapley_attribution": {"path_count": 10}} +) + +_TEST_JOB_GET_METHOD_NAME = "get_fake_job" +_TEST_JOB_LIST_METHOD_NAME = "list_fake_job" +_TEST_JOB_CANCEL_METHOD_NAME = "cancel_fake_job" +_TEST_JOB_DELETE_METHOD_NAME = "delete_fake_job" +_TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/fakeJobs/{_TEST_ID}" + +# TODO(b/171333554): Move reusable test fixtures to conftest.py file + + +@pytest.fixture +def fake_job_getter_mock(): + with patch.object( + job_service_client.JobServiceClient, _TEST_JOB_GET_METHOD_NAME, create=True + ) as fake_job_getter_mock: + fake_job_getter_mock.return_value = {} + yield fake_job_getter_mock + + +@pytest.fixture +def fake_job_cancel_mock(): + with patch.object( + job_service_client.JobServiceClient, _TEST_JOB_CANCEL_METHOD_NAME, create=True + ) as fake_job_cancel_mock: + yield fake_job_cancel_mock + + +class TestJob: + class FakeJob(jobs._Job): + _job_type = "fake-job" + _resource_noun = "fakeJobs" + _getter_method = _TEST_JOB_GET_METHOD_NAME + _list_method = _TEST_JOB_LIST_METHOD_NAME + _cancel_method = _TEST_JOB_CANCEL_METHOD_NAME + _delete_method = _TEST_JOB_DELETE_METHOD_NAME + resource_name = _TEST_JOB_RESOURCE_NAME + + def setup_method(self): + reload(initializer) + reload(aiplatform) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + # Unit Tests + def test_init_job_class(self): + """ + Raises TypeError since abstract property '_getter_method' is not set, + the _Job class should only be instantiated through a child class. 
+ """ + with pytest.raises(TypeError): + jobs._Job(job_name=_TEST_BATCH_PREDICTION_JOB_NAME) + + @pytest.mark.usefixtures("fake_job_getter_mock") + def test_cancel_mock_job(self, fake_job_cancel_mock): + """Create a fake `_Job` child class, and ensure the high-level cancel method works""" + fake_job = self.FakeJob(job_name=_TEST_JOB_RESOURCE_NAME) + fake_job.cancel() + + fake_job_cancel_mock.assert_called_once_with(name=_TEST_JOB_RESOURCE_NAME) + + +@pytest.fixture +def get_batch_prediction_job_mock(): + with patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.side_effect = [ + gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + state=_TEST_JOB_STATE_RUNNING, + ), + gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + state=_TEST_JOB_STATE_SUCCESS, + ), + gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + state=_TEST_JOB_STATE_SUCCESS, + ), + ] + yield get_batch_prediction_job_mock + + +@pytest.fixture +def create_batch_prediction_job_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "create_batch_prediction_job" + ) as create_batch_prediction_job_mock: + create_batch_prediction_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield create_batch_prediction_job_mock + + +@pytest.fixture +def create_batch_prediction_job_with_explanations_mock(): + with mock.patch.object( + job_service_client_v1beta1.JobServiceClient, "create_batch_prediction_job" + ) as create_batch_prediction_job_mock: + create_batch_prediction_job_mock.return_value = gca_batch_prediction_job_v1beta1.BatchPredictionJob( + 
name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield create_batch_prediction_job_mock + + +@pytest.fixture +def get_batch_prediction_job_gcs_output_mock(): + with patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=_TEST_GCS_INPUT_CONFIG, + output_config=_TEST_GCS_OUTPUT_CONFIG, + output_info=_TEST_GCS_OUTPUT_INFO, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield get_batch_prediction_job_mock + + +@pytest.fixture +def get_batch_prediction_job_bq_output_mock(): + with patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=_TEST_GCS_INPUT_CONFIG, + output_config=_TEST_BQ_OUTPUT_CONFIG, + output_info=_TEST_BQ_OUTPUT_INFO, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield get_batch_prediction_job_mock + + +@pytest.fixture +def get_batch_prediction_job_empty_output_mock(): + with patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=_TEST_GCS_INPUT_CONFIG, + output_config=_TEST_BQ_OUTPUT_CONFIG, + output_info=_TEST_EMPTY_OUTPUT_INFO, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield get_batch_prediction_job_mock + + +@pytest.fixture +def get_batch_prediction_job_running_bq_output_mock(): + with patch.object( + 
job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + get_batch_prediction_job_mock.return_value = gca_batch_prediction_job.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=_TEST_GCS_INPUT_CONFIG, + output_config=_TEST_BQ_OUTPUT_CONFIG, + output_info=_TEST_BQ_OUTPUT_INFO, + state=_TEST_JOB_STATE_RUNNING, + ) + yield get_batch_prediction_job_mock + + +@pytest.fixture +def storage_list_blobs_mock(): + with patch.object(storage.Client, "list_blobs") as list_blobs_mock: + list_blobs_mock.return_value = _TEST_GCS_BLOBS + yield list_blobs_mock + + +@pytest.fixture +def bq_list_rows_mock(): + with patch.object(bigquery.Client, "list_rows") as list_rows_mock: + list_rows_mock.return_value = mock.Mock(bigquery.table.RowIterator) + yield list_rows_mock + + +class TestBatchPredictionJob: + def setup_method(self): + reload(initializer) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_init_batch_prediction_job(self, get_batch_prediction_job_mock): + jobs.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + get_batch_prediction_job_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + + def test_batch_prediction_job_status(self, get_batch_prediction_job_mock): + bp = jobs.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + + # get_batch_prediction() is called again here + bp_job_state = bp.state + + assert get_batch_prediction_job_mock.call_count == 2 + assert bp_job_state == _TEST_JOB_STATE_SUCCESS + + get_batch_prediction_job_mock.assert_called_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + + @pytest.mark.usefixtures("get_batch_prediction_job_gcs_output_mock") + def test_batch_prediction_iter_dirs_gcs(self, storage_list_blobs_mock): + bp = jobs.BatchPredictionJob( + 
batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + blobs = bp.iter_outputs() + + storage_list_blobs_mock.assert_called_once_with( + _TEST_GCS_OUTPUT_INFO.gcs_output_directory, prefix=None + ) + + assert blobs == _TEST_GCS_BLOBS + + @pytest.mark.usefixtures("get_batch_prediction_job_bq_output_mock") + def test_batch_prediction_iter_dirs_bq(self, bq_list_rows_mock): + bp = jobs.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + + bp.iter_outputs() + + bq_list_rows_mock.assert_called_once_with( + table=f"{_TEST_BQ_DATASET_ID}.predictions", max_results=_TEST_BQ_MAX_RESULTS + ) + + @pytest.mark.usefixtures("get_batch_prediction_job_running_bq_output_mock") + def test_batch_prediction_iter_dirs_while_running(self): + """ + Raises RuntimeError since outputs cannot be read while BatchPredictionJob is still running + """ + with pytest.raises(RuntimeError): + bp = jobs.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + bp.iter_outputs() + + @pytest.mark.usefixtures("get_batch_prediction_job_empty_output_mock") + def test_batch_prediction_iter_dirs_invalid_output_info(self): + """ + Raises NotImplementedError since the BatchPredictionJob's output_info + contains no output GCS directory or BQ dataset. 
+ """ + with pytest.raises(NotImplementedError): + bp = jobs.BatchPredictionJob( + batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME + ) + bp.iter_outputs() + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_gcs_source_and_dest( + self, create_batch_prediction_job_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + # Make SDK batch_predict method call + batch_prediction_job = jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, + sync=sync, + ) + + if not sync: + batch_prediction_job.wait() + + # Construct expected request + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), + ), + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX + ), + predictions_format="jsonl", + ), + ) + + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_gcs_source_bq_dest( + self, create_batch_prediction_job_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + batch_prediction_job = jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + 
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + sync=sync, + ) + + if not sync: + batch_prediction_job.wait() + + # Construct expected request + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), + ), + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io.BigQueryDestination( + output_uri=_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL + ), + predictions_format="bigquery", + ), + ) + + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_with_all_args( + self, create_batch_prediction_job_with_explanations_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + creds = auth_credentials.AnonymousCredentials() + + batch_prediction_job = jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, + predictions_format="csv", + model_parameters={}, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + starting_replica_count=_TEST_STARTING_REPLICA_COUNT, + max_replica_count=_TEST_MAX_REPLICA_COUNT, + generate_explanation=True, + explanation_metadata=_TEST_EXPLANATION_METADATA, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS, + labels=_TEST_LABEL, + credentials=creds, + 
sync=sync, + ) + + if not sync: + batch_prediction_job.wait() + + # Construct expected request + expected_gapic_batch_prediction_job = gca_batch_prediction_job_v1beta1.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + model=_TEST_MODEL_NAME, + input_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io_v1beta1.GcsSource( + uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE] + ), + ), + output_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io_v1beta1.GcsDestination( + output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX + ), + predictions_format="csv", + ), + dedicated_resources=gca_machine_resources_v1beta1.BatchDedicatedResources( + machine_spec=gca_machine_resources_v1beta1.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ), + starting_replica_count=_TEST_STARTING_REPLICA_COUNT, + max_replica_count=_TEST_MAX_REPLICA_COUNT, + ), + generate_explanation=True, + explanation_spec=gca_explanation_v1beta1.ExplanationSpec( + metadata=_TEST_EXPLANATION_METADATA, + parameters=_TEST_EXPLANATION_PARAMETERS, + ), + labels=_TEST_LABEL, + ) + + create_batch_prediction_job_with_explanations_mock.assert_called_once_with( + parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", + batch_prediction_job=expected_gapic_batch_prediction_job, + ) + + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_no_source(self, create_batch_prediction_job_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + # Make SDK batch_predict method call without source + with pytest.raises(ValueError) as e: + jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert 
e.match(regexp=r"source") + + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_two_sources(self, create_batch_prediction_job_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + # Make SDK batch_predict method call with two sources + with pytest.raises(ValueError) as e: + jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + bigquery_source=_TEST_BATCH_PREDICTION_BQ_PREFIX, + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"source") + + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_no_destination(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + # Make SDK batch_predict method call without destination + with pytest.raises(ValueError) as e: + jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + ) + + assert e.match(regexp=r"destination") + + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_wrong_instance_format(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + # Make SDK batch_predict method call + with pytest.raises(ValueError) as e: + jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + instances_format="wrong", + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"accepted instances format") + + @pytest.mark.usefixtures("get_batch_prediction_job_mock") + def test_batch_predict_wrong_prediction_format(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + # Make SDK batch_predict method call + with pytest.raises(ValueError) 
as e: + jobs.BatchPredictionJob.create( + model_name=_TEST_MODEL_NAME, + job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + predictions_format="wrong", + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"accepted prediction format") diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py new file mode 100644 index 0000000000..47b000d189 --- /dev/null +++ b/tests/unit/aiplatform/test_models.py @@ -0,0 +1,1130 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import importlib +from concurrent import futures +import pytest +from unittest import mock + +from google.api_core import operation as ga_operation +from google.auth import credentials as auth_credentials + +from google.cloud import aiplatform + +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import models +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + client as endpoint_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.job_service import ( + client as job_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.services.model_service import ( + client as model_service_client_v1beta1, +) +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job_v1beta1, + env_var as gca_env_var_v1beta1, + explanation as gca_explanation_v1beta1, + io as gca_io_v1beta1, + model as gca_model_v1beta1, + endpoint as gca_endpoint_v1beta1, + machine_resources as gca_machine_resources_v1beta1, + model_service as gca_model_service_v1beta1, + endpoint_service as gca_endpoint_service_v1beta1, + encryption_spec as gca_encryption_spec_v1beta1, +) + +from google.cloud.aiplatform_v1.services.endpoint_service import ( + client as endpoint_service_client, +) +from google.cloud.aiplatform_v1.services.job_service import client as job_service_client +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, + io as gca_io, + job_state as gca_job_state, + model as gca_model, + endpoint as gca_endpoint, + machine_resources as gca_machine_resources, + model_service as gca_model_service, + endpoint_service as gca_endpoint_service, + encryption_spec as gca_encryption_spec, +) + + +from test_endpoints import create_endpoint_mock # noqa: F401 + +_TEST_PROJECT = "test-project" 
+_TEST_PROJECT_2 = "test-project-2" +_TEST_LOCATION = "us-central1" +_TEST_LOCATION_2 = "europe-west4" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_MODEL_NAME = "test-model" +_TEST_ARTIFACT_URI = "gs://test/artifact/uri" +_TEST_SERVING_CONTAINER_IMAGE = "gcr.io/test-serving/container:image" +_TEST_SERVING_CONTAINER_PREDICTION_ROUTE = "predict" +_TEST_SERVING_CONTAINER_HEALTH_ROUTE = "metadata" +_TEST_DESCRIPTION = "test description" +_TEST_SERVING_CONTAINER_COMMAND = ["python3", "run_my_model.py"] +_TEST_SERVING_CONTAINER_ARGS = ["--test", "arg"] +_TEST_SERVING_CONTAINER_ENVIRONMENT_VARIABLES = { + "learning_rate": 0.01, + "loss_fn": "mse", +} +_TEST_SERVING_CONTAINER_PORTS = [8888, 10000] +_TEST_ID = "1028944691210842416" +_TEST_LABEL = {"team": "experimentation", "trial_id": "x435"} + +_TEST_MACHINE_TYPE = "n1-standard-4" +_TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_P100" +_TEST_ACCELERATOR_COUNT = 2 +_TEST_STARTING_REPLICA_COUNT = 2 +_TEST_MAX_REPLICA_COUNT = 12 + +_TEST_BATCH_PREDICTION_GCS_SOURCE = "gs://example-bucket/folder/instance.jsonl" +_TEST_BATCH_PREDICTION_GCS_SOURCE_LIST = [ + "gs://example-bucket/folder/instance1.jsonl", + "gs://example-bucket/folder/instance2.jsonl", +] +_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX = "gs://example-bucket/folder/output" +_TEST_BATCH_PREDICTION_BQ_PREFIX = "ucaip-sample-tests" +_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL = ( + f"bq://{_TEST_BATCH_PREDICTION_BQ_PREFIX}" +) +_TEST_BATCH_PREDICTION_DISPLAY_NAME = "test-batch-prediction-job" +_TEST_BATCH_PREDICTION_JOB_NAME = job_service_client.JobServiceClient.batch_prediction_job_path( + project=_TEST_PROJECT, location=_TEST_LOCATION, batch_prediction_job=_TEST_ID +) + +_TEST_INSTANCE_SCHEMA_URI = "gs://test/schema/instance.yaml" +_TEST_PARAMETERS_SCHEMA_URI = "gs://test/schema/parameters.yaml" +_TEST_PREDICTION_SCHEMA_URI = "gs://test/schema/predictions.yaml" + +_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) + 
+_TEST_EXPLANATION_METADATA = aiplatform.explain.ExplanationMetadata( + inputs={ + "features": { + "input_tensor_name": "dense_input", + "encoding": "BAG_OF_FEATURES", + "modality": "numeric", + "index_feature_mapping": ["abc", "def", "ghj"], + } + }, + outputs={"medv": {"output_tensor_name": "dense_2"}}, +) +_TEST_EXPLANATION_PARAMETERS = aiplatform.explain.ExplanationParameters( + {"sampled_shapley_attribution": {"path_count": 10}} +) + +# CMEK encryption +_TEST_ENCRYPTION_KEY_NAME = "key_1234" +_TEST_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_KEY_NAME +) +_TEST_ENCRYPTION_SPEC_V1BETA1 = gca_encryption_spec_v1beta1.EncryptionSpec( + kms_key_name=_TEST_ENCRYPTION_KEY_NAME +) + +_TEST_MODEL_RESOURCE_NAME = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID +) +_TEST_MODEL_RESOURCE_NAME_CUSTOM_PROJECT = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID +) +_TEST_MODEL_RESOURCE_NAME_CUSTOM_LOCATION = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID +) + +_TEST_OUTPUT_DIR = "gs://my-output-bucket" + + +@pytest.fixture +def get_endpoint_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "get_endpoint" + ) as get_endpoint_mock: + test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID + ) + get_endpoint_mock.return_value = gca_endpoint.Endpoint( + display_name=_TEST_MODEL_NAME, name=test_endpoint_resource_name, + ) + yield get_endpoint_mock + + +@pytest.fixture +def get_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name=_TEST_MODEL_NAME, name=_TEST_MODEL_RESOURCE_NAME, + ) + yield get_model_mock + + +@pytest.fixture +def get_model_with_explanations_mock(): + 
with mock.patch.object( + model_service_client_v1beta1.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model_v1beta1.Model( + display_name=_TEST_MODEL_NAME, name=_TEST_MODEL_RESOURCE_NAME, + ) + yield get_model_mock + + +@pytest.fixture +def get_model_with_custom_location_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name=_TEST_MODEL_NAME, + name=_TEST_MODEL_RESOURCE_NAME_CUSTOM_LOCATION, + ) + yield get_model_mock + + +@pytest.fixture +def get_model_with_custom_project_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name=_TEST_MODEL_NAME, + name=_TEST_MODEL_RESOURCE_NAME_CUSTOM_PROJECT, + ) + yield get_model_mock + + +@pytest.fixture +def upload_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "upload_model" + ) as upload_model_mock: + mock_lro = mock.Mock(ga_operation.Operation) + mock_lro.result.return_value = gca_model_service.UploadModelResponse( + model=_TEST_MODEL_RESOURCE_NAME + ) + upload_model_mock.return_value = mock_lro + yield upload_model_mock + + +@pytest.fixture +def upload_model_with_explanations_mock(): + with mock.patch.object( + model_service_client_v1beta1.ModelServiceClient, "upload_model" + ) as upload_model_mock: + mock_lro = mock.Mock(ga_operation.Operation) + mock_lro.result.return_value = gca_model_service_v1beta1.UploadModelResponse( + model=_TEST_MODEL_RESOURCE_NAME + ) + upload_model_mock.return_value = mock_lro + yield upload_model_mock + + +@pytest.fixture +def upload_model_with_custom_project_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "upload_model" + ) as upload_model_mock: + mock_lro = mock.Mock(ga_operation.Operation) + mock_lro.result.return_value = 
gca_model_service.UploadModelResponse( + model=_TEST_MODEL_RESOURCE_NAME_CUSTOM_PROJECT + ) + upload_model_mock.return_value = mock_lro + yield upload_model_mock + + +@pytest.fixture +def upload_model_with_custom_location_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "upload_model" + ) as upload_model_mock: + mock_lro = mock.Mock(ga_operation.Operation) + mock_lro.result.return_value = gca_model_service.UploadModelResponse( + model=_TEST_MODEL_RESOURCE_NAME_CUSTOM_LOCATION + ) + upload_model_mock.return_value = mock_lro + yield upload_model_mock + + +@pytest.fixture +def delete_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "delete_model" + ) as delete_model_mock: + delete_model_lro_mock = mock.Mock(ga_operation.Operation) + delete_model_lro_mock.result.return_value = ( + gca_model_service.DeleteModelRequest() + ) + delete_model_mock.return_value = delete_model_lro_mock + yield delete_model_mock + + +@pytest.fixture +def deploy_model_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "deploy_model" + ) as deploy_model_mock: + deployed_model = gca_endpoint.DeployedModel( + model=_TEST_MODEL_RESOURCE_NAME, display_name=_TEST_MODEL_NAME, + ) + deploy_model_lro_mock = mock.Mock(ga_operation.Operation) + deploy_model_lro_mock.result.return_value = gca_endpoint_service.DeployModelResponse( + deployed_model=deployed_model, + ) + deploy_model_mock.return_value = deploy_model_lro_mock + yield deploy_model_mock + + +@pytest.fixture +def deploy_model_with_explanations_mock(): + with mock.patch.object( + endpoint_service_client_v1beta1.EndpointServiceClient, "deploy_model" + ) as deploy_model_mock: + deployed_model = gca_endpoint_v1beta1.DeployedModel( + model=_TEST_MODEL_RESOURCE_NAME, display_name=_TEST_MODEL_NAME, + ) + deploy_model_lro_mock = mock.Mock(ga_operation.Operation) + deploy_model_lro_mock.result.return_value = gca_endpoint_service_v1beta1.DeployModelResponse( + 
deployed_model=deployed_model, + ) + deploy_model_mock.return_value = deploy_model_lro_mock + yield deploy_model_mock + + +@pytest.fixture +def get_batch_prediction_job_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_batch_prediction_job_mock: + batch_prediction_mock = mock.Mock( + spec=gca_batch_prediction_job.BatchPredictionJob + ) + batch_prediction_mock.state = gca_job_state.JobState.JOB_STATE_SUCCEEDED + batch_prediction_mock.name = _TEST_BATCH_PREDICTION_JOB_NAME + get_batch_prediction_job_mock.return_value = batch_prediction_mock + yield get_batch_prediction_job_mock + + +@pytest.fixture +def create_batch_prediction_job_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "create_batch_prediction_job" + ) as create_batch_prediction_job_mock: + batch_prediction_job_mock = mock.Mock( + spec=gca_batch_prediction_job.BatchPredictionJob + ) + batch_prediction_job_mock.name = _TEST_BATCH_PREDICTION_JOB_NAME + create_batch_prediction_job_mock.return_value = batch_prediction_job_mock + yield create_batch_prediction_job_mock + + +@pytest.fixture +def create_batch_prediction_job_with_explanations_mock(): + with mock.patch.object( + job_service_client_v1beta1.JobServiceClient, "create_batch_prediction_job" + ) as create_batch_prediction_job_mock: + batch_prediction_job_mock = mock.Mock( + spec=gca_batch_prediction_job_v1beta1.BatchPredictionJob + ) + batch_prediction_job_mock.name = _TEST_BATCH_PREDICTION_JOB_NAME + create_batch_prediction_job_mock.return_value = batch_prediction_job_mock + yield create_batch_prediction_job_mock + + +@pytest.fixture +def create_client_mock(): + with mock.patch.object( + initializer.global_config, "create_client" + ) as create_client_mock: + api_client_mock = mock.Mock(spec=model_service_client.ModelServiceClient) + create_client_mock.return_value = api_client_mock + yield create_client_mock + + +class TestModel: + def setup_method(self): + 
importlib.reload(initializer) + importlib.reload(aiplatform) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_constructor_creates_client(self, create_client_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + models.Model(_TEST_ID) + create_client_mock.assert_called_once_with( + client_class=utils.ModelClientWithOverride, + credentials=initializer.global_config.credentials, + location_override=_TEST_LOCATION, + prediction_client=False, + ) + + def test_constructor_create_client_with_custom_location(self, create_client_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + models.Model(_TEST_ID, location=_TEST_LOCATION_2) + create_client_mock.assert_called_once_with( + client_class=utils.ModelClientWithOverride, + credentials=initializer.global_config.credentials, + location_override=_TEST_LOCATION_2, + prediction_client=False, + ) + + def test_constructor_creates_client_with_custom_credentials( + self, create_client_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + creds = auth_credentials.AnonymousCredentials() + models.Model(_TEST_ID, credentials=creds) + create_client_mock.assert_called_once_with( + client_class=utils.ModelClientWithOverride, + credentials=creds, + location_override=_TEST_LOCATION, + prediction_client=False, + ) + + def test_constructor_gets_model(self, get_model_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + models.Model(_TEST_ID) + get_model_mock.assert_called_once_with(name=_TEST_MODEL_RESOURCE_NAME) + + def test_constructor_gets_model_with_custom_project(self, get_model_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + models.Model(_TEST_ID, project=_TEST_PROJECT_2) + test_model_resource_name = 
model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID + ) + get_model_mock.assert_called_once_with(name=test_model_resource_name) + + def test_constructor_gets_model_with_custom_location(self, get_model_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + models.Model(_TEST_ID, location=_TEST_LOCATION_2) + test_model_resource_name = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID + ) + get_model_mock.assert_called_once_with(name=test_model_resource_name) + + @pytest.mark.parametrize("sync", [True, False]) + def test_upload_uploads_and_gets_model( + self, upload_model_mock, get_model_mock, sync + ): + + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + sync=sync, + ) + + if not sync: + my_model.wait() + + container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + managed_model = gca_model.Model( + display_name=_TEST_MODEL_NAME, container_spec=container_spec, + ) + + upload_model_mock.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + model=managed_model, + ) + + get_model_mock.assert_called_once_with(name=_TEST_MODEL_RESOURCE_NAME) + + def test_upload_raises_with_impartial_explanation_spec(self): + + with pytest.raises(ValueError) as e: + models.Model.upload( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS + # Missing the required explanations_metadata 
field + ) + + assert e.match(regexp=r"`explanation_parameters` should be specified or None.") + + @pytest.mark.parametrize("sync", [True, False]) + def test_upload_uploads_and_gets_model_with_all_args( + self, upload_model_with_explanations_mock, get_model_mock, sync + ): + + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + instance_schema_uri=_TEST_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_PREDICTION_SCHEMA_URI, + description=_TEST_DESCRIPTION, + serving_container_command=_TEST_SERVING_CONTAINER_COMMAND, + serving_container_args=_TEST_SERVING_CONTAINER_ARGS, + serving_container_environment_variables=_TEST_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + serving_container_ports=_TEST_SERVING_CONTAINER_PORTS, + explanation_metadata=_TEST_EXPLANATION_METADATA, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS, + sync=sync, + ) + + if not sync: + my_model.wait() + + env = [ + gca_env_var_v1beta1.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model_v1beta1.Port(container_port=port) + for port in _TEST_SERVING_CONTAINER_PORTS + ] + + container_spec = gca_model_v1beta1.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_SERVING_CONTAINER_COMMAND, + args=_TEST_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + managed_model = gca_model_v1beta1.Model( + display_name=_TEST_MODEL_NAME, + description=_TEST_DESCRIPTION, + artifact_uri=_TEST_ARTIFACT_URI, + 
container_spec=container_spec, + predict_schemata=gca_model_v1beta1.PredictSchemata( + instance_schema_uri=_TEST_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_PREDICTION_SCHEMA_URI, + ), + explanation_spec=gca_model_v1beta1.explanation.ExplanationSpec( + metadata=_TEST_EXPLANATION_METADATA, + parameters=_TEST_EXPLANATION_PARAMETERS, + ), + ) + + upload_model_with_explanations_mock.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + model=managed_model, + ) + get_model_mock.assert_called_once_with(name=_TEST_MODEL_RESOURCE_NAME) + + @pytest.mark.usefixtures("get_model_with_custom_project_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_upload_uploads_and_gets_model_with_custom_project( + self, + upload_model_with_custom_project_mock, + get_model_with_custom_project_mock, + sync, + ): + + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + test_model_resource_name = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID + ) + + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + project=_TEST_PROJECT_2, + sync=sync, + ) + + if not sync: + my_model.wait() + + container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + managed_model = gca_model.Model( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + container_spec=container_spec, + ) + + upload_model_with_custom_project_mock.assert_called_once_with( + parent=f"projects/{_TEST_PROJECT_2}/locations/{_TEST_LOCATION}", + 
model=managed_model, + ) + + get_model_with_custom_project_mock.assert_called_once_with( + name=test_model_resource_name + ) + + @pytest.mark.usefixtures("get_model_with_custom_location_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_upload_uploads_and_gets_model_with_custom_location( + self, + upload_model_with_custom_location_mock, + get_model_with_custom_location_mock, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model_resource_name = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID + ) + + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + location=_TEST_LOCATION_2, + sync=sync, + ) + + if not sync: + my_model.wait() + + container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + managed_model = gca_model.Model( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + container_spec=container_spec, + ) + + upload_model_with_custom_location_mock.assert_called_once_with( + parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION_2}", + model=managed_model, + ) + + get_model_with_custom_location_mock.assert_called_once_with( + name=test_model_resource_name + ) + + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy(self, deploy_model_mock, sync): + + test_model = models.Model(_TEST_ID) + test_endpoint = models.Endpoint(_TEST_ID) + + assert test_model.deploy(test_endpoint, sync=sync,) == test_endpoint + + if not sync: + test_endpoint.wait() + + automatic_resources 
= gca_machine_resources.AutomaticResources( + min_replica_count=1, max_replica_count=1, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures( + "get_endpoint_mock", "get_model_mock", "create_endpoint_mock" + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_no_endpoint(self, deploy_model_mock, sync): + + test_model = models.Model(_TEST_ID) + test_endpoint = test_model.deploy(sync=sync) + + if not sync: + test_endpoint.wait() + + automatic_resources = gca_machine_resources.AutomaticResources( + min_replica_count=1, max_replica_count=1, + ) + deployed_model = gca_endpoint.DeployedModel( + automatic_resources=automatic_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures( + "get_endpoint_mock", "get_model_mock", "create_endpoint_mock" + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync): + + test_model = models.Model(_TEST_ID) + test_endpoint = test_model.deploy( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + sync=sync, + ) + + if not sync: + test_endpoint.wait() + + expected_machine_spec = gca_machine_resources.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ) + expected_dedicated_resources = gca_machine_resources.DedicatedResources( + machine_spec=expected_machine_spec, min_replica_count=1, 
max_replica_count=1 + ) + expected_deployed_model = gca_endpoint.DeployedModel( + dedicated_resources=expected_dedicated_resources, + model=test_model.resource_name, + display_name=None, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=expected_deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures( + "get_endpoint_mock", "get_model_mock", "create_endpoint_mock" + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_no_endpoint_with_explanations( + self, deploy_model_with_explanations_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) + test_endpoint = test_model.deploy( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + explanation_metadata=_TEST_EXPLANATION_METADATA, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS, + sync=sync, + ) + + if not sync: + test_endpoint.wait() + + expected_machine_spec = gca_machine_resources_v1beta1.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ) + expected_dedicated_resources = gca_machine_resources_v1beta1.DedicatedResources( + machine_spec=expected_machine_spec, min_replica_count=1, max_replica_count=1 + ) + expected_deployed_model = gca_endpoint_v1beta1.DeployedModel( + dedicated_resources=expected_dedicated_resources, + model=test_model.resource_name, + display_name=None, + explanation_spec=gca_endpoint_v1beta1.explanation.ExplanationSpec( + metadata=_TEST_EXPLANATION_METADATA, + parameters=_TEST_EXPLANATION_PARAMETERS, + ), + ) + deploy_model_with_explanations_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=expected_deployed_model, + traffic_split={"0": 100}, + metadata=(), + ) + + @pytest.mark.usefixtures( + "get_endpoint_mock", 
"get_model_mock", "create_endpoint_mock" + ) + def test_deploy_raises_with_impartial_explanation_spec(self): + + test_model = models.Model(_TEST_ID) + + with pytest.raises(ValueError) as e: + test_model.deploy( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + explanation_metadata=_TEST_EXPLANATION_METADATA, + # Missing required `explanation_parameters` argument + ) + + assert e.match(regexp=r"`explanation_parameters` should be specified or None.") + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_and_dest( + self, create_batch_prediction_job_mock, sync + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + ) + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call + batch_prediction_job = test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, + sync=sync, + ) + + if not sync: + batch_prediction_job.wait() + + # Construct expected request + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + model=model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID + ), + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), + ), + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX + ), + predictions_format="jsonl", + ), + 
encryption_spec=_TEST_ENCRYPTION_SPEC, + ) + + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_gcs_source_and_dest( + self, create_batch_prediction_job_mock, sync + ): + + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call + batch_prediction_job = test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, + sync=sync, + ) + + if not sync: + batch_prediction_job.wait() + + # Construct expected request + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + model=model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID + ), + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), + ), + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX + ), + predictions_format="jsonl", + ), + ) + + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_gcs_source_bq_dest( + self, create_batch_prediction_job_mock, sync + ): + + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call + batch_prediction_job = test_model.batch_predict( + 
job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + sync=sync, + ) + + if not sync: + batch_prediction_job.wait() + + # Construct expected request + expected_gapic_batch_prediction_job = gca_batch_prediction_job.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + model=model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID + ), + input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io.GcsSource(uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]), + ), + output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io.BigQueryDestination( + output_uri=_TEST_BATCH_PREDICTION_BQ_DEST_PREFIX_WITH_PROTOCOL + ), + predictions_format="bigquery", + ), + ) + + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + ) + + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_with_all_args( + self, create_batch_prediction_job_with_explanations_mock, sync + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) + creds = auth_credentials.AnonymousCredentials() + + # Make SDK batch_predict method call passing all arguments + batch_prediction_job = test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, + predictions_format="csv", + model_parameters={}, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + starting_replica_count=_TEST_STARTING_REPLICA_COUNT, + 
max_replica_count=_TEST_MAX_REPLICA_COUNT, + generate_explanation=True, + explanation_metadata=_TEST_EXPLANATION_METADATA, + explanation_parameters=_TEST_EXPLANATION_PARAMETERS, + labels=_TEST_LABEL, + credentials=creds, + encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME, + sync=sync, + ) + + if not sync: + batch_prediction_job.wait() + + # Construct expected request + expected_gapic_batch_prediction_job = gca_batch_prediction_job_v1beta1.BatchPredictionJob( + display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + model=model_service_client_v1beta1.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_ID + ), + input_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io_v1beta1.GcsSource( + uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE] + ), + ), + output_config=gca_batch_prediction_job_v1beta1.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io_v1beta1.GcsDestination( + output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX + ), + predictions_format="csv", + ), + dedicated_resources=gca_machine_resources_v1beta1.BatchDedicatedResources( + machine_spec=gca_machine_resources_v1beta1.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ), + starting_replica_count=_TEST_STARTING_REPLICA_COUNT, + max_replica_count=_TEST_MAX_REPLICA_COUNT, + ), + generate_explanation=True, + explanation_spec=gca_explanation_v1beta1.ExplanationSpec( + metadata=_TEST_EXPLANATION_METADATA, + parameters=_TEST_EXPLANATION_PARAMETERS, + ), + labels=_TEST_LABEL, + encryption_spec=_TEST_ENCRYPTION_SPEC_V1BETA1, + ) + + create_batch_prediction_job_with_explanations_mock.assert_called_once_with( + parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", + batch_prediction_job=expected_gapic_batch_prediction_job, + ) + + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def 
test_batch_predict_no_source(self, create_batch_prediction_job_mock): + + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call without source + with pytest.raises(ValueError) as e: + test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"source") + + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_two_sources(self, create_batch_prediction_job_mock): + + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call with two sources + with pytest.raises(ValueError) as e: + test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + bigquery_source=_TEST_BATCH_PREDICTION_BQ_PREFIX, + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"source") + + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_no_destination(self): + + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call without destination + with pytest.raises(ValueError) as e: + test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + ) + + assert e.match(regexp=r"destination") + + @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_wrong_instance_format(self): + + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call + with pytest.raises(ValueError) as e: + test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + instances_format="wrong", + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"accepted instances format") + + 
@pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") + def test_batch_predict_wrong_prediction_format(self): + + test_model = models.Model(_TEST_ID) + + # Make SDK batch_predict method call + with pytest.raises(ValueError) as e: + test_model.batch_predict( + job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME, + gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + predictions_format="wrong", + bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + ) + + assert e.match(regexp=r"accepted prediction format") + + @pytest.mark.usefixtures("get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_delete_model(self, delete_model_mock, sync): + + test_model = models.Model(_TEST_ID) + test_model.delete(sync=sync) + + if not sync: + test_model.wait() + + delete_model_mock.assert_called_once_with(name=test_model.resource_name) + + @pytest.mark.usefixtures("get_model_mock") + def test_print_model(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) + assert ( + repr(test_model) + == f"{object.__repr__(test_model)} \nresource name: {test_model.resource_name}" + ) + + @pytest.mark.usefixtures("get_model_mock") + def test_print_model_if_waiting(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) + test_model._gca_resource = None + test_model._latest_future = futures.Future() + assert ( + repr(test_model) + == f"{object.__repr__(test_model)} is waiting for upstream dependencies to complete." 
+ ) + + @pytest.mark.usefixtures("get_model_mock") + def test_print_model_if_exception(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) + test_model._gca_resource = None + mock_exception = Exception("mock exception") + test_model._exception = mock_exception + assert ( + repr(test_model) + == f"{object.__repr__(test_model)} failed with {str(mock_exception)}" + ) diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py new file mode 100644 index 0000000000..b5520a5f4c --- /dev/null +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -0,0 +1,3865 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from distutils import core +import functools +import importlib +import pathlib +import pytest +import subprocess +import shutil +import sys +import tarfile +import tempfile +from unittest import mock +from unittest.mock import patch + +from google.auth import credentials as auth_credentials + +from google.cloud import aiplatform + +from google.cloud.aiplatform import datasets +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import schema +from google.cloud.aiplatform import training_jobs + +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client, +) +from google.cloud.aiplatform_v1.services.pipeline_service import ( + client as pipeline_service_client, +) + +from google.cloud.aiplatform_v1.types import ( + dataset as gca_dataset, + encryption_spec as gca_encryption_spec, + env_var as gca_env_var, + io as gca_io, + model as gca_model, + pipeline_state as gca_pipeline_state, + training_pipeline as gca_training_pipeline, +) + +from google.cloud import storage +from google.protobuf import json_format +from google.protobuf import struct_pb2 + + +_TEST_BUCKET_NAME = "test-bucket" +_TEST_GCS_PATH_WITHOUT_BUCKET = "path/to/folder" +_TEST_GCS_PATH = f"{_TEST_BUCKET_NAME}/{_TEST_GCS_PATH_WITHOUT_BUCKET}" +_TEST_GCS_PATH_WITH_TRAILING_SLASH = f"{_TEST_GCS_PATH}/" +_TEST_LOCAL_SCRIPT_FILE_NAME = "____test____script.py" +_TEST_LOCAL_SCRIPT_FILE_PATH = f"path/to/{_TEST_LOCAL_SCRIPT_FILE_NAME}" +_TEST_PYTHON_SOURCE = """ +print('hello world') +""" +_TEST_REQUIREMENTS = ["pandas", "numpy", "tensorflow"] + +_TEST_DATASET_DISPLAY_NAME = "test-dataset-display-name" +_TEST_DATASET_NAME = "test-dataset-name" +_TEST_DISPLAY_NAME = "test-display-name" +_TEST_METADATA_SCHEMA_URI_TABULAR = schema.dataset.metadata.tabular +_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image" +_TEST_TRAINING_CONTAINER_CMD = ["python3", "task.py"] +_TEST_SERVING_CONTAINER_IMAGE = 
"gcr.io/test-serving/container:image" +_TEST_SERVING_CONTAINER_PREDICTION_ROUTE = "predict" +_TEST_SERVING_CONTAINER_HEALTH_ROUTE = "metadata" + +_TEST_METADATA_SCHEMA_URI_NONTABULAR = schema.dataset.metadata.image +_TEST_ANNOTATION_SCHEMA_URI = schema.dataset.annotation.image.classification + +_TEST_BASE_OUTPUT_DIR = "gs://test-base-output-dir" +_TEST_BIGQUERY_DESTINATION = "bq://test-project" +_TEST_RUN_ARGS = ["-v", 0.1, "--test=arg"] +_TEST_REPLICA_COUNT = 1 +_TEST_MACHINE_TYPE = "n1-standard-4" +_TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_K80" +_TEST_INVALID_ACCELERATOR_TYPE = "NVIDIA_DOES_NOT_EXIST" +_TEST_ACCELERATOR_COUNT = 1 +_TEST_MODEL_DISPLAY_NAME = "model-display-name" +_TEST_DEFAULT_TRAINING_FRACTION_SPLIT = 0.8 +_TEST_DEFAULT_VALIDATION_FRACTION_SPLIT = 0.1 +_TEST_DEFAULT_TEST_FRACTION_SPLIT = 0.1 +_TEST_TRAINING_FRACTION_SPLIT = 0.6 +_TEST_VALIDATION_FRACTION_SPLIT = 0.2 +_TEST_TEST_FRACTION_SPLIT = 0.2 +_TEST_PREDEFINED_SPLIT_COLUMN_NAME = "split" + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_ID = "12345" +_TEST_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/trainingPipelines/{_TEST_ID}" +) +_TEST_ALT_PROJECT = "test-project-alt" +_TEST_ALT_LOCATION = "europe-west4" + +_TEST_MODEL_INSTANCE_SCHEMA_URI = "instance_schema_uri.yaml" +_TEST_MODEL_PARAMETERS_SCHEMA_URI = "parameters_schema_uri.yaml" +_TEST_MODEL_PREDICTION_SCHEMA_URI = "prediction_schema_uri.yaml" +_TEST_MODEL_SERVING_CONTAINER_COMMAND = ["test_command"] +_TEST_MODEL_SERVING_CONTAINER_ARGS = ["test_args"] +_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES = { + "learning_rate": 0.01, + "loss_fn": "mse", +} +_TEST_MODEL_SERVING_CONTAINER_PORTS = [8888, 10000] +_TEST_MODEL_DESCRIPTION = "test description" + +_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test/ouput/python/trainer.tar.gz" +_TEST_PYTHON_MODULE_NAME = "aiplatform.task" + +_TEST_MODEL_NAME = "projects/my-project/locations/us-central1/models/12345" + +_TEST_PIPELINE_RESOURCE_NAME = ( + 
"projects/my-project/locations/us-central1/trainingPipeline/12345" +) +_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) + +# CMEK encryption +_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default" +_TEST_DEFAULT_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME +) + +_TEST_PIPELINE_ENCRYPTION_KEY_NAME = "key_pipeline" +_TEST_PIPELINE_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME +) + +_TEST_MODEL_ENCRYPTION_KEY_NAME = "key_model" +_TEST_MODEL_ENCRYPTION_SPEC = gca_encryption_spec.EncryptionSpec( + kms_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME +) + + +def local_copy_method(path): + shutil.copy(path, ".") + return pathlib.Path(path).name + + +@pytest.fixture +def get_training_job_custom_mock(): + with patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as get_training_job_custom_mock: + get_training_job_custom_mock.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + training_task_definition=schema.training_job.definition.custom_task, + ) + + yield get_training_job_custom_mock + + +@pytest.fixture +def get_training_job_custom_mock_no_model_to_upload(): + with patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as get_training_job_custom_mock: + get_training_job_custom_mock.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=None, + training_task_definition=schema.training_job.definition.custom_task, + ) + + yield get_training_job_custom_mock + + +@pytest.fixture +def get_training_job_tabular_mock(): + with patch.object( + pipeline_service_client.PipelineServiceClient, 
"get_training_pipeline" + ) as get_training_job_tabular_mock: + get_training_job_tabular_mock.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + training_task_definition=schema.training_job.definition.automl_tabular, + ) + + yield get_training_job_tabular_mock + + +@pytest.fixture +def mock_client_bucket(): + with patch.object(storage.Client, "bucket") as mock_client_bucket: + + def blob_side_effect(name, mock_blob, bucket): + mock_blob.name = name + mock_blob.bucket = bucket + return mock_blob + + MockBucket = mock.Mock(autospec=storage.Bucket) + MockBucket.name = _TEST_BUCKET_NAME + MockBlob = mock.Mock(autospec=storage.Blob) + MockBucket.blob.side_effect = functools.partial( + blob_side_effect, mock_blob=MockBlob, bucket=MockBucket + ) + mock_client_bucket.return_value = MockBucket + + yield mock_client_bucket, MockBlob + + +class TestTrainingScriptPythonPackagerHelpers: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def test_timestamp_copy_to_gcs_calls_gcs_client_with_bucket( + self, mock_client_bucket + ): + + mock_client_bucket, mock_blob = mock_client_bucket + + gcs_path = training_jobs._timestamped_copy_to_gcs( + local_file_path=_TEST_LOCAL_SCRIPT_FILE_PATH, + gcs_dir=_TEST_BUCKET_NAME, + project=_TEST_PROJECT, + ) + + local_script_file_name = pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_PATH).name + + mock_client_bucket.assert_called_once_with(_TEST_BUCKET_NAME) + mock_client_bucket.return_value.blob.assert_called_once() + + blob_arg = mock_client_bucket.return_value.blob.call_args[0][0] + assert blob_arg.startswith("aiplatform-") + assert blob_arg.endswith(_TEST_LOCAL_SCRIPT_FILE_NAME) + + mock_blob.upload_from_filename.assert_called_once_with( + _TEST_LOCAL_SCRIPT_FILE_PATH + ) + assert gcs_path.endswith(local_script_file_name) + assert 
gcs_path.startswith(f"gs://{_TEST_BUCKET_NAME}/aiplatform-") + + def test_timestamp_copy_to_gcs_calls_gcs_client_with_gcs_path( + self, mock_client_bucket + ): + + mock_client_bucket, mock_blob = mock_client_bucket + + gcs_path = training_jobs._timestamped_copy_to_gcs( + local_file_path=_TEST_LOCAL_SCRIPT_FILE_PATH, + gcs_dir=_TEST_GCS_PATH_WITH_TRAILING_SLASH, + project=_TEST_PROJECT, + ) + + local_script_file_name = pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_PATH).name + + mock_client_bucket.assert_called_once_with(_TEST_BUCKET_NAME) + mock_client_bucket.return_value.blob.assert_called_once() + + blob_arg = mock_client_bucket.return_value.blob.call_args[0][0] + assert blob_arg.startswith(f"{_TEST_GCS_PATH_WITHOUT_BUCKET}/aiplatform-") + assert blob_arg.endswith(f"{_TEST_LOCAL_SCRIPT_FILE_NAME}") + + mock_blob.upload_from_filename.assert_called_once_with( + _TEST_LOCAL_SCRIPT_FILE_PATH + ) + + assert gcs_path.startswith(f"gs://{_TEST_GCS_PATH}/aiplatform-") + assert gcs_path.endswith(local_script_file_name) + + def test_timestamp_copy_to_gcs_calls_gcs_client_with_trailing_slash( + self, mock_client_bucket + ): + + mock_client_bucket, mock_blob = mock_client_bucket + + gcs_path = training_jobs._timestamped_copy_to_gcs( + local_file_path=_TEST_LOCAL_SCRIPT_FILE_PATH, + gcs_dir=_TEST_GCS_PATH, + project=_TEST_PROJECT, + ) + + local_script_file_name = pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_PATH).name + + mock_client_bucket.assert_called_once_with(_TEST_BUCKET_NAME) + mock_client_bucket.return_value.blob.assert_called_once() + + blob_arg = mock_client_bucket.return_value.blob.call_args[0][0] + assert blob_arg.startswith(f"{_TEST_GCS_PATH_WITHOUT_BUCKET}/aiplatform-") + assert blob_arg.endswith(_TEST_LOCAL_SCRIPT_FILE_NAME) + + mock_blob.upload_from_filename.assert_called_once_with( + _TEST_LOCAL_SCRIPT_FILE_PATH + ) + + assert gcs_path.startswith(f"gs://{_TEST_GCS_PATH}/aiplatform-") + assert gcs_path.endswith(local_script_file_name) + + def 
test_timestamp_copy_to_gcs_calls_gcs_client(self, mock_client_bucket): + + mock_client_bucket, mock_blob = mock_client_bucket + + gcs_path = training_jobs._timestamped_copy_to_gcs( + local_file_path=_TEST_LOCAL_SCRIPT_FILE_PATH, + gcs_dir=_TEST_BUCKET_NAME, + project=_TEST_PROJECT, + ) + + mock_client_bucket.assert_called_once_with(_TEST_BUCKET_NAME) + mock_client_bucket.return_value.blob.assert_called_once() + mock_blob.upload_from_filename.assert_called_once_with( + _TEST_LOCAL_SCRIPT_FILE_PATH + ) + assert gcs_path.endswith(pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_PATH).name) + assert gcs_path.startswith(f"gs://{_TEST_BUCKET_NAME}") + + def test_get_python_executable_raises_if_None(self): + with patch.object(sys, "executable", new=None): + with pytest.raises(EnvironmentError): + training_jobs._get_python_executable() + + def test_get_python_executable_returns_python_executable(self): + assert "python" in training_jobs._get_python_executable().lower() + + +class TestTrainingScriptPythonPackager: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + with open(_TEST_LOCAL_SCRIPT_FILE_NAME, "w") as fp: + fp.write(_TEST_PYTHON_SOURCE) + + def teardown_method(self): + pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_NAME).unlink() + python_package_file = f"{training_jobs._TrainingScriptPythonPackager._ROOT_MODULE}-{training_jobs._TrainingScriptPythonPackager._SETUP_PY_VERSION}.tar.gz" + if pathlib.Path(python_package_file).is_file(): + pathlib.Path(python_package_file).unlink() + subprocess.check_output( + [ + "pip3", + "uninstall", + "-y", + training_jobs._TrainingScriptPythonPackager._ROOT_MODULE, + ] + ) + + def test_packager_creates_and_copies_python_package(self): + tsp = training_jobs._TrainingScriptPythonPackager(_TEST_LOCAL_SCRIPT_FILE_NAME) + tsp.package_and_copy(copy_method=local_copy_method) + assert pathlib.Path( + f"{tsp._ROOT_MODULE}-{tsp._SETUP_PY_VERSION}.tar.gz" + ).is_file() + + def 
test_created_package_module_is_installable_and_can_be_run(self): + tsp = training_jobs._TrainingScriptPythonPackager(_TEST_LOCAL_SCRIPT_FILE_NAME) + source_dist_path = tsp.package_and_copy(copy_method=local_copy_method) + subprocess.check_output(["pip3", "install", source_dist_path]) + module_output = subprocess.check_output( + [training_jobs._get_python_executable(), "-m", tsp.module_name] + ) + assert "hello world" in module_output.decode() + + def test_requirements_are_in_package(self): + tsp = training_jobs._TrainingScriptPythonPackager( + _TEST_LOCAL_SCRIPT_FILE_NAME, requirements=_TEST_REQUIREMENTS + ) + source_dist_path = tsp.package_and_copy(copy_method=local_copy_method) + with tarfile.open(source_dist_path) as tf: + with tempfile.TemporaryDirectory() as tmpdirname: + setup_py_path = f"{training_jobs._TrainingScriptPythonPackager._ROOT_MODULE}-{training_jobs._TrainingScriptPythonPackager._SETUP_PY_VERSION}/setup.py" + tf.extract(setup_py_path, path=tmpdirname) + setup_py = core.run_setup( + pathlib.Path(tmpdirname, setup_py_path), stop_after="init" + ) + assert _TEST_REQUIREMENTS == setup_py.install_requires + + def test_packaging_fails_whith_RuntimeError(self): + with patch("subprocess.Popen") as mock_popen: + mock_subprocess = mock.Mock() + mock_subprocess.communicate.return_value = (b"", b"") + mock_subprocess.returncode = 1 + mock_popen.return_value = mock_subprocess + tsp = training_jobs._TrainingScriptPythonPackager( + _TEST_LOCAL_SCRIPT_FILE_NAME + ) + with pytest.raises(RuntimeError): + tsp.package_and_copy(copy_method=local_copy_method) + + def test_package_and_copy_to_gcs_copies_to_gcs(self, mock_client_bucket): + mock_client_bucket, mock_blob = mock_client_bucket + + tsp = training_jobs._TrainingScriptPythonPackager(_TEST_LOCAL_SCRIPT_FILE_NAME) + + gcs_path = tsp.package_and_copy_to_gcs( + gcs_staging_dir=_TEST_BUCKET_NAME, project=_TEST_PROJECT + ) + + mock_client_bucket.assert_called_once_with(_TEST_BUCKET_NAME) + 
mock_client_bucket.return_value.blob.assert_called_once() + + mock_blob.upload_from_filename.call_args[0][0].endswith( + "/trainer/dist/aiplatform_custom_trainer_script-0.1.tar.gz" + ) + + assert gcs_path.endswith("-aiplatform_custom_trainer_script-0.1.tar.gz") + assert gcs_path.startswith(f"gs://{_TEST_BUCKET_NAME}") + + +@pytest.fixture +def mock_pipeline_service_create(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_create_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_get(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME), + ) + yield mock_get_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_cancel(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "cancel_training_pipeline" + ) as mock_cancel_training_pipeline: + yield mock_cancel_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_create_with_no_model_to_upload(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + ) + yield mock_create_training_pipeline + + +@pytest.fixture +def 
mock_pipeline_service_get_with_no_model_to_upload(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + ) + yield mock_get_training_pipeline + + +@pytest.fixture +def mock_pipeline_service_create_and_get_with_fail(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ) + + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.return_value = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED, + ) + + yield mock_create_training_pipeline, mock_get_training_pipeline + + +@pytest.fixture +def mock_model_service_get(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as mock_get_model: + mock_get_model.return_value = gca_model.Model(name=_TEST_MODEL_NAME) + yield mock_get_model + + +@pytest.fixture +def mock_python_package_to_gcs(): + with mock.patch.object( + training_jobs._TrainingScriptPythonPackager, "package_and_copy_to_gcs" + ) as mock_package_to_copy_gcs: + mock_package_to_copy_gcs.return_value = _TEST_OUTPUT_PYTHON_PACKAGE_PATH + yield mock_package_to_copy_gcs + + +@pytest.fixture +def mock_tabular_dataset(): + ds = mock.MagicMock(datasets.TabularDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + 
display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +@pytest.fixture +def mock_nontabular_dataset(): + ds = mock.MagicMock(datasets.ImageDataset) + ds.name = _TEST_DATASET_NAME + ds._latest_future = None + ds._exception = None + ds._gca_resource = gca_dataset.Dataset( + display_name=_TEST_DATASET_DISPLAY_NAME, + metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR, + labels={}, + name=_TEST_DATASET_NAME, + metadata={}, + ) + return ds + + +class TestCustomTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + with open(_TEST_LOCAL_SCRIPT_FILE_NAME, "w") as fp: + fp.write(_TEST_PYTHON_SOURCE) + + def teardown_method(self): + pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_NAME).unlink() + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_tabular_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + 
model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + mock_python_package_to_gcs.assert_called_once_with( + gcs_staging_dir=_TEST_BUCKET_NAME, + project=_TEST_PROJECT, + credentials=initializer.global_config.credentials, + ) + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + 
ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_tabular_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + 
mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_bigquery_destination( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + training_encryption_spec_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME, + model_encryption_spec_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + bigquery_destination=_TEST_BIGQUERY_DESTINATION, + args=_TEST_RUN_ARGS, + replica_count=1, + 
machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + 
instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_tabular_dataset.name, + bigquery_destination=gca_io.BigQueryDestination( + output_uri=_TEST_BIGQUERY_DESTINATION + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_PIPELINE_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_python_package_to_gcs", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, 
staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_invalid_accelerator_type_raises( + self, + mock_pipeline_service_create, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + 
model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + with pytest.raises(ValueError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_incomplete_model_info_raises_with_model_to_upload( + self, + mock_pipeline_service_create, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_no_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_python_package_to_gcs, + 
mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + model_from_job = job.run( + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + mock_python_package_to_gcs.assert_called_once_with( + gcs_staging_dir=_TEST_BUCKET_NAME, + project=_TEST_PROJECT, + credentials=initializer.global_config.credentials, + ) + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + 
display_name=_TEST_MODEL_DISPLAY_NAME, container_spec=true_container_spec + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_no_model_to_upload", + "mock_pipeline_service_get_with_no_model_to_upload", + "mock_python_package_to_gcs", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_returns_none_if_no_model_to_upload( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + model = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + assert model is None + + @pytest.mark.usefixtures( + 
"mock_pipeline_service_create_with_no_model_to_upload", + "mock_pipeline_service_get_with_no_model_to_upload", + "mock_python_package_to_gcs", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_get_model_raises_if_no_model_to_upload( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, + mock_pipeline_service_create_and_get_with_fail, + mock_python_package_to_gcs, + mock_tabular_dataset, + sync, + ): + + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) 
+ + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + def test_raises_before_run_is_called( + self, mock_pipeline_service_create, mock_python_package_to_gcs + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + with pytest.raises(RuntimeError): + job.get_model() + + with pytest.raises(RuntimeError): + job.has_failed + + with pytest.raises(RuntimeError): + job.state + + def test_run_raises_if_no_staging_bucket(self): + + aiplatform.init(project=_TEST_PROJECT) + + with pytest.raises(RuntimeError): + training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_distributed_training( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + 
model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=10, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + mock_python_package_to_gcs.assert_called_once_with( + gcs_staging_dir=_TEST_BUCKET_NAME, + project=_TEST_PROJECT, + credentials=initializer.global_config.credentials, + ) + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = [ + { + "replicaCount": 1, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + }, + { + "replicaCount": 9, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + }, + ] + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + true_container_spec = 
gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_tabular_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": true_worker_pool_spec, + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures("get_training_job_custom_mock") + def test_get_training_job(self, get_training_job_custom_mock): + aiplatform.init(project=_TEST_PROJECT) + job = 
training_jobs.CustomTrainingJob.get(resource_name=_TEST_NAME) + + get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME) + assert isinstance(job, training_jobs.CustomTrainingJob) + + @pytest.mark.usefixtures("get_training_job_custom_mock") + def test_get_training_job_wrong_job_type(self, get_training_job_custom_mock): + aiplatform.init(project=_TEST_PROJECT) + + # The returned job is for a custom training task, + # but the calling type is of AutoMLImageTrainingJob. + # Hence, it should throw an error. + with pytest.raises(ValueError): + training_jobs.AutoMLImageTrainingJob.get(resource_name=_TEST_NAME) + + @pytest.mark.usefixtures("get_training_job_custom_mock_no_model_to_upload") + def test_get_training_job_no_model_to_upload( + self, get_training_job_custom_mock_no_model_to_upload + ): + aiplatform.init(project=_TEST_PROJECT) + + job = training_jobs.CustomTrainingJob.get(resource_name=_TEST_NAME) + + with pytest.raises(RuntimeError): + job.get_model(sync=False) + + @pytest.mark.usefixtures("get_training_job_tabular_mock") + def test_get_training_job_tabular(self, get_training_job_tabular_mock): + aiplatform.init(project=_TEST_PROJECT) + + with pytest.raises(ValueError): + training_jobs.CustomTrainingJob.get(resource_name=_TEST_NAME) + + @pytest.mark.usefixtures("get_training_job_custom_mock") + def test_get_training_job_with_id_only(self, get_training_job_custom_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + training_jobs.CustomTrainingJob.get(resource_name=_TEST_ID) + get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME) + + def test_get_training_job_with_id_only_with_project_and_location( + self, get_training_job_custom_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + training_jobs.CustomTrainingJob.get( + resource_name=_TEST_ID, project=_TEST_PROJECT, location=_TEST_LOCATION + ) + get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME) + + def 
test_get_training_job_with_project_and_location( + self, get_training_job_custom_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + training_jobs.CustomTrainingJob.get( + resource_name=_TEST_NAME, project=_TEST_PROJECT, location=_TEST_LOCATION + ) + get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME) + + def test_get_training_job_with_alt_project_and_location( + self, get_training_job_custom_mock + ): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + training_jobs.CustomTrainingJob.get( + resource_name=_TEST_NAME, project=_TEST_ALT_PROJECT, location=_TEST_LOCATION + ) + get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME) + + def test_get_training_job_with_project_and_alt_location(self): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with pytest.raises(RuntimeError): + training_jobs.CustomTrainingJob.get( + resource_name=_TEST_NAME, + project=_TEST_PROJECT, + location=_TEST_ALT_LOCATION, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_nontabular_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_python_package_to_gcs, + mock_nontabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + 
model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + ) + + model_from_job = job.run( + dataset=mock_nontabular_dataset, + annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + mock_python_package_to_gcs.assert_called_once_with( + gcs_staging_dir=_TEST_BUCKET_NAME, + project=_TEST_PROJECT, + credentials=initializer.global_config.credentials, + ) + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": training_jobs._TrainingScriptPythonPackager.module_name, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_DEFAULT_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_DEFAULT_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_DEFAULT_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = 
gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_nontabular_dataset.name, + annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == 
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + def test_run_call_pipeline_service_create_with_nontabular_dataset_raises_if_annotation_schema_uri( + self, mock_nontabular_dataset, + ): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + ) + + with pytest.raises(Exception): + job.run( + dataset=mock_nontabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_python_package_to_gcs", + "mock_model_service_get", + ) + def test_cancel_training_job(self, mock_pipeline_service_cancel): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + 
container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + job.run() + job.cancel() + + mock_pipeline_service_cancel.assert_called_once_with( + name=_TEST_PIPELINE_RESOURCE_NAME + ) + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_python_package_to_gcs", + "mock_model_service_get", + ) + def test_cancel_training_job_without_running(self, mock_pipeline_service_cancel): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + with pytest.raises(RuntimeError) as e: + job.cancel() + + assert e.match(regexp=r"TrainingJob has not been launched") + + +class TestCustomContainerTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_tabular_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + 
model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "containerSpec": { + "imageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "command": _TEST_TRAINING_CONTAINER_CMD, + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + 
health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_tabular_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == 
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_bigquery_destination( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + training_encryption_spec_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME, + model_encryption_spec_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + bigquery_destination=_TEST_BIGQUERY_DESTINATION, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + 
test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "containerSpec": { + "imageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "command": _TEST_TRAINING_CONTAINER_CMD, + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + 
key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_tabular_dataset.name, + bigquery_destination=gca_io.BigQueryDestination( + output_uri=_TEST_BIGQUERY_DESTINATION + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_PIPELINE_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_python_package_to_gcs", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + 
model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_invalid_accelerator_type_raises( + self, + mock_pipeline_service_create, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + with pytest.raises(ValueError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + 
machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_incomplete_model_info_raises_with_model_to_upload( + self, + mock_pipeline_service_create, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_no_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + 
model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + model_from_job = job.run( + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "containerSpec": { + "imageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "command": _TEST_TRAINING_CONTAINER_CMD, + "args": true_args, + }, + } + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, container_spec=true_container_spec + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert 
job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_returns_none_if_no_model_to_upload( + self, + mock_pipeline_service_create_with_no_model_to_upload, + mock_pipeline_service_get_with_no_model_to_upload, + mock_tabular_dataset, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + ) + + model = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + assert model is None + + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_no_model_to_upload", + "mock_pipeline_service_get_with_no_model_to_upload", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_get_model_raises_if_no_model_to_upload( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + ) + + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + 
training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, + mock_pipeline_service_create_and_get_with_fail, + mock_tabular_dataset, + sync, + ): + + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + def test_raises_before_run_is_called(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + with pytest.raises(RuntimeError): + job.get_model() + + with pytest.raises(RuntimeError): + job.has_failed + + with pytest.raises(RuntimeError): + job.state + + def 
test_run_raises_if_no_staging_bucket(self): + + aiplatform.init(project=_TEST_PROJECT) + + with pytest.raises(RuntimeError): + training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_distributed_training( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=10, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = [ + { + "replicaCount": 1, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + 
"containerSpec": { + "imageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "command": _TEST_TRAINING_CONTAINER_CMD, + "args": true_args, + }, + }, + { + "replicaCount": 9, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "containerSpec": { + "imageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "command": _TEST_TRAINING_CONTAINER_CMD, + "args": true_args, + }, + }, + ] + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_tabular_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": true_worker_pool_spec, + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + 
parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_nontabular_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_python_package_to_gcs, + mock_nontabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + ) + + model_from_job = job.run( + dataset=mock_nontabular_dataset, + annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, + base_output_dir=_TEST_BASE_OUTPUT_DIR, 
+ args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "containerSpec": { + "imageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "command": _TEST_TRAINING_CONTAINER_CMD, + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_DEFAULT_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_DEFAULT_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_DEFAULT_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + 
dataset_id=mock_nontabular_dataset.name, + annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + def test_run_call_pipeline_service_create_with_nontabular_dataset_raises_if_annotation_schema_uri( + self, mock_nontabular_dataset, + ): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomContainerTrainingJob( + display_name=_TEST_DISPLAY_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + command=_TEST_TRAINING_CONTAINER_CMD, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + 
model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + ) + + with pytest.raises(Exception): + job.run( + dataset=mock_nontabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + ) + + +class Test_MachineSpec: + def test_machine_spec_return_spec_dict(self): + test_spec = training_jobs._MachineSpec( + replica_count=_TEST_REPLICA_COUNT, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ) + + true_spec_dict = { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": _TEST_REPLICA_COUNT, + } + + assert test_spec.spec_dict == true_spec_dict + + def test_machine_spec_return_spec_dict_with_no_accelerator(self): + test_spec = training_jobs._MachineSpec( + replica_count=_TEST_REPLICA_COUNT, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=0, + accelerator_type="ACCELERATOR_TYPE_UNSPECIFIED", + ) + + true_spec_dict = { + "machineSpec": {"machineType": _TEST_MACHINE_TYPE}, + "replicaCount": _TEST_REPLICA_COUNT, + } + + assert test_spec.spec_dict == true_spec_dict + + def test_machine_spec_spec_dict_raises_invalid_accelerator(self): + test_spec = training_jobs._MachineSpec( + replica_count=_TEST_REPLICA_COUNT, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + 
accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + ) + + with pytest.raises(ValueError): + test_spec.spec_dict + + def test_machine_spec_spec_dict_is_empty(self): + test_spec = training_jobs._MachineSpec( + replica_count=0, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + ) + + assert test_spec.is_empty + + def test_machine_spec_spec_dict_is_not_empty(self): + test_spec = training_jobs._MachineSpec( + replica_count=_TEST_REPLICA_COUNT, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + ) + + assert not test_spec.is_empty + + +class Test_DistributedTrainingSpec: + def test_machine_spec_returns_pool_spec(self): + + spec = training_jobs._DistributedTrainingSpec( + chief_spec=training_jobs._MachineSpec( + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ), + worker_spec=training_jobs._MachineSpec( + replica_count=10, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ), + parameter_server_spec=training_jobs._MachineSpec( + replica_count=3, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ), + evaluator_spec=training_jobs._MachineSpec( + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ), + ) + + true_pool_spec = [ + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 1, + }, + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 10, + }, + 
{ + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 3, + }, + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 1, + }, + ] + + assert spec.pool_specs == true_pool_spec + + def test_chief_worker_pool_returns_spec(self): + + chief_worker_spec = training_jobs._DistributedTrainingSpec.chief_worker_pool( + replica_count=10, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ) + + true_pool_spec = [ + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 1, + }, + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 9, + }, + ] + + assert chief_worker_spec.pool_specs == true_pool_spec + + def test_chief_worker_pool_returns_just_chief(self): + + chief_worker_spec = training_jobs._DistributedTrainingSpec.chief_worker_pool( + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ) + + true_pool_spec = [ + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 1, + } + ] + + assert chief_worker_spec.pool_specs == true_pool_spec + + def test_machine_spec_raise_with_more_than_one_chief_replica(self): + + spec = training_jobs._DistributedTrainingSpec( + chief_spec=training_jobs._MachineSpec( + replica_count=2, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + 
accelerator_type=_TEST_ACCELERATOR_TYPE, + ), + ) + + with pytest.raises(ValueError): + spec.pool_specs + + def test_machine_spec_handles_missing_pools(self): + + spec = training_jobs._DistributedTrainingSpec( + chief_spec=training_jobs._MachineSpec( + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ), + worker_spec=training_jobs._MachineSpec(replica_count=0), + parameter_server_spec=training_jobs._MachineSpec( + replica_count=3, + machine_type=_TEST_MACHINE_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + accelerator_type=_TEST_ACCELERATOR_TYPE, + ), + evaluator_spec=training_jobs._MachineSpec(replica_count=0), + ) + + true_pool_spec = [ + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 1, + }, + {"machineSpec": {"machineType": "n1-standard-4"}, "replicaCount": 0}, + { + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "replicaCount": 3, + }, + ] + + assert spec.pool_specs == true_pool_spec + + +class TestCustomPythonPackageTrainingJob: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_tabular_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + 
python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": _TEST_PYTHON_MODULE_NAME, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + 
training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_tabular_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + 
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_bigquery_destination( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + 
training_encryption_spec_key_name=_TEST_PIPELINE_ENCRYPTION_KEY_NAME, + model_encryption_spec_key_name=_TEST_MODEL_ENCRYPTION_KEY_NAME, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + bigquery_destination=_TEST_BIGQUERY_DESTINATION, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": _TEST_PYTHON_MODULE_NAME, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + 
args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + encryption_spec=_TEST_MODEL_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + predefined_split=gca_training_pipeline.PredefinedSplit( + key=_TEST_PREDEFINED_SPLIT_COLUMN_NAME + ), + dataset_id=mock_tabular_dataset.name, + bigquery_destination=gca_io.BigQueryDestination( + output_uri=_TEST_BIGQUERY_DESTINATION + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_PIPELINE_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.usefixtures( + 
"mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_python_package_to_gcs", + "mock_model_service_get", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_called_twice_raises( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_invalid_accelerator_type_raises( + self, + mock_pipeline_service_create, + 
mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + with pytest.raises(ValueError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_INVALID_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_with_incomplete_model_info_raises_with_model_to_upload( + self, + mock_pipeline_service_create, + mock_python_package_to_gcs, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + 
accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_no_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + model_from_job = job.run( + model_display_name=_TEST_MODEL_DISPLAY_NAME, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": _TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": _TEST_PYTHON_MODULE_NAME, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": 
true_args, + }, + } + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, container_spec=true_container_spec + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_no_model_to_upload", + "mock_pipeline_service_get_with_no_model_to_upload", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_returns_none_if_no_model_to_upload( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + model = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + 
accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + assert model is None + + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_no_model_to_upload", + "mock_pipeline_service_get_with_no_model_to_upload", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_get_model_raises_if_no_model_to_upload( + self, mock_tabular_dataset, sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_raises_if_pipeline_fails( + self, + mock_pipeline_service_create_and_get_with_fail, + mock_tabular_dataset, + sync, + ): + + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + with pytest.raises(RuntimeError): + job.run( + dataset=mock_tabular_dataset, + 
base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + job.wait() + + with pytest.raises(RuntimeError): + job.get_model() + + def test_raises_before_run_is_called(self, mock_pipeline_service_create): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + with pytest.raises(RuntimeError): + job.get_model() + + with pytest.raises(RuntimeError): + job.has_failed + + with pytest.raises(RuntimeError): + job.state + + def test_run_raises_if_no_staging_bucket(self): + + aiplatform.init(project=_TEST_PROJECT) + + with pytest.raises(RuntimeError): + training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_distributed_training( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = 
training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=10, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = [ + { + "replicaCount": 1, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": _TEST_PYTHON_MODULE_NAME, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + }, + { + "replicaCount": 9, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": _TEST_PYTHON_MODULE_NAME, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + 
"args": true_args, + }, + }, + ] + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_tabular_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": true_worker_pool_spec, + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not 
job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_nontabular_dataset( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_python_package_to_gcs, + mock_nontabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + model_from_job = job.run( + dataset=mock_nontabular_dataset, + annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + + true_worker_pool_spec = { + "replicaCount": 
_TEST_REPLICA_COUNT, + "machineSpec": { + "machineType": _TEST_MACHINE_TYPE, + "acceleratorType": _TEST_ACCELERATOR_TYPE, + "acceleratorCount": _TEST_ACCELERATOR_COUNT, + }, + "pythonPackageSpec": { + "executorImageUri": _TEST_TRAINING_CONTAINER_IMAGE, + "pythonModule": _TEST_PYTHON_MODULE_NAME, + "packageUris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_DEFAULT_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_DEFAULT_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_DEFAULT_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_nontabular_dataset.name, + annotation_schema_uri=_TEST_ANNOTATION_SCHEMA_URI, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + 
training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "workerPoolSpecs": [true_worker_pool_spec], + "baseOutputDirectory": {"output_uri_prefix": _TEST_BASE_OUTPUT_DIR}, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + ) + + assert job._gca_resource is mock_pipeline_service_get.return_value + + mock_model_service_get.assert_called_once_with(name=_TEST_MODEL_NAME) + + assert model_from_job._gca_resource is mock_model_service_get.return_value + + assert job.get_model()._gca_resource is mock_model_service_get.return_value + + assert not job.has_failed + + assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + + def test_run_call_pipeline_service_create_with_nontabular_dataset_raises_if_annotation_schema_uri( + self, mock_nontabular_dataset, + ): + aiplatform.init( + project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME, + ) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + 
model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + with pytest.raises(Exception): + job.run( + dataset=mock_nontabular_dataset, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + replica_count=1, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + ) diff --git a/tests/unit/aiplatform/test_training_utils.py b/tests/unit/aiplatform/test_training_utils.py new file mode 100644 index 0000000000..1d4b839151 --- /dev/null +++ b/tests/unit/aiplatform/test_training_utils.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import json +import os +import pytest + +from google.cloud.aiplatform import training_utils +from unittest import mock + +_TEST_TRAINING_DATA_URI = "gs://training-data-uri" +_TEST_VALIDATION_DATA_URI = "gs://test-validation-data-uri" +_TEST_TEST_DATA_URI = "gs://test-data-uri" +_TEST_MODEL_DIR = "gs://test-model-dir" +_TEST_CHECKPOINT_DIR = "gs://test-checkpoint-dir" +_TEST_TENSORBOARD_LOG_DIR = "gs://test-tensorboard-log-dir" +_TEST_CLUSTER_SPEC = """{ + "cluster": { + "worker_pools":[ + { + "index":0, + "replicas":[ + "training-workerpool0-ab-0:2222" + ] + }, + { + "index":1, + "replicas":[ + "training-workerpool1-ab-0:2222", + "training-workerpool1-ab-1:2222" + ] + } + ] + }, + "environment": "cloud", + "task": { + "worker_pool_index":0, + "replica_index":0, + "trial":"TRIAL_ID" + } +}""" + + +class TestTrainingUtils: + @pytest.fixture + def mock_environment(self): + env_vars = { + "AIP_TRAINING_DATA_URI": _TEST_TRAINING_DATA_URI, + "AIP_VALIDATION_DATA_URI": _TEST_VALIDATION_DATA_URI, + "AIP_TEST_DATA_URI": _TEST_TEST_DATA_URI, + "AIP_MODEL_DIR": _TEST_MODEL_DIR, + "AIP_CHECKPOINT_DIR": _TEST_CHECKPOINT_DIR, + "AIP_TENSORBOARD_LOG_DIR": _TEST_TENSORBOARD_LOG_DIR, + "CLUSTER_SPEC": _TEST_CLUSTER_SPEC, + "TF_CONFIG": _TEST_CLUSTER_SPEC, + } + with mock.patch.dict(os.environ, env_vars): + yield + + @pytest.mark.usefixtures("mock_environment") + def test_training_data_uri(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.training_data_uri == _TEST_TRAINING_DATA_URI + + def test_training_data_uri_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.training_data_uri is None + + @pytest.mark.usefixtures("mock_environment") + def test_validation_data_uri(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.validation_data_uri == _TEST_VALIDATION_DATA_URI + + def test_validation_data_uri_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.validation_data_uri 
is None + + @pytest.mark.usefixtures("mock_environment") + def test_test_data_uri(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.test_data_uri == _TEST_TEST_DATA_URI + + def test_test_data_uri_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.test_data_uri is None + + @pytest.mark.usefixtures("mock_environment") + def test_model_dir(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.model_dir == _TEST_MODEL_DIR + + def test_model_dir_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.model_dir is None + + @pytest.mark.usefixtures("mock_environment") + def test_checkpoint_dir(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.checkpoint_dir == _TEST_CHECKPOINT_DIR + + def test_checkpoint_dir_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.checkpoint_dir is None + + @pytest.mark.usefixtures("mock_environment") + def test_tensorboard_log_dir(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.tensorboard_log_dir == _TEST_TENSORBOARD_LOG_DIR + + def test_tensorboard_log_dir_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.tensorboard_log_dir is None + + @pytest.mark.usefixtures("mock_environment") + def test_cluster_spec(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.cluster_spec == json.loads(_TEST_CLUSTER_SPEC) + + def test_cluster_spec_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.cluster_spec is None + + @pytest.mark.usefixtures("mock_environment") + def test_tf_config(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.tf_config == json.loads(_TEST_CLUSTER_SPEC) + + def test_tf_config_none(self): + env_vars = training_utils.EnvironmentVariables() + assert env_vars.tf_config is None diff --git a/tests/unit/aiplatform/test_utils.py 
b/tests/unit/aiplatform/test_utils.py new file mode 100644 index 0000000000..3032475069 --- /dev/null +++ b/tests/unit/aiplatform/test_utils.py @@ -0,0 +1,305 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import pytest +from uuid import uuid4 +from random import choice +from random import randint +from string import ascii_letters + +from google.api_core import client_options +from google.api_core import gapic_v1 +from google.cloud import aiplatform +from google.cloud.aiplatform import compat +from google.cloud.aiplatform import utils + +from google.cloud.aiplatform_v1beta1.services.model_service import ( + client as model_service_client_v1beta1, +) +from google.cloud.aiplatform_v1.services.model_service import ( + client as model_service_client_v1, +) + +model_service_client_default = model_service_client_v1 + + +@pytest.mark.parametrize( + "resource_name, expected", + [ + ("projects/123456/locations/us-central1/datasets/987654", True), + ("projects/857392/locations/us-central1/trainingPipelines/347292", True), + ("projects/acme-co-proj-1/locations/us-central1/datasets/123456", True), + ("projects/acme-co-proj-1/locations/us-central1/datasets/abcdef", False), + ("project/123456/locations/us-central1/datasets/987654", False), + ("project//locations//datasets/987654", False), + ("locations/europe-west4/datasets/987654", False), + ("987654", False), + ], +) +def 
test_extract_fields_from_resource_name(resource_name: str, expected: bool): + # Given a resource name and expected validity, test extract_fields_from_resource_name() + assert expected == bool(utils.extract_fields_from_resource_name(resource_name)) + + +@pytest.fixture +def generated_resource_fields(): + generated_fields = utils.Fields( + project=str(uuid4()), + location=str(uuid4()), + resource="".join(choice(ascii_letters) for i in range(10)), # 10 random letters + id=str(randint(0, 100000)), + ) + + yield generated_fields + + +@pytest.fixture +def generated_resource_name(generated_resource_fields: utils.Fields): + name = ( + f"projects/{generated_resource_fields.project}/" + f"locations/{generated_resource_fields.location}" + f"/{generated_resource_fields.resource}/{generated_resource_fields.id}" + ) + + yield name + + +def test_extract_fields_from_resource_name_with_extracted_fields( + generated_resource_name: str, generated_resource_fields: utils.Fields +): + """Verify fields extracted from resource name match the original fields""" + + assert ( + utils.extract_fields_from_resource_name(resource_name=generated_resource_name) + == generated_resource_fields + ) + + +@pytest.mark.parametrize( + "resource_name, resource_noun, expected", + [ + # Expects pattern "projects/.../locations/.../datasets/..." + ("projects/123456/locations/us-central1/datasets/987654", "datasets", True), + # Expects pattern "projects/.../locations/.../batchPredictionJobs/..." 
+ ( + "projects/857392/locations/us-central1/trainingPipelines/347292", + "batchPredictionJobs", + False, + ), + ], +) +def test_extract_fields_from_resource_name_with_resource_noun( + resource_name: str, resource_noun: str, expected: bool +): + assert ( + bool( + utils.extract_fields_from_resource_name( + resource_name=resource_name, resource_noun=resource_noun + ) + ) + == expected + ) + + +def test_invalid_region_raises_with_invalid_region(): + with pytest.raises(ValueError): + aiplatform.utils.validate_region(region="us-west4") + + +def test_invalid_region_does_not_raise_with_valid_region(): + aiplatform.utils.validate_region(region="us-central1") + + +@pytest.mark.parametrize( + "resource_noun, project, location, full_name", + [ + ( + "datasets", + "123456", + "us-central1", + "projects/123456/locations/us-central1/datasets/987654", + ), + ( + "trainingPipelines", + "857392", + "us-west20", + "projects/857392/locations/us-central1/trainingPipelines/347292", + ), + ], +) +def test_full_resource_name_with_full_name( + resource_noun: str, project: str, location: str, full_name: str, +): + # should ignore issues with other arguments as resource_name is full_name + assert ( + aiplatform.utils.full_resource_name( + resource_name=full_name, + resource_noun=resource_noun, + project=project, + location=location, + ) + == full_name + ) + + +@pytest.mark.parametrize( + "partial_name, resource_noun, project, location, full_name", + [ + ( + "987654", + "datasets", + "123456", + "us-central1", + "projects/123456/locations/us-central1/datasets/987654", + ), + ( + "347292", + "trainingPipelines", + "857392", + "us-central1", + "projects/857392/locations/us-central1/trainingPipelines/347292", + ), + ], +) +def test_full_resource_name_with_partial_name( + partial_name: str, resource_noun: str, project: str, location: str, full_name: str, +): + assert ( + aiplatform.utils.full_resource_name( + resource_name=partial_name, + resource_noun=resource_noun, + project=project, + 
location=location, + ) + == full_name + ) + + +@pytest.mark.parametrize( + "partial_name, resource_noun, project, location", + [("347292", "trainingPipelines", "857392", "us-west2020")], +) +def test_full_resource_name_raises_value_error( + partial_name: str, resource_noun: str, project: str, location: str, +): + with pytest.raises(ValueError): + aiplatform.utils.full_resource_name( + resource_name=partial_name, + resource_noun=resource_noun, + project=project, + location=location, + ) + + +def test_validate_display_name_raises_length(): + with pytest.raises(ValueError): + aiplatform.utils.validate_display_name( + "slanflksdnlikh;likhq290u90rflkasndfkljashndfkl;jhowq2342;iehoiwerhowqihjer34564356o;iqwjr;oijsdalfjasl;kfjas;ldifhja;slkdfsdlkfhj" + ) + + +def test_validate_display_name(): + aiplatform.utils.validate_display_name("my_model_abc") + + +@pytest.mark.parametrize( + "accelerator_type, expected", + [ + ("NVIDIA_TESLA_K80", True), + ("ACCELERATOR_TYPE_UNSPECIFIED", True), + ("NONEXISTENT_GPU", False), + ("NVIDIA_GALAXY_R7", False), + ("", False), + (None, False), + ], +) +def test_validate_accelerator_type(accelerator_type: str, expected: bool): + # Invalid type raises specific ValueError + if not expected: + with pytest.raises(ValueError) as e: + utils.validate_accelerator_type(accelerator_type) + assert e.match(regexp=r"Given accelerator_type") + # Valid type returns True + else: + assert utils.validate_accelerator_type(accelerator_type) + + +@pytest.mark.parametrize( + "gcs_path, expected", + [ + ("gs://example-bucket/path/to/folder", ("example-bucket", "path/to/folder")), + ("example-bucket/path/to/folder/", ("example-bucket", "path/to/folder")), + ("gs://example-bucket", ("example-bucket", None)), + ("gs://example-bucket/", ("example-bucket", None)), + ("gs://example-bucket/path", ("example-bucket", "path")), + ], +) +def test_extract_bucket_and_prefix_from_gcs_path(gcs_path: str, expected: tuple): + # Given a GCS path, ensure correct bucket and prefix 
are extracted + assert expected == utils.extract_bucket_and_prefix_from_gcs_path(gcs_path) + + +def test_wrapped_client(): + test_client_info = gapic_v1.client_info.ClientInfo() + test_client_options = client_options.ClientOptions() + + wrapped_client = utils.ClientWithOverride.WrappedClient( + client_class=model_service_client_default.ModelServiceClient, + client_options=test_client_options, + client_info=test_client_info, + ) + + assert isinstance( + wrapped_client.get_model.__self__, + model_service_client_default.ModelServiceClient, + ) + + +def test_client_w_override_default_version(): + + test_client_info = gapic_v1.client_info.ClientInfo() + test_client_options = client_options.ClientOptions() + + client_w_override = utils.ModelClientWithOverride( + client_options=test_client_options, client_info=test_client_info, + ) + assert isinstance( + client_w_override._clients[ + client_w_override._default_version + ].get_model.__self__, + model_service_client_default.ModelServiceClient, + ) + + +def test_client_w_override_select_version(): + + test_client_info = gapic_v1.client_info.ClientInfo() + test_client_options = client_options.ClientOptions() + + client_w_override = utils.ModelClientWithOverride( + client_options=test_client_options, client_info=test_client_info, + ) + + assert isinstance( + client_w_override.select_version(compat.V1BETA1).get_model.__self__, + model_service_client_v1beta1.ModelServiceClient, + ) + assert isinstance( + client_w_override.select_version(compat.V1).get_model.__self__, + model_service_client_v1.ModelServiceClient, + )