FSDP を使用して A4 Slurm クラスタで Llama 4 をファインチューニングする

このチュートリアルでは、 上のマルチノード、マルチ GPU Slurm クラスタで Llama-4-Scout-17 大規模言語 モデル(LLM)をファインチューニングする方法について説明します Google Cloud。このクラスタは 2 つの A4 仮想マシン(VM)インスタンスを使用し、各インスタンスには 8 個の NVIDIA B200 GPU が搭載されています。

このチュートリアルで説明する主なプロセスは次のとおりです。

  1. Cluster Toolkit を使用して、本番環境グレードの高性能 Slurm クラスタをデプロイします。このデプロイの一環として、必要なソフトウェアがプリインストールされたカスタム VM イメージを作成します。また、高速 RDMA ネットワークを構成し、Terraform の状態をホストする Cloud Storage バケットを作成します。
  2. クラスタがデプロイされたら、このチュートリアルに付属のスクリプト セットを使用して、分散ファインチューニング ジョブを実行します。このジョブでは、 Hugging Face Transformer Reinforcement Learning libraryを介してアクセスする PyTorch Fully Sharded Data Parallel(FSDP)を利用します。

このチュートリアルは、ファインチューニング ワークロードの処理に Slurm ジョブ スケジューリング機能を使用することを検討している ML エンジニア、プラットフォーム管理者、オペレーター、データおよび AI のスペシャリストを対象としています。

目標

  • Hugging Face を使用して Llama 4 にアクセスする。

  • 環境を準備する。

  • 本番環境グレードの A4 High-GPU Slurm クラスタを作成してデプロイする。

  • FSDP を使用した分散トレーニング用のマルチノード環境を構成する。

  • Hugging Face trl.SFTTrainer を使用して Llama 4 モデルをファインチューニングする。

  • データをローカル SSD にステージングする。

  • ジョブをモニタリングする。

  • クリーンアップする。

費用

このドキュメントでは、課金対象である次の コンポーネントを使用します Google Cloud:

料金計算ツールを使うと、予想使用量に基づいて費用の見積もりを生成できます。

新規の Google Cloud ユーザーは無料トライアルをご利用いただける場合があります。

始める前に

  1. アカウントにログインします。 Google Cloud を初めて使用する場合は、 アカウントを作成して、実際のシナリオで Google プロダクトのパフォーマンスを評価してください。 Google Cloud新規のお客様には、ワークロードの実行、テスト、デプロイができる無料クレジット $300 分を差し上げます。
  2. Google Cloud CLI をインストールします。

  3. 外部 ID プロバイダ(IdP)を使用している場合は、まず連携 ID を使用して gcloud CLI にログインする必要があります。

  4. gcloud CLI を初期化するには、次のコマンドを実行します:

    gcloud init
  5. プロジェクトを作成または選択します Google Cloud

    プロジェクトを選択または作成するために必要なロール

    • プロジェクトを選択する: プロジェクトの選択に特定の IAM ロールは必要ありません。ロールが付与されているプロジェクトを選択できます。
    • プロジェクトを作成する: プロジェクトを作成するには、プロジェクト作成者ロール (roles/resourcemanager.projectCreator)が必要です。これには resourcemanager.projects.create 権限が含まれています。詳しくは、ロールを付与する方法をご覧ください。
    • プロジェクトを作成します。 Google Cloud

      gcloud projects create PROJECT_ID

      PROJECT_ID は、作成する Google Cloud プロジェクトの名前に置き換えます。

    • 作成した Google Cloud プロジェクトを選択します。

      gcloud config set project PROJECT_ID

      PROJECT_ID は、 Google Cloud プロジェクトの名前に置き換えます。

  6. プロジェクトに対して課金が有効になっていることを確認します Google Cloud 。

  7. 必要な API を有効にします。

    API を有効にするために必要なロール

    API を有効にするには、 権限を含む Service Usage 管理者 IAM ロール(roles/serviceusage.serviceUsageAdmin)が必要です。serviceusage.services.enable詳しくは、ロールを付与する方法をご覧ください。

    gcloud services enable compute.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com
  8. Google Cloud CLI をインストールします。

  9. 外部 ID プロバイダ(IdP)を使用している場合は、まず連携 ID を使用して gcloud CLI にログインする必要があります。

  10. gcloud CLI を初期化するには、次のコマンドを実行します:

    gcloud init
  11. プロジェクトを作成または選択します Google Cloud

    プロジェクトを選択または作成するために必要なロール

    • プロジェクトを選択する: プロジェクトの選択に特定の IAM ロールは必要ありません。ロールが付与されているプロジェクトを選択できます。
    • プロジェクトを作成する: プロジェクトを作成するには、プロジェクト作成者ロール (roles/resourcemanager.projectCreator)が必要です。これには resourcemanager.projects.create 権限が含まれています。詳しくは、ロールを付与する方法をご覧ください。
    • プロジェクトを作成します。 Google Cloud

      gcloud projects create PROJECT_ID

      PROJECT_ID は、作成する Google Cloud プロジェクトの名前に置き換えます。

    • 作成した Google Cloud プロジェクトを選択します。

      gcloud config set project PROJECT_ID

      PROJECT_ID は、 Google Cloud プロジェクトの名前に置き換えます。

  12. プロジェクトに対して課金が有効になっていることを確認します Google Cloud 。

  13. 必要な API を有効にします。

    API を有効にするために必要なロール

    API を有効にするには、 権限を含む Service Usage 管理者 IAM ロール(roles/serviceusage.serviceUsageAdmin)が必要です。serviceusage.services.enable詳しくは、ロールを付与する方法をご覧ください。

    gcloud services enable compute.googleapis.com logging.googleapis.com cloudresourcemanager.googleapis.com servicenetworking.googleapis.com
  14. ユーザー アカウントにロールを付与します。次の IAM ロールごとに次のコマンドを 1 回実行します。 roles/compute.admin, roles/iam.serviceAccountUser, roles/file.editor, roles/storage.admin, roles/serviceusage.serviceUsageAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    次のように置き換えます。

    • PROJECT_ID: プロジェクト ID。
    • USER_IDENTIFIER: ユーザー アカウントの識別子。 例: myemail@example.com
    • ROLE: ユーザー アカウントに付与する IAM ロール。
  15. プロジェクトのデフォルト サービス アカウントを有効にします: Google Cloud
    gcloud iam service-accounts enable PROJECT_NUMBER-compute@developer.gserviceaccount.com 
    --project=PROJECT_ID

    PROJECT_NUMBER は、プロジェクト番号に置き換えます。プロジェクト番号を確認するには、 既存のプロジェクトを取得するをご覧ください。

  16. 編集者ロール(roles/editor)をデフォルトのサービス アカウントに付与します。
    gcloud projects add-iam-policy-binding PROJECT_ID 
    --member="serviceAccount:PROJECT_NUMBER-compute@developer.gserviceaccount.com"
    --role=roles/editor
  17. ユーザー アカウントのローカル認証情報を作成します。
    gcloud auth application-default login
  18. プロジェクトの OS Login を有効にします。
    gcloud compute project-info add-metadata --metadata=enable-oslogin=TRUE
  19. Hugging Face アカウントにログインするか、アカウントを作成します
  20. Cluster Toolkit の使用に必要な依存関係をインストールします。

Hugging Face を使用して Llama 4 にアクセスする

Hugging Face を使用して Llama 4 にアクセスする手順は次のとおりです。

  1. Llama 4 を使用するための同意契約に署名します

  2. Hugging Face read アクセス トークンを作成します

    Your Profile > Settings > Access tokens > +Create new token の順にクリックします。

  3. read access トークン値をコピーして保存します。このチュートリアルで後ほど使用します。

環境を準備する

環境を準備する手順は次のとおりです。

  1. Cluster Toolkit GitHub リポジトリのクローンを作成します。

    git clone https://github.com/GoogleCloudPlatform/cluster-toolkit.git
    
  2. Terraform の状態をホストする Cloud Storage バケットを作成します。

    gcloud storage buckets create gs://BUCKET_NAME \
        --project=PROJECT_ID
    

    次のように置き換えます。

    • BUCKET_NAME: バケット命名要件に準拠した Cloud Storage バケットの名前

    • PROJECT_ID: Cloud Storage バケットを作成するプロジェクトの ID。Google Cloud

A4 Slurm クラスタを作成する

A4 Slurm クラスタを作成する手順は次のとおりです。

  1. cluster-toolkit ディレクトリに移動します。

    cd cluster-toolkit
    
  2. Cluster Toolkit を初めて使用する場合は、gcluster バイナリをビルドします。

    make
    
  3. examples/machine-learning/a4-highgpu-8g ディレクトリに移動します。

    cd examples/machine-learning/a4-highgpu-8g/
    
  4. a4high-slurm-deployment.yaml ファイルを開き、次のように編集します。

    terraform_backend_defaults:
      type: gcs
      configuration:
        bucket: BUCKET_NAME
    
    vars:
      deployment_name: a4-high
      project_id: PROJECT_ID
      region: REGION
      zone: ZONE
      a4h_cluster_size: 2
      a4h_reservation_name: RESERVATION_URL
    

    次のように置き換えます。

    • BUCKET_NAME: 前のセクションで作成した Cloud Storage バケットの名前。

    • PROJECT_ID: バケットが存在し、 Google Cloud Slurm クラスタを作成するプロジェクトの ID。

    • REGION: 予約が存在するリージョン。

    • ZONE: 予約が存在するゾーン。

    • RESERVATION_URL: Slurm クラスタの作成に使用する予約の URL。予約が存在するプロジェクトに基づいて、次のいずれかの値を指定します。

      • 予約がプロジェクトに存在する場合: RESERVATION_NAME

      • 予約が別のプロジェクトに存在し、プロジェクトで予約を使用できる場合: 予約を使用できます: projects/RESERVATION_PROJECT_ID/reservations/RESERVATION_NAME

  5. クラスタをデプロイします。

    ./gcluster deploy -d examples/machine-learning/a4-highgpu-8g/a4high-slurm-deployment.yaml examples/machine-learning/a4-highgpu-8g/a4high-slurm-blueprint.yaml --auto-approve
    

    ./gcluster deploy コマンドは、次の 2 つのフェーズで構成されるプロセスです。

    • 最初のフェーズでは、すべてのソフトウェアがプリインストールされたカスタム イメージをビルドします。完了までに最大 35 分かかることがあります。

    • 2 番目のフェーズでは、そのカスタム イメージを使用してクラスタをデプロイします。このプロセスは、最初のフェーズよりも早く完了します。

    最初のフェーズは成功したが、2 番目のフェーズが失敗した場合は、最初のフェーズをスキップして Slurm クラスタのデプロイを再試行できます。

    ./gcluster deploy -d examples/machine-learning/a4-highgpu-8g/a4high-slurm-deployment.yaml examples/machine-learning/a4-highgpu-8g/a4high-slurm-blueprint.yaml --auto-approve --skip "image" -w
    

ワークロードを準備する

ワークロードを準備する手順は次のとおりです。

  1. ワークロード スクリプトを作成します

  2. スクリプトを Slurm クラスタにアップロードします

  3. Slurm クラスタに接続します

  4. フレームワークとツールをインストールします

ワークロード スクリプトを作成する

ファインチューニング ワークロードで使用するスクリプトを作成する手順は次のとおりです。

  1. Python 仮想環境を設定するには、次の内容で install_environment.sh ファイルを作成します。

    #!/bin/bash
    # This script sets up a consistent environment for FSDP training.
    # It is meant to be run once on the login node of your Slurm cluster
    set -e
    
    # --- 1. Create the Python virtual environment ---
    VENV_PATH="$HOME/.venv/venv-fsdp"
    if [ ! -d "$VENV_PATH" ]; then
      echo "--- Creating Python virtual environment at $VENV_PATH ---"
      python3 -m venv $VENV_PATH
    else
      echo "--- Virtual environment already exists at $VENV_PATH ---"
    fi
    
    source $VENV_PATH/bin/activate
    
    # --- 2. Install Dependencies ---
    echo "--- [STEP 2.1] Upgrading build toolchain ---"
    pip install --upgrade pip wheel packaging
    
    echo "--- [STEP 2.2] Installing PyTorch Nightly ---"
    pip install --force-reinstall --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
    
    echo "--- [STEP 2.3] Installing application dependencies ---"
    if [ -f "requirements-fsdp.txt" ]; then
        pip install -r requirements-fsdp.txt
    else
        echo "ERROR: requirements-fsdp.txt not found!"
        exit 1
    fi
    
    # --- 3. Download the Model ---
    echo "--- [STEP 2.4] Downloading Llama4 model ---"
    if [ -z "$HF_TOKEN" ]; then
      echo "ERROR: The HF_TOKEN environment variable is not set."; exit 1;
    fi
    pip install huggingface_hub[cli]
    
    # Execute the CLI using its full, explicit path
    $VENV_PATH/bin/huggingface-cli download meta-llama/Llama-4-Scout-17B-16E-Instruct --local-dir ~/Llama-4-Scout-17B-16E-Instruct --token $HF_TOKEN
    
    echo "--- Environment setup complete. ---"
    

    このスクリプトは、信頼性の高い Python 仮想環境を設定し、PyTorch のナイトリー ビルドをインストールして、Llama 4 モデルをダウンロードします。

  2. トレーニング スクリプトの Python 依存関係を指定するには、次の内容で requirements-fsdp.txt ファイルを作成します。

    transformers==4.55.0
    datasets==4.0.0
    peft==0.16.0
    accelerate==1.9.0
    trl==0.21.0
    
    # Other dependencies
    sentencepiece==0.2.0
    
  3. メインのトレーニング スクリプトとして llama4-train-distributed.py を指定します。

    import torch
    from datasets import load_dataset
    from peft import LoraConfig, PeftModel
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
        HfArgumentParser,
    )
    
    from torch.distributed import get_rank, get_world_size
    
    from transformers.models.llama4.modeling_llama4 import Llama4TextDecoderLayer
    from trl import SFTTrainer
    from dataclasses import dataclass, field
    from typing import Optional
    
    @dataclass
    class ScriptArguments:
        model_id: str = field(metadata={"help": "Hugging Face model ID from the Hub"})
        dataset_name: str = field(default="philschmid/gretel-synthetic-text-to-sql", metadata={"help": "Dataset from the Hub"})
        run_inference_after_training: bool = field(default=False, metadata={"help": "Run sample inference on rank 0 after training"})
        dataset_subset_size: Optional[int] = field(default=None, metadata={"help": "Number of samples to use from the dataset for training. If None, uses the full dataset."})
    
    @dataclass
    class PeftArguments:
        lora_r: int = field(default=16, metadata={"help": "LoRA attention dimension"})
        lora_alpha: int = field(default=32, metadata={"help": "LoRA alpha scaling factor"})
        lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout probability"})
    
    @dataclass
    class SftTrainingArguments(TrainingArguments):
        max_length: Optional[int] = field(default=2048, metadata={"help": "The maximum sequence length for SFTTrainer"})
        packing: Optional[bool] = field(default=False, metadata={"help": "Enable packing for SFTTrainer"})
        ddp_find_unused_parameters: Optional[bool] = field(default=True, metadata={"help": "When using FSDP activation checkpointing, this must be set to True"})
    
    def formatting_prompts_func(example):
        system_message = "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."
        user_prompt = f"### SCHEMA:\n{example['sql_context']}\n\n### USER QUERY:\n{example['sql_prompt']}"
        response = f"\n\n### SQL QUERY:\n{example['sql']}"
        return f"{system_message}\n\n{user_prompt}{response}"
    
    def main():
        parser = HfArgumentParser((ScriptArguments, PeftArguments, SftTrainingArguments))
        script_args, peft_args, training_args = parser.parse_args_into_dataclasses()
    
        training_args.gradient_checkpointing = True
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
    
        training_args.optim = "adamw_torch_fused"
    
        training_args.fsdp = "full_shard"
        training_args.fsdp_config = {
            "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
            "fsdp_transformer_layer_cls_to_wrap": [Llama4TextDecoderLayer],
            "fsdp_state_dict_type": "FULL_STATE_DICT",
            "fsdp_offload_params": False,
            "fsdp_forward_prefetch": True,
        }
    
        tokenizer = AutoTokenizer.from_pretrained(script_args.model_id, trust_remote_code=True)
    
        model = AutoModelForCausalLM.from_pretrained(
            script_args.model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            attn_implementation="sdpa",
        )
    
        peft_config = LoraConfig(
            r=peft_args.lora_r,
            lora_alpha=peft_args.lora_alpha,
            lora_dropout=peft_args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        )
        rank = get_rank()
        world_size = get_world_size()
    
        dataset = load_dataset(script_args.dataset_name, split="train")
    
        if script_args.dataset_subset_size is not None:
            dataset = dataset.select(range(script_args.dataset_subset_size))
        else:
            print(f"Using the full dataset with {len(dataset)} samples.")
    
        dataset = dataset.shuffle(seed=training_args.seed)
        print(f"Dataset shuffled with seed: {training_args.seed}.")
    
        if world_size > 1:
            print(f"Sharding dataset for Rank {rank} of {world_size}.")
            dataset = dataset.shard(num_shards=world_size, index=rank)
    
        print("Initializing SFTTrainer...")
        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            peft_config=peft_config,
            formatting_func=formatting_prompts_func,
            processing_class=tokenizer,
        )
    
        trainer.train()
    
        trainer.save_model(training_args.output_dir)
    
        if script_args.run_inference_after_training and trainer.is_world_process_zero():
            del model
            del trainer
            torch.cuda.empty_cache()
            run_post_training_inference(script_args, training_args, tokenizer)
    
    def run_post_training_inference(script_args, training_args, tokenizer):
        """
        Loads the fine-tuned PEFT adapter from the local output directory and runs inference.
        This should only be called on rank 0 after training is complete.
        """
        print("\n" + "="*50)
        print("=== RUNNING POST-TRAINING INFERENCE TEST ===")
        print("="*50 + "\n")
    
        # Load the base model and merge the adapter.
        base_model = AutoModelForCausalLM.from_pretrained(
            script_args.model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto"
        )
        # Load the PEFT adapter and merge it into the base model
        model = PeftModel.from_pretrained(base_model, training_args.output_dir)
        model = model.merge_and_unload() # Merge weights for faster inference
        model.eval()
    
        # Define the test case
        schema = "CREATE TABLE artists (Name TEXT, Country TEXT, Genre TEXT)"
        system_message = "You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."
        question = "Show me all artists from the Country just north of the USA."
    
        # This must match the formatting_func exactly
        prompt = f"{system_message}\n\n### SCHEMA:\n{schema}\n\n### USER QUERY:\n{question}\n\n### SQL QUERY:\n"
    
        print(f"Test Prompt:\n{prompt}")
    
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    
        print("\n--- Generating SQL... ---")
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False,
            temperature=None,
            top_p=None,
        )
    
        generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)[len(prompt):].strip()
    
        print(f"\n--- Generated SQL Query ---")
        print(generated_sql)
        print("\n" + "="*50)
        print("=== INFERENCE TEST COMPLETE ===")
        print("="*50 + "\n")
    
    if __name__ == "__main__":
        main()
    

    このスクリプトは、TRL 教師ありファインチューニング(SFT)Trainer を使用して、FSDP トレーニング ループ、Low-Rank Adaptation(LoRA)構成、データ形式を管理します。

  4. Slurm クラスタで実行するジョブのタスクを指定するには、次の内容で submit.slurm ファイルを作成します。

    #!/bin/bash
    #SBATCH --job-name=llama4-fsdp-fixed
    #SBATCH --nodes=2
    #SBATCH --ntasks-per-node=8
    #SBATCH --gpus-per-node=8
    #SBATCH --partition=a4high
    #SBATCH --output=llama4-%j.out
    #SBATCH --error=llama4-%j.err
    
    set -e
    set -x
    
    echo "--- Slurm Job Started ---"
    echo "Job ID: $SLURM_JOB_ID"
    echo "Node List: $SLURM_JOB_NODELIST"
    
    # --- Define Paths ---
    LOCAL_SSD_PATH="/mnt/localssd/job_${SLURM_JOB_ID}"
    VENV_PATH="${HOME}/.venv/venv-fsdp"
    MODEL_PATH="${HOME}/Llama-4-Scout-17B-16E-Instruct"
    
    # --- STAGE 1: Stage Data to Local SSD on Each Node ---
    srun --ntasks=$SLURM_NNODES --ntasks-per-node=1 bash -c "
      echo '--- Staging on node: $(hostname) ---'
    
      mkdir -p ${LOCAL_SSD_PATH}
    
      echo 'Copying virtual environment...'
      rsync -a -q ${VENV_PATH}/ ${LOCAL_SSD_PATH}/venv/
    
      echo 'Copying model weights...'
      rsync -a --info=progress2 ${MODEL_PATH}/ ${LOCAL_SSD_PATH}/model/
    
      mkdir -p ${LOCAL_SSD_PATH}/hf_cache
    
      echo '--- Staging on $(hostname) complete ---'
    "
    echo "--- Staging complete on all nodes ---"
    
    # --- STAGE 2: Run the Training Job ---
    echo "--- Launching Distributed Training with GIB NCCL Plugin ---"
    nodes=( $( scontrol show hostnames "$SLURM_JOB_NODELIST" ) )
    head_node=${nodes[0]}
    head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
    
    export MASTER_ADDR=$head_node_ip
    export MASTER_PORT=29500
    
    export NCCL_SOCKET_IFNAME=enp0s19
    export NCCL_NET=gIB
    
    # export NCCL_DEBUG=INFO # Un-comment to diagnose NCCL issues if needed
    
    srun --cpu-bind=none --accel-bind=g bash -c '
      # Activate the environment from the local copy
      source '${LOCAL_SSD_PATH}'/venv/bin/activate
    
      # Point Hugging Face cache to the local SSD
      export HF_HOME='${LOCAL_SSD_PATH}'/hf_cache
    
      export RANK=$SLURM_PROCID
      export WORLD_SIZE=$SLURM_NTASKS
      export LOCAL_RANK=$SLURM_LOCALID
    
      export LD_LIBRARY_PATH=/usr/local/gib/lib64:$LD_LIBRARY_PATH
      source /usr/local/gib/scripts/set_nccl_env.sh
    
      # --- Launch the training ---
      python \
        '${SLURM_SUBMIT_DIR}'/llama4-train-distributed.py \
          --model_id="'${LOCAL_SSD_PATH}'/model/" \
          --output_dir="'${LOCAL_SSD_PATH}'/outputs/" \
          --dataset_name="philschmid/gretel-synthetic-text-to-sql" \
          --seed=900913 \
          --bf16=True \
          --num_train_epochs=1 \
          --per_device_train_batch_size=2 \
          --gradient_accumulation_steps=4 \
          --learning_rate=2e-5 \
          --logging_steps=10 \
          --lora_r=16 \
          --lora_alpha=32 \
          --lora_dropout=0.05 \
          --run_inference_after_training
    '
    
    # --- STAGE 3: Copy Final Results Back to Persistent Storage ---
    echo "--- Copying final results from local SSD to shared storage ---"
    PERSISTENT_OUTPUT_DIR="${HOME}/outputs/llama4_job_${SLURM_JOB_ID}"
    mkdir -p "$PERSISTENT_OUTPUT_DIR"
    
    # Only copy from the head node where trl has combined the results
    srun --nodes=1 --ntasks=1 -w "$head_node" \
      rsync -a --info=progress2 "${LOCAL_SSD_PATH}/outputs/" "${PERSISTENT_OUTPUT_DIR}/"
    
    # --- STAGE 4: Cleanup ---
    echo "--- Cleaning up local SSD on all nodes ---"
    srun --ntasks=$SLURM_NNODES --ntasks-per-node=1 bash -c "rm -rf ${LOCAL_SSD_PATH}"
    
    echo "--- Slurm Job Finished ---"
    

スクリプトを Slurm クラスタにアップロードする

前のセクションで作成したスクリプトを Slurm クラスタにアップロードする手順は次のとおりです。

  1. ログインノードを特定するには、プロジェクト内のすべての A4 VM を一覧表示します。

    gcloud compute instances list --filter="machineType:a4-highgpu-8g"
    

    ログインノードの名前は a4-high-login-001 のようになります。

  2. スクリプトをログインノードのホーム ディレクトリにアップロードします。

    gcloud compute scp --project="$PROJECT_ID" --zone="$ZONE" --tunnel-through-iap \
      ./install_environment.sh \
      ./requirements-fsdp.txt \
      ./llama4-train-distributed.py \
      ./submit.slurm \
      "${LOGIN_NODE_NAME}":~/
    

    LOGIN_NODE_NAME は、ログインノードの名前に置き換えます。

Slurm クラスタに接続する

SSH を使用してログインノードに接続し、Slurm クラスタに接続します。

gcloud compute ssh LOGIN_NODE_NAME \
    --project=PROJECT_ID \
    --tunnel-through-iap \
    --zone=ZONE

フレームワークとツールをインストールする

ログインノードに接続したら、次の操作を行ってフレームワークとツールをインストールします。

  1. Hugging Face トークンをエクスポートします。

    # On the login node
    export HF_TOKEN="hf_..." # Replace with your token
    
  2. インストール スクリプトを実行します。

    # On the login node
    chmod +x install_environment.sh
    ./install_environment.sh
    

    このコマンドは、必要なすべての依存関係を含む仮想環境を設定し、モデルの重みを ~/Llama-4-Scout-17B-16E-Instruct ファイルにダウンロードします。

    モデルのダウンロードは非常に大きいため(約 200 GB)、ネットワークの状態によってはこのプロセスに 30 分ほどかかります。

ファインチューニング ワークロードを開始する

ワークロードのトレーニングを開始する手順は次のとおりです。

  1. ジョブを Slurm スケジューラに送信します。

    sbatch submit.slurm
    
  2. Slurm クラスタのログインノードで、home ディレクトリに作成された出力ファイルを確認して、ジョブの進行状況をモニタリングできます。

    # On the login node
    tail -f llama4-*.out
    

    ジョブが正常に開始されると、.err ファイルにプログレスバーが表示され、ジョブの進行状況に応じて更新されます。

    このジョブは、Slurm クラスタで完了するまでに 1 時間強かかります。このジョブには、次の 2 つのフェーズがあります。

    • 大規模なベースモデルを各コンピューティング ノードのローカル SSD にコピーする。
    • モデルのコピーが完了すると開始されるトレーニング ジョブ。 このジョブの実行には約 35 分かかります。

クリーンアップする

このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

プロジェクトの削除

プロジェクトを削除します。 Google Cloud

gcloud projects delete PROJECT_ID

Slurm クラスタを削除する

Slurm クラスタを削除する手順は次のとおりです。

  1. cluster-toolkit ディレクトリに移動します。

  2. Terraform ファイルと作成したすべてのリソースを破棄します。

    ./gcluster destroy a4-high --auto-approve
    

次のステップ