在 Android 上使用自訂分類模型來偵測、追蹤及分類物件

您可以使用 ML Kit 偵測及追蹤連續影片影格中的物件。

將圖片傳遞至 ML Kit 時,系統會偵測圖片中最多五個物件,以及每個物件在圖片中的位置。偵測影片串流中的物件時,每個物件都有專屬 ID,可用於追蹤影格中的物件。

您可以使用自訂圖片分類模型,分類偵測到的物件。如需模型相容性規定、預先訓練模型取得位置,以及如何訓練自有模型的相關指引,請參閱「使用 ML Kit 的自訂模型」。

整合自訂模型的方法有兩種。您可以將模型放入應用程式的資產資料夾中,或從 Cloud Storage 動態下載模型。下表比較這兩個選項。

組合模式 代管模型
模型是應用程式 APK 的一部分,因此會增加 APK 大小。 模型並非 APK 的一部分,而是上傳至 Cloud Storage 後由該服務代管。我們建議使用 Cloud Storage for Firebase
即使 Android 裝置未連上網路,也能立即使用模型 應用程式必須包含程式碼,才能視需要下載模型
不需要 Firebase 專案 需要 Firebase 專案 (如果使用 Cloud Storage for Firebase)。
您必須重新發布應用程式,才能更新模型 無須重新發布應用程式,即可推送模型更新
沒有內建的 A/B 測試 使用 Firebase 遠端設定進行 A/B 測試

立即試用

事前準備

1. 在專案層級的 build.gradle.kts 檔案中,請務必在 buildscriptallprojects 區段中加入 Google 的 Maven 存放區。

  1. 將 ML Kit Android 程式庫的依附元件新增至模組的應用程式層級 Gradle 檔案,通常為 app/build.gradle.kts

    如要將模型與應用程式組合:

    dependencies {
      // ...
      // Object detection & tracking feature with custom bundled model
      implementation("com.google.mlkit:object-detection-custom:17.0.2")
    }
    
  2. 如要從 Cloud Storage for Firebase 下載模型,請務必將 Firebase 新增至 Android 專案 (如果尚未新增)。如果您是將模型與應用程式套件組合,則不需要這麼做。

1. 載入模型

您可以從本機組合來源或遠端代管來源載入模型。

設定本機模型來源

如要將模型與應用程式組合,請按照下列步驟操作:

  1. 將模型檔案 (通常以 .tflite.lite 結尾) 複製到應用程式的 assets/ 資料夾。(您可能需要先建立資料夾,方法是按一下滑鼠右鍵點選 app/ 資料夾,然後依序點選「New」>「Folder」>「Assets Folder」)。

  2. 建立 LocalModel 物件,並指定模型檔案的路徑:

    Kotlin

    val localModel = LocalModel.Builder()
            .setAssetFilePath("model.tflite")
            // or .setAbsoluteFilePath(absolute path to model file)
            // or .setUri(URI to model file)
            .build()

    Java

    LocalModel localModel =
        new LocalModel.Builder()
            .setAssetFilePath("model.tflite")
            // or .setAbsoluteFilePath(absolute path to model file)
            // or .setUri(URI to model file)
            .build();

設定遠端代管模型來源

如要使用遠端託管模型,您必須使用自己的應用程式邏輯,將模型檔案下載至裝置的本機儲存空間,然後載入為本機模型。建議使用 Cloud Storage for Firebase 託管模型。如需實作詳細資料,請參閱 Firebase ML 遷移至 Cloud Storage 指南

2. 設定物件偵測器

設定模型來源後,請使用 CustomObjectDetectorOptions 物件,為您的用途設定物件偵測器。您可以變更下列設定:

物件偵測工具設定
偵測模式 STREAM_MODE (預設) | SINGLE_IMAGE_MODE

STREAM_MODE (預設) 模式下,物件偵測器會以低延遲執行,但可能在前幾次呼叫偵測器時產生不完整的結果 (例如不明的定界框或類別標籤)。此外,在 STREAM_MODE 模式下,偵測器會為物件指派追蹤 ID,您可以使用這些 ID 追蹤影格中的物件。如要追蹤物件,或低延遲很重要 (例如即時處理影片串流時),請使用這個模式。

SINGLE_IMAGE_MODE 中,物件偵測器會在判斷物件的邊界框後傳回結果。如果您也啟用分類功能,偵測器會在邊界框和類別標籤都可用時傳回結果。因此,偵測延遲時間可能會較長。此外,在 SINGLE_IMAGE_MODE 中,系統不會指派追蹤 ID。如果延遲時間不重要,且您不想處理部分結果,請使用這個模式。

偵測及追蹤多個物件 false (預設) | true

是否要偵測及追蹤最多五個物件,或只追蹤最顯眼的物件 (預設)。

分類物件 false (預設) | true

是否使用提供的自訂敏感內容類別模型分類偵測到的物件。如要使用自訂分類模型,請將此值設為 true

分類可信度門檻

偵測到的標籤的最低信賴度分數。如未設定,系統會使用模型中繼資料指定的任何分類器門檻。如果模型不含任何中繼資料,或中繼資料未指定分類器門檻,系統會使用 0.0 的預設門檻。

每個物件的標籤數量上限

偵測器傳回的每個物件標籤數量上限。如未設定,系統會使用預設值 10。

物件偵測和追蹤 API 適用於下列兩項核心用途:

  • 即時偵測及追蹤相機觀景窗中最顯眼的物件。
  • 從靜態圖片偵測多個物件。

如要設定 API 以支援這些用途,並使用本機綁定的模型,請按照下列步驟操作:

Kotlin

// Live detection and tracking
val customObjectDetectorOptions =
        CustomObjectDetectorOptions.Builder(localModel)
        .setDetectorMode(CustomObjectDetectorOptions.STREAM_MODE)
        .enableClassification()
        .setClassificationConfidenceThreshold(0.5f)
        .setMaxPerObjectLabelCount(3)
        .build()

// Multiple object detection in static images
val customObjectDetectorOptions =
        CustomObjectDetectorOptions.Builder(localModel)
        .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
        .enableMultipleObjects()
        .enableClassification()
        .setClassificationConfidenceThreshold(0.5f)
        .setMaxPerObjectLabelCount(3)
        .build()

val objectDetector =
        ObjectDetection.getClient(customObjectDetectorOptions)

Java

// Live detection and tracking
CustomObjectDetectorOptions customObjectDetectorOptions =
        new CustomObjectDetectorOptions.Builder(localModel)
                .setDetectorMode(CustomObjectDetectorOptions.STREAM_MODE)
                .enableClassification()
                .setClassificationConfidenceThreshold(0.5f)
                .setMaxPerObjectLabelCount(3)
                .build();

// Multiple object detection in static images
CustomObjectDetectorOptions customObjectDetectorOptions =
        new CustomObjectDetectorOptions.Builder(localModel)
                .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
                .enableMultipleObjects()
                .enableClassification()
                .setClassificationConfidenceThreshold(0.5f)
                .setMaxPerObjectLabelCount(3)
                .build();

ObjectDetector objectDetector =
    ObjectDetection.getClient(customObjectDetectorOptions);

如果您有遠端代管模型,必須先檢查模型是否已下載,才能運作執行。

雖然您只需要在執行偵測器前確認這一點,但如果您同時有遠端代管模型和本機模型,在例項化圖片偵測器時執行這項檢查可能很有意義:如果已下載遠端模型,請從遠端模型建立偵測器,否則請從本機模型建立偵測器。

Kotlin

val modelFile = File(context.cacheDir, "my_remote_model.tflite")

val model = if (modelFile.exists()) {
    // Use the downloaded model if available
    LocalModel.Builder().setAbsoluteFilePath(modelFile.absolutePath).build()
} else {
    // Fall back to the bundled model
    LocalModel.Builder().setAssetFilePath("model.tflite").build()
}

val customObjectDetectorOptions =
        CustomObjectDetectorOptions.Builder(model)
        .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
        .enableClassification()
        .setClassificationConfidenceThreshold(0.5f)
        .setMaxPerObjectLabelCount(3)
        .build()

val objectDetector =
        ObjectDetection.getClient(customObjectDetectorOptions)

Java

File modelFile = new File(context.getCacheDir(), "my_remote_model.tflite");

LocalModel model;
if (modelFile.exists()) {
    // Use the downloaded model if available
    model = new LocalModel.Builder().setAbsoluteFilePath(modelFile.getAbsolutePath()).build();
} else {
    // Fall back to the bundled model
    model = new LocalModel.Builder().setAssetFilePath("model.tflite").build();
}

CustomObjectDetectorOptions customObjectDetectorOptions =
        new CustomObjectDetectorOptions.Builder(model)
                .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
                .enableClassification()
                .setClassificationConfidenceThreshold(0.5f)
                .setMaxPerObjectLabelCount(3)
                .build();

ObjectDetector objectDetector =
        ObjectDetection.getClient(customObjectDetectorOptions);

如果只有遠端代管模型,您應停用模型相關功能 (例如將部分 UI 設為灰色或隱藏),直到確認模型已下載為止。

Kotlin

val localFile = File(context.cacheDir, "my_remote_model.tflite")
if (localFile.exists()) {
    // Model is already cached, initialize immediately
    initializeDetector(localFile)
} else {
    // Model is not yet available, show loading UI and start download
    showLoadingUI()
    val storage = Firebase.storage
    val modelRef = storage.getReferenceFromUrl("gs://YOUR_BUCKET/path/to/model.tflite")
    modelRef.getFile(localFile)
        .addOnSuccessListener {
            // Download complete, initialize the detector
            hideLoadingUI()
            initializeDetector(localFile)
        }
        .addOnFailureListener {
            // Handle download error
            showErrorUI()
        }
}

private fun initializeDetector(modelFile: File) {
    val localModel = LocalModel.Builder().setAbsoluteFilePath(modelFile.absolutePath).build()
    val customObjectDetectorOptions = CustomObjectDetectorOptions.Builder(localModel)
            .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
            .enableClassification()
            .build()
    val objectDetector = ObjectDetection.getClient(customObjectDetectorOptions)
    // Enable ML-related UI features here
    enableMLFeatures(objectDetector)
}

Java

File localFile = new File(context.getCacheDir(), "my_remote_model.tflite");
if (localFile.exists()) {
    // Model is already cached, initialize immediately
    initializeDetector(localFile);
} else {
    // Model is not yet available, show loading UI and start download
    showLoadingUI();
    FirebaseStorage storage = FirebaseStorage.getInstance();
    StorageReference modelRef = storage.getReferenceFromUrl("gs://YOUR_BUCKET/path/to/model.tflite");
    modelRef.getFile(localFile)
        .addOnSuccessListener(new OnSuccessListener<FileDownloadTask.TaskSnapshot>() {
            @Override
            public void onSuccess(FileDownloadTask.TaskSnapshot taskSnapshot) {
                // Download complete, initialize the detector
                hideLoadingUI();
                initializeDetector(localFile);
            }
        })
        .addOnFailureListener(new OnFailureListener() {
            @Override
            public void onFailure(@NonNull Exception exception) {
                // Handle download error
                showErrorUI();
            }
        });
}

private void initializeDetector(File modelFile) {
    LocalModel localModel = new LocalModel.Builder().setAbsoluteFilePath(modelFile.getAbsolutePath()).build();
    CustomObjectDetectorOptions customObjectDetectorOptions =
            new CustomObjectDetectorOptions.Builder(localModel)
                    .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
                    .enableClassification()
                    .build();
    ObjectDetector objectDetector = ObjectDetection.getClient(customObjectDetectorOptions);
    // Enable ML-related UI features here
    enableMLFeatures(objectDetector);
}

3. 準備輸入圖片

從圖片建立 InputImage 物件。 物件偵測器會直接從 Bitmap、NV21 ByteBuffer 或 YUV_420_888 media.Image 執行。如果您可直接存取其中一個來源,建議從這些來源建構 InputImage。如果您從其他來源建構 InputImage,我們會為您處理內部轉換,但效率可能較低。

您可以從不同來源建立 InputImage 物件,說明如下:

使用 media.Image

如要從 media.Image 物件建立 InputImage 物件 (例如從裝置的相機擷取圖片時),請將 media.Image 物件和圖片的旋轉角度傳遞至 InputImage.fromMediaImage()

如果您使用 CameraX 程式庫,OnImageCapturedListenerImageAnalysis.Analyzer 類別會為您計算旋轉值。

Kotlin

private class YourImageAnalyzer : ImageAnalysis.Analyzer {

    override fun analyze(imageProxy: ImageProxy) {
        val mediaImage = imageProxy.image
        if (mediaImage != null) {
            val image = InputImage.fromMediaImage(mediaImage, imageProxy.imageInfo.rotationDegrees)
            // Pass image to an ML Kit Vision API
            // ...
        }
    }
}

Java

private class YourAnalyzer implements ImageAnalysis.Analyzer {

    @Override
    public void analyze(ImageProxy imageProxy) {
        Image mediaImage = imageProxy.getImage();
        if (mediaImage != null) {
          InputImage image =
                InputImage.fromMediaImage(mediaImage, imageProxy.getImageInfo().getRotationDegrees());
          // Pass image to an ML Kit Vision API
          // ...
        }
    }
}

如果您使用的相機程式庫未提供圖片的旋轉角度,可以根據裝置的旋轉角度和裝置中相機感應器的方向計算:

Kotlin

private val ORIENTATIONS = SparseIntArray()

init {
    ORIENTATIONS.append(Surface.ROTATION_0, 0)
    ORIENTATIONS.append(Surface.ROTATION_90, 90)
    ORIENTATIONS.append(Surface.ROTATION_180, 180)
    ORIENTATIONS.append(Surface.ROTATION_270, 270)
}

/**
 * Get the angle by which an image must be rotated given the device's current
 * orientation.
 */
@RequiresApi(api = Build.VERSION_CODES.LOLLIPOP)
@Throws(CameraAccessException::class)
private fun getRotationCompensation(cameraId: String, activity: Activity, isFrontFacing: Boolean): Int {
    // Get the device's current rotation relative to its "native" orientation.
    // Then, from the ORIENTATIONS table, look up the angle the image must be
    // rotated to compensate for the device's rotation.
    val deviceRotation = activity.windowManager.defaultDisplay.rotation
    var rotationCompensation = ORIENTATIONS.get(deviceRotation)

    // Get the device's sensor orientation.
    val cameraManager = activity.getSystemService(CAMERA_SERVICE) as CameraManager
    val sensorOrientation = cameraManager
            .getCameraCharacteristics(cameraId)
            .get(CameraCharacteristics.SENSOR_ORIENTATION)!!

    if (isFrontFacing) {
        rotationCompensation = (sensorOrientation + rotationCompensation) % 360
    } else { // back-facing
        rotationCompensation = (sensorOrientation - rotationCompensation + 360) % 360
    }
    return rotationCompensation
}

Java

private static final SparseIntArray ORIENTATIONS = new SparseIntArray();
static {
    ORIENTATIONS.append(Surface.ROTATION_0, 0);
    ORIENTATIONS.append(Surface.ROTATION_90, 90);
    ORIENTATIONS.append(Surface.ROTATION_180, 180);
    ORIENTATIONS.append(Surface.ROTATION_270, 270);
}

/**
 * Get the angle by which an image must be rotated given the device's current
 * orientation.
 */
@RequiresApi(api = Build.VERSION_CODES.LOLLIPOP)
private int getRotationCompensation(String cameraId, Activity activity, boolean isFrontFacing)
        throws CameraAccessException {
    // Get the device's current rotation relative to its "native" orientation.
    // Then, from the ORIENTATIONS table, look up the angle the image must be
    // rotated to compensate for the device's rotation.
    int deviceRotation = activity.getWindowManager().getDefaultDisplay().getRotation();
    int rotationCompensation = ORIENTATIONS.get(deviceRotation);

    // Get the device's sensor orientation.
    CameraManager cameraManager = (CameraManager) activity.getSystemService(CAMERA_SERVICE);
    int sensorOrientation = cameraManager
            .getCameraCharacteristics(cameraId)
            .get(CameraCharacteristics.SENSOR_ORIENTATION);

    if (isFrontFacing) {
        rotationCompensation = (sensorOrientation + rotationCompensation) % 360;
    } else { // back-facing
        rotationCompensation = (sensorOrientation - rotationCompensation + 360) % 360;
    }
    return rotationCompensation;
}

接著,將 media.Image 物件和旋轉角度值傳遞至 InputImage.fromMediaImage()

Kotlin

val image = InputImage.fromMediaImage(mediaImage, rotation)

Java

InputImage image = InputImage.fromMediaImage(mediaImage, rotation);

使用檔案 URI

如要從檔案 URI 建立 InputImage 物件,請將應用程式內容和檔案 URI 傳遞至 InputImage.fromFilePath()。當您使用 ACTION_GET_CONTENT intent 提示使用者從相片庫應用程式選取圖片時,這項功能就非常實用。

Kotlin

val image: InputImage
try {
    image = InputImage.fromFilePath(context, uri)
} catch (e: IOException) {
    e.printStackTrace()
}

Java

InputImage image;
try {
    image = InputImage.fromFilePath(context, uri);
} catch (IOException e) {
    e.printStackTrace();
}

使用 ByteBufferByteArray

如要從 ByteBufferByteArray 建立 InputImage 物件,請先計算圖片旋轉角度,如先前所述的 media.Image 輸入內容。接著,使用緩衝區或陣列建立 InputImage 物件,並提供圖片的高度、寬度、顏色編碼格式和旋轉角度:

Kotlin

val image = InputImage.fromByteBuffer(
        byteBuffer,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
)
// Or:
val image = InputImage.fromByteArray(
        byteArray,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
)

Java

InputImage image = InputImage.fromByteBuffer(byteBuffer,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
);
// Or:
InputImage image = InputImage.fromByteArray(
        byteArray,
        /* image width */480,
        /* image height */360,
        rotation,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
);

使用 Bitmap

如要從 Bitmap 物件建立 InputImage 物件,請進行下列宣告:

Kotlin

val image = InputImage.fromBitmap(bitmap, 0)

Java

InputImage image = InputImage.fromBitmap(bitmap, rotationDegree);

圖片會以 Bitmap 物件和旋轉角度表示。

4. 執行物件偵測工具

Kotlin

objectDetector
    .process(image)
    .addOnFailureListener(e -> {...})
    .addOnSuccessListener(results -> {
        for (detectedObject in results) {
          // ...
        }
    });

Java

objectDetector
    .process(image)
    .addOnFailureListener(e -> {...})
    .addOnSuccessListener(results -> {
        for (DetectedObject detectedObject : results) {
          // ...
        }
    });

5. 取得標示物件的相關資訊

如果呼叫 process() 成功,系統會將 DetectedObject 清單傳遞至成功事件監聽器。

每個 DetectedObject 都包含下列屬性:

定界框 Rect:表示物件在圖片中的位置。
追蹤 ID 用來識別跨圖像物件的整數。在 SINGLE_IMAGE_MODE 中為空值。
標籤
標籤說明 標籤的文字說明。只有在 LiteRT 模型的 中繼資料包含標籤說明時,才會傳回這項資訊。
標籤索引 分類器支援的所有標籤中,這個標籤的索引。
標籤可信度 物件分類的信心值。

Kotlin

// The list of detected objects contains one item if multiple
// object detection wasn't enabled.
for (detectedObject in results) {
    val boundingBox = detectedObject.boundingBox
    val trackingId = detectedObject.trackingId
    for (label in detectedObject.labels) {
      val text = label.text
      val index = label.index
      val confidence = label.confidence
    }
}

Java

// The list of detected objects contains one item if multiple
// object detection wasn't enabled.
for (DetectedObject detectedObject : results) {
  Rect boundingBox = detectedObject.getBoundingBox();
  Integer trackingId = detectedObject.getTrackingId();
  for (Label label : detectedObject.getLabels()) {
    String text = label.getText();
    int index = label.getIndex();
    float confidence = label.getConfidence();
  }
}

確保提供優質的使用者體驗

為提供最佳使用者體驗,請確保應用程式遵循下列規範:

  • 成功進行物件偵測與否,取決於物件的視覺複雜度。如要偵測視覺特徵較少的物件,可能需要讓物件在圖片中占據較大比例。您應向使用者提供指引,說明如何擷取適合偵測目標物件的輸入內容。
  • 使用分類功能時,如要偵測不屬於支援類別的物件,請針對不明物件實作特殊處理方式。

此外,也請參閱 ML Kit Material Design 展示應用程式,以及機器學習輔助功能適用的 Material Design 模式集合。

提升效能

如要在即時應用程式中使用物件偵測功能,請按照下列指南操作,以達到最佳影格速率:

  • 在即時應用程式中使用串流模式時,請勿使用多個物件偵測功能,因為大多數裝置無法產生足夠的影格速率。

  • 如果您使用 Cameracamera2 API,請節流對偵測器的呼叫。如果偵測器執行期間有新的影片影格可用,請捨棄該影格。如需範例,請參閱快速入門範例應用程式中的 VisionProcessorBase 類別。
  • 如果您使用 CameraX API,請務必將回壓策略設為預設值 ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST。這可確保系統一次只會傳送一張圖片進行分析。如果分析器忙碌時產生更多圖片,系統會自動捨棄這些圖片,不會排隊等待傳送。呼叫 ImageProxy.close() 關閉正在分析的圖片後,系統會傳送下一個最新圖片。
  • 如果使用偵測器的輸出內容,在輸入圖片上疊加圖像,請先從 ML Kit 取得結果,然後在單一步驟中算繪圖片並疊加圖像。每個輸入影格只會轉譯到顯示介面一次。如需範例,請參閱快速入門範例應用程式中的 CameraSourcePreview GraphicOverlay 類別。
  • 如果您使用 Camera2 API,請以 ImageFormat.YUV_420_888 格式擷取圖片。如果您使用舊版 Camera API,請以 ImageFormat.NV21 格式擷取圖片。