Run PyTorch code on TPU slices

Before running the commands in this document, make sure you have followed the instructions in Set up an account and a Cloud TPU project.

After you have your PyTorch code running on a single TPU VM, you can scale up your code by running it on a TPU slice. A TPU slice is a set of TPU boards connected to each other over dedicated high-speed network connections. This document shows you how to run PyTorch code on a TPU slice.

Create a Cloud TPU slice

  1. Define some environment variables to make the commands easier to use.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-32
    export RUNTIME_VERSION=v2-alpha-tpuv5

    Environment variable descriptions

    Variable           Description
    PROJECT_ID         Your Google Cloud project ID. Use an existing project or create a new one.
    TPU_NAME           The name of the TPU.
    ZONE               The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    ACCELERATOR_TYPE   The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about the accelerator types supported for each TPU version, see TPU versions.
    RUNTIME_VERSION    The Cloud TPU software version.

  2. Create your TPU VM by running the following command:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --version=${RUNTIME_VERSION}
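Before moving on to the SSH steps, it can help to confirm that every variable from step 1 is set and that the slice has reached the READY state. The following is a minimal POSIX shell sketch; require_tpu_env and wait_for_tpu_ready are hypothetical helper names, not part of gcloud, and the polling assumes that gcloud compute tpus tpu-vm describe reports the node's state field as READY once the slice is usable.

```shell
#!/bin/sh
# Minimal sketch (hypothetical helpers, not part of gcloud): check the step 1
# variables, then poll the TPU until its state field reports READY.

# Return 1 and name the first step 1 variable that is unset or empty.
require_tpu_env() {
  for _tpu_var in PROJECT_ID TPU_NAME ZONE ACCELERATOR_TYPE RUNTIME_VERSION; do
    # Indirect lookup in portable sh: expands to value=${PROJECT_ID:-} etc.
    eval "_tpu_val=\${${_tpu_var}:-}"
    if [ -z "${_tpu_val}" ]; then
      echo "missing environment variable: ${_tpu_var}" >&2
      return 1
    fi
  done
  return 0
}

# Poll the TPU state until it is READY.
# $1 = maximum number of polls (default 30), $2 = seconds between polls (default 20).
wait_for_tpu_ready() {
  _tpu_attempts="${1:-30}"
  _tpu_delay="${2:-20}"
  _tpu_i=1
  while [ "${_tpu_i}" -le "${_tpu_attempts}" ]; do
    _tpu_state="$(gcloud compute tpus tpu-vm describe "${TPU_NAME}" \
      --zone="${ZONE}" \
      --project="${PROJECT_ID}" \
      --format="value(state)")"
    echo "attempt ${_tpu_i}: state=${_tpu_state}"
    if [ "${_tpu_state}" = "READY" ]; then
      return 0
    fi
    sleep "${_tpu_delay}"
    _tpu_i=$((_tpu_i + 1))
  done
  return 1
}
```

For example, after the create command you could run require_tpu_env && wait_for_tpu_ready 30 20, which gives up after roughly ten minutes of polling.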

Install PyTorch/XLA on your slice

After you create the TPU slice, you must install PyTorch on all hosts in the slice. You can do this by running the gcloud compute tpus tpu-vm ssh command with the --worker=all and --command parameters.

If the following commands fail because of an SSH connection error, it might be because the TPU VMs don't have external IP addresses. To access a TPU VM without an external IP address, follow the instructions in Connect to a TPU VM without a public IP address.

  1. Install PyTorch/XLA on all TPU VM workers.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"

  2. Clone the PyTorch/XLA repository on all TPU VM workers.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="git clone https://github.com/pytorch/xla.git"
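To check that both steps succeeded on every host, you could run a quick import test over the same --worker=all SSH connection. This is only a sketch: check_pytorch_xla_on_workers is a hypothetical helper name, and it assumes the repository was cloned into each worker's home directory as in step 2.

```shell
#!/bin/sh
# Minimal sketch (hypothetical helper name): on every worker, confirm that the
# cloned training script exists and that torch / torch_xla import cleanly.
check_pytorch_xla_on_workers() {
  # The ~ is inside double quotes, so it is expanded by each worker's shell,
  # not by the local one.
  gcloud compute tpus tpu-vm ssh "${TPU_NAME}" \
    --zone="${ZONE}" \
    --project="${PROJECT_ID}" \
    --worker=all \
    --command="ls ~/xla/test/spmd/test_train_spmd_imagenet.py && python3 -c 'import torch, torch_xla; print(torch.__version__, torch_xla.__version__)'"
}
```

Each worker should print the script path followed by the installed torch and torch_xla versions; an import error on any worker suggests the pip step did not complete there.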

Run the training script on your TPU slice

Run the training script on all workers. The training script uses a Single Program Multiple Data (SPMD) sharding strategy. For more information on SPMD, see the PyTorch/XLA SPMD User Guide.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --worker=all \
   --command="PJRT_DEVICE=TPU python3 ~/xla/test/spmd/test_train_spmd_imagenet.py  \
   --fake_data \
   --model=resnet50  \
   --num_epochs=1 2>&1 | tee ~/logs.txt"

The training takes about 15 minutes. When it completes, you should see a message similar to the following:

Epoch 1 test end 23:49:15, Accuracy=100.00
     10.164.0.11 [0] Max Accuracy: 100.00%
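Because the training command pipes its output through tee, each worker keeps a copy of the log in ~/logs.txt. The sketch below shows one way to read the result back; the helper names are hypothetical, and the sed pattern assumes the log contains a line in the "Max Accuracy: NN.NN%" format shown above.

```shell
#!/bin/sh
# Minimal sketch (hypothetical helper names): read back the "Max Accuracy"
# line that the training command writes to ~/logs.txt via tee.

# Print the number from the last "Max Accuracy: NN.NN%" line in a log file.
# $1 = log file (default: ~/logs.txt on the machine where this runs).
max_accuracy_from_log() {
  sed -n 's/.*Max Accuracy: \([0-9.]*\)%.*/\1/p' "${1:-$HOME/logs.txt}" | tail -n 1
}

# Show the "Max Accuracy" line from ~/logs.txt on every worker. A worker whose
# log has no such line makes grep, and therefore that SSH session, exit non-zero.
show_max_accuracy_on_workers() {
  gcloud compute tpus tpu-vm ssh "${TPU_NAME}" \
    --zone="${ZONE}" \
    --project="${PROJECT_ID}" \
    --worker=all \
    --command="grep 'Max Accuracy' ~/logs.txt"
}
```

Run max_accuracy_from_log on a worker itself, or show_max_accuracy_on_workers from the machine where you ran the training command.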

Clean up

When you are done with your TPU VM, follow these steps to clean up your resources.

  1. Disconnect from the Cloud TPU instance, if you have not already done so:

    (vm)$ exit

    When your prompt changes to username@projectname, you are in Cloud Shell.

  2. Delete your Cloud TPU resources.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID}

  3. Verify that the resources have been deleted by running gcloud compute tpus tpu-vm list. The deletion might take several minutes. The output of the following command should not include any of the resources created in this tutorial.

    $ gcloud compute tpus tpu-vm list --zone=${ZONE}
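In a script, you might prefer to poll until the TPU disappears from the list rather than checking the output by eye. A minimal sketch, assuming the same environment variables as above; tpu_deleted is a hypothetical helper, and because the name field may be returned as a full resource path, it matches either the bare name or the last path component.

```shell
#!/bin/sh
# Minimal sketch (hypothetical helper name). Exit status:
#   0 = TPU_NAME no longer listed, 1 = still listed, 2 = gcloud itself failed.
tpu_deleted() {
  _tpu_names="$(gcloud compute tpus tpu-vm list \
    --zone="${ZONE}" \
    --project="${PROJECT_ID}" \
    --format="value(name)")" || return 2
  # Match a bare name or a path ending in /NAME, but not e.g. NAME2.
  if printf '%s\n' "${_tpu_names}" | grep -Eq "(^|/)${TPU_NAME}\$"; then
    return 1
  fi
  return 0
}
```

For example, until tpu_deleted; do sleep 30; done waits for the deletion to finish; since that loop also keeps going when gcloud fails (status 2), a real script would cap the number of attempts.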