# 機能設計書 106-TPUクロスレプリカ操作

## 概要

本ドキュメントは、TPUレプリカ間のAllReduce・AllToAll・CollectivePermuteなどのクロスレプリカ通信操作の設計を記述する。

### 本機能の処理概要

**業務上の目的・背景**：データ並列学習では、各レプリカが独立に計算した勾配を全レプリカで集約（AllReduce）する必要がある。また、モデル並列やパイプライン並列では、レプリカ間でテンソルの交換（AllToAll）やルーティング（CollectivePermute）が必要となる。本機能はTPU上でのこれらの通信操作を提供する。

**機能の利用シーン**：データ並列学習での勾配集約（CrossReplicaSum）、モデル並列でのテンソル分割・結合（AllToAll）、レプリカ間のデータルーティング（CollectivePermute）。

**主要な処理内容**：
1. AllToAll操作 - テンソルを分割してレプリカ間で全対全交換
2. CrossReplicaSum操作 - レプリカ間でテンソルの総和を計算
3. CollectivePermute操作 - 指定された送信元・送信先ペアに基づくテンソル転送

**関連システム・外部連携**：TPUインターコネクト（ICI: Inter-Core Interconnect）。

**権限による制御**：TPUデバイスへのアクセス権限が必要。

## 関連画面

| 画面No | 画面名 | 関連種別 | 関連する操作・処理 |
|--------|--------|----------|------------------|
| - | - | - | 本機能に関連する画面は登録されていない |

## 機能種別

計算処理（レプリカ間通信）

## 入力仕様

### 入力パラメータ

| パラメータ名 | 型 | 必須 | 説明 | バリデーション |
|-------------|-----|-----|------|---------------|
| input | T | Yes | 操作対象テンソル | AllToAll: numbertype+bool、CrossReplicaSum: half/bfloat16/float/float64/int32/uint32、CollectivePermute: numbertype |
| group_assignment | int32 | Yes（AllToAll, CrossReplicaSum） | レプリカグループ割り当て（2次元行列） | rank == 2 |
| source_target_pairs | int32 | Yes（CollectivePermute） | 送信元・送信先ペア | 2次元行列 |
| concat_dimension | int | Yes（AllToAll） | 結合次元 | 0以上、ランク未満 |
| split_dimension | int | Yes（AllToAll） | 分割次元 | 0以上、ランク未満 |
| split_count | int | Yes（AllToAll） | 分割数 | >= 1 |

### 入力データソース

TPUレプリカ上の計算結果テンソル。

## 出力仕様

### 出力データ

| 項目名 | 型 | 説明 |
|--------|-----|------|
| output (AllToAll) | T | 全対全交換後のテンソル（concat_dimension方向にsplit_count倍、split_dimension方向に1/split_count） |
| output (CrossReplicaSum) | T | レプリカ間総和テンソル（入力と同一形状） |
| output (CollectivePermute) | T | パーミュテーション後のテンソル（入力と同一形状） |

### 出力先

TPUレプリカ上の計算結果テンソル。

## 処理フロー

### 処理シーケンス

```
[AllToAll]
1. 入力テンソルをsplit_dimension方向にsplit_count個に分割
2. グループ内の各レプリカ間で断片を交換
3. 受信した断片をconcat_dimension方向に結合

[CrossReplicaSum]
1. グループ内の全レプリカの入力テンソルを要素ごとに加算
2. 結果を各レプリカに返却

[CollectivePermute]
1. source_target_pairsに基づいてテンソルをルーティング
2. 各レプリカは指定された送信元からテンソルを受信
```

### フローチャート

```mermaid
flowchart TD
    subgraph AllToAll
        A1[入力テンソル] --> A2[split_dimension方向に分割]
        A2 --> A3[レプリカ間で全対全交換]
        A3 --> A4[concat_dimension方向に結合]
    end
    subgraph CrossReplicaSum
        B1[入力テンソル] --> B2[グループ内AllReduce Sum]
        B2 --> B3[合計テンソル]
    end
    subgraph CollectivePermute
        C1[入力テンソル] --> C2[source_target_pairsでルーティング]
        C2 --> C3[ルーティング済みテンソル]
    end
```

## ビジネスルール

### 業務ルール

| ルールNo | ルール名 | 内容 | 適用条件 |
|---------|---------|------|---------|
| BR-106-01 | split_count整合性 | group_assignmentの第2次元サイズがsplit_countと一致する必要がある | AllToAll |
| BR-106-02 | 分割可能性 | split_dimension方向のテンソルサイズがsplit_countで割り切れる必要がある | AllToAll |
| BR-106-03 | group_assignment形状 | group_assignmentはランク2の行列である必要がある | AllToAll |
| BR-106-04 | 形状不変性 | CrossReplicaSumとCollectivePermuteは入力と同一形状の出力を返す | CrossReplicaSum, CollectivePermute |

### 計算ロジック

AllToAllの出力テンソル形状:
- concat_dimension方向: 入力サイズ * split_count
- split_dimension方向: 入力サイズ / split_count
- その他の次元: 変化なし

## データベース操作仕様

本機能はデータベース操作を行わない。

## エラー処理

### エラーケース一覧

| エラーコード | エラー種別 | 発生条件 | 対処方法 |
|------------|----------|---------|---------|
| InvalidArgument | バリデーションエラー | split_count < 1 | split_countを1以上に設定 |
| InvalidArgument | 形状不整合 | group_assignmentがランク2でない | 2次元行列を指定 |
| InvalidArgument | 次元範囲外 | concat_dimensionまたはsplit_dimensionが範囲外 | 有効な次元インデックスを指定 |
| InvalidArgument | 分割不能 | split_dimension方向のサイズがsplit_countで割り切れない | テンソルサイズまたはsplit_countを調整 |
| InvalidArgument | カウント不一致 | group_assignmentの第2次元 != split_count | 値を一致させる |

### リトライ仕様

自動リトライは行われない。

## トランザクション仕様

全OpsはSetIsStateful()が設定されている。クロスレプリカ操作は全レプリカが同時に実行される必要があり、1つのレプリカでもエラーが発生すると全体が影響を受ける。

## パフォーマンス要件

- TPUインターコネクト（ICI）の帯域幅に依存
- AllToAllはデータ量 * (split_count - 1) / split_count の通信量

## セキュリティ考慮事項

- TPUデバイスへのアクセス制御が必要

## 備考

- CrossReplicaSumはUnchangedShapeのShapeFnを使用（入出力形状同一）
- CollectivePermuteもUnchangedShape
- AllToAllは詳細なShapeFnで出力形状を計算（34-103行目）

---

## コードリーディングガイド

本機能を理解するために参照すべきファイルと、推奨する読み解き順序を以下に示す。

### 推奨読解順序

#### Step 1: Op登録を理解する

| 順序 | ファイル | パス | 読解ポイント |
|-----|---------|------|-------------|
| 1-1 | tpu_cross_replica_ops.cc | `tensorflow/core/ops/tpu_cross_replica_ops.cc` | 全3個のクロスレプリカOp登録定義 |

**主要処理フロー**:
1. **25-103行目**: AllToAll Op - 入力ランク確認、split_count検証、group_assignment形状検証、出力形状計算
2. **47-50行目**: split_count >= 1の検証
3. **51-53行目**: group_assignmentのランク2検証
4. **54-61行目**: group_assignment第2次元サイズとsplit_countの一致検証
5. **65-68行目**: concat_dimensionの範囲検証
6. **70-74行目**: split_dimensionの範囲検証
7. **85-98行目**: 出力形状計算（concat方向は倍増、split方向は縮小）
8. **91-96行目**: 分割可能性チェック
9. **105-111行目**: CrossReplicaSum Op - half/bfloat16/float/float64/int32/uint32型、UnchangedShape
10. **113-119行目**: CollectivePermute Op - numbertype、UnchangedShape

**読解のコツ**: AllToAllのShapeFnが最も複雑。入力テンソルの各次元について、concat_dimensionならsplit_count倍、split_dimensionなら1/split_countとなる出力形状を計算する。

### プログラム呼び出し階層図

```
AllToAll Op
    ├─ ShapeFn: split_count / group_assignment検証
    │      ├─ concat_dimension / split_dimension範囲チェック
    │      └─ 出力形状計算（concat方向×split_count、split方向÷split_count）
    └─ TPU ICI通信で全対全交換

CrossReplicaSum Op
    └─ ShapeFn: UnchangedShape（入力形状保持）

CollectivePermute Op
    └─ ShapeFn: UnchangedShape（入力形状保持）
```

### データフロー図

```
[AllToAll]
レプリカ0: [A0|A1|A2|A3] ──split──▶ A0,A1,A2,A3 ──exchange──▶ [A0|B0|C0|D0] concat
レプリカ1: [B0|B1|B2|B3] ──split──▶ B0,B1,B2,B3 ──exchange──▶ [A1|B1|C1|D1] concat
レプリカ2: [C0|C1|C2|C3] ──split──▶ C0,C1,C2,C3 ──exchange──▶ [A2|B2|C2|D2] concat
レプリカ3: [D0|D1|D2|D3] ──split──▶ D0,D1,D2,D3 ──exchange──▶ [A3|B3|C3|D3] concat

[CrossReplicaSum]
レプリカ0: X0 ─┐
レプリカ1: X1 ─┼─▶ Sum(X0,X1,...,XN) ──▶ 全レプリカに配信
レプリカN: XN ─┘
```

### 関連ファイル一覧

| ファイル | パス | 種別 | 役割 |
|---------|------|------|------|
| tpu_cross_replica_ops.cc | `tensorflow/core/ops/tpu_cross_replica_ops.cc` | ソース | クロスレプリカOp登録定義（3 Ops） |
