Train an image classification model

This page shows you how to train an AutoML classification model from an image dataset by using either the Google Cloud console or the Vertex AI API.

Train an AutoML model

Google Cloud console

  1. In the Google Cloud console, in the Vertex AI section, go to the Datasets page.

    Go to the Datasets page

  2. Click the name of the dataset that you want to use to train your model to open its details page.

  3. Click Train new model.

  4. For the training method, select AutoML.

  5. Click Continue.

  6. Enter a name for the model.

  7. To set manually how your training data is split, expand Advanced options and select a data split option. Learn more

  8. Click Start Training.

    Model training can take many hours, depending on the size and complexity of your data and on your training budget, if you specified one. You can close this tab and return to it later. You receive an email when model training completes.

API

Select the tab below for your objective.

Classification (single-label)

Select a tab below for your language or environment.

REST

Before using any of the request data, make the following replacements:

  • LOCATION: The region where the dataset is located and where the model is created. For example, us-central1.
  • PROJECT: Your project ID.
  • TRAININGPIPELINE_DISPLAYNAME: Required. A display name for the trainingPipeline.
  • DATASET_ID: The ID number of the dataset to use for training.
  • fractionSplit: Optional. One of several possible ML use split options for your data. For fractionSplit, the values must sum to 1. For example:
    • {"trainingFraction": "0.7","validationFraction": "0.15","testFraction": "0.15"}
  • MODEL_DISPLAYNAME*: A display name for the model uploaded (created) by the TrainingPipeline.
  • MODEL_DESCRIPTION*: A description of the model.
  • modelToUpload.labels*: Any set of key-value pairs to organize your models. For example:
    • "env": "prod"
    • "tier": "backend"
  • MODELTYPE†: The type of Cloud-hosted model to train. The options are:
    • CLOUD (default)
  • NODE_HOUR_BUDGET†: The actual training cost is less than or equal to this value. For Cloud models, the budget must be between 8,000 and 800,000 milli node hours, inclusive. The default value is 192,000, which represents one day in wall time, assuming 8 nodes are used.
  • PROJECT_NUMBER: Your project's automatically generated project number.
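
The two numeric constraints in the list above (the split fractions summing to 1, and the budget range) can be checked locally before a request is sent. The following is a minimal Python sketch, not part of the Vertex AI API: the helper names `check_fraction_split`, `check_budget`, and `budget_to_wall_hours` are hypothetical, and the 8-node figure is taken from the NODE_HOUR_BUDGET description.

```python
import math

MIN_BUDGET = 8_000    # milli node hours, lower bound from the list above
MAX_BUDGET = 800_000  # milli node hours, upper bound from the list above


def check_fraction_split(training, validation, test):
    """Return True if the three fractions sum to 1 (within float tolerance)."""
    return math.isclose(training + validation + test, 1.0, abs_tol=1e-9)


def check_budget(milli_node_hours):
    """Return True if the budget lies in the inclusive range for Cloud models."""
    return MIN_BUDGET <= milli_node_hours <= MAX_BUDGET


def budget_to_wall_hours(milli_node_hours, nodes=8):
    """Convert milli node hours to wall-clock hours, assuming `nodes` nodes."""
    node_hours = milli_node_hours / 1000
    return node_hours / nodes


print(check_fraction_split(0.7, 0.15, 0.15))
print(check_budget(192_000))
print(budget_to_wall_hours(192_000))  # 192 node hours / 8 nodes = 24.0
```

The last line works through the default: 192,000 milli node hours is 192 node hours, which is 24 hours of wall time if 8 nodes are used.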

HTTP method and URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

JSON request body:

{
  "displayName": "TRAININGPIPELINE_DISPLAYNAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID",
    "fractionSplit": {
      "trainingFraction": "DECIMAL",
      "validationFraction": "DECIMAL",
      "testFraction": "DECIMAL"
    }
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAYNAME",
    "description": "MODEL_DESCRIPTION",
    "labels": {
      "KEY": "VALUE"
    }
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml",
  "trainingTaskInputs": {
    "multiLabel": "false",
    "modelType": ["MODELTYPE"],
    "budgetMilliNodeHours": NODE_HOUR_BUDGET
  }
}
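
One way to fill in the template above is to generate request.json from a script. This is a sketch that uses only the Python standard library; the function name `build_request_body` and the sample values are hypothetical, and the string form of the fractions and of `multiLabel` simply mirrors the template rather than a confirmed API requirement.

```python
import json

TASK_DEFINITION = (
    "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
    "automl_image_classification_1.0.0.yaml"
)


def build_request_body(pipeline_name, dataset_id, model_name,
                       budget_milli_node_hours=8000, multi_label=False):
    """Assemble a request body with the same fields as the template."""
    return {
        "displayName": pipeline_name,
        "inputDataConfig": {
            "datasetId": dataset_id,
            # Illustrative split; the three values must sum to 1.
            "fractionSplit": {
                "trainingFraction": "0.8",
                "validationFraction": "0.1",
                "testFraction": "0.1",
            },
        },
        "modelToUpload": {"displayName": model_name},
        "trainingTaskDefinition": TASK_DEFINITION,
        "trainingTaskInputs": {
            "multiLabel": "true" if multi_label else "false",
            "modelType": ["CLOUD"],
            "budgetMilliNodeHours": budget_milli_node_hours,
        },
    }


# Placeholder values; replace them with your own names and dataset ID.
body = build_request_body("my-pipeline", "1234567890", "my-model")
with open("request.json", "w", encoding="utf-8") as f:
    json.dump(body, f, indent=2)
```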

To send your request, choose one of these options:

curl

Save the request body in a file named request.json, and execute the following command:

curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines"

PowerShell

Save the request body in a file named request.json, and execute the following command:

$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines" | Select-Object -Expand Content
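
The same POST request can also be assembled from Python with the standard library. The sketch below only builds the request object and does not send it; the function name `build_pipeline_request` is hypothetical, and `ACCESS_TOKEN` stands in for the output of `gcloud auth print-access-token`.

```python
import urllib.request


def build_pipeline_request(location, project, access_token, body_bytes):
    """Build (but do not send) the trainingPipelines POST request."""
    url = (
        f"https://{location}-aiplatform.googleapis.com/v1/"
        f"projects/{project}/locations/{location}/trainingPipelines"
    )
    return urllib.request.Request(
        url,
        data=body_bytes,
        method="POST",
        headers={
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json; charset=utf-8",
        },
    )


# Placeholder values; the body would normally be the contents of request.json.
req = build_pipeline_request("us-central1", "my-project", "ACCESS_TOKEN", b"{}")
print(req.get_method(), req.full_url)
# To send it, pass the object to urllib.request.urlopen(req).
```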

The response contains information about specifications as well as the TRAININGPIPELINE_ID.
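
Assuming the response's `name` field follows the resource-name pattern `projects/PROJECT_NUMBER/locations/LOCATION/trainingPipelines/TRAININGPIPELINE_ID`, the ID can be read from its last path segment. The following Python sketch illustrates this; the helper name `pipeline_id_from_response` is hypothetical and the sample response is invented for illustration, not captured from a real call.

```python
import json


def pipeline_id_from_response(response_text):
    """Return the last path segment of the response's `name` field."""
    name = json.loads(response_text)["name"]
    parts = name.split("/")
    if len(parts) != 6 or parts[4] != "trainingPipelines":
        raise ValueError(f"unexpected resource name: {name}")
    return parts[5]


# Invented sample; a real response contains many more fields.
sample = json.dumps({
    "name": "projects/123456789/locations/us-central1/trainingPipelines/987654321",
    "displayName": "my-pipeline",
})
print(pipeline_id_from_response(sample))
```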

Java

Before trying this sample, follow the Java setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Java API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineImageClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    createTrainingPipelineImageClassificationSample(
        project, trainingPipelineDisplayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineImageClassificationSample(
      String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_image_classification_1.0.0.yaml";
      LocationName locationName = LocationName.of(project, location);

      AutoMlImageClassificationInputs autoMlImageClassificationInputs =
          AutoMlImageClassificationInputs.newBuilder()
              .setModelType(ModelType.CLOUD)
              .setMultiLabel(false)
              .setBudgetMilliNodeHours(8000)
              .setDisableEarlyStopping(false)
              .build();

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs))
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Image Classification Response");
      System.out.format("Name: %s\n", trainingPipelineResponse.getName());
      System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName());

      System.out.format(
          "Training Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("State: %s\n", trainingPipelineResponse.getState());

      System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("Start Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("Input Data Config");
      System.out.format("Dataset Id: %s\n", inputDataConfig.getDatasetId());
      System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("Fraction Split");
      System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("Filter Split");
      System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("Predefined Split");
      System.out.format("Key: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("Timestamp Split");
      System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("Key: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("Model To Upload");
      System.out.format("Name: %s\n", modelResponse.getName());
      System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
      System.out.format("Description: %s\n", modelResponse.getDescription());

      System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("Metadata: %s\n", modelResponse.getMetadata());
      System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "Supported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "Supported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "Supported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());

      System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
      System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("Labels: %s\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("Predict Schemata");
      System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("Supported Export Format");
        System.out.format("Id: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("Container Spec");
      System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("Command: %s\n", modelContainerSpec.getCommandList());
      System.out.format("Args: %s\n", modelContainerSpec.getArgsList());
      System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("Env");
        System.out.format("Name: %s\n", envVar.getName());
        System.out.format("Value: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("Port");
        System.out.format("Container Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("Deployed Model");
        System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("Error");
      System.out.format("Code: %s\n", status.getCode());
      System.out.format("Message: %s\n", status.getMessage());
    }
  }
}

Node.js

Before trying this sample, follow the Node.js setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Node.js API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */
/*
const datasetId = 'YOUR DATASET';
const modelDisplayName = 'NEW MODEL NAME';
const trainingPipelineDisplayName = 'NAME FOR TRAINING PIPELINE';
const project = 'YOUR PROJECT ID';
const location = 'us-central1';
  */
// Imports the Google Cloud Pipeline Service Client library
const aiplatform = require('@google-cloud/aiplatform');

const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;
const ModelType = definition.AutoMlImageClassificationInputs.ModelType;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const {PipelineServiceClient} = aiplatform.v1;
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineImageClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  // Values should match the input expected by your model.
  const trainingTaskInputsMessage =
    new definition.AutoMlImageClassificationInputs({
      multiLabel: false, // matches the REST request in this tab; set to true for multi-label
      modelType: ModelType.CLOUD,
      budgetMilliNodeHours: 8000,
      disableEarlyStopping: false,
    });

  const trainingTaskInputs = trainingTaskInputsMessage.toValue();

  const trainingTaskDefinition =
    'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml';

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition,
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {parent, trainingPipeline};

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline image classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}

createTrainingPipelineImageClassification();

Python

To learn how to install or update the Vertex AI SDK for Python, see Install the Vertex AI SDK for Python. For more information, see the Python API reference documentation.

from typing import Optional

from google.cloud import aiplatform


def create_training_pipeline_image_classification_sample(
    project: str,
    location: str,
    display_name: str,
    dataset_id: str,
    model_display_name: Optional[str] = None,
    model_type: str = "CLOUD",
    multi_label: bool = False,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    job = aiplatform.AutoMLImageTrainingJob(
        display_name=display_name,
        model_type=model_type,
        prediction_type="classification",
        multi_label=multi_label,
    )

    my_image_ds = aiplatform.ImageDataset(dataset_id)

    model = job.run(
        dataset=my_image_ds,
        model_display_name=model_display_name,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Classification (multi-label)

Select a tab below for your language or environment.

REST

Before using any of the request data, make the following replacements:

  • LOCATION: The region where the dataset is located and where the model is created. For example, us-central1.
  • PROJECT: Your project ID.
  • TRAININGPIPELINE_DISPLAYNAME: Required. A display name for the trainingPipeline.
  • DATASET_ID: The ID number of the dataset to use for training.
  • fractionSplit: Optional. One of several possible ML use split options for your data. For fractionSplit, the values must sum to 1. For example:
    • {"trainingFraction": "0.7","validationFraction": "0.15","testFraction": "0.15"}
  • MODEL_DISPLAYNAME*: A display name for the model uploaded (created) by the TrainingPipeline.
  • MODEL_DESCRIPTION*: A description of the model.
  • modelToUpload.labels*: Any set of key-value pairs to organize your models. For example:
    • "env": "prod"
    • "tier": "backend"
  • MODELTYPE†: The type of Cloud-hosted model to train. The options are:
    • CLOUD (default)
  • NODE_HOUR_BUDGET†: The actual training cost is less than or equal to this value. For Cloud models, the budget must be between 8,000 and 800,000 milli node hours, inclusive. The default value is 192,000, which represents one day in wall time, assuming 8 nodes are used.
  • PROJECT_NUMBER: Your project's automatically generated project number.

HTTP method and URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

JSON request body:

{
  "displayName": "TRAININGPIPELINE_DISPLAYNAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID",
    "fractionSplit": {
      "trainingFraction": "DECIMAL",
      "validationFraction": "DECIMAL",
      "testFraction": "DECIMAL"
    }
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAYNAME",
    "description": "MODEL_DESCRIPTION",
    "labels": {
      "KEY": "VALUE"
    }
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml",
  "trainingTaskInputs": {
    "multiLabel": "true",
    "modelType": ["MODELTYPE"],
    "budgetMilliNodeHours": NODE_HOUR_BUDGET
  }
}

To send your request, choose one of these options:

curl

Save the request body in a file named request.json, and execute the following command:

curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines"

PowerShell

Save the request body in a file named request.json, and execute the following command:

$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines" | Select-Object -Expand Content

The response contains information about specifications as well as the TRAININGPIPELINE_ID.

Java

Before trying this sample, follow the Java setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Java API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineImageClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    createTrainingPipelineImageClassificationSample(
        project, trainingPipelineDisplayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineImageClassificationSample(
      String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_image_classification_1.0.0.yaml";
      LocationName locationName = LocationName.of(project, location);

      AutoMlImageClassificationInputs autoMlImageClassificationInputs =
          AutoMlImageClassificationInputs.newBuilder()
              .setModelType(ModelType.CLOUD)
              .setMultiLabel(true)
              .setBudgetMilliNodeHours(8000)
              .setDisableEarlyStopping(false)
              .build();

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs))
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Image Classification Response");
      System.out.format("Name: %s\n", trainingPipelineResponse.getName());
      System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName());

      System.out.format(
          "Training Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("State: %s\n", trainingPipelineResponse.getState());

      System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("Start Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("Input Data Config");
      System.out.format("Dataset Id: %s\n", inputDataConfig.getDatasetId());
      System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("Fraction Split");
      System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("Filter Split");
      System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("Predefined Split");
      System.out.format("Key: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("Timestamp Split");
      System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("Key: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("Model To Upload");
      System.out.format("Name: %s\n", modelResponse.getName());
      System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
      System.out.format("Description: %s\n", modelResponse.getDescription());

      System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("Metadata: %s\n", modelResponse.getMetadata());
      System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "Supported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "Supported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "Supported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());

      System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
      System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("Labels: %s\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("Predict Schemata");
      System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("Supported Export Format");
        System.out.format("Id: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("Container Spec");
      System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("Command: %s\n", modelContainerSpec.getCommandList());
      System.out.format("Args: %s\n", modelContainerSpec.getArgsList());
      System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("Env");
        System.out.format("Name: %s\n", envVar.getName());
        System.out.format("Value: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("Port");
        System.out.format("Container Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("Deployed Model");
        System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("Error");
      System.out.format("Code: %s\n", status.getCode());
      System.out.format("Message: %s\n", status.getMessage());
    }
  }
}

Node.js

Before trying this sample, follow the Node.js setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Node.js API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */
/*
const datasetId = 'YOUR DATASET';
const modelDisplayName = 'NEW MODEL NAME';
const trainingPipelineDisplayName = 'NAME FOR TRAINING PIPELINE';
const project = 'YOUR PROJECT ID';
const location = 'us-central1';
  */
// Imports the Google Cloud Pipeline Service Client library
const aiplatform = require('@google-cloud/aiplatform');

const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;
const ModelType = definition.AutoMlImageClassificationInputs.ModelType;

// Specifies the location of the API endpoint.
// The region in this hostname should match the `location` value above.
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const {PipelineServiceClient} = aiplatform.v1;
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineImageClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  // Values should match the input expected by your model.
  const trainingTaskInputsMessage =
    new definition.AutoMlImageClassificationInputs({
      multiLabel: true,
      modelType: ModelType.CLOUD,
      budgetMilliNodeHours: 8000,
      disableEarlyStopping: false,
    });

  const trainingTaskInputs = trainingTaskInputsMessage.toValue();

  const trainingTaskDefinition =
    'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml';

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition,
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {parent, trainingPipeline};

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline image classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}

createTrainingPipelineImageClassification();

Python

To learn how to install or update the Vertex AI SDK for Python, see Install the Vertex AI SDK for Python. For more information, see the Python API reference documentation.

from typing import Optional

from google.cloud import aiplatform


def create_training_pipeline_image_classification_sample(
    project: str,
    location: str,
    display_name: str,
    dataset_id: str,
    model_display_name: Optional[str] = None,
    model_type: str = "CLOUD",
    multi_label: bool = False,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    job = aiplatform.AutoMLImageTrainingJob(
        display_name=display_name,
        model_type=model_type,
        prediction_type="classification",
        multi_label=multi_label,
    )

    my_image_ds = aiplatform.ImageDataset(dataset_id)

    model = job.run(
        dataset=my_image_ds,
        model_display_name=model_display_name,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model
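
The `budget_milli_node_hours` argument in the sample above is expressed in milli node hours, where 1,000 milli node hours equals one node hour, so the default of 8000 corresponds to 8 node hours. The following is a minimal sketch of that conversion. The helper name `node_hours_to_milli` is hypothetical and is not part of the Vertex AI SDK.

```python
def node_hours_to_milli(node_hours: float) -> int:
    """Convert node hours to milli node hours (1 node hour = 1,000 milli node hours)."""
    if node_hours <= 0:
        raise ValueError("node_hours must be positive")
    return int(round(node_hours * 1000))


if __name__ == "__main__":
    # Hypothetical usage: an 8 node hour budget for budget_milli_node_hours.
    print(node_hours_to_milli(8))  # 8000
```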

Control the data split using REST

You can control how your training data is split between the training, validation, and test sets. When you use the Vertex AI API, use the Split object to determine the data split. The Split object can be included in the InputConfig object as one of several object types, each of which provides a different way to split the training data. You can select only one method.

  • FractionSplit:
    • TRAINING_FRACTION: The fraction of the training data to use for the training set.
    • VALIDATION_FRACTION: The fraction of the training data to use for the validation set. Not used for video data.
    • TEST_FRACTION: The fraction of the training data to use for the test set.

    If any of the fractions is specified, all of them must be specified. The fractions must add up to 1.0. The default values for the fractions differ depending on the data type. Learn more.

    "fractionSplit": {
      "trainingFraction": TRAINING_FRACTION,
      "validationFraction": VALIDATION_FRACTION,
      "testFraction": TEST_FRACTION
    },
    
  • FilterSplit:
    • TRAINING_FILTER: Data items that match this filter are used for the training set.
    • VALIDATION_FILTER: Data items that match this filter are used for the validation set. Must be '-' for video data.
    • TEST_FILTER: Data items that match this filter are used for the test set.

    These filters can be used with the ml_use label, or with any labels that you apply to your data. Learn more about using the ml-use label and other labels to filter your data.

    The following example shows how to use the filterSplit object with the ml_use label, with the validation set included:

    "filterSplit": {
      "trainingFilter": "labels.aiplatform.googleapis.com/ml_use=training",
      "validationFilter": "labels.aiplatform.googleapis.com/ml_use=validation",
      "testFilter": "labels.aiplatform.googleapis.com/ml_use=test"
    }
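
The fractionSplit and filterSplit fragments described above can be assembled and checked locally before a request is sent. The following is a minimal Python sketch, not an official sample. The helper names `build_fraction_split`, `build_filter_split`, and `build_input_data_config` are hypothetical, and the sketch assumes that the chosen split fragment sits next to `datasetId` in the `inputDataConfig` portion of the request body. It only builds JSON and does not call the Vertex AI API.

```python
import json
import math


def build_fraction_split(training: float, validation: float, test: float) -> dict:
    """Return a fractionSplit fragment after checking that the fractions sum to 1.0."""
    fractions = (training, validation, test)
    if any(f < 0 or f > 1 for f in fractions):
        raise ValueError("each fraction must be between 0 and 1")
    if not math.isclose(sum(fractions), 1.0, abs_tol=1e-9):
        raise ValueError(f"fractions must add up to 1.0, got {sum(fractions)}")
    return {
        "fractionSplit": {
            "trainingFraction": training,
            "validationFraction": validation,
            "testFraction": test,
        }
    }


def build_filter_split(training: str, validation: str, test: str) -> dict:
    """Return a filterSplit fragment from three filter expressions."""
    return {
        "filterSplit": {
            "trainingFilter": training,
            "validationFilter": validation,
            "testFilter": test,
        }
    }


def build_input_data_config(dataset_id: str, split: dict) -> dict:
    """Combine a dataset ID with exactly one split fragment (only one method is allowed)."""
    if len(split) != 1:
        raise ValueError("choose exactly one split method")
    return {"datasetId": dataset_id, **split}


if __name__ == "__main__":
    # DATASET_ID is a placeholder, as in the REST instructions.
    config = build_input_data_config(
        "DATASET_ID", build_fraction_split(0.8, 0.1, 0.1)
    )
    print(json.dumps(config, indent=2))
```

Validating the sum locally is a convenience only; the service remains the authority on which split values it accepts.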