使用自定义模型给图片加标签 (Android)

您可以使用机器学习套件识别图片中的实体并为其添加标签。此 API 支持各种自定义图片分类模型。如需有关模型兼容性要求、预训练模型的位置以及如何训练自有模型的指导,请参阅使用机器学习套件的自定义模型

您可以通过以下两种方式将图片标记与自定义模型集成:将流水线捆绑为应用的一部分,或使用依赖于 Google Play 服务的非捆绑流水线。如果您选择非捆绑流水线,应用会更小。有关详情,请查看下表。

捆绑非捆绑
库名称com.google.mlkit:image-labeling-customcom.google.android.gms:play-services-mlkit-image-labeling-custom

实现
流水线在构建时会静态链接到您的应用。流水线是使用 Google Play 服务动态下载的。
应用大小大小增加约 3.8 MB。大小增加约 200 KB。
初始化时间流水线可立即使用。首次使用前可能需要等待流水线下载完毕。
API 生命周期阶段正式版 (GA)Beta 版

您可以通过以下两种方式集成自定义模型:可以将模型嵌入应用的资源文件夹中以捆绑模型,也可以从 Firebase 动态下载模型。下表比较了这两个选项。

捆绑模型 托管模型
模型是应用的 APK 的一部分,这会增加 APK 的大小。 模型不是 APK 的一部分。通过上传到 Cloud Storage 进行托管。我们建议使用 Cloud Storage for Firebase
即使 Android 设备处于离线状态,模型也可立即使用 您的应用必须包含按需下载模型的代码
不需要 Firebase 项目 需要 Firebase 项目(如果使用 Cloud Storage for Firebase)。
您必须重新发布应用才能更新模型 无需重新发布应用即可推送模型更新
没有内置 A/B 测试 使用 Firebase Remote Config 进行 A/B 测试

试试看

准备工作

  1. 请务必在项目级 build.gradle.kts 文件中的 buildscriptallprojects 部分添加 Google 的 Maven 制品库。

  2. 将 Android 版机器学习套件库的依赖项添加到模块的应用级 Gradle 文件(通常为 app/build.gradle.kts)。根据您的需求选择以下依赖项之一:

    如需将流水线与您的应用捆绑在一起,请执行以下操作

    dependencies {
      // ...
      // Use this dependency to bundle the pipeline with your app
      implementation("com.google.mlkit:image-labeling-custom:17.0.3")
    }
    

    如需在 Google Play 服务中使用流水线,请执行以下操作

    dependencies {
      // ...
      // Use this dependency to use the dynamically downloaded pipeline in Google Play services
      implementation("com.google.android.gms:play-services-mlkit-image-labeling-custom:16.0.0-beta5")
    }
    
  3. 如果您选择在 Google Play 服务中使用该流水线,则可将应用配置为从 Play 商店安装后自动将该流水线下载到设备。为此,请将以下声明添加到应用的 AndroidManifest.xml 文件中:

    <application ...>
        ...
        <meta-data
            android:name="com.google.mlkit.vision.DEPENDENCIES"
            android:value="custom_ica" />
        <!-- To use multiple downloads: android:value="custom_ica,download2,download3" -->
    </application>
    

    您还可以通过 Google Play 服务 ModuleInstallClient API 显式检查流水线可用性并请求下载。

    如果您未启用安装时流水线下载或请求显式下载,系统会在您首次运行标记器时下载流水线。您在下载完成之前提出的请求不会产生任何结果。

  4. 如果您想使用 Cloud Storage for Firebase 下载模型,请务必将 Firebase 添加到您的 Android 项目(如果尚未添加)。捆绑模型时不需要这样做。

1. 加载模型

您可以从本地捆绑的来源或远程托管的来源加载模型。

配置本地模型来源

如需将模型与您的应用捆绑在一起,请执行以下操作:

  1. 将模型文件(通常以 .tflite.lite 结尾)复制到应用的 assets/ 文件夹。(您可能需要先创建此文件夹,方法是右键点击 app/ 文件夹,然后依次点击新建 > 文件夹 > Assets 文件夹。)

  2. 创建一个 LocalModel 对象,指定模型文件的路径:

    Kotlin

    val localModel = LocalModel.Builder()
            .setAssetFilePath("model.tflite")
            // or .setAbsoluteFilePath(absolute path to model file)
            // or .setUri(URI to model file)
            .build()

    Java

    LocalModel localModel =
        new LocalModel.Builder()
            .setAssetFilePath("model.tflite")
            // or .setAbsoluteFilePath(absolute path to model file)
            // or .setUri(URI to model file)
            .build();

配置远程托管的模型来源

如需使用远程托管的模型,您必须使用自己的应用逻辑将模型文件下载到设备的本地存储空间,然后将其作为本地模型加载。我们建议使用 Cloud Storage for Firebase 来托管模型。如需了解实现详情,请参阅 Firebase ML 到 Cloud Storage 迁移指南

配置图片标记器

配置模型来源后,根据其中一个模型创建 ImageLabeler 对象。

提供的选项如下:

选项
confidenceThreshold

检测到的标签的最低置信度得分。如果未设置,系统将使用模型元数据指定的任何分类器阈值。如果模型不包含任何元数据,或者元数据未指定分类器阈值,系统将使用默认阈值 0.0。

maxResultCount

要返回的标签数量上限。如果未设置,系统将使用默认值 10。

如果您只有本地捆绑的模型,只需根据您的 LocalModel 对象创建一个标记器即可:

Kotlin

val customImageLabelerOptions = CustomImageLabelerOptions.Builder(localModel)
    .setConfidenceThreshold(0.5f)
    .setMaxResultCount(5)
    .build()
val labeler = ImageLabeling.getClient(customImageLabelerOptions)

Java

CustomImageLabelerOptions customImageLabelerOptions =
        new CustomImageLabelerOptions.Builder(localModel)
            .setConfidenceThreshold(0.5f)
            .setMaxResultCount(5)
            .build();
ImageLabeler labeler = ImageLabeling.getClient(customImageLabelerOptions);

如果您使用的是远程托管的模型,则必须在运行之前检查该模型是否已下载。

虽然您只需在运行标签器之前确认这一点,但如果您同时拥有远程托管模型和本地捆绑模型,那么在实例化图片标签器时执行此检查可能是有意义的:如果远程模型已下载,则从该模型创建标签器;否则,从本地模型创建标签器。

Kotlin

val modelFile = File(context.cacheDir, "my_downloaded_model.tflite")
val model = if (modelFile.exists()) {
    // Use the downloaded model if available
    LocalModel.Builder().setAbsoluteFilePath(modelFile.absolutePath).build()
} else {
    // Fall back to the bundled model
    LocalModel.Builder().setAssetFilePath("model.tflite").build()
}
val options = CustomImageLabelerOptions.Builder(model)
    .setConfidenceThreshold(0.5f)
    .setMaxResultCount(5)
    .build()
val labeler = ImageLabeling.getClient(options)

Java

File modelFile = new File(context.getCacheDir(), "my_downloaded_model.tflite");
LocalModel model;
if (modelFile.exists()) {
    // Use the downloaded model if available
    model = new LocalModel.Builder().setAbsoluteFilePath(modelFile.getAbsolutePath()).build();
} else {
    // Fall back to the bundled model
    model = new LocalModel.Builder().setAssetFilePath("model.tflite").build();
}
CustomImageLabelerOptions options = new CustomImageLabelerOptions.Builder(model)
    .setConfidenceThreshold(0.5f)
    .setMaxResultCount(5)
    .build();
ImageLabeler labeler = ImageLabeling.getClient(options);

如果您只有远程托管的模型,则应停用与模型相关的功能(例如使界面的一部分变灰或将其隐藏),直到您确认模型已下载。

Kotlin

val localFile = File(context.cacheDir, "my_remote_model.tflite")
if (localFile.exists()) {
    initializeLabeler(localFile)
} else {
    showLoadingUI()
    val storage = Firebase.storage
    val modelRef = storage.getReferenceFromUrl("gs://YOUR_BUCKET/path/to/model.tflite")
    modelRef.getFile(localFile)
        .addOnSuccessListener {
            hideLoadingUI()
            initializeLabeler(localFile)
        }
        .addOnFailureListener {
            showErrorUI()
        }
}

private fun initializeLabeler(modelFile: File) {
    val localModel = LocalModel.Builder().setAbsoluteFilePath(modelFile.absolutePath).build()
    val options = CustomImageLabelerOptions.Builder(localModel).build()
    val labeler = ImageLabeling.getClient(options)
    enableMLFeatures(labeler)
}

Java

File localFile = new File(context.getCacheDir(), "my_remote_model.tflite");
if (localFile.exists()) {
    initializeLabeler(localFile);
} else {
    showLoadingUI();
    FirebaseStorage storage = FirebaseStorage.getInstance();
    StorageReference modelRef = storage.getReferenceFromUrl("gs://YOUR_BUCKET/path/to/model.tflite");
    modelRef.getFile(localFile)
        .addOnSuccessListener(new OnSuccessListener<FileDownloadTask.TaskSnapshot>() {
            @Override
            public void onSuccess(FileDownloadTask.TaskSnapshot taskSnapshot) {
                hideLoadingUI();
                initializeLabeler(localFile);
            }
        })
        .addOnFailureListener(new OnFailureListener() {
            @Override
            public void onFailure(@NonNull Exception exception) {
                showErrorUI();
            }
        });
}

private void initializeLabeler(File modelFile) {
    LocalModel localModel = new LocalModel.Builder().setAbsoluteFilePath(modelFile.getAbsolutePath()).build();
    CustomImageLabelerOptions options = new CustomImageLabelerOptions.Builder(localModel).build();
    ImageLabeler labeler = ImageLabeling.getClient(options);
    enableMLFeatures(labeler);
}

2. 准备输入图片

接下来,基于每个您想要加标签的图片创建一个 InputImage 对象。使用 Bitmap 或 YUV_420_888 media.Image(如果您使用 Camera2 API)时,图片标记器的运行速度最快;建议您尽量使用这两种格式的图片。

您可以基于不同来源创建 InputImage 对象,下文分别介绍了具体方法。

使用 media.Image

如需基于 media.Image 对象创建 InputImage 对象(例如从设备的相机捕获图片时),请将 media.Image 对象和图片的旋转角度传递给 InputImage.fromMediaImage()

如果您使用 CameraX 库,OnImageCapturedListenerImageAnalysis.Analyzer 类会为您计算旋转角度值。

Kotlin

private class YourImageAnalyzer : ImageAnalysis.Analyzer {

    override fun analyze(imageProxy: ImageProxy) {
        val mediaImage = imageProxy.image
        if (mediaImage != null) {
            val image = InputImage.fromMediaImage(mediaImage, imageProxy.imageInfo.rotationDegrees)
            // Pass image to an ML Kit Vision API
            // ...
        }
    }
}

Java

private class YourAnalyzer implements ImageAnalysis.Analyzer {

    @Override
    public void analyze(ImageProxy imageProxy) {
        Image mediaImage = imageProxy.getImage();
        if (mediaImage != null) {
          InputImage image =
                InputImage.fromMediaImage(mediaImage, imageProxy.getImageInfo().getRotationDegrees());
          // Pass image to an ML Kit Vision API
          // ...
        }
    }
}

如果您不使用可提供图片旋转角度的相机库,则可以根据设备的旋转角度和设备中相机传感器的朝向来计算旋转角度:

Kotlin

private val ORIENTATIONS = SparseIntArray()

init {
    ORIENTATIONS.append(Surface.ROTATION_0, 0)
    ORIENTATIONS.append(Surface.ROTATION_90, 90)
    ORIENTATIONS.append(Surface.ROTATION_180, 180)
    ORIENTATIONS.append(Surface.ROTATION_270, 270)
}

/**
 * Get the angle by which an image must be rotated given the device's current
 * orientation.
 */
@RequiresApi(api = Build.VERSION_CODES.LOLLIPOP)
@Throws(CameraAccessException::class)
private fun getRotationCompensation(cameraId: String, activity: Activity, isFrontFacing: Boolean): Int {
    // Get the device's current rotation relative to its "native" orientation.
    // Then, from the ORIENTATIONS table, look up the angle the image must be
    // rotated to compensate for the device's rotation.
    val deviceRotation = activity.windowManager.defaultDisplay.rotation
    var rotationCompensation = ORIENTATIONS.get(deviceRotation)

    // Get the device's sensor orientation.
    val cameraManager = activity.getSystemService(CAMERA_SERVICE) as CameraManager
    val sensorOrientation = cameraManager
            .getCameraCharacteristics(cameraId)
            .get(CameraCharacteristics.SENSOR_ORIENTATION)!!

    if (isFrontFacing) {
        rotationCompensation = (sensorOrientation + rotationCompensation) % 360
    } else { // back-facing
        rotationCompensation = (sensorOrientation - rotationCompensation + 360) % 360
    }
    return rotationCompensation
}

Java

private static final SparseIntArray ORIENTATIONS = new SparseIntArray();
static {
    ORIENTATIONS.append(Surface.ROTATION_0, 0);
    ORIENTATIONS.append(Surface.ROTATION_90, 90);
    ORIENTATIONS.append(Surface.ROTATION_180, 180);
    ORIENTATIONS.append(Surface.ROTATION_270, 270);
}

/**
 * Get the angle by which an image must be rotated given the device's current
 * orientation.
 */
@RequiresApi(api = Build.VERSION_CODES.LOLLIPOP)
private int getRotationCompensation(String cameraId, Activity activity, boolean isFrontFacing)
        throws CameraAccessException {
    // Get the device's current rotation relative to its "native" orientation.
    // Then, from the ORIENTATIONS table, look up the angle the image must be
    // rotated to compensate for the device's rotation.
    int deviceRotation = activity.getWindowManager().getDefaultDisplay().getRotation();
    int rotationCompensation = ORIENTATIONS.get(deviceRotation);

    // Get the device's sensor orientation.
    CameraManager cameraManager = (CameraManager) activity.getSystemService(CAMERA_SERVICE);
    int sensorOrientation = cameraManager
            .getCameraCharacteristics(cameraId)
            .get(CameraCharacteristics.SENSOR_ORIENTATION);

    if (isFrontFacing) {
        rotationCompensation = (sensorOrientation + rotationCompensation) % 360;
    } else { // back-facing
        rotationCompensation = (sensorOrientation - rotationCompensation + 360) % 360;
    }
    return rotationCompensation;
}

然后,将 media.Image 对象及其旋转角度值传递给 InputImage.fromMediaImage()

Kotlin

val image = InputImage.fromMediaImage(mediaImage, rotation)

Java

InputImage image = InputImage.fromMediaImage(mediaImage, rotation);

使用文件 URI

如需基于文件 URI 创建 InputImage 对象,请将应用上下文和文件 URI 传递给 InputImage.fromFilePath()。如果您使用 ACTION_GET_CONTENT intent 提示用户从相册应用中选择图片,这一操作会非常有用。

Kotlin

val image: InputImage
try {
    image = InputImage.fromFilePath(context, uri)
} catch (e: IOException) {
    e.printStackTrace()
}

Java

InputImage image;
try {
    image = InputImage.fromFilePath(context, uri);
} catch (IOException e) {
    e.printStackTrace();
}

使用 ByteBufferByteArray

如需基于 ByteBufferByteArray 创建 InputImage 对象,请首先按先前 media.Image 输入的说明计算图片旋转角度。然后,使用缓冲区或数组以及图片的高度、宽度、颜色编码格式和旋转角度创建 InputImage 对象:

Kotlin

val image = InputImage.fromByteBuffer(
        byteBuffer,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
)
// Or:
val image = InputImage.fromByteArray(
        byteArray,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
)

Java

InputImage image = InputImage.fromByteBuffer(byteBuffer,
        /* image width */ 480,
        /* image height */ 360,
        rotationDegrees,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
);
// Or:
InputImage image = InputImage.fromByteArray(
        byteArray,
        /* image width */480,
        /* image height */360,
        rotation,
        InputImage.IMAGE_FORMAT_NV21 // or IMAGE_FORMAT_YV12
);

使用 Bitmap

如需基于 Bitmap 对象创建 InputImage 对象,请进行以下声明:

Kotlin

val image = InputImage.fromBitmap(bitmap, 0)

Java

InputImage image = InputImage.fromBitmap(bitmap, rotationDegree);

图片由 Bitmap 对象以及旋转角度表示。

3. 运行图片标记器

如需给图片中的对象加标签,请将 image 对象传递给 ImageLabelerprocess() 方法。

Kotlin

labeler.process(image)
        .addOnSuccessListener { labels ->
            // Task completed successfully
            // ...
        }
        .addOnFailureListener { e ->
            // Task failed with an exception
            // ...
        }

Java

labeler.process(image)
        .addOnSuccessListener(new OnSuccessListener<List<ImageLabel>>() {
            @Override
            public void onSuccess(List<ImageLabel> labels) {
                // Task completed successfully
                // ...
            }
        })
        .addOnFailureListener(new OnFailureListener() {
            @Override
            public void onFailure(@NonNull Exception e) {
                // Task failed with an exception
                // ...
            }
        });

4. 获取有关已加标签的实体的信息

如果为图片添加标签的操作成功完成,系统会向成功监听器传递一组 ImageLabel 对象。每个 ImageLabel 对象代表图片中加了标签的某个事物。您可以获取每个标签的文本说明(如果在 LiteRT 模型文件的元数据中可用)、置信度分数和索引。例如:

Kotlin

for (label in labels) {
    val text = label.text
    val confidence = label.confidence
    val index = label.index
}

Java

for (ImageLabel label : labels) {
    String text = label.getText();
    float confidence = label.getConfidence();
    int index = label.getIndex();
}

提高实时性能的相关提示

如果要在实时应用中为图片加标签,请遵循以下准则以实现最佳帧速率:

  • 如果您使用 Cameracamera2 API,请限制对图片标记器的调用次数。如果在图片标记器运行时有新的视频帧可用,请丢弃该帧。如需查看示例,请参阅快速入门示例应用中的 VisionProcessorBase 类。
  • 如果您使用 CameraX API,请确保将反压策略设置为其默认值 ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST。这样可确保一次只传送一张图片进行分析。如果分析器繁忙时生成了更多图片,这些图片将被自动丢弃,而不会排队等待传送。通过调用 ImageProxy.close() 关闭正在分析的图片后,系统会传送下一张最新图片。
  • 如果要将图片标记器的输出作为图形叠加在输入图片上,请先从机器学习套件获取结果,然后在一个步骤中完成图片的呈现和叠加。这样,对于每个输入帧,系统只会向显示界面呈现一次。如需查看示例,请参阅快速入门示例应用中的 CameraSourcePreview GraphicOverlay 类。
  • 如果您使用 Camera2 API,请以 ImageFormat.YUV_420_888 格式捕获图片。如果您使用旧版 Camera API,请以 ImageFormat.NV21 格式捕获图片。