TorchScript入門【PyTorchモデルを変換してC++/モバイルで実行できるモデルを手に入れる】

まえがき

R&Dチームの宮﨑です。今回はPyTorchのモデルをC++/モバイルから実行できるように変換してくれるTorchScriptについてご紹介したいと思います。本記事ではPyTorch 1.4.0を前提にしています。

8/26に発表するOPTiM TECH NIGHTでもLibTorch/TorchScriptについて話すのでぜひご覧になってください!

LibTorch/TorchScript概要

LibTorchはC++版のPyTorchです。学習・推論両方記述できます。 Pythonで書かれたPyTorchをC++で書かれたLibTorchで実装しなおすことで速度・汎用性の向上が望めます。 LibTorchのフォルダを設置し、CMakeで簡単にセットアップできるのも特徴です。

TorchScriptを用いることでPyTorchで学習させたモデル(もしくは単なるTensor処理)を最適化しつつ変換し、学習したパラメータごとファイルに保存できます。保存したファイルはLibTorchをインクルードしたC++環境であれば自由に読み込み・実行(推論)できます。さらにモバイル環境(iOS / Android)からもこのファイルを読み込み・実行(推論)させることができます。

TorchScript入門

基本機能

Trace

torch.jit.traceを用いることでPyTorchで記述したTensorの処理に対してサンプル入力を流し、その様子をTraceして最適化した上でTorchScript Modelに変換し、ptファイルとして保存できます。保存したptファイルはPyTorch/LibTorchで読み込み、実行させることができます。Python/PyTorchの多くの機能を対象にTraceできます。PyTorchで用意した学習済みモデルもTraceできます。

PyTorch: save

# Traceする対象のメソッドを記述
def test_sample(tensor):
    return torch.sum(tensor)

# Traceする際に用いるサンプル入力を用意
test_sample_input_trace = torch.tensor([1, 2, 3])
# TraceしてTorchScript Modelとして取得
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
# PyTorchの学習済みモデルもTraceできる
traced_from_pytorch_model = torch.jit.trace(torchvision.models.resnet18(), torch.rand(1, 3, 224, 224))

# TraceしたものをTorchScript ModelとしてSave
traced_from_pytorch_model.save("traced_from_pytorch_model.pt")
# PyTorch側でLoad / 実行できるのはもちろん・・・・
load_traced_from_pytorch_model = torch.jit.load("traced_from_pytorch_model.pt")
load_traced_from_pytorch_model(input_tensor)

LibTorch: load

// LibTorch(C++)でもTorchScript ModelをLoadできる!!!
torch::jit::script::Module module = torch::jit::load("traced_from_pytorch_model.pt");
// LibTorch側で用意したTensorを入力して実行
module.forward({input_tensor});

Script

@torch.jit.scriptのデコレータを付与したメソッドを記述することでTorchScriptを直接記述することもできます。上手くTraceできないときは直接記述してしまうのが楽です。TorchScriptは静的な型を持ったPythonのサブセットでほぼPythonと同じ感覚で記述できます。詳しくは下記の公式ドキュメントをご参照ください。

# TorchScriptを直接記述してしまう
@torch.jit.script
def test_sample(tensor):
    return torch.sum(tensor)

# TorchScriptの内容をPythonライクに表示
print(test_sample.code)
'''
def test_sample(tensor: Tensor) -> Tensor:
  return torch.sum(tensor, dtype=None)
'''
# TorchScriptの内容を内部graph表現で表示
print(test_sample.graph)
'''
graph(%tensor.1 : Tensor):
  %2 : None = prim::Constant()
  %3 : Tensor = aten::sum(%tensor.1, %2) # test.py:5:11
  return (%3)
'''

Tips

  • 生成されるTorchScript Modelはloop文などのcontrol flow演算子やその他のプリミティブな演算子で構成される静的な型の中間表現が採用されています。

  • TorchScript Modelへの入力及び出力の受け取りは後でC++で記述することになるため、TraceするPythonコードの入出力はC++でも表現しやすいようにしておくと後で楽になります。

  • TorchScript ModelはC++/LibTorchを用いて記述したメソッドを埋め込むことができます。そのためにはTraceする際にtorch::RegisterOperatorsなどで拡張するメソッドをTorchScriptに登録する必要があります。Python環境で登録したメソッドを使うためにはtorch.ops.load_library()などでライブラリをLoadする必要がありますが、既にそのライブラリを使用できる状態でLoadするとセグメンテーションフォルトを起こしたりするので注意が必要です。C++環境で実行する際は拡張コードをC++環境にも配置しRegisterOperatorsする必要があります。またLibTorchでカスタムレイヤを記述する際はfloat -> doubleint -> int64_tの書き換えが必要でした。型に制約があるようです。 詳しくは公式のドキュメントに分かりやすい説明があるのでご参照ください。

Traceの制約

Traceする際に用いるtorch.jit.traceはデータ非依存のモジュール・関数をTraceするだけであり、グローバル変数などの外的要因はTraceしません。要するにTrace時に実行させる、「サンプル入力を読み込んでの処理の様子」だけを基にTorchScriptを構成するということです。従って上手くTraceできない場合やPyTorch側で実行できてもLibTorch側でエラーが起こるケースもあります。Warningであらかじめ危険性を警告してくれる時もあります。いくつかの具体例はこの章でご紹介します。

筆者が勧める対策の方針

  1. Traceしたモデルがどこから想定外の挙動をしているのか突き止めます。筆者はいつもコード/実行結果(警告/エラーメッセージなど)を観察したり二分探索的に探しています。
  2. 想定外の挙動をしている箇所を特定したら次章に挙げるエラーケースと照らし合わせたり、公式のドキュメントを参照するなどしてどのようにTraceが対応できていないかを突き止めます。
  3. 原因に応じた対処をします。筆者が得た対処法は以下の通りです。具体例は次章でご紹介します。
    • Traceがサポートされていない記述はTraceできる記述に置き換えましょう
    • TorchScriptの制約を満たしていない記述は制約内の記述に置き換えましょう
    • @torch.jit.scriptをTraceするメソッドに付けてあげると多くの問題を解決します
      • TraceではなくTorchScriptを直接記述することになるため意図通りの処理を行わせやすい
      • 条件分岐をはじめとした多くの「Trace時のもので固定化」から解放されます
      • 場合分けができるのでテスト・デバッグに使えます
      • ただし記述できる内容に制約があるのでエラーが発生した際は制約内の記述に置き換えましょう
    • 型がうまくTraceされていない場合はmypyスタイルで型アノテーションをして型を明示してあげましょう
    • torch.jit.tracecheck_trace=Trueで元のPythonコードの結果と比較するテストをしてくれます。テストが通らなかった場合はエラーが出て処理が中止されますが、そのエラーメッセージで原因が分からないときはcheck_trace=Falseを設定しテストを無効にした上でTraceしてしまい、それをLoadして実行した際のエラーを見てみる、という手もあります。
    • TorchScript Modelのプロパティにあるgraphやcodeを表示させてどのように変換されてるかを見るのがかなり有効だと思います。
def test_sample(tensor):
    return torch.sum(tensor)

test_sample_input_trace = torch.tensor([1, 2, 3])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
print(test_sample_trace.graph)
'''
内部のグラフ表現を表示
graph(%tensor : Long(3)):
  %1 : None = prim::Constant()
  %2 : Long() = aten::sum(%tensor, %1) # test.py:4:0
  return (%2)
'''
print(test_sample_trace.code)
'''
Pythonライクに内部のグラフ表現を表示
def test_sample(tensor: Tensor) -> Tensor:
  return torch.sum(tensor, dtype=None)
'''

筆者が直面した具体的なTraceが上手くいかないケースとその対策

筆者が実際にTorchScriptでPyTorchモデルの変換を試みた際に上手く変換できなかったケースとその対策を列挙していきます。

TorchScript Modelは静的型付けなので他の型で再代入できないケース

def test_sample(tensor):
    tensor[0] = "string"
    return torch.sum(tensor)

test_sample_input_trace = torch.tensor([1, 2, 3])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
TypeError: can't assign a str to a torch.LongTensor

対策: 一度代入した型と別の型で再代入しないようにする

autograd.Functionなど非対応のものはTraceできないケース

対策: Traceできるものに置き換える。例えばautograd.Functionは非対応なのでnn.Moduleで記述しなおすなど。

公式による非対応機能一覧

TensorではないただのPython valueはTrace時の値で固定化されるケース

def test_sample(tensor):
    tensor[0] = tensor[0].item()
    return tensor

test_sample_input = torch.tensor([1, 2, 3])
test_sample_input2 = torch.tensor([4, 5, 6])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input)
print(test_sample_trace(test_sample_input))
# => tensor([1, 2, 3])
print(test_sample_trace(test_sample_input2))
# => tensor([1, 5, 6])
test_stream.py:34: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  tensor[0] = tensor[0].item()

対策: Python valueは@torch.jit.scriptのデコレータをつけた関数内で扱う

@torch.jit.script
def script_len(tensor):
    return torch.tensor(tensor[0].item())

def test_sample(tensor):
    tensor[0] = script_len(tensor)
    return tensor

test_sample_input = torch.tensor([1, 2, 3])
test_sample_input2 = torch.tensor([4, 5, 6])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input)
print(test_sample_trace(test_sample_input))
# => tensor([1, 2, 3])
print(test_sample_trace(test_sample_input2))
# => tensor([4, 5, 6])

Pythonのboolean/ if文のflowがTrace時のもので固定化されるケース

def test_sample(tensor):
    if tensor[0]:
        tensor[0] = 1
        return tensor
    else:
        tensor[0] = 0
        return tensor

test_sample_input_trace = torch.tensor([1])
test_sample_input = test_sample_input_trace.clone()
test_sample_input2 = torch.tensor([0])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
print(test_sample_trace(test_sample_input))
# => tensor([1])
print(test_sample_trace(test_sample_input2))
# => tensor([1])
test_stream.py:6: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor[0]:

対策: booleanはPython valueなので@torch.jit.scriptのデコレータをつけた関数内で扱う

@torch.jit.script
def test_sample(tensor):
    if tensor[0]:
        tensor[0] = 1
        return tensor
    else:
        tensor[0] = 0
        return tensor

test_sample_input_trace = torch.tensor([1])
test_sample_input = test_sample_input_trace.clone()
test_sample_input2 = torch.tensor([0])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
print(test_sample_trace(test_sample_input))
# => tensor([1])
print(test_sample_trace(test_sample_input2))
# => tensor([0])

loop文がTrace時のflowで固定化されるケース

@torch.jit.script
def get_len(tensor):
    return torch.tensor(len(tensor))

def test_sample(tensor):
    end = get_len(tensor)
    tensor[1] = end
    for i in range(end):
        tensor[0] += 1
    return tensor

test_sample_input_trace = torch.tensor([1, 1])
test_sample_input = test_sample_input_trace.clone()
test_sample_input2 = torch.tensor([1, 1, 1])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
print(test_sample_trace(test_sample_input))
# => tensor([3, 2])
print(test_sample_trace(test_sample_input2))
# => tensor([3, 3, 1])
test.py:10: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  for i in range(end):

対策: @torch.jit.scriptのデコレータをつけた関数内で扱う

@torch.jit.script
def test_sample(tensor):
    end = len(tensor)
    tensor[1] = end
    for i in range(end):
        tensor[0] += 1
    return tensor

test_sample_input_trace = torch.tensor([1, 1])
test_sample_input = test_sample_input_trace.clone()
test_sample_input2 = torch.tensor([1, 1, 1])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
print(test_sample_trace(test_sample_input))
# => tensor([3, 2])
print(test_sample_trace(test_sample_input2))
# => tensor([4, 3, 1])

TensorのイテレーションはTrace時のもので固定化されるケース

def test_sample(tensor):
    for tensor_e in tensor:
        tensor[0] += 1
    return tensor

test_sample_input_trace = torch.tensor([1, 2, 3])
test_sample_input = test_sample_input_trace.clone()
test_sample_input2 = torch.tensor([1, 2, 3, 4])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
print(test_sample_trace(test_sample_input))
# => tensor([4, 2, 3])
print(test_sample_trace(test_sample_input2))
# => tensor([4, 2, 3, 4])
/home/user/.local/lib/python3.6/site-packages/torch/tensor.py:461: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  'incorrect results).', category=RuntimeWarning)

対策: @torch.jit.scriptのデコレータをつけた関数内で扱う

@torch.jit.script
def test_sample(tensor):
    for tensor_e in tensor:
        tensor[0] += 1
    return tensor

test_sample_input_trace = torch.tensor([1, 2, 3])
test_sample_input = test_sample_input_trace.clone()
test_sample_input2 = torch.tensor([1, 2, 3, 4])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
print(test_sample_trace(test_sample_input))
# => tensor([4, 2, 3])
print(test_sample_trace(test_sample_input2))
# => tensor([5, 2, 3, 4])

as_tensor(), torch.tensor()で生成したTensorの値は書き方によってはTrace時の値で固定化されるケース

def test_sample(tensor):
    return torch.tensor([torch.sum(tensor)])

test_sample_input = torch.tensor([1, 2, 3])
test_sample_input2 = torch.tensor([4, 5, 6])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input)
print(test_sample_trace(test_sample_input))
# => tensor([6]) 
print(test_sample_trace(test_sample_input2))
# => tensor([6])
test_stream.py:34: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  return torch.tensor([torch.sum(tensor)])
test_stream.py:34: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return torch.tensor([torch.sum(tensor)])
def test_sample(tensor):
    return torch.as_tensor([torch.sum(tensor)])

test_sample_input = torch.tensor([1, 2, 3])
test_sample_input2 = torch.tensor([4, 5, 6])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input)
print(test_sample_trace(test_sample_input))
# => tensor([6]) 
print(test_sample_trace(test_sample_input2))
# => tensor([6]) 
test_stream.py:34: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  return torch.as_tensor([torch.sum(tensor)])
test_stream.py:34: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return torch.as_tensor([torch.sum(tensor)])

対策: 以下の書き方で書く

def test_sample(tensor):
    return torch.tensor(torch.sum(tensor))

test_sample_input = torch.tensor([1, 2, 3])
test_sample_input2 = torch.tensor([4, 5, 6])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input)
print(test_sample_trace(test_sample_input))
# => tensor([6]) 
print(test_sample_trace(test_sample_input2))
# => tensor([15]) 
def test_sample(tensor):
    return torch.as_tensor(torch.sum(tensor))

test_sample_input = torch.tensor([1, 2, 3])
test_sample_input2 = torch.tensor([4, 5, 6])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input)
print(test_sample_trace(test_sample_input))
# => tensor([6]) 
print(test_sample_trace(test_sample_input2))
# => tensor([15]) 

import copyして用いるcopy.copy()では無警告で想定外の挙動をするケース

import torch
import copy

def test_sample(tensor):
    return copy.copy(tensor)

test_sample_input_trace = torch.tensor([1, 2, 3, 4])
test_sample_input = test_sample_input_trace.clone()
test_sample_input2 = torch.tensor([5, 6, 7, 8])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
print(test_sample_trace(test_sample_input))
# => tensor([1, 2, 3, 4])
print(test_sample_trace(test_sample_input2))
# => tensor([1, 2, 3, 4])

対策: copyの記述は避ける

デフォルトでは引数をTensorだと想定するためエラーを起こすケース

@torch.jit.script
def test_script(input):
    return torch.tensor(input)

def test_sample(input):
    input = test_script([1, 2, 3])
    return input

test_sample_input_trace = torch.tensor([1, 2, 3])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
Traceback (most recent call last):
  File "test_stream.py", line 16, in <module>
    test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
  File "/home/user/.local/lib/python3.6/site-packages/torch/jit/__init__.py", line 906, in trace
    _force_outplace)
  File "test_stream.py", line 10, in test_sample
    input = test_script([1, 2, 3])
RuntimeError: test_script() Expected a value of type 'Tensor' for argument 'input' but instead found type 'list'.
Inferred 'input' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 0
Value: [1, 2, 3]
Declaration: test_script(Tensor input) -> (Tensor)
Cast error details: Unable to cast Python instance of type <class 'list'> to C++ type 'at::Tensor'

対策: mypyスタイルの型アノテーションで明示する

@torch.jit.script
def test_script(input):
    # type:(List[int]) -> Tensor
    return torch.tensor(input)

def test_sample(input):
    input = test_script([1, 2, 3])
    return input

test_sample_input_trace = torch.tensor([1, 2, 3])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)

Python実行時とTrace時で型が異なる / Trace時に引数の型を必ずしも正確に推論しないケース

@torch.jit.script
def test_script(input):
    return torch.tensor([input])

def test_sample(input):
    input = input.size()[0]
    # inputが3: intになってほしいのにtensor(3): Tensor
    return test_script(input)

test_sample_input_trace = torch.tensor([1, 2, 3])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
Traceback (most recent call last):
  File "test_stream.py", line 5, in <module>
    @torch.jit.script
  File "/home/user/.local/lib/python3.6/site-packages/torch/jit/__init__.py", line 1281, in script
    fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
RuntimeError:
Input list to torch.tensor must be of ints, floats, or bools, got Tensor:
  File "test_stream.py", line 7
@torch.jit.script
def test_script(input):
    return torch.tensor([input])
           ~~~~~~~~~~~~ <--- HERE

対策: mypyスタイルの型アノテーションで明示する

@torch.jit.script
def test_script(input):
    # type:(int) -> Tensor
    return torch.tensor([input])

def test_sample(input):
    input = input.size()[0]
    # inputが3: intが渡されてくれる
    return test_script(input)

test_sample_input_trace = torch.tensor([1, 2, 3])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)

Tensorの伝搬を追跡しているため、print系や出力Tensorに影響しない記述はTraceされないケース

プロパティのcodeを表示させてみるとその辺りが垣間見える

def test_sample(tensor):
    print("(Д`; )")
    return tensor

test_sample_input_trace = torch.tensor([1, 2, 3])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
print(test_sample_trace.code)
# Trace時に何度か実行されるため出力されている
(Д`; )
(Д`; )
(Д`; )
# printの記述がない
def test_sample(argument_0: Tensor) -> Tensor:
  return argument_0

型の制限によるエラーを起こすケース

対応する型はバージョンによって変わったりするので詳しくはドキュメントをご参照ください

def test_sample(input_int):
    return torch.tensor([input_int])

test_sample_input_trace = [1, 2, 3]
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
print(test_sample_trace(test_sample_input))
  File "test.py", line 7, in <module>
    test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py", line 906, in trace
    _force_outplace)
RuntimeError: Type 'Tuple[int, int, int]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced (toTraceableStack at /pytorch/torch/csrc/jit/pybind_utils.h:305)
def test_sample(input_tensor):
    return 1

test_sample_input_trace = torch.tensor([1, 2, 3])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
print(test_sample_trace(test_sample_input))
Traceback (most recent call last):
  File "test_stream.py", line 9, in <module>
    test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
  File "/home/user/.local/lib/python3.6/site-packages/torch/jit/__init__.py", line 906, in trace
    _force_outplace)
RuntimeError: Only tensors or tuples of tensors can be output from traced functions (getOutput at /pytorch/torch/csrc/jit/tracer.cpp:212)
@torch.jit.script
def test_script(input):
    return 1

def test_sample(input):
    test_script(input)
    return input

test_sample_input_trace = torch.tensor([1, 2, 3])
test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
Traceback (most recent call last):
  File "test.py", line 12, in <module>
    test_sample_trace = torch.jit.trace(test_sample, test_sample_input_trace)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py", line 906, in trace
    _force_outplace)
  File "test.py", line 8, in test_sample
    test_script(input)
RuntimeError: Tracer cannot set value trace for type Int. Supported types are tensor, tensor list, and tuple of tensors.

Trainingモード・EvalモードがTrace時のもので固定化されるケース

対策: TorchScriptを適用するのは基本的にEvalモードなのでTrainingモードでTraceしないようにだけ気をつける

最後に

今回はLibTorchの実装で用いるTorchScriptについてご紹介しました。PyTorchモデルをより汎用性が高く最適化されたTorchScriptにバンバン転換していきましょう!

オプティムは優秀なC++深層学習を組める若人さまを探す旅に出たい・・・