こんにちは、EdTechLabの水谷です。
私の所属するEdTechLabで行なっている業務の一つとして、スピーキングテストPROGOSの自動採点を行う機械学習モデルをはじめ、各種機械学習モデルの開発から保守・運用までの一通りを担当しています。
今回は、一般にその中でも障壁が高いと言われる開発からデプロイの間のギャップに焦点を当て、人気のDeep LearningフレームワークであるPyTorchが提供するプロダクション環境向けモデルのTorchScriptについて解説していきたいと思います。
PyTorch
機械学習モデルの開発段階では試行錯誤・デバッグのし易さが実行速度よりも重要なことが多いです。そのため、Pythonのようなインタープリタ型で動的型付けの言語が以下の点で大きな利点があり、機械学習で広く使われています。
PythonのDeep LearningフレームワークであるPyTorchは
- デフォルトではeager execution (define-by-run型)で、Tensorオブジェクトの中身が何かを即時にprint等で確認することができ、デバッグがしやすい
- Pythonicな哲学で設計されており、またPythonのsyntaxをそのまま使うことができるため、学習コストが少ない上、直感的に記述することができる
などといったメリットがあり、人気のDeep Learningフレームワークとなっています。
一方で、define-by-runでは事前に計算グラフを構築しないためグラフのコンパイルによる処理の最適化が行えず、また動的型付けのためにメモリの最適化も行いづらいため計算量・メモリ消費量にオーバーヘッドがあり、eager modeのPytorchのモデルをプロダクションにデプロイするのは不向きと言えます。
TorchScript
試行錯誤が少なく、かつ推論速度やメモリ削減が要求されるプロダクションには、PyTorch (eager mode) やPythonの持つメリットよりも、静的型付けでコンパイル(define-and-run)型の最適化されたモデルの方が相応しそうです。
そのような特徴を持ったものがTorchScriptで、以下の特徴があります。
- 静的型付けでdefine-and-run型なので、計算グラフをコンパイルすることが可能で計算量・メモリの最適化が可能
- Pythonのランタイムに依存せず、C++のライブラリである
libtorch
から実行可能でモバイル/エッジデバイスにもデプロイでき、PythonのGlobal Interpreter Lockの制約も受けないため、スレッド並列化が可能
そもそもTorchScriptとは何なのか
TorchScriptという単語は少し乱用され気味で、公式ドキュメントにもさまざまな定義があります。
https://pytorch.org/docs/stable/jit.html#torchscript
TorchScript is a way to create serializable and optimizable models from PyTorch code
https://pytorch.org/docs/stable/jit.html#torchscript-language
TorchScript is a statically typed subset of Python
https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html
TorchScript, an intermediate representation of a PyTorch model
TorchScript, our high-performance deployment runtime
それぞれ、PyTorchから変換する方法、Pythonのサブセット、IR、ランタイムなど、文脈によって異なるものを差しており、混乱の元になりかねないので、これ以降の本投稿では単にTorchScriptという言葉は極力避け、より限定的な単語を使って解説したいと思います。
TorchScript IR
PyTorchではtorch.nn
モジュールを使ってモデルを構築します。それぞれのtorch.nn
モジュールは内部にパラメータを持ち、テンソルを受け取り、テンソルを返すようなクラスですが、これらはC++の拡張モジュールとして実装されています。
つまり、実際のモデルの数値計算はC++レベルで行われており、このPyTorchのコアと呼べるC++の自動微分ライブラリはATenと呼ばれています。
具体的にどのように呼ばれているかをnn.Linear
のソースコードを例にとって追っていくと、nn.Linear
のforward
中では、torch.nn.funcitonal.linear
という関数が呼ばれており、さらにtorch.nn.functional.linear
の中身を確認すると、
torch._C._nn.linear
のエイリアスとなっています。この定義はaten/src/ATen/native/Linear.cpp
にあり、実際の計算はC++のレベルで行われていることが分かります。(その後、バックエンド (e.g. cudnn, nnpack, mkldnn) に応じた適切なカーネルが呼ばれます。)
このように、torch.nn
で構築したモデルの計算は全てATenで行われるので、計算グラフさえ分かれば理論的にはATenでモデルを構築する事が出来そうです。実際にPyTorchではモデルの計算グラフを表す中間表現(IR: Intermediate Representation)から独自のC++のインタープリタを使ってモデルを構築する事ができ、この計算グラフを表す中間表現はTorchScript IRと呼ばれます。TorchScript IRが得られればPythonのインタープリタに依存せずにモデルの構築ができ、また静的型付けになっているので効率的にコンパイラ最適化を行う事が出来ます。
TracingとScripting
PytorchのモデルからTorchScript IRを得る方法は以下の2通りがあります。
- Tracingを使った方法
- Scriptingを使った方法
Tracing
1の方法は明らかな方法で、PyTorchはdefine-by-runなので出力を得るまでグラフは分かりませんが、出力が得られれば、それまでに行われた処理の記録を辿ることで計算グラフを得ることが出来ます。
よって、PyTorchモデルと適当な入力テンソルさえあればTorchScript IRを得る事ができ、この方法はtracingと呼ばれます。
具体的には、torch.jit.trace
関数にPytorchモデルインスタンスと入力テンソルを引数として渡すことで、torch.jit.ScriptModule
(の子クラス)のモデルに変換できます。
class MyModel(torch.nn.Module): def __init__(self): super(MyModel, self).__init__() self.fc1 = torch.nn.Linear(4, 4) self.softmax = torch.nn.Softmax(dim=-1) def forward(self, x): return self.softmax(self.fc1(x)) pytorch_model = MyModel() dummy_input = torch.rand(1, 4) script_model = torch.jit.trace(pytorch_model, dummy_input) print(isinstance(script_model, torch.jit.ScriptModule)) # True
また、変換されたscript_model
とpytorch_model
の出力が一致していることのほか、script_model
でもauto-gradが使えていることが確認できます。
print(pytorch_model(dummy_input)) # tensor([[0.2628, 0.3168, 0.2951, 0.1253]], grad_fn=<SoftmaxBackward0>) print(script_model(dummy_input)) # tensor([[0.2628, 0.3168, 0.2951, 0.1253]], grad_fn=<SoftmaxBackward0>)
TorchScript IRの情報も持っており、.graph
プロパティでグラフを見る事ができます。
print(script_model.graph)
Out:
graph(%self.1 : __torch__.___torch_mangle_11.MyModel, %x : Float(1, 4, strides=[4, 1], requires_grad=0, device=cpu)): %softmax : __torch__.torch.nn.modules.activation.___torch_mangle_10.Softmax = prim::GetAttr[name="softmax"](%self.1) %fc1 : __torch__.torch.nn.modules.linear.___torch_mangle_9.Linear = prim::GetAttr[name="fc1"](%self.1) %24 : Tensor = prim::CallMethod[name="forward"](%fc1, %x) %25 : Tensor = prim::CallMethod[name="forward"](%softmax, %24) return (%25)
この方法は、殆ど何も手を加える事なくPyTorchモデルをTorchScript moduleに変換する事ができる一方で、特定のInputの場合の計算グラフしか得られず、control-flow (e.g. if-else) を含み計算グラフが動的に変化するようなモデルでは正しく動作しません。
例えば、以下のようなトイモデルを変換してみます。
class DecisionGate(torch.nn.Module): def forward(self, x): if x.sum() > 0: return x else: return -x pytorch_model = DecisionGate() traced_model = torch.jit.trace(pytorch_model, torch.ones(2))
traceの実行時にx = torch.Tensor([1.0, 1.0])
を入力に使ったため、x.sum() > 0
となり、そのままreturn x
が実行されます。
tracingで変換されたモデルは、その時の計算グラフに固定されてしまいます。実際にtracingで変換されたモデルのIRをPythonシンタックスに直したコード (.code
プロパティ) を見てみると、条件分岐が実行されていない事が分かります。
print(traced_model.code)
Out:
def forward(self, x: Tensor) -> Tensor: return x
当然、入力の合計が負の場合はPytorchモデルとTracingで変換したモデルの出力は異なってしまいます。
input = -torch.ones(2) print(pytorch_model(input)) # tensor([1., 1.]) print(traced_model(input)) # tensor([-1., -1.])
Scripting
計算グラフが動的に変わるようなcontrol-flowを含むモデルをTorchScript IRに変換する方法がscriptingによる変換です。
Scriptingでは計算グラフは介さず、Python(及びPyTorch)コードをパースして直接TorchScript IRに変換します。全てのPythonの型やシンタックスがサポートされている訳ではなく、Pythonの一部の型・シンタックスで記述する静的型付け言語となっており、これがTorchScript Languageと呼ばれます。(これが最も多くTorchScriptと呼ばれています。)
いくつかの (単純な) control-flowもサポートされており、TorchScript Languageを使えば、control-flowもTorchScript IRに組み込む事が出来ます。
上述のDecisionGate
を再び例に挙げて、scriptingで変換してみましょう。
If statementはサポートされているので、DecisionGate
のforward
メソッドはパースしてTorchScriptに変換する事が出来ます。
script_model = torch.jit.script(pytorch_model) print(isinstance(script_model, torch.jit.ScriptModule)) # True
script_model
の方の.code
を確認してみましょう。
def forward(self, x: Tensor) -> Tensor: if bool(torch.gt(torch.sum(x), 0)): _0 = x else: _0 = torch.neg(x) return _0
if-else
も反映されている事が分かります。script_model
の方は、先ほどの入力(-torch.ones(2)
)を入れた場合にも結果はPytorchモデルと一致します。
input = -torch.ones(2) print(pytorch_model(input)) # tensor([1., 1.]) print(script_model(input)) # tensor([1., 1.])
TorchScript module
Tracingもしくはscriptingで得られたtorch.jit.ScriptModule
はC++のtorch::jit::Module
のラッパーになっており、.save()
メソッドでTorchScript IRの情報をシリアライズして保存しておけばC++からも同様に呼び出す事が出来ます。
torch::jit::script::Module module = torch::jit::load("<path>")
JITコンパイル
TorchScript IRは静的型付けなので、ある程度のコンパイラ最適化(デッドコード除去, 定数伝搬, ループ展開などの一般的なコンパイラ最適化のほか、element-wiseな演算のfusion, 行列積のバッチ化などの数値計算に特化した最適化)はできますが、ニューラルネットワークの計算においてはまだ不十分な情報があります。
例えば、
x + y + z
という単純なelement-wiseな計算でも、x, y, zの形状が全て(1000,)の場合はfuseした方が計算量、メモリアクセスともに少なくなりますが、x, yが(3,)で、zが(3, 1000)の場合にはfuseしない方が計算量は小さくなります(メモリアクセスは同じ)。
このようにテンソルの形状、auto-gradが有効なのか等に依存して最適化が異なりますが、型付けの情報からでは分かりません。
ラッキーな事に、通常のニューラルネットワークのモデルではそれほど動的な処理は行なわれない事が多いので、JITコンパイル時にテンソルの形状やauto-gradの有無などの情報 (profile) を記録し、そのprofileを元に最適化 (profile guided optimization) を行います。profile guided optimizationの詳細は文献に譲りますが、最初の数サンプルでprofileの取得とJITコンパイルが行われ、その後はprofile通りであれば最適化されたコードが、そうでなければ最適化されていない元のコードが実行されます。
では、どの程度処理速度、メモリ消費が改善されるのかをhuggingfaceのBERTを使って実験してみましょう。
BERTで実験
huggingfaceのドキュメントにはTorchScriptへの変換に関して、以下のように記載されています。
Exporting a model requires two things:
- a forward pass with dummy inputs.
- model instantiation with the
torchscript
flag.
1点目は、tracingの際に必要になる入力だとすぐに分かります。
2点目に関しては、Encoder-Decoderモデルでは、input tokenのembeddingレイヤーとoutput tokenのembeddingレイヤーのweightが共有されていますが、TorchScriptではweight sharingができないないそうです。そのため、このフラグがある場合にはモデルをインスタンス化する際に呼ばれるメソッドでoutputのembeddingレイヤーにはinputのembeddingレイヤの重みをcloneしています。
ドキュメントにも記載がある通り、decoderが無いモデルでは当然torchscriptフラグは必要がなく、BERTはEncoderモデルなので、今回は1のみの適当な入力でtracingを実行するだけで良さそうです。
ちなみに、scriptingを使ったBERTの変換はgenerator式(self.modules()
がgeneratorになっている)のところでUnsupportedNodeError
となり、変換ができませんでした。
pytorch_bert = BertModel.from_pretrained("bert-base-uncased")
torch.jit.script(pytorch_bert)
Out:
UnsupportedNodeError: GeneratorExp aren't supported: File "/usr/local/lib/python3.7/dist-packages/transformers/modeling_utils.py", line 1538 activations". """ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) ~ <--- HERE
transformersのモデルがscriptingで変換できないという幾つかのissue(e.g. #5067)が上がっていますが、ライブラリの大幅な書き直しが必要なため、今のところ対応の見込みはないようです。
なので、今回はサポートされているtracingを使ってTorchScript moduleに変換しました。
Pytorch Profilerを使って実行速度、メモリ消費量を測定
colab環境で測定を行いました。GPUは不使用、測定に使う入力の長さは128、batch sizeは1としています。
token_length = 128 text = ' '.join(["test",] * (token_length-2)) # subtract [BOS] and [EOS] inputs = tokenizer(text, return_tensors="pt")
Pytorch profilerのschedule設定は以下としました。
wait=5
: 最初の5ステップはprofilerをinactiveな状態でidlingする。warmup=5
: profilerの起動直後はオーバーヘッドがあるので、その後5ステップは結果から除外する。active=20
: その後の20ステップを実際の計測に使う。
※最初の数ステップはJITコンパイルが行われるため、オーバーヘッドが大きいのでwait/warmupで結果に含めないようにします。
with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, ], schedule=torch.profiler.schedule( wait=5, warmup=5, active=20, ), profile_memory=True, ) as prof: for _ in range(30): with torch.no_grad(): pytorch_bert(**inputs) prof.step() print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
PyTorchモデル、TorchScriptモデルでの実行結果はそれぞれ以下のようになりました。CPU toalとSelf CPUの違いは、Self CPUが該当演算のみの実行時間なのに対し、CPU totalは該当演算が呼んでいる演算を全て足した実行時間になっています。
PyTorch Model:
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls --------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ProfilerStep* 4.09% 477.156ms 100.00% 11.653s 582.656ms -80 b -2.41 Gb 20 aten::linear 0.42% 48.971ms 86.53% 10.084s 6.907ms 810.06 Mb 0 b 1460 aten::addmm 83.95% 9.783s 85.69% 9.985s 6.839ms 810.06 Mb 810.06 Mb 1460 aten::matmul 0.10% 11.675ms 2.84% 331.225ms 690.052us 270.00 Mb 0 b 480 aten::bmm 2.61% 304.199ms 2.61% 304.653ms 634.694us 270.00 Mb 270.00 Mb 480 aten::gelu 1.98% 230.406ms 1.98% 230.406ms 960.025us 360.00 Mb 360.00 Mb 240 aten::copy_ 1.93% 225.005ms 1.93% 225.005ms 112.502us 4 b 4 b 2000 aten::layer_norm 0.03% 4.021ms 1.43% 166.564ms 333.128us 187.50 Mb -493.50 Kb 500 aten::native_layer_norm 1.33% 155.240ms 1.39% 162.543ms 325.086us 187.98 Mb -6.50 Kb 500 aten::softmax 0.02% 1.764ms 1.05% 122.851ms 511.879us 180.00 Mb 0 b 240 --------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ Self CPU time total: 11.654s
TorchScript Model:
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls --------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ProfilerStep* 0.05% 5.227ms 100.00% 10.656s 532.801ms -80 b -7.56 Mb 20 forward 0.50% 53.352ms 99.95% 10.651s 532.539ms 7.56 Mb -180.02 Mb 20 aten::linear 0.95% 100.736ms 89.45% 9.533s 6.529ms 360.06 Mb -450.00 Mb 1460 aten::addmm 86.34% 9.201s 88.12% 9.390s 6.432ms 810.06 Mb 810.06 Mb 1460 aten::matmul 0.26% 28.168ms 3.12% 332.192ms 692.067us -180.00 Mb -450.00 Mb 480 aten::bmm 2.72% 290.308ms 2.73% 290.623ms 605.465us 270.00 Mb 270.00 Mb 480 aten::gelu 2.14% 227.959ms 2.14% 227.959ms 949.829us 0 b 0 b 240 aten::copy_ 1.97% 209.491ms 1.97% 209.491ms 120.397us 0 b 0 b 1740 aten::softmax 0.45% 48.158ms 1.45% 155.002ms 645.842us 0 b -180.00 Mb 240 aten::layer_norm 0.10% 10.234ms 1.45% 154.717ms 309.434us 0 b -187.91 Mb 500 --------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ Self CPU time total: 10.657s
PyTorchモデルでは1ステップ平均583msでしたが、TorchScriptモデルでは1ステップ平均533msになっており、8.58%の速度改善になっていました。
また、メモリ消費量に関しても、linear(1ステップの合計)で810 Mb -> 360 Mbに削減されている、matmulでメモリ解放が優位になっている、GELU, Softmax, LayerNormではメモリ消費が0 bになっているなど、こちらも最適化されている事が分かります。(メモリ消費量が負の値になっているものはメモリ解放を意味しているようです。)
まとめ
今回はTorchScriptの調査を行ない、変換の原理や変換時に行われている処理などについて深掘りしてみました。また、huggingfaceのBERTをTorchScriptモデルに変換し、処理速度とメモリ消費量削減の測定を行いました。
BERTのケースでは8%以上推論速度が改善されましたが、本番環境において推論速度はコストに直結することもあるので、8%以上のコスト削減をPytorchのモデルを1度関数にかけるだけで実現できるのは非常に良い体験だと思いました。
一方で、transformersのような大きな外部ライブラリで、動的な計算グラフになるモデルを利用する場合、自分でTorchScript Languageで書き直すのは大変そうなので、そのような場合は前もってscriptingで変換できるのか確認しておくと良いでしょう。
長くなりましたが、最後までお読みいただきありがとうございます。