PyTorch Metric LearningによるDeep Metric Learningの実践

こんにちは、R&Dチームの河野(@ps3kono)です。

前回は深層距離学習(Deep Metric Learning)の基礎知識とアルゴリズムの進化について紹介しましたが、この記事ではPyTorch Metric Learningという深層距離学習ライブラリを紹介したいと思います。

PyTorch Metric Learningについて

PyTorch Metric Learningはオープンソースライブラリであり(MIT License)、訓練・評価パイプラインに必要なコンポーネント(下図)がモジュール別で実装されたため、柔軟に組み合わせを変えられることで手軽に色々試すことができます。

f:id:optim-tech:20220331110639j:plain:w625:h408

  • Minerサンプル選択(Miner)は距離学習の重要な要素であり、モデル訓練の成功と収束性に大きく影響します。このモジュールの役割は、バッチ内から最適なサンプル組み合わせを選択することです(online miners)。easy、semi-hard、hard minersなどサンプル組み合わせを選択するためのTuple Minersの他には、ただバッチ内からサブセットサンプルを選択するためのSubset Batch Minersも実装されています。

  • Sampler:バッチを作成する際のサンプル選択する方法を指定するために利用されます(offline miners)。

  • Loss:損失関数を設定するためのモジュールです。Contrastive loss、Triplet loss、NT-Xent loss、SphereFace、CosFace、ArcFaceなど、様々な損失関数を選択できます。また、任意のminer, distance, regularizerを渡すことでカスタマイズすることが可能です。

  • Distance:サンプルの間の距離を算出するためのモジュールです。損失関数に合わせて、ユークリッド距離やコサイン類似度などいくつかの距離測定方法を選択することができます。

  • Regularizer:重みまたは特徴ベクトル(embedding)を正規化するためのモジュールです。

  • Reducer:ロスはサンプル毎、ペア毎または組み合わせ毎で計算されますが、reducerはその結果を一つの値にするためのモジュールです。PyTorchの損失関数のreductionと同様な役割を果たします。

  • Trainer:上記のlossとminerの他にsampler、optimizer、learning rate schedulerなどのパラメータを渡すことで、訓練を行ってくれます。また、end_of_iteration_hookまたはend_of_epoch_hookを設定することで訓練中のモデル精度評価を実施することが可能です。

  • Tester:テストデータに対する精度評価と可視化を行うためのモジュールです。accuracy_calculatorにカスタマイズされたAccuracyCalculatorを渡すことで任意の評価指標で精度評価をすることができます。また、visualizer_hookに可視化用の関数を渡すと、embedding空間を可視化することなどができます。下記の実装例をご参考してください。

  • AccuracyCalulator:モデル精度はk平均法(k-means)法とk近傍法(k-NN)を基づいて算出されます。AccuracyCalculatorクラスで評価指標の計算方法をカスタマイズすることができます。

  • HookContainer:end-of-iterationまたはend-of-epoch時に行う処理内容をカスタマイズするためのコンポーネントです。

利用可能な損失関数(loss)・距離測定方法(distance)・mining方法(miner)などはこちらをご参考してください。

Deep Metric Learningの実践:Triplet lossとArcFaceを比較

下記のソースコードは、公式リポジトリチュートリアルを参考して実装したものです。PyTorch Metric Learning v.1.0.0を利用し、Google Colabで実行しました。

まずは、必要なライブラリをインストールします。

pip install pytorch-metric-learning[with-hooks]
pip install umap-learn

今回は、FashionMNISTデータセットを利用します。この公開データセットは、MIT Licenseとなっています。

from torchvision import datasets, transforms

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

data_path = "dataset"  # 任意の場所
train_dataset = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(data_path, train=False, transform=transform)

このデータセットは、10クラスの衣料品の画像データが含まれて、画像のサンプル例は下記のとおりです。

f:id:optim-tech:20220331110631j:plain:w464:h347

訓練に利用したネットワーク構造は下記のとおりです。

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 1, 1)
        self.conv3 = nn.Conv2d(64, 128, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.25)
        self.fc = nn.Linear(18432, 1152)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = F.relu(x)
        x = self.dropout2(x)
        return x

class Embedder(nn.Module):
    def __init__(self):
        super(Embedder, self).__init__()
        self.fc = nn.Linear(1152, 128)
     
    def forward(self, x):
        x = self.fc(x)
        return x

次は、訓練パイプラインに必要なコンポーネントを用意します。

  • Triplet lossの場合
from pytorch_metric_learning import losses, miners, distances, reducers, samplers

distance = distances.CosineSimilarity()
reducer = reducers.ThresholdReducer(low=0)
loss = losses.TripletMarginLoss(margin=0.2, distance=distance, reducer=reducer)
miner = miners.TripletMarginMiner(margin=0.2, distance=distance, type_of_triplets="semihard")
sampler = samplers.MPerClassSampler(train_dataset.targets, m=4,
                                    length_before_new_iter=len(train_dataset))

loss_funcs = {"metric_loss": loss}
mining_funcs = {"tuple_miner": miner}
  • ArcFaceの場合
from pytorch_metric_learning import losses, distances, regularizers

distance = distances.CosineSimilarity()
regularizer = regularizers.RegularFaceRegularizer()
loss = losses.ArcFaceLoss(10, 128, margin=28.6, scale=64,
                          weight_regularizer=regularizer, distance=distance)
sampler = None

loss_funcs = {"metric_loss": loss}
mining_funcs = dict()

最後は、訓練と評価を行います。

import os
import logging
import numpy as np
import matplotlib.pyplot as plt
import umap
from cycler import cycler
from torch import optim
from pytorch_metric_learning import trainers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.utils import logging_presets


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# モデルをロード
trunk = Net()
trunk = torch.nn.DataParallel(trunk.to(device))
embedder = Embedder()
embedder = torch.nn.DataParallel(embedder.to(device))
models = {"trunk": trunk, "embedder": embedder}

# Optimizerの設定
trunk_optimizer = optim.Adam(trunk.parameters(), lr=0.005)
embedder_optimizer = optim.Adam(embedder.parameters(), lr=0.001)
optimizers = {"trunk_optimizer": trunk_optimizer,
              "embedder_optimizer": embedder_optimizer}

# 可視化用のvisual_hookの実装
record_keeper, _, _ = logging_presets.get_record_keeper("logs", "tensorboard")
hooks = logging_presets.get_hook_container(record_keeper)

def visualizer_hook(umapper, umap_embeddings, labels, split_name, keyname, epoch):
    class_labels = np.unique(labels)
    num_classes = len(class_labels)
    
    fig = plt.figure(figsize=(8, 6))
    colors = [plt.cm.nipy_spectral(i) for i in np.linspace(0, 0.9, num_classes)]
    plt.gca().set_prop_cycle(cycler("color", colors))

    for i, lab in enumerate(class_labels):
        idx = labels == class_labels[i]
        plt.plot(umap_embeddings[idx, 0], umap_embeddings[idx, 1], ".", markersize=3, label=lab) 

    plt.legend(frameon=False, fontsize=12, bbox_to_anchor=(1.05, 1), loc='upper left')
    os.makedirs("result", exist_ok=True)
    plt.savefig(f"result/{epoch:02d}.png")
    plt.show()
    plt.close()

# Testerの設定
tester = testers.GlobalEmbeddingSpaceTester(end_of_testing_hook=hooks.end_of_testing_hook, 
                                            visualizer=umap.UMAP(), 
                                            visualizer_hook=visualizer_hook,
                                            dataloader_num_workers=4)

# Hookの設定
dataset_dict = {"val": test_dataset}
model_dir = "saved_models"
end_of_epoch_hook = hooks.end_of_epoch_hook(tester, 
                                            dataset_dict, 
                                            model_dir, 
                                            test_interval=1,
                                            patience=1)

# モデル訓練
num_epochs = 5
batch_size = 256

trainer = trainers.MetricLossOnly(models,
                                  optimizers,
                                  batch_size,
                                  loss_funcs,
                                  mining_funcs,
                                  train_dataset,
                                  sampler=sampler,
                                  dataloader_num_workers=4,
                                  end_of_iteration_hook=hooks.end_of_iteration_hook,
                                  end_of_epoch_hook=end_of_epoch_hook)
trainer.train(num_epochs=num_epochs)

Tripless lossとArcFaceで訓練した結果は下記のとおりです。ハイパーパラメータチューニングを行っていない状態ですが、下記のembedding空間の可視化結果を比較すると同程度の識別性能を得られたようにみられます。シャツ・Tシャツ・セーターなど上着系は識別しにくいようです。

f:id:optim-tech:20220331110648j:plain:w925:h327
※ 分かりやすくするため、上図では手動で凡例にクラス名を付けました。

本記事では、PyTorch Metric Learningについて紹介しました。深層距離学習を挑戦したい方は是非試してみてください。

おわりに

OPTiMでは、AIとIoT技術でビジネス課題・社会的課題を解決したい、チャレンジ精神・向上心を持っているエンジニアを積極的に募集してます。農業・医療・建設・産業分野などにおいてAIやIoTの活用で推進しているため、関心のある技術領域に挑戦するチャンスがあります。興味のある方は、こちらをご覧ください。 www.optim.co.jp

参考資料