使用自定义分类模型检测、跟踪和分类对象 (Android)

您可以使用机器学习套件来检测和跟踪连续视频帧中的对象。

当您向机器学习套件传递图片时,它会检测图片中的最多五个对象以及每个对象在图片中的位置。检测视频流中的对象时,每个对象都有一个唯一的 ID,您可以使用此 ID 来跨帧跟踪对象。

您可以使用自定义图片分类模型对检测到的对象进行分类。如需了解模型兼容性要求、在哪里可以找到预训练模型以及如何训练自己的模型,请参阅使用机器学习套件的自定义模型

有两种方法可以集成自定义模型。您可以将模型放入应用的资源文件夹中以捆绑该模型,也可以从 Cloud Storage 动态下载该模型。下表对这两种选项进行了比较。

捆绑的模型 托管的模型
模型是应用的 APK 的一部分,这会增加 APK 的大小。 模型不是 APK 的一部分。它是通过上传到 Cloud Storage 进行托管的。我们建议使用 Cloud Storage for Firebase
即使 Android 设备处于离线状态,模型也可立即使用 您的应用必须包含按需下载模型的代码
不需要 Firebase 项目 需要 Firebase 项目(如果使用 Cloud Storage for Firebase)。
您必须重新发布应用才能更新模型 无需重新发布应用即可推送模型更新
没有内置 A/B 测试 使用 Firebase Remote Config 进行 A/B 测试

试试看

准备工作

1. 请务必在项目级 build.gradle.kts 文件中的 buildscriptallprojects 部分添加 Google 的 Maven 制品库。

  1. 将 Android 版机器学习套件库的依赖项添加到模块的应用级 Gradle 文件(通常为 app/build.gradle.kts):

    如需将模型与您的应用捆绑在一起,请执行以下操作:

    dependencies {
      // ...
      // Object detection & tracking feature with custom bundled model
      implementation("com.google.mlkit:object-detection-custom:17.0.2")
    }
    
  2. 如果您想从 Cloud Storage for Firebase 下载模型,请 务必将 Firebase 添加到您的 Android 项目(如果尚未添加)。捆绑 模型时不需要这样做。

1. 加载模型

您可以从本地捆绑的来源或远程托管的来源加载模型。

配置本地模型来源

如需将模型与您的应用捆绑在一起,请执行以下操作:

  1. 将模型文件(通常以 .tflite.lite 结尾)复制到应用的 assets/ 文件夹。(您可能需要先创建此文件夹,方法是右键点击 app/ 文件夹,然后依次点击新建 > 文件夹 > Assets 文件夹 。)

  2. 创建一个 LocalModel 对象,指定模型文件的路径:

    Kotlin

    val localModel = LocalModel.Builder()
            .setAssetFilePath("model.tflite")
            // or .setAbsoluteFilePath(absolute path to model file)
            // or .setUri(URI to model file)
            .build()

    Java

    LocalModel localModel =
        new LocalModel.Builder()
            .setAssetFilePath("model.tflite")
            // or .setAbsoluteFilePath(absolute path to model file)
            // or .setUri(URI to model file)
            .build();

配置远程托管的模型来源

如需使用远程托管的模型,您必须使用自己的应用逻辑将模型文件下载到设备的本地存储空间,然后将其作为本地模型加载。我们建议使用 Cloud Storage for Firebase 来托管模型。如需了解 实现详情,请参阅 Firebase ML 到 Cloud Storage 迁移指南

2. 配置对象检测器

配置模型来源后,使用 CustomObjectDetectorOptions 对象为您的使用场景配置对象检测器。您可以更改以下设置:

对象检测器设置
检测模式 STREAM_MODE (默认) | SINGLE_IMAGE_MODE

STREAM_MODE(默认)下,对象检测器以低延迟高速运行,但在前几次调用检测器时可能会产生不完整的结果(例如未指定的边界框或类别标签)。此外,在 STREAM_MODE, 检测器会为对象分配跟踪 ID,您可以使用该 ID 来 跨帧跟踪对象。如果您想要跟踪 对象,或者对延迟有要求(例如在实时处理 视频流时),请使用此模式。

SINGLE_IMAGE_MODE 下,对象检测器会在确定对象的边界框后返回结果。如果您 还启用了分类,则它会在边界框和类别标签都可用后返回结果。因此, 此模式下的检测延迟可能较高。此外,在 SINGLE_IMAGE_MODE 下,不会分配跟踪 ID。如果不计较延迟高低,且不想处理不完整的结果,请使用此模式。

检测和跟踪多个对象 false (默认) | true

是检测和跟踪最多五个对象,还是仅检测和跟踪最 突出的对象(默认)。

对对象进行分类 false (默认) | true

是否使用提供的 自定义分类器模型对检测到的对象进行分类。如需使用自定义分类 模型,您需要将此项设置为 true

分类置信度阈值

检测到的标签的最低置信度分数。如果未设置,系统将使用模型的元数据指定的任何 分类器阈值。 如果模型不包含任何元数据,或者元数据未 指定分类器阈值,则系统将使用默认阈值 0.0。

每个对象的标签数上限

检测器将返回的每个对象的标签数上限。 如果未设置,系统将使用默认值 10。

对象检测和跟踪 API 针对以下两个核心使用场景进行了优化:

  • 实时检测和跟踪相机取景器中最突出的对象。
  • 检测静态图片中的多个对象。

如需为这些使用场景配置 API(使用本地捆绑的模型),请运行以下代码:

Kotlin

// Live detection and tracking
val customObjectDetectorOptions =
        CustomObjectDetectorOptions.Builder(localModel)
        .setDetectorMode(CustomObjectDetectorOptions.STREAM_MODE)
        .enableClassification()
        .setClassificationConfidenceThreshold(0.5f)
        .setMaxPerObjectLabelCount(3)
        .build()

// Multiple object detection in static images
val customObjectDetectorOptions =
        CustomObjectDetectorOptions.Builder(localModel)
        .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
        .enableMultipleObjects()
        .enableClassification()
        .setClassificationConfidenceThreshold(0.5f)
        .setMaxPerObjectLabelCount(3)
        .build()

val objectDetector =
        ObjectDetection.getClient(customObjectDetectorOptions)

Java

// Live detection and tracking
CustomObjectDetectorOptions customObjectDetectorOptions =
        new CustomObjectDetectorOptions.Builder(localModel)
                .setDetectorMode(CustomObjectDetectorOptions.STREAM_MODE)
                .enableClassification()
                .setClassificationConfidenceThreshold(0.5f)
                .setMaxPerObjectLabelCount(3)
                .build();

// Multiple object detection in static images
CustomObjectDetectorOptions customObjectDetectorOptions =
        new CustomObjectDetectorOptions.Builder(localModel)
                .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
                .enableMultipleObjects()
                .enableClassification()
                .setClassificationConfidenceThreshold(0.5f)
                .setMaxPerObjectLabelCount(3)
                .build();

ObjectDetector objectDetector =
    ObjectDetection.getClient(customObjectDetectorOptions);

如果您使用的是远程托管的模型,则必须在运行之前检查该模型是否已下载。

虽然您只需在运行检测器之前确认这一点,但如果您同时拥有远程托管模型和本地捆绑模型,则可以考虑在实例化图片检测器时执行此检查:如果已下载,则根据远程模型创建检测器,否则根据本地模型进行创建。

Kotlin

val modelFile = File(context.cacheDir, "my_remote_model.tflite")

val model = if (modelFile.exists()) {
    // Use the downloaded model if available
    LocalModel.Builder().setAbsoluteFilePath(modelFile.absolutePath).build()
} else {
    // Fall back to the bundled model
    LocalModel.Builder().setAssetFilePath("model.tflite").build()
}

val customObjectDetectorOptions =
        CustomObjectDetectorOptions.Builder(model)
        .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
        .enableClassification()
        .setClassificationConfidenceThreshold(0.5f)
        .setMaxPerObjectLabelCount(3)
        .build()

val objectDetector =
        ObjectDetection.getClient(customObjectDetectorOptions)

Java

File modelFile = new File(context.getCacheDir(), "my_remote_model.tflite");

LocalModel model;
if (modelFile.exists()) {
    // Use the downloaded model if available
    model = new LocalModel.Builder().setAbsoluteFilePath(modelFile.getAbsolutePath()).build();
} else {
    // Fall back to the bundled model
    model = new LocalModel.Builder().setAssetFilePath("model.tflite").build();
}

CustomObjectDetectorOptions customObjectDetectorOptions =
        new CustomObjectDetectorOptions.Builder(model)
                .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
                .enableClassification()
                .setClassificationConfidenceThreshold(0.5f)
                .setMaxPerObjectLabelCount(3)
                .build();

ObjectDetector objectDetector =
        ObjectDetection.getClient(customObjectDetectorOptions);

如果您只有远程托管的模型,则应停用与模型相关的功能(例如使界面的一部分变灰或将其隐藏),直到您确认模型已下载。

Kotlin

val localFile = File(context.cacheDir, "my_remote_model.tflite")
if (localFile.exists()) {
    // Model is already cached, initialize immediately
    initializeDetector(localFile)
} else {
    // Model is not yet available, show loading UI and start download
    showLoadingUI()
    val storage = Firebase.storage
    val modelRef = storage.getReferenceFromUrl("gs://YOUR_BUCKET/path/to/model.tflite")
    modelRef.getFile(localFile)
        .addOnSuccessListener {
            // Download complete, initialize the detector
            hideLoadingUI()
            initializeDetector(localFile)
        }
        .addOnFailureListener {
            // Handle download error
            showErrorUI()
        }
}

private fun initializeDetector(modelFile: File) {
    val localModel = LocalModel.Builder().setAbsoluteFilePath(modelFile.absolutePath).build()
    val customObjectDetectorOptions = CustomObjectDetectorOptions.Builder(localModel)
            .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
            .enableClassification()
            .build()
    val objectDetector = ObjectDetection.getClient(customObjectDetectorOptions)
    // Enable ML-related UI features here
    enableMLFeatures(objectDetector)
}

Java

File localFile = new File(context.getCacheDir(), "my_remote_model.tflite");
if (localFile.exists()) {
    // Model is already cached, initialize immediately
    initializeDetector(localFile);
} else {
    // Model is not yet available, show loading UI and start download
    showLoadingUI();
    FirebaseStorage storage = FirebaseStorage.getInstance();
    StorageReference modelRef = storage.getReferenceFromUrl("gs://YOUR_BUCKET/path/to/model.tflite");
    modelRef.getFile(localFile)
        .addOnSuccessListener(new OnSuccessListener<FileDownloadTask.TaskSnapshot>() {
            @Override
            public void onSuccess(FileDownloadTask.TaskSnapshot taskSnapshot) {
                // Download complete, initialize the detector
                hideLoadingUI();
                initializeDetector(localFile);
            }
        })
        .addOnFailureListener(new OnFailureListener() {
            @Override
            public void onFailure(@NonNull Exception exception) {
                // Handle download error
                showErrorUI();
            }
        });
}

private void initializeDetector(File modelFile) {
    LocalModel localModel = new LocalModel.Builder().setAbsoluteFilePath(modelFile.getAbsolutePath()).build();
    CustomObjectDetectorOptions customObjectDetectorOptions =
            new CustomObjectDetectorOptions.Builder(localModel)
                    .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE)
                    .enableClassification()
                    .build();
    ObjectDetector objectDetector = ObjectDetection.getClient(customObjectDetectorOptions);
    // Enable ML-related UI features here
    enableMLFeatures(objectDetector);
}

3. 准备输入图片

根据您的图片创建 InputImage 对象。 对象检测器直接从 Bitmap、NV21 ByteBuffer 或 YUV_420_888 media.Image 运行。如果您可以直接访问其中一个来源,建议您从这些来源构建 InputImage。如果您从其他来源构建 InputImage,我们会为您在内部处理转换,但效率可能会较低。

您可以从不同的来源创建 InputImage 对象,下面将对每个来源进行说明。

使用 media.Image

如需基于 media.Image 对象创建 InputImage 对象(例如从设备的相机捕获图片时),请将 media.Image 对象和图片的旋转角度传递给 InputImage.fromMediaImage()

如果您使用 CameraX 库,OnImageCapturedListenerImageAnalysis.Analyzer 类会为您计算旋转角度值。

Kotlin

private class YourImageAnalyzer : ImageAnalysis.Analyzer {

    override fun analyze(imageProxy: ImageProxy) {
        val mediaImage = imageProxy.image
        if (mediaImage != null) {
            val image = InputImage.fromMediaImage(mediaImage, imageProxy.imageInfo.rotationDegrees)
            // Pass image to an ML Kit Vision API
            // ...
        }
    }
}

Java

private class YourAnalyzer implements ImageAnalysis.Analyzer {

    @Override
    public void analyze(ImageProxy imageProxy) {
        Image mediaImage = imageProxy.getImage();
        if (mediaImage != null) {
          InputImage image =
                InputImage.fromMediaImage(mediaImage, imageProxy.getImageInfo().getRotationDegrees());
          // Pass image to an ML Kit Vision API
          // ...
        }
    }
}

如果您不使用可提供图片旋转角度的相机库,则可以根据设备的旋转角度和设备中相机传感器的朝向来计算旋转角度:

Kotlin

private val ORIENTATIONS = SparseIntArray()

init {
    ORIENTATIONS.append(Surface.ROTATION_0, 0)
    ORIENTATIONS.append(Surface.ROTATION_90, 90)
    ORIENTATIONS.append(Surface.ROTATION_180, 180)
    ORIENTATIONS.append(Surface.ROTATION_270, 270)
}

/**
 * Get the angle by which an image must be rotated given the device's current
 * orientation.
 */
@RequiresApi(api = Build.VERSION_CODES.LOLLIPOP)
@Throws(CameraAccessException::class)
private fun getRotationCompensation(cameraId: String, activity: Activity, isFrontFacing: Boolean): Int {
    // Get the device's current rotation relative to its "native" orientation.
    // Then, from the ORIENTATIONS table, look up the angle the image must be
    // rotated to compensate for the device's rotation.
    val deviceRotation = activity.windowManager.defaultDisplay.rotation
    var rotationCompensation = ORIENTATIONS.get(deviceRotation)

    // Get the device's sensor orientation.
    val cameraManager = activity.getSystemService(CAMERA_SERVICE) as CameraManager
    val sensorOrientation = cameraManager
            .getCameraCharacteristics(cameraId)
            .get(CameraCharacteristics.SENSOR_ORIENTATION)!!

    if (isFrontFacing) {
        rotationCompensation = (sensorOrientation + rotationCompensation) % 360
    } else { // back-facing
        rotationCompensation = (sensorOrientation - rotationCompensation + 360) % 360
    }
    return rotationCompensation
}

Java

private static final SparseIntArray ORIENTATIONS = new SparseIntArray();
static {
    ORIENTATIONS.append(Surface.ROTATION_0, 0);
    ORIENTATIONS.append(Surface.ROTATION_90, 90);
    ORIENTATIONS.append(Surface.ROTATION_180, 180);
    ORIENTATIONS.append(Surface.ROTATION_270, 270);
}

/**
 * Get the angle by which an image must be rotated given the device's current
 * orientation.
 */
@RequiresApi(api = Build.VERSION_CODES.LOLLIPOP)
private int getRotationCompensation(String cameraId, Activity activity, boolean isFrontFacing)
        throws CameraAccessException {
    // Get the device's current rotation relative to its "native" orientation.
    // Then, from the ORIENTATIONS table, look up the angle the image must be
    // rotated to compensate for the device's rotation.
    int deviceRotation = activity.getWindowManager().getDefaultDisplay().getRotation();
    int rotationCompensation = ORIENTATIONS.get(deviceRotation);

    // Get the device's sensor orientation.
    CameraManager cameraManager = (CameraManager) activity.getSystemService(CAMERA_SERVICE);
    int sensorOrientation = cameraManager
            .getCameraCharacteristics(cameraId)
            .get(CameraCharacteristics.SENSOR_ORIENTATION);

    if (isFrontFacing) {
        rotationCompensation = (sensorOrientation + rotationCompensation) % 360;
    } else { // back-facing
        rotationCompensation = (sensorOrientation - rotationCompensation + 360) % 360;
    }
    return rotationCompensation;
}

然后,将 media.Image 对象及其旋转角度值传递给 InputImage.fromMediaImage()

Kotlin

val image = InputImage.fromMediaImage(mediaImage, rotation)

Java

InputImage image = InputImage.fromMediaImage(mediaImage, rotation);

使用文件 URI

如需基于文件 URI 创建 InputImage 对象,请将应用上下文和文件 URI 传递给 InputImage.fromFilePath()。如果您使用 ACTION_GET_CONTENT Intent 提示用户从相册应用中选择图片,这一操作会非常有用。

Kotlin

val image: InputImage
try {
    image = InputImage.fromFilePath(context, uri)
} catch (e: IOException) {
    e.printStackTrace()
}

Java

InputImage image;
try {
    image = InputImage.fromFilePath(context, uri);
} catch (IOException e) {
    e.printStackTrace();
}

使用 ByteBufferByteArray

如需基于 ByteBufferByteArray 创建 InputImage 对象,请首先按先前 media.Image 输入的说明计算图片 旋转角度。然后,使用缓冲区或数组以及图片的高度、宽度、颜色编码格式和旋转角度创建 InputImage 对象:

Kotlin

val image = InputImage.fromByteBuffer(
        byteBuffer,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
)
// Or:
val image = InputImage.fromByteArray(
        byteArray,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
)

Java

InputImage image = InputImage.fromByteBuffer(byteBuffer,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
);
// Or:
InputImage image = InputImage.fromByteArray(
        byteArray,
        /* image width */480,
        /* image height */360,
        rotation,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
);

使用 Bitmap

如需基于 Bitmap 对象创建 InputImage 对象,请进行以下声明:

Kotlin

val image = InputImage.fromBitmap(bitmap, 0)

Java

InputImage image = InputImage.fromBitmap(bitmap, rotationDegree);

图片由 Bitmap 对象以及旋转角度表示。

4. 运行对象检测器

Kotlin

objectDetector
    .process(image)
    .addOnFailureListener(e -> {...})
    .addOnSuccessListener(results -> {
        for (detectedObject in results) {
          // ...
        }
    });

Java

objectDetector
    .process(image)
    .addOnFailureListener(e -> {...})
    .addOnSuccessListener(results -> {
        for (DetectedObject detectedObject : results) {
          // ...
        }
    });

5. 获取已加标签的对象的相关信息

如果对 process() 的调用成功完成,系统会向成功监听器传递一组 DetectedObject

每个 DetectedObject 都包含以下属性:

边界框 一个 Rect,指示图片中对象的位置。
跟踪 ID 一个整数,用于跨图片识别对象。在 SINGLE_IMAGE_MODE 下为 Null。
标签
标签说明 标签的文本说明。仅当 LiteRT 模型的 元数据包含标签说明时才会返回。
标签索引 标签在分类器支持的所有标签中的 索引。
标签置信度 对象分类的置信度值。

Kotlin

// The list of detected objects contains one item if multiple
// object detection wasn't enabled.
for (detectedObject in results) {
    val boundingBox = detectedObject.boundingBox
    val trackingId = detectedObject.trackingId
    for (label in detectedObject.labels) {
      val text = label.text
      val index = label.index
      val confidence = label.confidence
    }
}

Java

// The list of detected objects contains one item if multiple
// object detection wasn't enabled.
for (DetectedObject detectedObject : results) {
  Rect boundingBox = detectedObject.getBoundingBox();
  Integer trackingId = detectedObject.getTrackingId();
  for (Label label : detectedObject.getLabels()) {
    String text = label.getText();
    int index = label.getIndex();
    float confidence = label.getConfidence();
  }
}

确保出色的用户体验

为了获得最佳用户体验,请在应用中遵循以下准则:

  • 对象检测成功与否取决于对象的视觉复杂性。为了能够被检测到,具有较少视觉特征的对象可能需要占据待检测图片的较大部分区域。您应为用户提供有关捕获输入的指导,该输入应适用于您要检测的对象类型。
  • 使用分类时,如果您要检测不完全归于受支持类别的对象,请对未知对象执行特殊处理。

另请参阅 机器学习套件 Material Design 展示应用和适用于机器学习所支持功能集的 Material Design 模式

提高性能

如果要在实时应用中使用对象检测,请遵循以下准则以实现最佳帧速率:

  • 在实时应用中使用流式传输模式时,请勿使用多个对象检测,因为大多数设备无法产生足够高的帧速率。

  • 如果您使用 Cameracamera2 API, 请限制对检测器的调用。如果在检测器运行时有新的视频 帧可用,请丢弃该帧。如需查看示例,请参阅快速入门示例应用中的 VisionProcessorBase 类。
  • 如果您使用 CameraX API,请确保将反压策略设置为其默认值 ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST。这样可确保一次只传送一张图片进行分析。如果分析器繁忙时生成了更多图片,这些图片将被自动丢弃,而不会排队等待传送。通过调用 ImageProxy.close() 关闭正在分析的图片后,系统会传送下一张最新图片。
  • 如果要将检测器的输出作为图形叠加在 输入图片上,请先从机器学习套件获取结果,然后在一个步骤中完成图片的呈现和叠加。这样,每个输入帧只需在显示表面 呈现一次。如需查看示例,请参阅快速入门示例应用中的 CameraSourcePreview GraphicOverlay 类。
  • 如果您使用 Camera2 API,请以 ImageFormat.YUV_420_888 格式捕获图片。如果您使用旧版 Camera API,请以 ImageFormat.NV21 格式捕获图片。