# 機能設計書 70-クロスデバイス通信

## 概要

本ドキュメントは、TensorFlowにおけるAllReduce、ReduceTo、Broadcastなどのクロスデバイス通信操作を提供する機能の設計を記述する。

### 本機能の処理概要

クロスデバイス通信機能は、tf.distribute.CrossDeviceOpsベースクラスとその具象実装（ReductionToOneDevice、NcclAllReduce、HierarchicalCopyAllReduce等）を通じて、複数デバイス間でのテンソル値のリダクション（集約）、ブロードキャスト、ギャザーなどの通信操作を提供する。分散ストラテジーの勾配集約やモデルパラメータ同期の基盤となる。

**業務上の目的・背景**：分散学習では、各デバイス（GPU/TPU）で計算された勾配やメトリクスを全デバイスで共有・集約する必要がある。クロスデバイス通信機能は、この集約処理を抽象化し、異なるハードウェア構成（単一マシン内GPU間、複数マシン間）やプロトコル（NCCL、gRPC等）に対応した効率的な通信を実現する。

**機能の利用シーン**：MirroredStrategyでの勾配All-Reduce、MultiWorkerMirroredStrategyでのワーカー間集約、パラメータの初期ブロードキャスト、メトリクスの集約。

**主要な処理内容**：
1. reduce: PerReplica値を指定デバイスにリダクション（SUM/MEAN）
2. batch_reduce: 複数のリダクションをバッチ化して効率的に実行
3. broadcast: テンソルを複数デバイスにブロードキャスト
4. gather: 各レプリカのテンソルを連結してギャザー
5. デスティネーション検証とデバイスマッチング

**関連システム・外部連携**：NCCL（GPU間高速通信）、collective通信、gRPC。

**権限による制御**：特になし。

## 関連画面

| 画面No | 画面名 | 関連種別 | 関連する操作・処理 |
|--------|--------|----------|------------------|
| - | - | - | 画面機能マッピングに該当なし |

## 機能種別

計算処理（デバイス間通信）

## 入力仕様

### 入力パラメータ

| パラメータ名 | 型 | 必須 | 説明 | バリデーション |
|-------------|-----|-----|------|---------------|
| reduce_op | ReduceOp | Yes(reduce時) | 集約操作の種類（SUM/MEAN） | ReduceOpインスタンス |
| per_replica_value | DistributedValues/Tensor | Yes | リダクション対象のPerReplica値 | PerReplicaに変換可能であること |
| destinations | DistributedValues/Variable/str | Yes | リダクション先のデバイス | 空でないこと |
| options | CommunicationOptions | No | 通信オプション | CommunicationOptionsインスタンス |
| axis (gather時) | int | Yes | ギャザーする軸 | - |

### 入力データソース

各デバイス上のテンソル値（PerReplica等のDistributedValues）。

## 出力仕様

### 出力データ

| 項目名 | 型 | 説明 |
|--------|-----|------|
| reduce結果 | Tensor/Mirrored | 集約後のテンソル値 |
| batch_reduce結果 | list[Tensor/Mirrored] | 複数の集約結果のリスト |
| broadcast結果 | Tensor/Mirrored | ブロードキャスト後のテンソル値 |
| gather結果 | Tensor/Mirrored | ギャザー後のテンソル値 |

### 出力先

指定されたデスティネーションデバイス上のテンソル。

## 処理フロー

### 処理シーケンス

```
1. 入力の検証と変換
   └─ 1-1. per_replica_value をPerReplicaオブジェクトに変換
   └─ 1-2. destinations の検証（型チェック、空チェック）
2. ショートカット判定
   └─ 単一レプリカかつデバイスが一致する場合、直接コピーで返却
3. 実際のリダクション実行
   └─ サブクラスの reduce_implementation を呼び出し
   └─ SUM: 全値を加算
   └─ MEAN: 全値を加算後、レプリカ数で除算
4. 結果のラッピング
   └─ Mirroredとしてregroup
```

### フローチャート

```mermaid
flowchart TD
    A[reduce呼び出し] --> B[入力をPerReplicaに変換]
    B --> C[destinationsの検証]
    C --> D{単一レプリカ?}
    D -->|Yes| E{デバイス一致?}
    E -->|Yes| F[identity コピーで返却]
    D -->|No| G[reduce_implementation 呼び出し]
    E -->|No| G
    G --> H[Mirroredとしてregroup]
    H --> I[結果返却]
```

## ビジネスルール

### 業務ルール

| ルールNo | ルール名 | 内容 | 適用条件 |
|---------|---------|------|---------|
| BR-70-01 | 空デスティネーション禁止 | destinationsは空であってはならない | 常時 |
| BR-70-02 | デスティネーション型制約 | DistributedValues、Variable、Tensor、文字列のいずれかでなければならない | 常時 |
| BR-70-03 | ReduceOp制約 | SUM またはMEANのみサポート | reduce時 |
| BR-70-04 | IndexedSlicesギャザー非サポート | gatherはIndexedSlicesを受け付けない | gather時 |
| BR-70-05 | クロスレプリカコンテキスト必須 | reduce/batch_reduce/broadcastはクロスレプリカコンテキストでのみ呼び出し可能 | 常時 |

### 計算ロジック

**_simple_reduce**:
```python
reduced = aggregate_tensors_or_indexed_slices(all_values, accumulation_fn)
if reduce_op == MEAN:
    reduced = reduced / count
```

## データベース操作仕様

### 操作別データベース影響一覧

データベース操作は行わない。

## エラー処理

### エラーケース一覧

| エラーコード | エラー種別 | 発生条件 | 対処方法 |
|------------|----------|---------|---------|
| ValueError | 空デスティネーション | destinationsが空 | 有効なデスティネーションを指定 |
| ValueError | 不正な型 | destinationsの型が不正 | DistributedValues、Variable、Tensor、文字列のいずれかを指定 |
| ValueError | 空PerReplica | per_replica_value.valuesが空 | 値を含むPerReplicaを指定 |
| NotImplementedError | 未実装メソッド | サブクラスがreduce_implementationを未実装 | 具象サブクラスを使用 |
| NotImplementedError | IndexedSlicesギャザー | gatherにIndexedSlicesを渡した | 通常のテンソルを使用 |

### リトライ仕様

通信エラー時のリトライはcollective通信実装に依存。

## トランザクション仕様

該当なし。All-Reduceは分散環境での同期ポイントとして機能する。

## パフォーマンス要件

- batch_reduceで複数のreduceをバッチ化して通信効率を向上
- 単一レプリカの場合はショートカットで通信をスキップ
- NCCLを使用した場合、GPU間の高帯域幅通信が可能

## セキュリティ考慮事項

デバイス間通信は同一プロセス内またはマシン間で行われる。マシン間通信のセキュリティは上位のストラテジーに依存。

## 備考

- `_canonicalize_devices`フラグでデバイス名の正規化を制御
- `_num_between_graph_workers`はデフォルト1で、サブクラスでオーバーライド可能

---

## コードリーディングガイド

本機能を理解するために参照すべきファイルと、推奨する読み解き順序を以下に示す。

### 推奨読解順序

#### Step 1: データ構造を理解する

| 順序 | ファイル | パス | 読解ポイント |
|-----|---------|------|-------------|
| 1-1 | cross_device_ops.py | `tensorflow/python/distribute/cross_device_ops.py` | ヘルパー関数群（52-216行目）：デスティネーション検証、値変換、単純リダクション |

**読解のコツ**: `PerReplica`はデバイスごとの値を保持するコンテナ、`Mirrored`は全デバイスで同じ値を持つことが保証されたPerReplicaのサブタイプ。`DistributedValues`はこれらの基底クラス。

#### Step 2: ヘルパー関数を理解する

| 順序 | ファイル | パス | 読解ポイント |
|-----|---------|------|-------------|
| 2-1 | cross_device_ops.py | `tensorflow/python/distribute/cross_device_ops.py` | `validate_destinations`関数（69-81行目）：デスティネーション型チェック |
| 2-2 | cross_device_ops.py | `tensorflow/python/distribute/cross_device_ops.py` | `simple_broadcast`関数（201-216行目）：単純ブロードキャスト |
| 2-3 | cross_device_ops.py | `tensorflow/python/distribute/cross_device_ops.py` | `_simple_reduce`関数（219-236行目）：単純リダクション |

**主要処理フロー**:
- **69-81行目**: destinationsの型検証（DistributedValues、Tensor、IndexedSlices、Variable、str、TPUMirroredVariable）
- **201-216行目**: 単一デバイスの場合はコピー、複数デバイスの場合はMirroredとしてregroup
- **219-236行目**: 全値を集約し、MEANの場合はレプリカ数で除算

#### Step 3: CrossDeviceOpsベースクラス

| 順序 | ファイル | パス | 読解ポイント |
|-----|---------|------|-------------|
| 3-1 | cross_device_ops.py | `tensorflow/python/distribute/cross_device_ops.py` | `CrossDeviceOps`クラス（251-519行目） |

**主要処理フロー**:
- **266-268行目**: `__init__`で`_canonicalize_devices`をTrueに初期化
- **275-321行目**: `reduce`メソッド：入力変換、検証、ショートカット判定、reduce_implementation委譲
- **323-367行目**: `_gather`メソッド：ギャザー操作
- **400-447行目**: `batch_reduce`メソッド：バッチリダクション
- **449-466行目**: `broadcast`メソッド：ブロードキャスト

### プログラム呼び出し階層図

```
tf.distribute.CrossDeviceOps  [ベースクラス]
    │
    ├─ reduce(reduce_op, per_replica_value, destinations)
    │      ├─ _make_tensor_into_per_replica()  [入力変換]
    │      ├─ validate_destinations()  [検証]
    │      └─ reduce_implementation()  [サブクラス実装]
    │
    ├─ batch_reduce(reduce_op, value_destination_pairs)
    │      └─ batch_reduce_implementation()  [サブクラス実装]
    │
    ├─ broadcast(tensor, destinations)
    │      └─ broadcast_implementation()  [サブクラス実装]
    │
    └─ _gather(per_replica_value, destinations, axis)
           └─ _gather_implementation()  [サブクラス実装]

具象実装:
    ├─ tf.distribute.ReductionToOneDevice
    ├─ tf.distribute.NcclAllReduce
    └─ tf.distribute.HierarchicalCopyAllReduce
```

### データフロー図

```
[入力]                          [処理]                              [出力]

PerReplica values ───────────▶ reduce() ──▶ Tensor/Mirrored
  GPU0: grad_0                     │ (SUM or MEAN)
  GPU1: grad_1                     │
  GPU2: grad_2                     │

Tensor ──────────────────────▶ broadcast() ──▶ Mirrored
  (from one device)                │               GPU0: value
                                   │               GPU1: value
                                   │               GPU2: value

PerReplica values ───────────▶ _gather() ──▶ Tensor (concatenated)
  GPU0: [a, b]                     │
  GPU1: [c, d]                     │           result: [a, b, c, d]
```

### 関連ファイル一覧

| ファイル | パス | 種別 | 役割 |
|---------|------|------|------|
| cross_device_ops.py | `tensorflow/python/distribute/cross_device_ops.py` | ソース | CrossDeviceOpsベースクラスと具象実装 |
| cross_device_utils.py | `tensorflow/python/distribute/cross_device_utils.py` | ソース | テンソル集約・コピーユーティリティ |
| collective_util.py | `tensorflow/python/distribute/collective_util.py` | ソース | CommunicationOptionsとcollective設定 |
| reduce_util.py | `tensorflow/python/distribute/reduce_util.py` | ソース | ReduceOp定義（SUM、MEAN） |
| values.py | `tensorflow/python/distribute/values.py` | ソース | PerReplica、Mirrored等のDistributedValues |
| distribute_utils.py | `tensorflow/python/distribute/distribute_utils.py` | ソース | regroup等の分散ユーティリティ |
| device_util.py | `tensorflow/python/distribute/device_util.py` | ソース | デバイス名の解決・正規化 |
