使用机器学习套件的自定义模型

默认情况下,机器学习套件的 API 使用 Google 训练的机器学习模型。 这些模型旨在涵盖各种应用。不过,某些用例需要更具针对性的模型。因此,某些机器学习套件 API 现在允许您将默认模型替换为自定义 LiteRT 模型。

图像标签和对象检测与跟踪 API 都支持自定义图像分类模型。它们与 TensorFlow Hub 上精选的优质预训练模型或您使用 TensorFlow 或 AutoML 训练的自定义模型兼容。

如果您需要针对其他领域或用例的自定义解决方案,请访问 设备端机器学习页面,获取有关 Google 所有 设备端机器学习解决方案和工具的指南。

将机器学习套件与自定义模型搭配使用的优势

将自定义图像分类模型与机器学习套件搭配使用的优势包括:

  • 易于使用的高级 API - 无需处理低级模型 输入/输出、处理图像预处理/后处理或构建处理 流水线。
  • 无需自行处理标签映射,机器学习套件会从 LiteRT 模型元数据中提取标签并为您执行映射。
  • 支持来自各种来源的自定义模型,从 TensorFlow Hub 上发布的预训练 模型到使用 TensorFlow 或 AutoML 训练的新模型。
  • 针对与 Android 的 Camera API 集成进行了优化。

具体而言,对于对象检测和跟踪

  • 先定位对象,然后仅对相关图像区域运行分类器,从而提高分类准确率
  • 在检测和分类对象时立即向用户提供反馈,从而提供实时互动体验

使用预训练图像分类模型

您可以根据一组条件使用预训练 LiteRT 模型。通过 TensorFlow Hub,我们提供了一组经过审核的模型(来自 Google 或其他模型创建者),这些模型符合这些条件。

使用 TensorFlow Hub 上发布的模型

TensorFlow Hub 提供了各种模型创建者提供的各种预训练图像 分类模型,这些模型可与 图像标签和对象检测与跟踪 API 搭配使用。请按以下步骤操作。

  1. 机器学习套件兼容模型集合中选择一个模型。
  2. 从模型详情页面下载 .tflite 模型文件。如果可用,请选择包含元数据的模型格式。
  3. 按照图像标签 API对象检测与跟踪 API 的指南,了解如何将模型文件与项目捆绑在一起,并在 Android 或 iOS 应用中使用该模型文件。

训练您自己的图像分类模型

如果没有预训练图像分类模型符合您的需求,您可以通过多种方式训练自己的 LiteRT 模型,其中一些方式将在以下部分中进行概述和更详细的讨论。

训练您自己的图像分类模型的选项
AutoML
  • 通过 Google Cloud AI 提供
  • 创建最先进的图像分类模型
  • 评估性能和大小之间的平衡
将 TensorFlow 模型转换为 LiteRT
  • 使用 TensorFlow 训练模型,然后将其转换为 LiteRT

AutoML

图像标签和对象检测与跟踪 API 中的自定义模型支持使用 AutoML 训练的图像分类模型。这些 API 还 支持下载使用 Cloud Storage 托管的模型。

如需详细了解如何在 Android 和 iOS 应用中使用使用 AutoML 训练的模型,请根据您的用例按照每个 API 的自定义模型指南进行操作。

使用 LiteRT 转换器创建的模型

如果您有现有的 TensorFlow 图像分类模型,可以使用 LiteRT 转换器 对其进行转换 。确保创建的模型符合以下兼容性要求。

如需详细了解如何在 Android 和 iOS 应用中使用 LiteRT 模型, 请根据您的用例按照图像标签 API对象检测与跟踪 API 的指南进行操作。

LiteRT 模型兼容性

您可以使用任何预训练 LiteRT 图像分类模型,前提是该模型符合以下要求:

张量

  • 模型必须只有一个输入张量,且具有以下限制:
    • 数据采用 RGB 像素格式。
    • 数据类型为 UINT8 或 FLOAT32。如果输入张量类型为 FLOAT32,则必须通过附加 元数据来指定 NormalizationOptions。
    • 张量具有 4 个维度:BxHxWxC,其中:
      • B 是批次大小。它必须为 1(不支持对较大批次进行推理)。
      • W 和 H 是输入宽度和高度。
      • C 是预期通道数。它必须为 3。
  • 模型必须至少有一个输出张量,其中包含 N 个类,并且具有 2 个或 4 个维度:
    • (1xN)
    • (1x1x1xN)
  • 仅完全支持单头模型。多头模型可能会输出意外结果。

元数据

您可以按照 向 LiteRT 模型添加元数据中所述,向 LiteRT 文件添加元数据。

如需使用具有 FLOAT32 输入张量的模型,您必须在元数据中指定 NormalizationOptions

我们还建议您将此元数据附加到输出张量 TensorMetadata