Android でカスタム分類モデルを使用してオブジェクトを検出、追跡、分類する

ML Kit を使用すると、連続する動画フレーム内のオブジェクトを検出してトラックできます。

画像を ML Kit に渡すと、画像内の最大 5 つのオブジェクトと、各オブジェクトの画像内での位置が検出されます。動画ストリーム内のオブジェクトを検出する場合は、すべてのオブジェクトに一意の ID が割り当てられます。この ID を使用して、フレーム全体でオブジェクトをトラックできます。

カスタム画像分類モデルを使用して、検出されたオブジェクトを分類できます。モデルの互換性要件、事前トレーニング済みモデルの入手先、独自のモデルのトレーニング方法については、ML Kit のカスタムモデルをご覧ください。

カスタムモデルを統合する方法は 2 つあります。モデルをアプリのアセット フォルダに配置してバンドルする方法と、Cloud Storage から動的にダウンロードする方法があります。次の表に、2 つのオプションを比較します。

バンドルモデル ホストされているモデル
モデルはアプリの APK の一部であるため、サイズが大きくなります。 モデルは APK の一部ではありません。Cloud Storage にアップロードすることでホストされます。Cloud Storage for Firebase を使用することをおすすめします。
このモデルは、Android デバイスがオフラインのときでもすぐに利用可能 アプリには、モデルをオンデマンドでダウンロードするコードを含める必要があります
Firebase プロジェクトは不要 Firebase プロジェクトが必要(Cloud Storage for Firebase を使用する場合)。
モデルを更新するにはアプリを再公開する必要がある アプリを再公開することなくモデルの更新を push できる
組み込みの A/B テストはない Firebase Remote Config を使用した A/B テスト

試してみる

始める前に

1. プロジェクト レベルの build.gradle.kts ファイル内で、 buildscript セクションと allprojects セクションの両方に Google の Maven リポジトリを組み込みます。

  1. ML Kit Android ライブラリの依存関係をモジュールのアプリレベルの Gradle ファイル(通常は app/build.gradle.kts)に追加します。

    モデルをアプリにバンドルする場合:

    dependencies {
      // ...
      // Object detection & tracking feature with custom bundled model
      implementation("com.google.mlkit:object-detection-custom:17.0.2")
    }
    
  2. Cloud Storage for Firebase からモデルをダウンロードする場合は、 まだ行っていない場合は Firebase を Android プロジェクトに追加します。これは、モデルをバンドルする場合には必要ありません。

1. モデルを読み込む

モデルは、ローカル バンドルソースまたはリモート ホストソースから読み込むことができます。

ローカル モデルソースを構成する

モデルをアプリにバンドルするには:

  1. モデルファイル(拡張子は通常 .tflite または .lite)をアプリの assets/ フォルダにコピーします (先にこのフォルダを作成する必要がある場合があります。作成するには app/ フォルダを右クリックし、[新規] > [フォルダ] > Assets フォルダ をクリックします)。

  2. モデルファイルへのパスを指定して LocalModel オブジェクトを作成します。

    Kotlin

    val localModel = LocalModel.Builder()
            .setAssetFilePath("model.tflite")
            // or .setAbsoluteFilePath(absolute path to model file)
            // or .setUri(URI to model file)
            .build()

    Java

    LocalModel localModel =
        new LocalModel.Builder()
            .setAssetFilePath("model.tflite")
            // or .setAbsoluteFilePath(absolute path to model file)
            // or .setUri(URI to model file)
            .build();

リモート ホストモデルソースを構成する

リモート ホストモデルを使用するには、独自のアプリロジックを使用してモデルファイルをデバイスのローカル ストレージにダウンロードし、ローカルモデルとして読み込む必要があります。モデルをホストするには、Cloud Storage for Firebase を使用することをおすすめします。実装の詳細については、 Firebase ML から Cloud Storage への移行ガイドをご覧ください。

2. オブジェクト検出を構成する

モデルソースを構成したら、CustomObjectDetectorOptions オブジェクトを使用して、ユースケースにオブジェクト検出を構成します。次の設定を変更できます。

オブジェクト検出の設定
検出モード STREAM_MODE (デフォルト)| SINGLE_IMAGE_MODE

STREAM_MODE(デフォルト)では、オブジェクト検出機能は低レイテンシで実行されますが、最初の数回の検出機能の呼び出しで不完全な結果(未指定の境界ボックスやカテゴリラベルなど)が発生する可能性があります。また、STREAM_MODE、 検出機能でオブジェクトにトラッキング ID が割り当てられます。これを使用して、 フレームをまたいでオブジェクトをトラックできます。このモードは、オブジェクトをトラック する場合、または動画ストリームをリアルタイムで処理 する場合のように低レイテンシが重要な場合に使用します。

SINGLE_IMAGE_MODE では、オブジェクトの境界ボックスが決定された後に結果が返されます。分類も有効にすると、境界ボックスとカテゴリラベルの両方が使用可能になった後に結果が返されます。結果として、 検出のレイテンシが潜在的に長くなります。また、 SINGLE_IMAGE_MODE ではトラッキング ID が割り当てられません。レイテンシが重要ではなく、部分的な結果を処理しない場合は、このモードを使用します。

複数のオブジェクトを検出してトラックする false (デフォルト)| true

最大 5 つのオブジェクトを検出してトラックするか、最も 目立つオブジェクトのみをトラックするか(デフォルト)。

オブジェクトを分類する false (デフォルト)| true

提供された カスタム分類モデルを使用して、検出されたオブジェクトを分類するかどうか。カスタム分類 モデルを使用するには、これをtrueに設定する必要があります。

分類の信頼度のしきい値

検出されたラベルの最小信頼スコア。設定しない場合、モデルのメタデータで指定された分類子のしきい値が使用されます。モデルにメタデータが含まれていない場合、またはメタデータで 分類子のしきい値が指定されていない場合は、デフォルトのしきい値 0.0 が 使用されます。

オブジェクトあたりの最大ラベル数

検出器が返すオブジェクトあたりのラベルの最大数。 設定しない場合、デフォルト値の 10 が使用されます。

オブジェクトの検出とトラッキングの API は主に、次の 2 つのユースケース用に最適化されています。

  • カメラのビューファインダー内で最も目立つオブジェクトをライブで検出してトラッキングする。
  • 静止画像から複数のオブジェクトを検出する。

これらのユースケースに API を構成するには、ローカル バンドルモデルを使用します。

Kotlin

// Live detection and tracking
val customObjectDetectorOptions =
        CustomObjectDetectorOptions.Builder(localModel)
        .setDetectorMode(CustomObjectDetectorOptions.STREAM_MODE)
        .enableClassification()
        .setClassificationConfidenceThreshold(0.5f)
        .setMaxPerObjectLabelCount(3)
        .build()

// Multiple object detection in static images
val customObjectDetectorOptions =
        CustomObjectDetectorOptions.Builder(localModel)
        .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
        .enableMultipleObjects()
        .enableClassification()
        .setClassificationConfidenceThreshold(0.5f)
        .setMaxPerObjectLabelCount(3)
        .build()

val objectDetector =
        ObjectDetection.getClient(customObjectDetectorOptions)

Java

// Live detection and tracking
CustomObjectDetectorOptions customObjectDetectorOptions =
        new CustomObjectDetectorOptions.Builder(localModel)
                .setDetectorMode(CustomObjectDetectorOptions.STREAM_MODE)
                .enableClassification()
                .setClassificationConfidenceThreshold(0.5f)
                .setMaxPerObjectLabelCount(3)
                .build();

// Multiple object detection in static images
CustomObjectDetectorOptions customObjectDetectorOptions =
        new CustomObjectDetectorOptions.Builder(localModel)
                .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
                .enableMultipleObjects()
                .enableClassification()
                .setClassificationConfidenceThreshold(0.5f)
                .setMaxPerObjectLabelCount(3)
                .build();

ObjectDetector objectDetector =
    ObjectDetection.getClient(customObjectDetectorOptions);

リモートでホストされるモデルがある場合は、そのモデルを実行する前にダウンロード済みであることを確認する必要があります。

ダウンロードのステータスの確認は検出器を実行する前に行いますが、リモートでホストされるモデルとローカル バンドル モデルの両方を使用する場合は、画像検出器をインスタンス化するときにこの確認を行うという選択肢があります。この方法では、リモートモデルがダウンロードされている場合はリモートモデルから検出器を作成し、リモートモデルがダウンロードされていない場合はローカルモデルから検出器を作成するということが可能です。

Kotlin

val modelFile = File(context.cacheDir, "my_remote_model.tflite")

val model = if (modelFile.exists()) {
    // Use the downloaded model if available
    LocalModel.Builder().setAbsoluteFilePath(modelFile.absolutePath).build()
} else {
    // Fall back to the bundled model
    LocalModel.Builder().setAssetFilePath("model.tflite").build()
}

val customObjectDetectorOptions =
        CustomObjectDetectorOptions.Builder(model)
        .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
        .enableClassification()
        .setClassificationConfidenceThreshold(0.5f)
        .setMaxPerObjectLabelCount(3)
        .build()

val objectDetector =
        ObjectDetection.getClient(customObjectDetectorOptions)

Java

File modelFile = new File(context.getCacheDir(), "my_remote_model.tflite");

LocalModel model;
if (modelFile.exists()) {
    // Use the downloaded model if available
    model = new LocalModel.Builder().setAbsoluteFilePath(modelFile.getAbsolutePath()).build();
} else {
    // Fall back to the bundled model
    model = new LocalModel.Builder().setAssetFilePath("model.tflite").build();
}

CustomObjectDetectorOptions customObjectDetectorOptions =
        new CustomObjectDetectorOptions.Builder(model)
                .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
                .enableClassification()
                .setClassificationConfidenceThreshold(0.5f)
                .setMaxPerObjectLabelCount(3)
                .build();

ObjectDetector objectDetector =
        ObjectDetection.getClient(customObjectDetectorOptions);

リモートでホストされるモデルのみがある場合は、モデルがダウンロード済みであることを確認するまで、モデルに関連する機能(UI の一部をグレー表示または非表示にするなど)を無効にする必要があります。

Kotlin

val localFile = File(context.cacheDir, "my_remote_model.tflite")
if (localFile.exists()) {
    // Model is already cached, initialize immediately
    initializeDetector(localFile)
} else {
    // Model is not yet available, show loading UI and start download
    showLoadingUI()
    val storage = Firebase.storage
    val modelRef = storage.getReferenceFromUrl("gs://YOUR_BUCKET/path/to/model.tflite")
    modelRef.getFile(localFile)
        .addOnSuccessListener {
            // Download complete, initialize the detector
            hideLoadingUI()
            initializeDetector(localFile)
        }
        .addOnFailureListener {
            // Handle download error
            showErrorUI()
        }
}

private fun initializeDetector(modelFile: File) {
    val localModel = LocalModel.Builder().setAbsoluteFilePath(modelFile.absolutePath).build()
    val customObjectDetectorOptions = CustomObjectDetectorOptions.Builder(localModel)
            .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
            .enableClassification()
            .build()
    val objectDetector = ObjectDetection.getClient(customObjectDetectorOptions)
    // Enable ML-related UI features here
    enableMLFeatures(objectDetector)
}

Java

File localFile = new File(context.getCacheDir(), "my_remote_model.tflite");
if (localFile.exists()) {
    // Model is already cached, initialize immediately
    initializeDetector(localFile);
} else {
    // Model is not yet available, show loading UI and start download
    showLoadingUI();
    FirebaseStorage storage = FirebaseStorage.getInstance();
    StorageReference modelRef = storage.getReferenceFromUrl("gs://YOUR_BUCKET/path/to/model.tflite");
    modelRef.getFile(localFile)
        .addOnSuccessListener(new OnSuccessListener<FileDownloadTask.TaskSnapshot>() {
            @Override
            public void onSuccess(FileDownloadTask.TaskSnapshot taskSnapshot) {
                // Download complete, initialize the detector
                hideLoadingUI();
                initializeDetector(localFile);
            }
        })
        .addOnFailureListener(new OnFailureListener() {
            @Override
            public void onFailure(@NonNull Exception exception) {
                // Handle download error
                showErrorUI();
            }
        });
}

private void initializeDetector(File modelFile) {
    LocalModel localModel = new LocalModel.Builder().setAbsoluteFilePath(modelFile.getAbsolutePath()).build();
    CustomObjectDetectorOptions customObjectDetectorOptions =
            new CustomObjectDetectorOptions.Builder(localModel)
                    .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
                    .enableClassification()
                    .build();
    ObjectDetector objectDetector = ObjectDetection.getClient(customObjectDetectorOptions);
    // Enable ML-related UI features here
    enableMLFeatures(objectDetector);
}

3. 入力画像を準備する

画像から InputImage オブジェクトを作成します。 オブジェクト検出は、Bitmap、NV21 ByteBuffer、YUV_420_888 media.Image から直接実行されます。これらのソースのいずれかに直接アクセスできる場合は、これらのソースから InputImage を構築することをおすすめします。他のソースから InputImage を構築する場合、変換は内部で処理されますが、効率が低下する可能性があります。

さまざまなソースから InputImage オブジェクトを作成できます。各ソースは次のとおりです。

media.Image の使用

InputImage オブジェクトから media.Image オブジェクトを作成するには(デバイスのカメラから画像をキャプチャする場合など)、media.Image オブジェクトと画像の回転を InputImage.fromMediaImage() に渡します。

CameraX ライブラリを使用する場合は、OnImageCapturedListenerImageAnalysis.Analyzer クラスによって回転値が計算されます。

Kotlin

private class YourImageAnalyzer : ImageAnalysis.Analyzer {

    override fun analyze(imageProxy: ImageProxy) {
        val mediaImage = imageProxy.image
        if (mediaImage != null) {
            val image = InputImage.fromMediaImage(mediaImage, imageProxy.imageInfo.rotationDegrees)
            // Pass image to an ML Kit Vision API
            // ...
        }
    }
}

Java

private class YourAnalyzer implements ImageAnalysis.Analyzer {

    @Override
    public void analyze(ImageProxy imageProxy) {
        Image mediaImage = imageProxy.getImage();
        if (mediaImage != null) {
          InputImage image =
                InputImage.fromMediaImage(mediaImage, imageProxy.getImageInfo().getRotationDegrees());
          // Pass image to an ML Kit Vision API
          // ...
        }
    }
}

画像の回転角度を取得するカメラ ライブラリを使用しない場合は、デバイスの回転角度とデバイス内のカメラセンサーの向きから計算できます。

Kotlin

private val ORIENTATIONS = SparseIntArray()

init {
    ORIENTATIONS.append(Surface.ROTATION_0, 0)
    ORIENTATIONS.append(Surface.ROTATION_90, 90)
    ORIENTATIONS.append(Surface.ROTATION_180, 180)
    ORIENTATIONS.append(Surface.ROTATION_270, 270)
}

/**
 * Get the angle by which an image must be rotated given the device's current
 * orientation.
 */
@RequiresApi(api = Build.VERSION_CODES.LOLLIPOP)
@Throws(CameraAccessException::class)
private fun getRotationCompensation(cameraId: String, activity: Activity, isFrontFacing: Boolean): Int {
    // Get the device's current rotation relative to its "native" orientation.
    // Then, from the ORIENTATIONS table, look up the angle the image must be
    // rotated to compensate for the device's rotation.
    val deviceRotation = activity.windowManager.defaultDisplay.rotation
    var rotationCompensation = ORIENTATIONS.get(deviceRotation)

    // Get the device's sensor orientation.
    val cameraManager = activity.getSystemService(CAMERA_SERVICE) as CameraManager
    val sensorOrientation = cameraManager
            .getCameraCharacteristics(cameraId)
            .get(CameraCharacteristics.SENSOR_ORIENTATION)!!

    if (isFrontFacing) {
        rotationCompensation = (sensorOrientation + rotationCompensation) % 360
    } else { // back-facing
        rotationCompensation = (sensorOrientation - rotationCompensation + 360) % 360
    }
    return rotationCompensation
}

Java

private static final SparseIntArray ORIENTATIONS = new SparseIntArray();
static {
    ORIENTATIONS.append(Surface.ROTATION_0, 0);
    ORIENTATIONS.append(Surface.ROTATION_90, 90);
    ORIENTATIONS.append(Surface.ROTATION_180, 180);
    ORIENTATIONS.append(Surface.ROTATION_270, 270);
}

/**
 * Get the angle by which an image must be rotated given the device's current
 * orientation.
 */
@RequiresApi(api = Build.VERSION_CODES.LOLLIPOP)
private int getRotationCompensation(String cameraId, Activity activity, boolean isFrontFacing)
        throws CameraAccessException {
    // Get the device's current rotation relative to its "native" orientation.
    // Then, from the ORIENTATIONS table, look up the angle the image must be
    // rotated to compensate for the device's rotation.
    int deviceRotation = activity.getWindowManager().getDefaultDisplay().getRotation();
    int rotationCompensation = ORIENTATIONS.get(deviceRotation);

    // Get the device's sensor orientation.
    CameraManager cameraManager = (CameraManager) activity.getSystemService(CAMERA_SERVICE);
    int sensorOrientation = cameraManager
            .getCameraCharacteristics(cameraId)
            .get(CameraCharacteristics.SENSOR_ORIENTATION);

    if (isFrontFacing) {
        rotationCompensation = (sensorOrientation + rotationCompensation) % 360;
    } else { // back-facing
        rotationCompensation = (sensorOrientation - rotationCompensation + 360) % 360;
    }
    return rotationCompensation;
}

次に、media.Image オブジェクトと 回転角度値を InputImage.fromMediaImage() に渡します。

Kotlin

val image = InputImage.fromMediaImage(mediaImage, rotation)

Java

InputImage image = InputImage.fromMediaImage(mediaImage, rotation);

ファイル URI の使用

InputImage オブジェクトをファイルの URI から作成するには、アプリ コンテキストとファイルの URI を InputImage.fromFilePath() に渡します。これは、 ACTION_GET_CONTENT インテントを使用して、写真アプリから画像を選択するようにユーザーに促すときに便利です。

Kotlin

val image: InputImage
try {
    image = InputImage.fromFilePath(context, uri)
} catch (e: IOException) {
    e.printStackTrace()
}

Java

InputImage image;
try {
    image = InputImage.fromFilePath(context, uri);
} catch (IOException e) {
    e.printStackTrace();
}

ByteBuffer または ByteArray の使用

InputImage オブジェクトを ByteBuffer または ByteArray から作成するには、media.Image 入力について上記のように、まず画像の 回転角度を計算します。 次に、画像の高さ、幅、カラー エンコード形式、回転角度とともに、バッファまたは配列を含む InputImage オブジェクトを作成します。

Kotlin

val image = InputImage.fromByteBuffer(
        byteBuffer,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
)
// Or:
val image = InputImage.fromByteArray(
        byteArray,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
)

Java

InputImage image = InputImage.fromByteBuffer(byteBuffer,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
);
// Or:
InputImage image = InputImage.fromByteArray(
        byteArray,
        /* image width */480,
        /* image height */360,
        rotation,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
);

Bitmap の使用

InputImage オブジェクトを Bitmap オブジェクトから作成するには、次の宣言を行います。

Kotlin

val image = InputImage.fromBitmap(bitmap, 0)

Java

InputImage image = InputImage.fromBitmap(bitmap, rotationDegree);

画像は Bitmap オブジェクトと回転角度で表されます。

4. オブジェクト検出器を実行する

Kotlin

objectDetector
    .process(image)
    .addOnFailureListener(e -> {...})
    .addOnSuccessListener(results -> {
        for (detectedObject in results) {
          // ...
        }
    });

Java

objectDetector
    .process(image)
    .addOnFailureListener(e -> {...})
    .addOnSuccessListener(results -> {
        for (DetectedObject detectedObject : results) {
          // ...
        }
    });

5. ラベル付きオブジェクトに関する情報を取得する

process() の呼び出しが成功すると、DetectedObject のリストが成功リスナーに渡されます。

DetectedObject には次のプロパティが含まれています。

境界ボックス 画像内のオブジェクトの位置を示す Rect
トラッキング ID イメージ間でオブジェクトを識別する整数。SINGLE_IMAGE_MODE では、NULL です。
ラベル
ラベルの説明 ラベルのテキスト説明。LiteRT モデルの メタデータにラベルの説明が含まれている場合にのみ返されます。
ラベル インデックス 分類子でサポートされているすべてのラベルにおけるラベルのインデックス。
ラベルの信頼度 オブジェクト分類の信頼値。

Kotlin

// The list of detected objects contains one item if multiple
// object detection wasn't enabled.
for (detectedObject in results) {
    val boundingBox = detectedObject.boundingBox
    val trackingId = detectedObject.trackingId
    for (label in detectedObject.labels) {
      val text = label.text
      val index = label.index
      val confidence = label.confidence
    }
}

Java

// The list of detected objects contains one item if multiple
// object detection wasn't enabled.
for (DetectedObject detectedObject : results) {
  Rect boundingBox = detectedObject.getBoundingBox();
  Integer trackingId = detectedObject.getTrackingId();
  for (Label label : detectedObject.getLabels()) {
    String text = label.getText();
    int index = label.getIndex();
    float confidence = label.getConfidence();
  }
}

優れたユーザー エクスペリエンスを確保する

最高のユーザー エクスペリエンスを提供するため、次のガイドラインに従ってアプリを作成してください。

  • オブジェクト検出の成功は、オブジェクトの視覚的な複雑さによります。視覚的特徴の少ないオブジェクトは、検出対象の画像の大部分を占めていないと検出に成功しない可能性があります。検出するオブジェクトの種類に適した入力をキャプチャするためのガイダンスを用意する必要があります。
  • 分類を使用するときに、サポート対象のカテゴリに該当しないオブジェクトを検出する場合は、未知のオブジェクトに対して特別な処理を実装してください。

また、 ML Kit マテリアル デザイン ショーケース アプリと Material Design の Patterns for machine learning-powered featuresのコレクションも確認してください。

パフォーマンスの向上

リアルタイムのアプリケーションでオブジェクト検出を使用する場合は、適切なフレームレートを得るために次のガイドラインに従ってください。

  • リアルタイム アプリケーションでストリーミング モードを使用する場合は、複数のオブジェクト検出を使用しないでください。ほとんどのデバイスは十分なフレームレートを生成できません。

  • Camera または camera2 API を使用する場合は、 検出機能への呼び出しをスロットリングします。検出器の実行中に新しい動画 フレームが使用可能になった場合は、そのフレームをドロップします。例については、クイックスタート サンプルアプリの VisionProcessorBase クラスをご覧ください。
  • CameraX API を使用する場合は、バックプレッシャー戦略がデフォルト値 ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST に設定されていることを確認してください。これにより、一度に 1 つの画像のみが分析用に配信されます。アナライザがビジー状態のときに画像が 生成された場合、それらの画像は自動的にドロップされ、配信のために キューに入れられることはありません。分析中の画像が ImageProxy.close() を呼び出して閉じられると、次の最新の画像が配信されます。
  • 検出器の出力を使用して入力画像の上にグラフィックスをオーバーレイする場合は、まず ML Kit から検出結果を取得し、画像とオーバーレイを 1 つのステップでレンダリングします。これにより、ディスプレイ サーフェスへのレンダリングは入力フレームごとに 1 回で済みます。例については、クイックスタート サンプルアプリの CameraSourcePreview クラスと GraphicOverlay クラスをご覧ください。
  • Camera2 API を使用する場合は、画像を ImageFormat.YUV_420_888 形式でキャプチャします。古い Camera API を使用する場合は、 ImageFormat.NV21 形式で画像をキャプチャします。