# 機能設計書 39-カスタム勾配（Custom Gradient）

## 概要

本ドキュメントは、TensorFlowにおけるカスタム勾配機能の設計を記載する。`tf.custom_gradient` デコレータにより、関数のカスタム勾配関数を定義できる機構であり、`tensorflow/python/ops/custom_gradient.py` に実装される。

### 本機能の処理概要

**業務上の目的・背景**：標準の自動微分では数値的に不安定な勾配が得られる場合や、より効率的な勾配計算が可能な場合に、ユーザが独自の勾配関数を定義できる機構が必要となる。`tf.custom_gradient` デコレータはこの要件を満たす。

**機能の利用シーン**：
- 数値安定性の改善（例: log(1+exp(x)) の勾配）
- 計算効率の向上（例: 融合された順伝播と勾配計算）
- 近似勾配の提供（例: ストレートスルー推定器）
- 変数に対するカスタム勾配の定義

**主要な処理内容**：
1. `@tf.custom_gradient` デコレータを適用した関数は `(output, grad_fn)` タプルを返す
2. Eagerモードでは `_eager_mode_decorator` がGradientTapeでカスタム勾配を記録
3. グラフモードでは `_graph_mode_decorator` がgradient_override_mapを使用
4. 関数内でVariableを読む場合、grad_fnは `variables` パラメータも受け取る
5. Bindクラスにより、メソッドとしてもデコレータとして機能

**関連システム・外部連携**：GradientTape、OpRegistration、tf.functionとの統合。

**権限による制御**：特になし。

## 関連画面

| 画面No | 画面名 | 関連種別 | 関連する操作・処理 |
|--------|--------|----------|------------------|
| - | 該当なし | - | 画面機能マッピングに関連画面なし |

## 機能種別

計算処理（勾配カスタマイズ）

## 入力仕様

### custom_gradientデコレータ

| パラメータ名 | 型 | 必須 | 説明 | バリデーション |
|-------------|-----|-----|------|---------------|
| f | callable | Yes | `(y, grad_fn)` を返す関数。fは `f(*x)` の形式 | callableであること |

### デコレートされた関数の要件

```python
@tf.custom_gradient
def f(*inputs):
    # 順伝播計算
    outputs = ...
    def grad_fn(*upstream):
        # カスタム勾配計算
        return input_grads  # inputsと同数の勾配
    return outputs, grad_fn
```

### 変数を含む場合

```python
@tf.custom_gradient
def f(x):
    # 順伝播でVariableを使用
    y = weights * x
    def grad_fn(upstream, variables=None):
        grad_inputs = upstream * weights
        grad_vars = [upstream * x]
        return grad_inputs, grad_vars
    return y, grad_fn
```

## 出力仕様

### 出力データ

| 項目名 | 型 | 説明 |
|--------|-----|------|
| output | Tensor | 装飾された関数の順伝播出力（`f(x)[0]` と同等） |

### 出力先

通常の関数呼び出しと同様に返却。勾配計算時にはカスタム勾配関数が使用される。

## 処理フロー

### 処理シーケンス

```
1. @tf.custom_gradient が関数に適用される
2. Bind.decoratorでdecoratedラッパーを作成
3. 関数呼び出し時:
   a. Eagerモード: _eager_mode_decorator
      └─ 入力テンソルの準備、関数呼び出し
      └─ record.record_operation でカスタム勾配を登録
   b. グラフモード: _graph_mode_decorator
      └─ gradient_override_map を使用
      └─ RegisterGradient でカスタム勾配を登録
4. tape.gradient()呼び出し時:
   └─ カスタムgrad_fnが呼ばれ、ユーザ定義の勾配を返す
```

### フローチャート

```mermaid
flowchart TD
    A[@tf.custom_gradient 適用] --> B[Bind.decorator でラッパー作成]
    B --> C[関数呼び出し]
    C --> D{実行モード}
    D -->|Eager| E[_eager_mode_decorator]
    D -->|Graph| F[_graph_mode_decorator]
    E --> G[f(*inputs) -> outputs, grad_fn]
    F --> H[f(*inputs) -> outputs, grad_fn]
    G --> I[record.record_operation でgrad_fn登録]
    H --> J[gradient_override_map でgrad_fn登録]
    I --> K[outputs返却]
    J --> K
    K --> L[tape.gradient呼び出し]
    L --> M[カスタムgrad_fn実行]
    M --> N[カスタム勾配返却]
```

## ビジネスルール

### 業務ルール

| ルールNo | ルール名 | 内容 | 適用条件 |
|---------|---------|------|---------|
| BR-39-1 | 関数シグネチャ | デコレートされた関数は `(outputs, grad_fn)` タプルを返す必要がある | 全使用時 |
| BR-39-2 | 勾配関数シグネチャ | grad_fnは入力と同数の勾配を返す。Variableを読む場合は `(grad_inputs, grad_vars)` を返す | gradient計算時 |
| BR-39-3 | Variable検出 | 順伝播中にアクセスされたVariableはgrad_fnのvariables引数として渡される | Variable使用時 |
| BR-39-4 | None引数時のデコレータ動作 | f=Noneの場合、lambda f: custom_gradient(f=f) を返す（デコレータファクトリ） | @tf.custom_gradient() 形式 |
| BR-39-5 | メソッドバインディング | Bindクラスにより、クラスメソッドとしても正しく動作 | クラス内でのデコレータ使用時 |
| BR-39-6 | ネストされたcustom_gradient | ネストされた@tf.custom_gradientは直感的でない結果になる可能性がある。別関数でラップ推奨 | ネスト時 |

### 計算ロジック

カスタム勾配は連鎖律のVJP形式:
```
grad_fn(upstream) = upstream * d(output)/d(input)
```
upstreamは後段からの勾配。grad_fnはupstreamを受け取り、入力に対する勾配を返す。

## データベース操作仕様

### 操作別データベース影響一覧

| 操作 | 対象テーブル | 操作種別 | 概要 |
|-----|-------------|---------|------|
| - | - | - | データベース操作なし |

### テーブル別操作詳細

データベース操作は発生しない。

## エラー処理

### エラーケース一覧

| エラーコード | エラー種別 | 発生条件 | 対処方法 |
|------------|----------|---------|---------|
| TypeError | シグネチャ不正 | grad_fnのシグネチャがexpectedと一致しない | 正しいシグネチャで定義する |
| ValueError | 勾配数不一致 | grad_fnが返す勾配の数が入力の数と一致しない | 入力と同数の勾配を返す |
| LookupError | 未登録勾配 | tf.function内でカスタム勾配のないOpが使用された | tf.stop_gradientを使用 |

### リトライ仕様

リトライは不要。

## トランザクション仕様

特になし。

## パフォーマンス要件

- カスタム勾配自体のオーバーヘッドは最小
- Pythonでのgrad_fn実行はC++勾配関数と比べて低速
- tf.function内で使用する場合はトレース時にグラフに組み込まれ効率的

## セキュリティ考慮事項

特になし。

## 備考

- `tf.RegisterGradient` はプリミティブなTensorFlow Opの勾配関数を登録するのに対し、`tf.custom_gradient` は操作のシーケンスに対する勾配をカスタマイズする
- VAR_OP_TYPES = ["VariableV2", "VarHandleOp"] でVariable操作を識別

---

## コードリーディングガイド

本機能を理解するために参照すべきファイルと、推奨する読み解き順序を以下に示す。

### 推奨読解順序

#### Step 1: デコレータの構造を理解する

| 順序 | ファイル | パス | 読解ポイント |
|-----|---------|------|-------------|
| 1-1 | custom_gradient.py | `tensorflow/python/ops/custom_gradient.py` | custom_gradient関数の全体構造 |

**主要処理フロー**:
1. **45-46行目**: `@tf_export("custom_gradient")` でエクスポート
2. **286-287行目**: `f is None` の場合、デコレータファクトリとして動作
3. **289-297行目**: `@Bind.decorator` で `decorated` ラッパーを作成
4. **292-295行目**: Eagerモード判定で `_eager_mode_decorator` / `_graph_mode_decorator` を呼び分け

#### Step 2: Bindクラスの仕組み

| 順序 | ファイル | パス | 読解ポイント |
|-----|---------|------|-------------|
| 2-1 | custom_gradient.py | `tensorflow/python/ops/custom_gradient.py` | Bindクラス |

**主要処理フロー**:
- **300-339行目**: `Bind` クラス - `__get__` でインスタンスメソッドバインディングをサポート
- **323-325行目**: `@classmethod decorator` でデコレータを返す
- **331-334行目**: `__get__` でバインド済み関数を返す
- **338-339行目**: `__call__` で `self._d(self._f, a, k)` を実行

#### Step 3: Eagerモードの実装

| 順序 | ファイル | パス | 読解ポイント |
|-----|---------|------|-------------|
| 3-1 | custom_gradient.py | `tensorflow/python/ops/custom_gradient.py` | _eager_mode_decorator |

**読解のコツ**: `_eager_mode_decorator` は `record.record_operation` を使用して、カスタム勾配をテープに記録する。`flat_result` と `all_tensors` からgrad_fnを構築し、Variable検出も行う。

#### Step 4: 変数処理

| 順序 | ファイル | パス | 読解ポイント |
|-----|---------|------|-------------|
| 4-1 | custom_gradient.py | `tensorflow/python/ops/custom_gradient.py` | get_variable_by_name, _get_dependent_variables |

**主要処理フロー**:
- **39-42行目**: `VAR_OP_TYPES` でVariable操作を識別
- **342-349行目**: `get_variable_by_name` でグローバル変数から取得

### プログラム呼び出し階層図

```
@tf.custom_gradient (custom_gradient.py)
    |
    +-- custom_gradient(f)
    |       |
    |       +-- f is None -> lambda f: custom_gradient(f=f) [デコレータファクトリ]
    |       |
    |       +-- @Bind.decorator -> decorated(wrapped, args, kwargs)
    |               |
    |               +-- [Eager] _eager_mode_decorator(wrapped, args, kwargs)
    |               |       +-- wrapped(*all_inputs) -> outputs, grad_fn
    |               |       +-- record.record_operation(grad_fn)
    |               |       +-- Variable検出と勾配ルーティング
    |               |
    |               +-- [Graph] _graph_mode_decorator(wrapped, args, kwargs)
    |                       +-- gradient_override_map
    |                       +-- RegisterGradient
    |
    +-- Bind.__get__(instance, owner) [メソッドバインディング]
    +-- Bind.__call__(*a, **k)
```

### データフロー図

```
[入力]                [処理]                      [出力]

入力テンソル -------> custom_gradient(f)
                          |
                     f(*inputs) -> (outputs, grad_fn)
                          |
                     record_operation(grad_fn)
                          |                             outputs
                     tape.gradient() 呼び出し時
                          |
                     grad_fn(*upstream)
                          |                             カスタム勾配
                     [Variable使用時]
                          |
                     grad_fn(*upstream, variables)
                          |                             (input_grads, var_grads)
```

### 関連ファイル一覧

| ファイル | パス | 種別 | 役割 |
|---------|------|------|------|
| custom_gradient.py | `tensorflow/python/ops/custom_gradient.py` | ソース | custom_gradientデコレータの主実装 |
| record.py | `tensorflow/python/eager/record.py` | ソース | record_operation（Eager勾配記録） |
| backprop.py | `tensorflow/python/eager/backprop.py` | ソース | GradientTape（勾配計算元） |
| ops.py | `tensorflow/python/framework/ops.py` | ソース | gradient_override_map（グラフモード） |
| tf_decorator.py | `tensorflow/python/util/tf_decorator.py` | ソース | make_decorator（関数装飾） |
| variable_scope.py | `tensorflow/python/ops/variable_scope.py` | ソース | Variable操作のスコープ管理 |
