Training with TPU accelerators

Vertex AI supports training with various frameworks and libraries using a TPU VM. When configuring compute resources, you can specify TPU v2, TPU v3, or TPU v5e VMs. TPU v5e supports JAX 0.4.6 or later, TensorFlow 2.15 or later, and PyTorch 2.1 or later. TPU v6e supports Python 3.10 or later, JAX 0.4.37 or later, and PyTorch 2.1 or later, using PJRT as the default runtime.
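The version floors above can be encoded as a small lookup so that a build script fails early on an unsupported combination. This is a minimal sketch: `MIN_VERSIONS` and `meets_minimum` are hypothetical names, not part of any Vertex AI API, and the table contains only the minimums stated in the paragraph above.

```python
# Minimum versions stated for TPU v5e and TPU v6e.
# MIN_VERSIONS and meets_minimum are hypothetical helpers, not a Vertex AI API.
MIN_VERSIONS = {
    "v5e": {"jax": "0.4.6", "tensorflow": "2.15", "pytorch": "2.1"},
    "v6e": {"python": "3.10", "jax": "0.4.37", "pytorch": "2.1"},
}


def _as_tuple(version: str) -> tuple[int, ...]:
    """Turn a plain numeric version such as '2.15.0' into (2, 15, 0)."""
    return tuple(int(part) for part in version.split("."))


def meets_minimum(tpu: str, component: str, version: str) -> bool:
    """Return True if `version` of `component` is at or above the stated floor.

    Raises KeyError if no minimum is listed for that TPU/component pair.
    """
    floor = _as_tuple(MIN_VERSIONS[tpu][component])
    have = _as_tuple(version)
    # Pad the shorter tuple with zeros so that '2.15' compares equal to '2.15.0'.
    width = max(len(floor), len(have))
    floor += (0,) * (width - len(floor))
    have += (0,) * (width - len(have))
    return have >= floor


print(meets_minimum("v5e", "tensorflow", "2.12.0"))  # False
print(meets_minimum("v6e", "jax", "0.4.37"))         # True
```

Comparing integer tuples avoids the string-comparison trap where "2.9" would sort after "2.15". Pre-release tags such as "2.15.0rc1" are not handled.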

For more information about configuring a TPU VM for custom training, see Configure compute resources for custom training.

TensorFlow training

Prebuilt container

Use a prebuilt training container that supports TPUs, and create a Python training application.

Custom container

Use a custom container in which versions of tensorflow and libtpu built specifically for TPU VMs are installed. These libraries are maintained by the Cloud TPU service and are listed in the Supported TPU configurations documentation.

Select the tensorflow version of your choice and its corresponding libtpu library. Then install them in the Docker container image when you build the container.

For example, to use TensorFlow 2.15, include the following instructions in your Dockerfile:

  # Download and install `tensorflow`.
  RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

  # Download and install `libtpu`.
  # You must save `libtpu.so` in the '/lib' directory of the container image.
  RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so

  # TensorFlow training on TPU v5e requires the PJRT runtime. To enable the PJRT
  # runtime, uncomment the following environment variables in your Dockerfile.
  # For details, see https://cloud.google.com/tpu/docs/runtimes#tf-pjrt-support.
  # ENV NEXT_PLUGGABLE_DEVICE_USE_C_API=true
  # ENV TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
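If you uncomment the two ENV lines for TPU v5e, a short preflight check at the start of your trainer can catch a missing variable or a misplaced `libtpu.so` before TensorFlow starts. This is a sketch. The function name `check_tf_pjrt_env` is hypothetical, and it checks only the two variables and the library path used in the Dockerfile above.

```python
import os
from typing import Mapping


def check_tf_pjrt_env(env: Mapping[str, str]) -> list[str]:
    """Return problems with the PJRT settings for TensorFlow on TPU v5e.

    An empty list means both variables from the Dockerfile look correct.
    check_tf_pjrt_env is a hypothetical helper, not a TensorFlow API.
    """
    problems = []
    if env.get("NEXT_PLUGGABLE_DEVICE_USE_C_API") != "true":
        problems.append("NEXT_PLUGGABLE_DEVICE_USE_C_API should be 'true'")
    lib_path = env.get("TF_PLUGGABLE_DEVICE_LIBRARY_PATH")
    if lib_path is None:
        problems.append("TF_PLUGGABLE_DEVICE_LIBRARY_PATH is not set")
    elif not os.path.isfile(lib_path):
        # The Dockerfile saves libtpu.so as /lib/libtpu.so.
        problems.append(f"{lib_path} does not exist")
    return problems


if __name__ == "__main__":
    for problem in check_tf_pjrt_env(os.environ):
        print("PJRT config:", problem)
```

The function takes a mapping instead of reading `os.environ` directly, so you can test it with a plain dictionary.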

TPU Pod

Training tensorflow on a TPU Pod requires additional setup in the training container. Vertex AI maintains a base Docker image that handles the initial setup.

Base image URIs, grouped by Python version:

Python 3.8
  • us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
  • europe-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
  • asia-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest

Python 3.10
  • us-docker.pkg.dev/vertex-ai/training/tf-tpu.2-15-pod-base-cp310:latest
  • europe-docker.pkg.dev/vertex-ai/training/tf-tpu.2-15-pod-base-cp310:latest
  • asia-docker.pkg.dev/vertex-ai/training/tf-tpu.2-15-pod-base-cp310:latest
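The list above, combined with the wheel rule from step 1 below (TensorFlow 2.12 and earlier use Python 3.8, TensorFlow 2.13 and later use Python 3.10), can be expressed as a small lookup. `BASE_IMAGES`, `base_image_uri`, and `python_version_for_tensorflow` are hypothetical helpers for illustration. Note that the Python 3.10 image name also references TensorFlow 2.15.

```python
# Base image URIs from the list above, with the region prefix as a placeholder.
# BASE_IMAGES, base_image_uri and python_version_for_tensorflow are hypothetical helpers.
BASE_IMAGES = {
    "3.8": "{region}-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest",
    "3.10": "{region}-docker.pkg.dev/vertex-ai/training/tf-tpu.2-15-pod-base-cp310:latest",
}
REGIONS = ("us", "europe", "asia")


def base_image_uri(python_version: str, region: str = "us") -> str:
    """Return the TPU Pod base image for a Python version ('3.8' or '3.10') and region."""
    if region not in REGIONS:
        raise ValueError(f"region must be one of {REGIONS}, got {region!r}")
    return BASE_IMAGES[python_version].format(region=region)


def python_version_for_tensorflow(tf_version: str) -> str:
    """Apply the wheel rule: TF 2.12 and earlier -> Python 3.8, TF 2.13 and later -> 3.10."""
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    return "3.8" if (major, minor) <= (2, 12) else "3.10"


print(base_image_uri(python_version_for_tensorflow("2.12.0"), "europe"))
# europe-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
```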

To build your custom container, follow these steps:

  1. Choose the base image for your Python version. TPU TensorFlow wheels for TensorFlow 2.12 and earlier support Python 3.8. TensorFlow 2.13 and later support Python 3.10 or later. For the specific TensorFlow wheels, see Cloud TPU configurations.
  2. Extend the image with your trainer code and the startup command.
# Specifies base image and tag
FROM us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
WORKDIR /root

# Download and install `tensorflow`.
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.12.0/tensorflow-2.12.0-cp38-cp38-linux_x86_64.whl

# Download and install `libtpu`.
# You must save `libtpu.so` in the '/lib' directory of the container image.
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.6.0/libtpu.so -o /lib/libtpu.so

# Copies the trainer code to the docker image.
COPY your-path-to/model.py /root/model.py
COPY your-path-to/trainer.py /root/trainer.py

# The base image is set up so that it runs the CMD that you provide.
# You can provide CMD inside the Dockerfile as follows.
# Alternatively, you can pass it as an `args` value in ContainerSpec:
# (https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec#containerspec)
CMD ["python3", "trainer.py"]
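As the Dockerfile comment notes, the startup command can be passed as the `args` value of `ContainerSpec` instead of being baked in with CMD. The following is a minimal sketch of that worker pool fragment built as a Python dictionary. `containerSpec`, `imageUri`, and `args` are field names from the linked ContainerSpec reference. The helper name and the image path are hypothetical. `machineSpec` is deliberately left out; set it for your TPU VM as described in Configure compute resources for custom training.

```python
import json


def worker_pool_with_args(image_uri: str, cmd: list[str]) -> dict:
    """Build a worker pool fragment whose containerSpec.args stands in for the Dockerfile CMD.

    worker_pool_with_args is a hypothetical helper; machineSpec is omitted on purpose.
    """
    return {
        "containerSpec": {
            "imageUri": image_uri,
            # Passed to the container in place of the CMD from the Dockerfile.
            "args": list(cmd),
        },
    }


# Hypothetical Artifact Registry path for the image built from the Dockerfile above.
spec = worker_pool_with_args(
    "us-docker.pkg.dev/my-project/my-repo/tf-tpu-trainer:latest",
    ["python3", "trainer.py"],
)
print(json.dumps(spec, indent=2))
```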

PyTorch training

When training with TPUs, you can use either a prebuilt container or a custom container for PyTorch.

Prebuilt container

Use a prebuilt training container that supports TPUs, and create a Python training application.

Custom container

Use a custom container in which the PyTorch library is installed.

For example, your Dockerfile might look like the following:

FROM python:3.10

# v5e, v6e specific requirement - enable PJRT runtime
ENV PJRT_DEVICE=TPU

# install pytorch and torch_xla
# Quote the extras so that the shell does not treat [tpu] as a glob pattern.
RUN pip3 install torch~=2.1.0 torchvision 'torch_xla[tpu]~=2.1.0' \
 -f https://storage.googleapis.com/libtpu-releases/index.html

# Add your artifacts here
COPY trainer.py .

# Run the trainer code
CMD ["python3", "trainer.py"]
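The `~=2.1.0` pins use the compatible-release operator from PEP 440. That operator is equivalent to `>=2.1.0, ==2.1.*`, so patch releases such as 2.1.2 are accepted but 2.2.0 is not, and `torch` and `torch_xla` stay on the same minor release. The sketch below spells out that rule for plain numeric versions only. `compatible_release` is a hypothetical helper; in real code, `packaging.specifiers.SpecifierSet` implements the full PEP 440 semantics.

```python
def _parts(version: str) -> tuple[int, ...]:
    """Split a plain numeric version such as '2.1.3' into (2, 1, 3)."""
    return tuple(int(p) for p in version.split("."))


def compatible_release(version: str, spec: str) -> bool:
    """Check `version` against a '~=X.Y[.Z]' specifier (no pre/post/dev releases)."""
    if not spec.startswith("~="):
        raise ValueError("only the ~= operator is handled here")
    base = _parts(spec[2:])
    if len(base) < 2:
        raise ValueError("~= needs at least two version components")
    have = _parts(version)
    prefix = base[:-1]  # e.g. (2, 1) for ~=2.1.0, which gives the ==2.1.* part
    # Zero-pad both sides for the >= comparison, so '2.1' equals '2.1.0'.
    width = max(len(base), len(have))
    padded_have = have + (0,) * (width - len(have))
    padded_base = base + (0,) * (width - len(base))
    return padded_have >= padded_base and have[: len(prefix)] == prefix


for candidate in ("2.1.0", "2.1.3", "2.2.0", "2.0.9"):
    print(candidate, compatible_release(candidate, "~=2.1.0"))
```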

TPU Pod

ํ•™์Šต์€ TPU Pod์˜ ๋ชจ๋“  ํ˜ธ์ŠคํŠธ์—์„œ ์‹คํ–‰๋ฉ๋‹ˆ๋‹ค(TPU Pod ์Šฌ๋ผ์ด์Šค์—์„œ PyTorch ์ฝ”๋“œ ์‹คํ–‰ ์ฐธ์กฐ).

Vertex AI๋Š” ๋ชจ๋“  ํ˜ธ์ŠคํŠธ์˜ ์‘๋‹ต์„ ๊ธฐ๋‹ค๋ ค ์ž‘์—… ์™„๋ฃŒ๋ฅผ ๊ฒฐ์ •ํ•ฉ๋‹ˆ๋‹ค.

JAX training

Prebuilt container

There are no prebuilt containers for JAX.

Custom container

Use a custom container in which the JAX library is installed.

For example, your Dockerfile might look like the following:

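# Base image. python:3.10 is one possible choice; TPU v6e requires Python 3.10 or later.
FROM python:3.10

# Install JAX.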
RUN pip install 'jax[tpu]>=0.4.6' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Add your artifacts here
COPY trainer.py trainer.py

# Set an entrypoint.
ENTRYPOINT ["python3", "trainer.py"]

TPU Pod

ํ•™์Šต์€ TPU Pod์˜ ๋ชจ๋“  ํ˜ธ์ŠคํŠธ์—์„œ ์‹คํ–‰๋ฉ๋‹ˆ๋‹ค(TPU Pod ์Šฌ๋ผ์ด์Šค์—์„œ JAX ์ฝ”๋“œ ์‹คํ–‰ ์ฐธ์กฐ).

Vertex AI๋Š” TPU Pod์˜ ์ฒซ ๋ฒˆ์งธ ํ˜ธ์ŠคํŠธ๋ฅผ ๊ฐ์‹œํ•˜์—ฌ ์ž‘์—… ์™„๋ฃŒ๋ฅผ ๊ฒฐ์ •ํ•ฉ๋‹ˆ๋‹ค. ๋‹ค์Œ ์ฝ”๋“œ ์Šค๋‹ˆํŽซ์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋“  ํ˜ธ์ŠคํŠธ๊ฐ€ ๋™์‹œ์— ์ข…๋ฃŒ๋˜๋Š”์ง€ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

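import jax
import jax.numpy as jnp

# Your training logic
...

if jax.process_count() > 1:
    # Make sure all hosts stay up until the end of main. psum is a collective
    # over every device in the Pod, so no host gets past it until all hosts
    # have reached this point.
    x = jnp.ones([jax.local_device_count()])
    x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
    # Each device contributed 1, so the global sum equals the total device count.
    assert x[0] == jax.device_count()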
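The two completion rules described above (all hosts for PyTorch, the first host for JAX) can be contrasted with a toy model that shows why the barrier matters. This only illustrates the stated behavior under simplifying assumptions and does not reproduce how Vertex AI is implemented. `host_done` is a hypothetical list of per-host exit flags, with host 0 first.

```python
from typing import Sequence


def complete_when_all_hosts_exit(host_done: Sequence[bool]) -> bool:
    """Rule described for PyTorch on a TPU Pod: wait for every host."""
    return len(host_done) > 0 and all(host_done)


def complete_when_first_host_exits(host_done: Sequence[bool]) -> bool:
    """Rule described for JAX on a TPU Pod: only the first host is watched."""
    return len(host_done) > 0 and host_done[0]


# Host 0 finished early while hosts 1-3 are still training.
snapshot = [True, False, False, False]
print(complete_when_all_hosts_exit(snapshot))    # False
print(complete_when_first_host_exits(snapshot))  # True: the job would look done too early
```

The `psum` barrier in the snippet above prevents that snapshot. Host 0 cannot get past the collective until every host has reached it.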

Environment variables

The following table describes the environment variables that you can use within the container:

Name            Value (example)
TPU_NODE_NAME   my-first-tpu-node
TPU_CONFIG      {"project": "tenant-project-xyz", "zone": "us-central1-b", "tpu_node_name": "my-first-tpu-node"}
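`TPU_CONFIG` holds a JSON object, so a trainer can read the project, zone, and node name from it with the standard `json` module. This is a minimal sketch. `read_tpu_config` is a hypothetical helper, and it assumes only the keys shown in the example value above.

```python
import json
import os
from typing import Mapping


def read_tpu_config(env: Mapping[str, str]) -> dict:
    """Parse the JSON in TPU_CONFIG; return {} when the variable is not set."""
    raw = env.get("TPU_CONFIG")
    if raw is None:
        return {}
    return json.loads(raw)


config = read_tpu_config(os.environ)
# When TPU_CONFIG is absent (for example, when running locally), print a fallback message.
print(config.get("zone", "TPU_CONFIG is not set"))
print(os.environ.get("TPU_NODE_NAME", "TPU_NODE_NAME is not set"))
```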

Custom service account

You can use a custom service account for TPU training. For information on how to use a custom service account, see the page on how to use a custom service account.

Private IP (VPC Network Peering) for training

You can use a private IP for TPU training. See the page on how to use private IP for custom training.

VPC Service Controls

Projects with VPC Service Controls enabled can submit TPU training jobs.

Limitations

The following limitations apply when you train using a TPU VM:

TPU types

For more information about TPU accelerators, such as memory limits, see TPU types.