RareJob Tech Blog

レアジョブテクノロジーズのエンジニア・デザイナーによる技術ブログです

TorchScript入門 n番煎じ

こんにちは、EdTechLabの水谷です。

私の所属するEdTechLabで行なっている業務の一つとして、スピーキングテストPROGOSの自動採点を行う機械学習モデルをはじめ、各種機械学習モデルの開発から保守・運用までの一通りを担当しています。

今回は、一般にその中でも障壁が高いと言われる開発からデプロイの間のギャップに焦点を当て、人気のDeep LearningフレームワークであるPyTorchが提供するプロダクション環境向けモデルのTorchScriptについて解説していきたいと思います。

PyTorch

機械学習モデルの開発段階では試行錯誤・デバッグのし易さが実行速度よりも重要なことが多いです。そのため、Pythonのようなインタープリタ型で動的型付けの言語が以下の点で大きな利点があり、機械学習で広く使われています。

  • printやpdbで変数の中身を逐次確認でき、デバッグが容易
  • 少ない記述量で直感的にプログラミングできる

PythonDeep LearningフレームワークであるPyTorch

  • デフォルトではeager execution (define-by-run型)で、Tensorオブジェクトの中身が何かを即時にprint等で確認することができ、デバッグがしやすい
  • Pythonicな哲学で設計されており、またPythonのsyntaxをそのまま使うことができるため、学習コストが少ない上、直感的に記述することができる

などといったメリットがあり、人気のDeep Learningフレームワークとなっています。

一方で、define-by-runでは事前に計算グラフを構築しないためグラフのコンパイルによる処理の最適化が行えず、また動的型付けのためにメモリの最適化も行いづらいため計算量・メモリ消費量にオーバーヘッドがあり、eager modeのPytorchのモデルをプロダクションにデプロイするのは不向きと言えます。

TorchScript

試行錯誤が少なく、かつ推論速度やメモリ削減が要求されるプロダクションには、PyTorch (eager mode) やPythonの持つメリットよりも、静的型付けでコンパイル(define-and-run)型の最適化されたモデルの方が相応しそうです。

そのような特徴を持ったものがTorchScriptで、以下の特徴があります。

  1. 静的型付けでdefine-and-run型なので、計算グラフをコンパイルすることが可能で計算量・メモリの最適化が可能
  2. Pythonのランタイムに依存せず、C++のライブラリであるlibtorchから実行可能でモバイル/エッジデバイスにもデプロイでき、PythonのGlobal Interpreter Lockの制約も受けないため、スレッド並列化が可能

そもそもTorchScriptとは何なのか

TorchScriptという単語は少し乱用され気味で、公式ドキュメントにもさまざまな定義があります。

https://pytorch.org/docs/stable/jit.html#torchscript

TorchScript is a way to create serializable and optimizable models from PyTorch code

https://pytorch.org/docs/stable/jit.html#torchscript-language

TorchScript is a statically typed subset of Python

https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

TorchScript, an intermediate representation of a PyTorch model


TorchScript, our high-performance deployment runtime

それぞれ、PyTorchから変換する方法、Pythonのサブセット、IR、ランタイムなど、文脈によって異なるものを差しており、混乱の元になりかねないので、これ以降の本投稿では単にTorchScriptという言葉は極力避け、より限定的な単語を使って解説したいと思います。

TorchScript IR

PyTorchではtorch.nnモジュールを使ってモデルを構築します。それぞれのtorch.nnモジュールは内部にパラメータを持ち、テンソルを受け取り、テンソルを返すようなクラスですが、これらはC++の拡張モジュールとして実装されています。

つまり、実際のモデルの数値計算C++レベルで行われており、このPyTorchのコアと呼べるC++の自動微分ライブラリはATenと呼ばれています。

具体的にどのように呼ばれているかをnn.Linearソースコードを例にとって追っていくと、nn.Linearforward中では、torch.nn.funcitonal.linearという関数が呼ばれており、さらにtorch.nn.functional.linearの中身を確認すると、

torch._C._nn.linear

エイリアスとなっています。この定義はaten/src/ATen/native/Linear.cppにあり、実際の計算はC++のレベルで行われていることが分かります。(その後、バックエンド (e.g. cudnn, nnpack, mkldnn) に応じた適切なカーネルが呼ばれます。)

このように、torch.nnで構築したモデルの計算は全てATenで行われるので、計算グラフさえ分かれば理論的にはATenでモデルを構築する事が出来そうです。実際にPyTorchではモデルの計算グラフを表す中間表現(IR: Intermediate Representation)から独自のC++インタープリタを使ってモデルを構築する事ができ、この計算グラフを表す中間表現はTorchScript IRと呼ばれます。TorchScript IRが得られればPythonインタープリタに依存せずにモデルの構築ができ、また静的型付けになっているので効率的にコンパイラ最適化を行う事が出来ます。

TracingとScripting

PytorchのモデルからTorchScript IRを得る方法は以下の2通りがあります。

  1. Tracingを使った方法
  2. Scriptingを使った方法

Tracing

1の方法は明らかな方法で、PyTorchはdefine-by-runなので出力を得るまでグラフは分かりませんが、出力が得られれば、それまでに行われた処理の記録を辿ることで計算グラフを得ることが出来ます。

よって、PyTorchモデルと適当な入力テンソルさえあればTorchScript IRを得る事ができ、この方法はtracingと呼ばれます。

具体的には、torch.jit.trace関数にPytorchモデルインスタンスと入力テンソルを引数として渡すことで、torch.jit.ScriptModule(の子クラス)のモデルに変換できます。

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = torch.nn.Linear(4, 4)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.fc1(x))

pytorch_model = MyModel()
dummy_input = torch.rand(1, 4)
script_model = torch.jit.trace(pytorch_model, dummy_input)
print(isinstance(script_model, torch.jit.ScriptModule))  # True

また、変換されたscript_modelpytorch_modelの出力が一致していることのほか、script_modelでもauto-gradが使えていることが確認できます。

print(pytorch_model(dummy_input))  # tensor([[0.2628, 0.3168, 0.2951, 0.1253]], grad_fn=<SoftmaxBackward0>)
print(script_model(dummy_input))  # tensor([[0.2628, 0.3168, 0.2951, 0.1253]], grad_fn=<SoftmaxBackward0>)

TorchScript IRの情報も持っており、.graphプロパティでグラフを見る事ができます。

print(script_model.graph)

Out:

graph(%self.1 : __torch__.___torch_mangle_11.MyModel,
      %x : Float(1, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %softmax : __torch__.torch.nn.modules.activation.___torch_mangle_10.Softmax = prim::GetAttr[name="softmax"](%self.1)
  %fc1 : __torch__.torch.nn.modules.linear.___torch_mangle_9.Linear = prim::GetAttr[name="fc1"](%self.1)
  %24 : Tensor = prim::CallMethod[name="forward"](%fc1, %x)
  %25 : Tensor = prim::CallMethod[name="forward"](%softmax, %24)
  return (%25)

この方法は、殆ど何も手を加える事なくPyTorchモデルをTorchScript moduleに変換する事ができる一方で、特定のInputの場合の計算グラフしか得られず、control-flow (e.g. if-else) を含み計算グラフが動的に変化するようなモデルでは正しく動作しません。

例えば、以下のようなトイモデルを変換してみます。

class DecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

pytorch_model = DecisionGate()
traced_model = torch.jit.trace(pytorch_model, torch.ones(2))

traceの実行時にx = torch.Tensor([1.0, 1.0])を入力に使ったため、x.sum() > 0となり、そのままreturn xが実行されます。

tracingで変換されたモデルは、その時の計算グラフに固定されてしまいます。実際にtracingで変換されたモデルのIRをPythonシンタックスに直したコード (.codeプロパティ) を見てみると、条件分岐が実行されていない事が分かります。

print(traced_model.code)

Out:

def forward(self,
    x: Tensor) -> Tensor:
  return x

当然、入力の合計が負の場合はPytorchモデルとTracingで変換したモデルの出力は異なってしまいます。

input = -torch.ones(2)
print(pytorch_model(input))  # tensor([1., 1.])
print(traced_model(input))  # tensor([-1., -1.])

Scripting

計算グラフが動的に変わるようなcontrol-flowを含むモデルをTorchScript IRに変換する方法がscriptingによる変換です。

Scriptingでは計算グラフは介さず、Python(及びPyTorch)コードをパースして直接TorchScript IRに変換します。全てのPythonの型やシンタックスがサポートされている訳ではなく、Python一部の型・シンタックスで記述する静的型付け言語となっており、これがTorchScript Languageと呼ばれます。(これが最も多くTorchScriptと呼ばれています。)

いくつかの (単純な) control-flowもサポートされており、TorchScript Languageを使えば、control-flowもTorchScript IRに組み込む事が出来ます。

上述のDecisionGateを再び例に挙げて、scriptingで変換してみましょう。

If statementはサポートされているので、DecisionGateforwardメソッドはパースしてTorchScriptに変換する事が出来ます。

script_model = torch.jit.script(pytorch_model)
print(isinstance(script_model, torch.jit.ScriptModule))  # True

script_modelの方の.codeを確認してみましょう。

def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

if-elseも反映されている事が分かります。script_modelの方は、先ほどの入力(-torch.ones(2))を入れた場合にも結果はPytorchモデルと一致します。

input = -torch.ones(2)
print(pytorch_model(input))  # tensor([1., 1.])
print(script_model(input))  # tensor([1., 1.])

TorchScript module

Tracingもしくはscriptingで得られたtorch.jit.ScriptModuleC++torch::jit::Moduleのラッパーになっており、.save()メソッドでTorchScript IRの情報をシリアライズして保存しておけばC++からも同様に呼び出す事が出来ます。

torch::jit::script::Module module = torch::jit::load("<path>")

JITコンパイル

TorchScript IRは静的型付けなので、ある程度のコンパイラ最適化(デッドコード除去, 定数伝搬, ループ展開などの一般的なコンパイラ最適化のほか、element-wiseな演算のfusion, 行列積のバッチ化などの数値計算に特化した最適化)はできますが、ニューラルネットワークの計算においてはまだ不十分な情報があります。

例えば、

x + y + z

という単純なelement-wiseな計算でも、x, y, zの形状が全て(1000,)の場合はfuseした方が計算量、メモリアクセスともに少なくなりますが、x, yが(3,)で、zが(3, 1000)の場合にはfuseしない方が計算量は小さくなります(メモリアクセスは同じ)。

このようにテンソルの形状、auto-gradが有効なのか等に依存して最適化が異なりますが、型付けの情報からでは分かりません。

ラッキーな事に、通常のニューラルネットワークのモデルではそれほど動的な処理は行なわれない事が多いので、JITコンパイル時にテンソルの形状やauto-gradの有無などの情報 (profile) を記録し、そのprofileを元に最適化 (profile guided optimization) を行います。profile guided optimizationの詳細は文献に譲りますが、最初の数サンプルでprofileの取得とJITコンパイルが行われ、その後はprofile通りであれば最適化されたコードが、そうでなければ最適化されていない元のコードが実行されます。

では、どの程度処理速度、メモリ消費が改善されるのかをhuggingfaceのBERTを使って実験してみましょう。

BERTで実験

huggingfaceのドキュメントにはTorchScriptへの変換に関して、以下のように記載されています。

Exporting a model requires two things:

  • a forward pass with dummy inputs.
  • model instantiation with the torchscript flag.

1点目は、tracingの際に必要になる入力だとすぐに分かります。

2点目に関しては、Encoder-Decoderモデルでは、input tokenのembeddingレイヤーとoutput tokenのembeddingレイヤーのweightが共有されていますが、TorchScriptではweight sharingができないないそうです。そのため、このフラグがある場合にはモデルをインスタンス化する際に呼ばれるメソッドでoutputのembeddingレイヤーにはinputのembeddingレイヤの重みをcloneしています。

ドキュメントにも記載がある通り、decoderが無いモデルでは当然torchscriptフラグは必要がなく、BERTはEncoderモデルなので、今回は1のみの適当な入力でtracingを実行するだけで良さそうです。

ちなみに、scriptingを使ったBERTの変換はgenerator式(self.modules()がgeneratorになっている)のところでUnsupportedNodeErrorとなり、変換ができませんでした。

pytorch_bert = BertModel.from_pretrained("bert-base-uncased")
torch.jit.script(pytorch_bert)

Out:

UnsupportedNodeError: GeneratorExp aren't supported:
  File "/usr/local/lib/python3.7/dist-packages/transformers/modeling_utils.py", line 1538
        activations".
        """
        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
                   ~ <--- HERE

transformersのモデルがscriptingで変換できないという幾つかのissue(e.g. #5067)が上がっていますが、ライブラリの大幅な書き直しが必要なため、今のところ対応の見込みはないようです。

なので、今回はサポートされているtracingを使ってTorchScript moduleに変換しました。

Pytorch Profilerを使って実行速度、メモリ消費量を測定

colab環境で測定を行いました。GPUは不使用、測定に使う入力の長さは128、batch sizeは1としています。

token_length = 128
text = ' '.join(["test",] * (token_length-2))  # subtract [BOS] and [EOS]
inputs = tokenizer(text, return_tensors="pt")

Pytorch profilerのschedule設定は以下としました。

  • wait=5 : 最初の5ステップはprofilerをinactiveな状態でidlingする。
  • warmup=5 : profilerの起動直後はオーバーヘッドがあるので、その後5ステップは結果から除外する。
  • active=20 : その後の20ステップを実際の計測に使う。

※最初の数ステップはJITコンパイルが行われるため、オーバーヘッドが大きいのでwait/warmupで結果に含めないようにします。

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
    ],
    schedule=torch.profiler.schedule(
        wait=5,
        warmup=5,
        active=20,
    ),
    profile_memory=True,
) as prof:
    for _ in range(30):
        with torch.no_grad():
            pytorch_bert(**inputs)
        prof.step()

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

PyTorchモデル、TorchScriptモデルでの実行結果はそれぞれ以下のようになりました。CPU toalとSelf CPUの違いは、Self CPUが該当演算のみの実行時間なのに対し、CPU totalは該当演算が呼んでいる演算を全て足した実行時間になっています。

PyTorch Model:

---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
              ProfilerStep*         4.09%     477.156ms       100.00%       11.653s     582.656ms         -80 b      -2.41 Gb            20  
               aten::linear         0.42%      48.971ms        86.53%       10.084s       6.907ms     810.06 Mb           0 b          1460  
                aten::addmm        83.95%        9.783s        85.69%        9.985s       6.839ms     810.06 Mb     810.06 Mb          1460  
               aten::matmul         0.10%      11.675ms         2.84%     331.225ms     690.052us     270.00 Mb           0 b           480  
                  aten::bmm         2.61%     304.199ms         2.61%     304.653ms     634.694us     270.00 Mb     270.00 Mb           480  
                 aten::gelu         1.98%     230.406ms         1.98%     230.406ms     960.025us     360.00 Mb     360.00 Mb           240  
                aten::copy_         1.93%     225.005ms         1.93%     225.005ms     112.502us           4 b           4 b          2000  
           aten::layer_norm         0.03%       4.021ms         1.43%     166.564ms     333.128us     187.50 Mb    -493.50 Kb           500  
    aten::native_layer_norm         1.33%     155.240ms         1.39%     162.543ms     325.086us     187.98 Mb      -6.50 Kb           500  
              aten::softmax         0.02%       1.764ms         1.05%     122.851ms     511.879us     180.00 Mb           0 b           240  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 11.654s

TorchScript Model:

---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
              ProfilerStep*         0.05%       5.227ms       100.00%       10.656s     532.801ms         -80 b      -7.56 Mb            20  
                    forward         0.50%      53.352ms        99.95%       10.651s     532.539ms       7.56 Mb    -180.02 Mb            20  
               aten::linear         0.95%     100.736ms        89.45%        9.533s       6.529ms     360.06 Mb    -450.00 Mb          1460  
                aten::addmm        86.34%        9.201s        88.12%        9.390s       6.432ms     810.06 Mb     810.06 Mb          1460  
               aten::matmul         0.26%      28.168ms         3.12%     332.192ms     692.067us    -180.00 Mb    -450.00 Mb           480  
                  aten::bmm         2.72%     290.308ms         2.73%     290.623ms     605.465us     270.00 Mb     270.00 Mb           480  
                 aten::gelu         2.14%     227.959ms         2.14%     227.959ms     949.829us           0 b           0 b           240  
                aten::copy_         1.97%     209.491ms         1.97%     209.491ms     120.397us           0 b           0 b          1740  
              aten::softmax         0.45%      48.158ms         1.45%     155.002ms     645.842us           0 b    -180.00 Mb           240  
           aten::layer_norm         0.10%      10.234ms         1.45%     154.717ms     309.434us           0 b    -187.91 Mb           500  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 10.657s

PyTorchモデルでは1ステップ平均583msでしたが、TorchScriptモデルでは1ステップ平均533msになっており、8.58%の速度改善になっていました。

また、メモリ消費量に関しても、linear(1ステップの合計)で810 Mb -> 360 Mbに削減されている、matmulでメモリ解放が優位になっている、GELU, Softmax, LayerNormではメモリ消費が0 bになっているなど、こちらも最適化されている事が分かります。(メモリ消費量が負の値になっているものはメモリ解放を意味しているようです。)

まとめ

今回はTorchScriptの調査を行ない、変換の原理や変換時に行われている処理などについて深掘りしてみました。また、huggingfaceのBERTをTorchScriptモデルに変換し、処理速度とメモリ消費量削減の測定を行いました。

BERTのケースでは8%以上推論速度が改善されましたが、本番環境において推論速度はコストに直結することもあるので、8%以上のコスト削減をPytorchのモデルを1度関数にかけるだけで実現できるのは非常に良い体験だと思いました。

一方で、transformersのような大きな外部ライブラリで、動的な計算グラフになるモデルを利用する場合、自分でTorchScript Languageで書き直すのは大変そうなので、そのような場合は前もってscriptingで変換できるのか確認しておくと良いでしょう。

長くなりましたが、最後までお読みいただきありがとうございます。