まえがき
R&Dチームの宮﨑です。最近Fortnite熱が再燃して毎日練習してますが肝心な時にポンプを外してばかりでへこんでいます。 今回はPyTorch&TorchScriptで推論をFP16で実行し、速度計測やプロファイルしてみました。
- まえがき
- PyTorchとTorchScriptを用いてFP16で推論させる方法
- EC2(T4 Tensor Core GPU)でPyTorch/TorchScriptのFP32/FP16を速度計測&プロファイルしてみた
- 最後に
- おまけ: 計測コード
PyTorchとTorchScriptを用いてFP16で推論させる方法
PyTorchをFP16で推論するには基本的にmodelとinput tensorに対してhalf()で半精度化するだけです。簡単ですね。 半精度化したmodelとinput tensorを用いてtorch.jit.traceでTraceしてやると半精度のTorchScript Modelが手に入ります。注意点としては以下があります。
- modelの内部にある処理としてfloatのテンソル生成やキャストなどに起因してfloatのテンソルが混じり、半精度のmodelと食い違いエラーになることがあります。単純にエラーが発生している箇所からfloatのテンソルが発生している箇所を遡って半精度のテンソルに修正すれば良いです。
- modelの内部にあるC++拡張の部分が半精度に対応しておらずこけることもあります。そこだけfloatにしてあげるとエラーを解消できます。
- BatchNormalizationはFP16にしなければ性能が出ないという情報がありました。引用元ではBatchNormalizationだけfloat()にしているようです。
従ってPyTorchでFP16対応するコードはイメージとしては以下のようになります。
# modelを用意して半精度化 vgg16_fp32 = torchvision.models.vgg16(pretrained=True) vgg16_fp32.eval() vgg16_fp32 = vgg16_fp32.to(device) vgg16_fp16 = copy.deepcopy(vgg16_fp32).half() # input tensorを用意して半精度化 input_fp32 = input_fp32.to(device) input_fp16 = input_fp32.half() # 推論 with torch.no_grad(): output_fp16 = vgg16_fp16(input_fp16) output_fp32 = output_fp16.float().cpu()
EC2(T4 Tensor Core GPU)でPyTorch/TorchScriptのFP32/FP16を速度計測&プロファイルしてみた
- 今回は話を簡単にするためにBatchNormalizationのないVGG16をmodelとして使います。torchvisionから拝借しました。
- ImageNetの画像約700枚を使用して推論結果に差異がないことを確認しつつ計測を行いました。
- PyTorchのバージョンは1.4.0で、torchvisionのバージョンは0.5.0です。
速度計測
- PyTorch(の中で使われるCUDA)の速度計測で気をつけることとして下記の記事が非常に参考になります。PyTorchでは高速化のためにCUDAが使われますがGPU処理はCPU処理と非同期なので、速度計測する際はsynchronize()を用いて同期するかEventのrecord機能を用いて記録してやる必要があります。今回はsynchronize()を用いました。イメージとしては以下のようなコードになります。
with torch.no_grad(): # 計測開始 torch.cuda.synchronize() start_time = time.time() # 前処理 input = preprocess(data_path) # 推論 output = model(input) # 後処理 result = postprocess(output) # 計測終了 torch.cuda.synchronize() elapsed_time = time.time() - start_time
計測結果:単位(秒/枚)
FP32 | FP16 | |
---|---|---|
PyTorch | 0.0127 | 0.0091 |
TorchScript | 0.0127 | 0.0091 |
- PyTorchとTorchScriptでは速度に差異は出ず・・・他のモデルでもTorchScriptにした際速度に変化が見られる場合とほぼ変わらない場合があったのでTorchScriptによる高速化はモデルによるようです。
- Turing世代であるT4はテンソルコアを備え深層学習で用いられることが想定される精度(FP32/FP16/INT8/INT4)において性能を発揮するように設計されていますが確かにFP16にすることでFP32よりも約1.4倍の性能を発揮してくれました。
プロファイル
PyTorchとTorchScriptにおける推論処理のコアはCUDAで実行されるのでNVIDIAが提供するNsightを用いてプロファイルしました。
Nsightとは
Nsightの説明は公式のドキュメントが分かりやすいです。ざっと説明するとNsightはNsight Systems/Nsight Compute/Nsight Graphicsの三つに分かれています。Nsight Systemsでシステム全体の解析を行い、さらに詳細に解析を行いたい場合、CUDAカーネルレベルで解析を行いたいときはNsight Computeを、グラフィックスの解析を行いたいときはNsight Graphicsを用います。それぞれのドキュメントを参照して自分が調べたい情報を得るにはどれを用いれば良いのか確認してみると良いでしょう。今回はCUDAカーネルレベルで詳細を調べたいのでNsight Computeを用います。 従来はCUDA周りのプロファイラといえばVisual Profilerやnvprofがありましたが、最近のCUDAのドキュメントには以下の文があります。
Note that Visual Profiler and nvprof will be deprecated in a future CUDA release. The NVIDIA Volta platform is the last architecture on which these tools are fully supported
従って今回はTuring世代であるT4を用いるのでNsightを使うことになります。いずれにせよ従来のプロファイラは将来的に非推奨となるので機を見て移行した方が良さそうです。公式も移行ガイドを用意してくれているので参考にすると良いでしょう。
Nsightを使用するための前準備
Nsight ComputeのインストールはLinux環境では簡単でrunファイルを公式からダウンロードして実行するだけです(Nsight Systemsも同様)。ただし権限の設定が少し面倒で、例えばユーザ側の権限が足りない場合はNsight実行の際に以下のようなエラー文が出ます。
==ERROR== Error: ERR_NVGPUCTRPERM - The user does not have permission to access NVIDIA GPU Performance Counters on the target device 0. For instructions on enabling permissions and to get more information see https://developer.nvidia.com/ERR_NVGPUCTRPERM
対処するためには以下の情報を参考にしてください
- 基本的な対処はエラー文にも出ている公式のドキュメント通りに対応すれば良いです。管理者側とユーザ側で対応が必要なことに注意です。
- EC2ではこの管理者側の設定はあらかじ行われているのでユーザ側の対応のみで大丈夫です。例えばsudoでNsightを起動するだけで良いです。Dockerなど仮想環境のコンテナ内で実行する場合は権限の譲渡を行いNsightを実行できるようにします。
Nsightを用いたプロファイル&結果表示
プロファイルを行う際は以下のようなコマンドを実行します。ncuがNsight ComputeのCLIコマンドで、-oでプロファイル結果を格納するファイルの名前を指定、つづいてプロファイル対象のプログラム実行コマンドを記述します。
sudo /usr/local/cuda-11.1/bin/ncu -o profile_vgg16_fp32 python3 vgg16.py
- プロファイルはとても時間がかかるので対象の処理をできるだけ少な目にすると良いかもしれません。今回のプロファイルでは画像一枚だけを対象にしました。
結果を表示するには以下のようなコマンドを叩いてNsight ComputeのGUIを起動し、そこで生成したプロファイル結果を格納したファイルを開きます。
sudo /usr/local/NVIDIA-Nsight-Compute-2020.3/host/linux-desktop-glibc_2_11_3-x64/ncu-ui
左上のタブから様々な解析の表示を見れますが今回はRAWを選択します。すると以下のような画面になります。
PyTorchのFP16 TorchScriptのFP16
上記2つの結果を見てざっと分かることとして
- function nameにfp16が含まれた命令が実行されておりちゃんとFP16で実行されていることがわかります。
- PyTorchでもTorchScriptでもコアな部分で実行されている命令は同じものが並んでおり内部的には同じ処理が走っていることが分かります。上記の速度計測で結果がほぼ変わらなかったのもこれで納得できますね。
Nsightを使ったより便利に解析できるツールたち
DLProfやPyProfなど、内部でNsightを使い結果の可視化やテンソルコアの詳細な使用状況の取得などの機能を追加したツール群が提供されています。
最後に
今回はPyTorch&TorchScriptで推論をFP16で実行し、速度計測やプロファイルしてみました。 オプティムは半人前、単人前、倍人前なエンジニアを募集しています。
おまけ: 計測コード
実際に計測で使用したコードです。各関数を呼び出すだけのmain関数や結果の比較処理など本質でない部分は省略してあります。
def prepare_vgg16(device="cuda:0"): vgg16_fp32 = models.vgg16(pretrained=True) vgg16_fp32.eval() vgg16_fp32 = vgg16_fp32.to(device) vgg16_fp16 = copy.deepcopy(vgg16_fp32).half() input_fp32 = torch.rand(1, 3, 224, 224).to(device) input_fp16 = input_fp32.half() with torch.no_grad(): vgg16_fp32_ts = torch.jit.trace(vgg16_fp32, input_fp32) vgg16_fp16_ts = torch.jit.trace(vgg16_fp16, input_fp16) return vgg16_fp32, vgg16_fp16, vgg16_fp32_ts, vgg16_fp16_ts def prepare_input(data_path, device='cuda:0', half=False): img_org = Image.open(data_path) preprocess = transforms.Compose([ transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ]) img = preprocess(img_org) if half: img = img.half() input = torch.unsqueeze(img, 0) return input def bench_infer(model, data_path, device, half=False): with torch.no_grad(): # 計測開始 torch.cuda.synchronize() start_time = time.time() # 前処理 input = prepare_input(data_path, device, half) input = input.to(device) # 推論 output = model(input) output = output.float().cpu() # 計測終了 torch.cuda.synchronize() elapsed_time = time.time() - start_time return output, elapsed_time