Vertex AI๋ TPU VM์ ์ฌ์ฉํ์ฌ ๋ค์ํ ํ๋ ์์ํฌ ๋ฐ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ก ํ์ต์ ์ง์ํฉ๋๋ค. ์ปดํจํ ๋ฆฌ์์ค๋ฅผ ๊ตฌ์ฑํ ๋ TPU v2, TPU v3, TPU v5e VM์ ์ง์ ํ ์ ์์ต๋๋ค. TPU v5e๋ JAX 0.4.6 ์ด์, TensorFlow 2.15 ์ด์, PyTorch 2.1 ์ด์์ ์ง์ํฉ๋๋ค. TPU v6e๋ PJRT๋ฅผ ๊ธฐ๋ณธ ๋ฐํ์์ผ๋ก ์ฌ์ฉํ๋ Python 3.10 ์ด์, JAX 0.4.37 ์ด์, PyTorch 2.1 ์ด์์ ์ง์ํฉ๋๋ค.
์ปค์คํ ํ์ต์ ์ํ TPU VM ๊ตฌ์ฑ์ ๊ดํ ์์ธํ ๋ด์ฉ์ ์ปค์คํ ํ์ต์ ์ํ ์ปดํจํ ๋ฆฌ์์ค ๊ตฌ์ฑ์ ์ฐธ๊ณ ํ์ธ์.
TensorFlow ํ์ต
์ฌ์ ๋น๋๋ ์ปจํ ์ด๋
TPU๋ฅผ ์ง์ํ๋ ์ฌ์ ๋น๋๋ ํ์ต ์ปจํ ์ด๋๋ฅผ ์ฌ์ฉํ๊ณ Python ํ์ต ์ ํ๋ฆฌ์ผ์ด์ ์ ๋ง๋ญ๋๋ค.
์ปค์คํ ์ปจํ ์ด๋
TPU VM์ฉ์ผ๋ก ํน๋ณํ ๋น๋๋ tensorflow
๋ฐ libtpu
๋ฒ์ ์ด ์ค์น๋ ์ปค์คํ
์ปจํ
์ด๋๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ด๋ฌํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ Cloud TPU ์๋น์ค์์ ์ ์ง ๊ด๋ฆฌ๋๊ณ ์ง์๋๋ TPU ๊ตฌ์ฑ ๋ฌธ์์ ๋์ด๋์ด ์์ต๋๋ค.
์ํ๋ tensorflow
๋ฒ์ ๊ณผ ํด๋น libtpu
๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ ํํฉ๋๋ค. ๊ทธ๋ฐ ํ ์ปจํ
์ด๋๋ฅผ ๋น๋ํ ๋ Docker ์ปจํ
์ด๋ ์ด๋ฏธ์ง์ ์ด๋ฅผ ์ค์นํฉ๋๋ค.
์๋ฅผ ๋ค์ด TensorFlow 2.12๋ฅผ ์ฌ์ฉํ๋ ค๋ฉด Dockerfile์ ๋ค์ ์๋ด๋ฅผ ํฌํจํฉ๋๋ค.
# Download and install `tensorflow`.
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
# Download and install `libtpu`.
# You must save `libtpu.so` in the '/lib' directory of the container image.
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so
# TensorFlow training on TPU v5e requires the PJRT runtime. To enable the PJRT
# runtime, configure the following environment variables in your Dockerfile.
# For details, see https://cloud.google.com/tpu/docs/runtimes#tf-pjrt-support.
# ENV NEXT_PLUGGABLE_DEVICE_USE_C_API=true
# ENV TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
TPU Pod
TPU Pod
์์ tensorflow
๋ฅผ ํ์ต์ํค๋ ค๋ฉด ํ์ต ์ปจํ
์ด๋์ ์ถ๊ฐ ์ค์ ์ด ํ์ํฉ๋๋ค. Vertex AI๋ ์ด๊ธฐ ์ค์ ์ ์ฒ๋ฆฌํ๋ ๊ธฐ๋ณธ Docker ์ด๋ฏธ์ง๋ฅผ ์ ์งํฉ๋๋ค.
์ด๋ฏธ์ง URI | Python ๋ฒ์ ๋ฐ TPU ๋ฒ์ |
---|---|
|
3.8 |
|
3.10 |
์ปค์คํ ์ปจํ ์ด๋๋ฅผ ๋น๋ํ๋ ๋จ๊ณ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- ์ ํํ Python ๋ฒ์ ์ ๊ธฐ๋ณธ ์ด๋ฏธ์ง๋ฅผ ์ ํํฉ๋๋ค. TensorFlow 2.12 ์ดํ๋ฅผ ์ํ TPU TensorFlow ํ ์ Python 3.8์ ์ง์ํฉ๋๋ค. TensorFlow 2.13 ์ด์์ Python 3.10 ์ด์์ ์ง์ํฉ๋๋ค. ํน์ TensorFlow ํ ์ ๊ฒฝ์ฐ์๋ Cloud TPU ๊ตฌ์ฑ์ ์ฐธ์กฐํ์ธ์.
- ํธ๋ ์ด๋ ์ฝ๋ ๋ฐ ์์ ๋ช ๋ น์ด๋ก ์ด๋ฏธ์ง๋ฅผ ํ์ฅํฉ๋๋ค.
# Specifies base image and tag
FROM us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
WORKDIR /root
# Download and install `tensorflow`.
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.12.0/tensorflow-2.12.0-cp38-cp38-linux_x86_64.whl
# Download and install `libtpu`.
# You must save `libtpu.so` in the '/lib' directory of the container image.
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.6.0/libtpu.so -o /lib/libtpu.so
# Copies the trainer code to the docker image.
COPY your-path-to/model.py /root/model.py
COPY your-path-to/trainer.py /root/trainer.py
# The base image is setup so that it runs the CMD that you provide.
# You can provide CMD inside the Dockerfile like as follows.
# Alternatively, you can pass it as an `args` value in ContainerSpec:
# (https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec#containerspec)
CMD ["python3", "trainer.py"]
PyTorch ํ์ต
TPU๋ก ํ์ตํ ๋ PyTorch์ ์ฌ์ ๋น๋๋ ์ปจํ ์ด๋ ๋๋ ์ปค์คํ ์ปจํ ์ด๋๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค.
์ฌ์ ๋น๋๋ ์ปจํ ์ด๋
TPU๋ฅผ ์ง์ํ๋ ์ฌ์ ๋น๋๋ ํ์ต ์ปจํ ์ด๋๋ฅผ ์ฌ์ฉํ๊ณ Python ํ์ต ์ ํ๋ฆฌ์ผ์ด์ ์ ๋ง๋ญ๋๋ค.
์ปค์คํ ์ปจํ ์ด๋
PyTorch
๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋ ์ปค์คํ
์ปจํ
์ด๋๋ฅผ ์ฌ์ฉํฉ๋๋ค.
์๋ฅผ ๋ค์ด Dockerfile์ ๋ค์๊ณผ ๊ฐ์ด ๋ณด์ผ ์ ์์ต๋๋ค.
FROM python:3.10
# v5e, v6e specific requirement - enable PJRT runtime
ENV PJRT_DEVICE=TPU
# install pytorch and torch_xla
RUN pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0
-f https://storage.googleapis.com/libtpu-releases/index.html
# Add your artifacts here
COPY trainer.py .
# Run the trainer code
CMD ["python3", "trainer.py"]
TPU Pod
ํ์ต์ TPU Pod์ ๋ชจ๋ ํธ์คํธ์์ ์คํ๋ฉ๋๋ค(TPU Pod ์ฌ๋ผ์ด์ค์์ PyTorch ์ฝ๋ ์คํ ์ฐธ์กฐ).
Vertex AI๋ ๋ชจ๋ ํธ์คํธ์ ์๋ต์ ๊ธฐ๋ค๋ ค ์์ ์๋ฃ๋ฅผ ๊ฒฐ์ ํฉ๋๋ค.
JAX ํ์ต
์ฌ์ ๋น๋๋ ์ปจํ ์ด๋
JAX์๋ ์ฌ์ ๋น๋๋ ์ปจํ ์ด๋๊ฐ ์์ต๋๋ค.
์ปค์คํ ์ปจํ ์ด๋
JAX
๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋ ์ปค์คํ
์ปจํ
์ด๋๋ฅผ ์ฌ์ฉํฉ๋๋ค.
์๋ฅผ ๋ค์ด Dockerfile์ ๋ค์๊ณผ ๊ฐ์ด ๋ณด์ผ ์ ์์ต๋๋ค.
# Install JAX.
RUN pip install 'jax[tpu]>=0.4.6' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Add your artifacts here
COPY trainer.py trainer.py
# Set an entrypoint.
ENTRYPOINT ["python3", "trainer.py"]
TPU Pod
ํ์ต์ TPU Pod์ ๋ชจ๋ ํธ์คํธ์์ ์คํ๋ฉ๋๋ค(TPU Pod ์ฌ๋ผ์ด์ค์์ JAX ์ฝ๋ ์คํ ์ฐธ์กฐ).
Vertex AI๋ TPU Pod์ ์ฒซ ๋ฒ์งธ ํธ์คํธ๋ฅผ ๊ฐ์ํ์ฌ ์์ ์๋ฃ๋ฅผ ๊ฒฐ์ ํฉ๋๋ค. ๋ค์ ์ฝ๋ ์ค๋ํซ์ ์ฌ์ฉํ์ฌ ๋ชจ๋ ํธ์คํธ๊ฐ ๋์์ ์ข ๋ฃ๋๋์ง ํ์ธํ ์ ์์ต๋๋ค.
# Your training logic
...
if jax.process_count() > 1:
# Make sure all hosts stay up until the end of main.
x = jnp.ones([jax.local_device_count()])
x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
assert x[0] == jax.device_count()
ํ๊ฒฝ ๋ณ์
๋ค์ ํ์์๋ ์ปจํ ์ด๋ ๋ด์์ ์ฌ์ฉํ ์ ์๋ ํ๊ฒฝ ๋ณ์์ ๋ํด ์์ธํ ์ค๋ช ํฉ๋๋ค.
์ด๋ฆ | ๊ฐ |
---|---|
TPU_NODE_NAME | my-first-tpu-node |
TPU_CONFIG | {"project": "tenant-project-xyz", "zone": "us-central1-b", "tpu_node_name": "my-first-tpu-node"} |
์ปค์คํ ์๋น์ค ๊ณ์
์ปค์คํ ์๋น์ค ๊ณ์ ์ TPU ํ์ต์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ปค์คํ ์๋น์ค ๊ณ์ ์ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ์ปค์คํ ์๋น์ค ๊ณ์ ์ฌ์ฉ ๋ฐฉ๋ฒ ํ์ด์ง๋ฅผ ์ฐธ๊ณ ํ์ธ์.
ํ์ต์ฉ ๋น๊ณต๊ฐ IP(VPC ๋คํธ์ํฌ ํผ์ด๋ง)
๋น๊ณต๊ฐ IP๋ฅผ TPU ํ์ต์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ปค์คํ ํ์ต์ ๋น๊ณต๊ฐ IP๋ฅผ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ ํ์ด์ง๋ฅผ ์ฐธ๊ณ ํ์ธ์.
VPC ์๋น์ค ์ ์ด
VPC ์๋น์ค ์ ์ด๊ฐ ์ฌ์ฉ ์ค์ ๋ ํ๋ก์ ํธ๋ TPU ํ์ต ์์ ์ ์ ์ถํ ์ ์์ต๋๋ค.
์ ํ์ฌํญ
TPU VM์ ์ฌ์ฉํ์ฌ ํ์ตํ ๋ ๋ค์ ์ ํ ์ฌํญ์ด ์ ์ฉ๋ฉ๋๋ค.
TPU ์ ํ
๋ฉ๋ชจ๋ฆฌ ํ๋์ ๊ฐ์ด TPU ๊ฐ์๊ธฐ์ ๋ํ ์์ธํ ๋ด์ฉ์ TPU ์ ํ์ ์ฐธ์กฐํ์ธ์.