使用自定义分类模型检测、跟踪和分类对象 (iOS)

您可以使用机器学习套件来检测和跟踪连续视频帧中的对象。

当您向机器学习套件传递图片时,它会检测图片中最多五个对象以及每个对象在图片中的位置。检测视频流中的对象时,每个对象都有一个唯一 ID,您可以使用此 ID 来逐帧跟踪对象。

您可以使用自定义图片分类模型对检测到的对象进行分类。如需有关模型兼容性要求、预训练模型的位置以及如何训练自有模型的指导,请参阅使用机器学习套件的自定义模型

您可以通过以下两种方式集成自定义模型。您可以将模型放入应用的资源文件夹中以捆绑该模型,也可以从 Cloud Storage 动态下载该模型。下表比较了这两个选项。

捆绑模型 托管模型
模型是应用的 .ipa 文件的一部分,这会增加其大小。 该模型不属于应用的 .ipa 文件。它通过上传到 Cloud Storage 进行托管。我们建议使用 Cloud Storage for Firebase
即使 Android 设备处于离线状态,模型也可立即使用 您的应用必须包含按需下载模型的代码
不需要 Firebase 项目 需要 Firebase 项目(如果使用 Cloud Storage for Firebase)。
您必须重新发布应用才能更新模型 无需重新发布应用即可推送模型更新
没有内置的 A/B 测试 使用 Firebase Remote Config 进行 A/B 测试

试试看

准备工作

  1. 在 Podfile 中添加机器学习套件库:

    pod 'GoogleMLKit/ObjectDetectionCustom', '8.0.0'
    
  2. 安装或更新项目的 Pod 之后,请使用 Xcode 项目的 .xcworkspace 来打开项目。Xcode 版本 13.2.1 或更高版本支持机器学习套件。

  3. 如果您想使用 Cloud Storage for Firebase 下载模型,请务必将 Firebase 添加到您的 iOS 项目(如果您尚未添加)。捆绑模型时不需要这样做。

1. 加载模型

配置本地模型来源

如需将模型与您的应用捆绑在一起,请执行以下操作:

  1. 将模型文件(通常以 .tflite.lite 结尾)复制到您的 Xcode 项目,并在执行此操作时注意选择 Copy bundle resources。模型文件将包含在 app bundle 中,并提供给机器学习套件使用。

  2. 创建一个 LocalModel 对象,指定模型文件的路径:

    Swift

    let localModel = LocalModel(path: localModelFilePath)

    Objective-C

    MLKLocalModel *localModel =
        [[MLKLocalModel alloc] initWithPath:localModelFilePath];

配置远程托管的模型来源

如需使用远程托管的模型,您必须使用自己的应用逻辑将模型文件下载到设备的本地存储空间,然后将其作为本地模型加载。我们建议使用 Cloud Storage for Firebase 来托管模型。如需了解实现详情,请参阅 Firebase ML 到 Cloud Storage 迁移指南

2. 配置对象检测器

配置模型来源后,使用 CustomObjectDetectorOptions 对象为您的使用场景配置对象检测器。您可以更改以下设置:

对象检测器设置
检测模式 STREAM_MODE(默认)| SINGLE_IMAGE_MODE

STREAM_MODE(默认)下,对象检测器以低延迟高速运行,但在前几次调用检测器时可能会产生不完整的结果(例如未指定的边界框或类别标签)。此外,在 STREAM_MODE 下,检测器会为对象分配跟踪 ID,您可以使用该 ID 来跨帧跟踪对象。如果您想要跟踪对象,或者对延迟有要求(例如在实时处理视频流时),请使用此模式。

SINGLE_IMAGE_MODE 中,对象检测器会在确定对象的边界框后返回结果。如果您还启用了分类,则在边界框和类别标签都可用后,它会返回结果。因此,此模式下的检测延迟可能较高。此外,在 SINGLE_IMAGE_MODE 下,不会分配跟踪 ID。如果不计较延迟高低,且不想处理不完整的结果,请使用此模式。

检测和跟踪多个对象 false(默认)| true

是检测和跟踪最多五个对象,还是仅检测和跟踪最突出的对象(默认)。

对对象进行分类 false(默认)| true

是否使用提供的自定义敏感类别模型对检测到的对象进行分类。如需使用自定义分类模型,您需要将此属性设置为 true

分类置信度阈值

检测到的标签的最低置信度得分。如果未设置,系统将使用模型元数据指定的任何分类器阈值。如果模型不包含任何元数据,或者元数据未指定分类器阈值,系统将使用默认阈值 0.0。

每个对象的标签数上限

检测器将返回的每个对象的最大标签数。如果未设置,则系统会使用默认值 10。

如果您只有本地捆绑的模型,只需根据 LocalModel 对象创建一个对象检测器即可:

Swift

let options = CustomObjectDetectorOptions(localModel: localModel)
options.detectorMode = .singleImage
options.shouldEnableClassification = true
options.shouldEnableMultipleObjects = true
options.classificationConfidenceThreshold = NSNumber(value: 0.5)
options.maxPerObjectLabelCount = 3

Objective-C

MLKCustomObjectDetectorOptions *options =
    [[MLKCustomObjectDetectorOptions alloc] initWithLocalModel:localModel];
options.detectorMode = MLKObjectDetectorModeSingleImage;
options.shouldEnableClassification = YES;
options.shouldEnableMultipleObjects = YES;
options.classificationConfidenceThreshold = @(0.5);
options.maxPerObjectLabelCount = 3;

如果您使用的是远程托管的模型,则必须在运行之前检查该模型是否已下载。

虽然您只需在运行对象检测器之前确认这一点,但如果您同时拥有远程托管模型和本地捆绑模型,那么在实例化 ObjectDetector 时执行此检查可能是有意义的:如果远程模型已下载,则从该模型创建检测器,否则从本地模型创建检测器。

Swift

// Path where your download logic saves the model
let documentDirectory = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first!
let localModelURL = documentDirectory.appendingPathComponent("my_remote_model.tflite")

let model: LocalModel
if FileManager.default.fileExists(atPath: localModelURL.path) {
  // Use the downloaded model
  model = LocalModel(path: localModelURL.path)
} else {
  // Fall back to bundled model
  guard let bundledModelPath = Bundle.main.path(forResource: "model", ofType: "tflite") else { return }
  model = LocalModel(path: bundledModelPath)
}

let options = CustomObjectDetectorOptions(localModel: model)
options.detectorMode = .singleImage
options.shouldEnableClassification = true
options.shouldEnableMultipleObjects = true
options.classificationConfidenceThreshold = NSNumber(value: 0.5)
options.maxPerObjectLabelCount = 3
let objectDetector = ObjectDetector.objectDetector(options: options)

Objective-C

NSString *documentsDirectory = [NSSearchPathForDirectoriesInDomains(NSDocumentDirectory, NSUserDomainMask, YES) firstObject];
NSString *localModelPath = [documentsDirectory stringByAppendingPathComponent:@"my_remote_model.tflite"];

MLKLocalModel *model;
if ([NSFileManager.defaultManager fileExistsAtPath:localModelPath]) {
  // Use the downloaded model
  model = [[MLKLocalModel alloc] initWithPath:localModelPath];
} else {
  // Fall back to bundled model
  NSString *bundledModelPath = [NSBundle.mainBundle pathForResource:@"model" ofType:@"tflite"];
  model = [[MLKLocalModel alloc] initWithPath:bundledModelPath];
}

MLKCustomObjectDetectorOptions *options = [[MLKCustomObjectDetectorOptions alloc] initWithLocalModel:model];
options.detectorMode = MLKObjectDetectorModeSingleImage;
options.shouldEnableClassification = YES;
options.shouldEnableMultipleObjects = YES;
options.classificationConfidenceThreshold = @(0.5);
options.maxPerObjectLabelCount = 3;
MLKObjectDetector *objectDetector = [MLKObjectDetector objectDetectorWithOptions:options];

如果您只有远程托管的模型,则应停用与模型相关的功能(例如灰显或隐藏部分界面),直到您确认模型已下载。

Swift

let documentDirectory = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first!
let localModelURL = documentDirectory.appendingPathComponent("my_remote_model.tflite")
if FileManager.default.fileExists(atPath: localModelURL.path) {
  // Model is already cached, initialize immediately
  self.initializeDetector(with: localModelURL)
} else {
  // Model is not yet available, show loading UI and start download
  self.showLoadingUI()
  let storage = Storage.storage()
  let modelRef = storage.reference(forURL: "gs://YOUR_BUCKET/path/to/model.tflite")
  modelRef.write(toFile: localModelURL) { url, error in
    self.hideLoadingUI()
    if let error = error {
      // Handle download error
      self.showErrorUI()
    } else if let modelURL = url {
      // Download success, initialize detector
      self.initializeDetector(with: modelURL)
    }
  }
}

func initializeDetector(with modelURL: URL) {
  let localModel = LocalModel(path: modelURL.path)
  let options = CustomObjectDetectorOptions(localModel: localModel)
  options.detectorMode = .singleImage
  options.shouldEnableClassification = true
  options.shouldEnableMultipleObjects = true
  self.objectDetector = ObjectDetector.objectDetector(options: options)
  // Enable ML features in UI
  self.enableMLFeatures()
}

Objective-C

NSString *documentsDirectory = [NSSearchPathForDirectoriesInDomains(NSDocumentDirectory, NSUserDomainMask, YES) firstObject];
NSString *localModelPath = [documentsDirectory stringByAppendingPathComponent:@"my_remote_model.tflite"];
NSURL *localModelURL = [NSURL fileURLWithPath:localModelPath];

if ([NSFileManager.defaultManager fileExistsAtPath:localModelPath]) {
  // Model is already cached, initialize immediately
  [self initializeDetectorWithURL:localModelURL];
} else {
  // Model is not yet available, show loading UI and start download
  [self showLoadingUI];

  FIRStorage *storage = [FIRStorage storage];
  FIRStorageReference *modelRef = [storage referenceForURL:@"gs://YOUR_BUCKET/path/to/model.tflite"];

  [modelRef writeToFile:localModelURL
             completion:^(NSURL * _Nullable URL, NSError * _Nullable error) {
               [self hideLoadingUI];
               if (error != nil) {
                 // Handle download error
                 [self showErrorUI];
               } else {
                 // Download success, initialize detector
                 [self initializeDetectorWithURL:URL];
               }
             }];
}

- (void)initializeDetectorWithURL:(NSURL *)modelURL {
  MLKLocalModel *localModel = [[MLKLocalModel alloc] initWithPath:modelURL.path];
  MLKCustomObjectDetectorOptions *options = [[MLKCustomObjectDetectorOptions alloc] initWithLocalModel:localModel];
  options.detectorMode = MLKObjectDetectorModeSingleImage;
  options.shouldEnableClassification = YES;
  options.shouldEnableMultipleObjects = YES;
  self.objectDetector = [MLKObjectDetector objectDetectorWithOptions:options];

  // Enable ML features in UI
  [self enableMLFeatures];
}

对象检测和跟踪 API 针对以下两个核心使用场景进行了优化:

  • 实时检测和跟踪相机取景器中最突出的对象。
  • 检测静态图片中的多个对象。

如需为这些使用场景配置 API,请运行以下代码:

Swift

// Live detection and tracking
let options = CustomObjectDetectorOptions(localModel: localModel)
options.shouldEnableClassification = true
options.maxPerObjectLabelCount = 3

// Multiple object detection in static images
let options = CustomObjectDetectorOptions(localModel: localModel)
options.detectorMode = .singleImage
options.shouldEnableMultipleObjects = true
options.shouldEnableClassification = true
options.maxPerObjectLabelCount = 3

Objective-C

// Live detection and tracking
MLKCustomObjectDetectorOptions *options =
    [[MLKCustomObjectDetectorOptions alloc] initWithLocalModel:localModel];
options.shouldEnableClassification = YES;
options.maxPerObjectLabelCount = 3;

// Multiple object detection in static images
MLKCustomObjectDetectorOptions *options =
    [[MLKCustomObjectDetectorOptions alloc] initWithLocalModel:localModel];
options.detectorMode = MLKObjectDetectorModeSingleImage;
options.shouldEnableMultipleObjects = YES;
options.shouldEnableClassification = YES;
options.maxPerObjectLabelCount = 3;

3. 准备输入图片

使用 UIImageCMSampleBuffer 创建一个 VisionImage 对象。

如果您使用的是 UIImage,请按以下步骤操作:

  • 使用 UIImage 创建一个 VisionImage 对象。请务必指定正确的 .orientation

    Swift

    let image = VisionImage(image: UIImage)
    visionImage.orientation = image.imageOrientation

    Objective-C

    MLKVisionImage *visionImage = [[MLKVisionImage alloc] initWithImage:image];
    visionImage.orientation = image.imageOrientation;

如果您使用的是 CMSampleBuffer,请按以下步骤操作:

  • 指定 CMSampleBuffer 中所含图片数据的方向。

    如需获取图片方向,请运行以下命令:

    Swift

    func imageOrientation(
      deviceOrientation: UIDeviceOrientation,
      cameraPosition: AVCaptureDevice.Position
    ) -> UIImage.Orientation {
      switch deviceOrientation {
      case .portrait:
        return cameraPosition == .front ? .leftMirrored : .right
      case .landscapeLeft:
        return cameraPosition == .front ? .downMirrored : .up
      case .portraitUpsideDown:
        return cameraPosition == .front ? .rightMirrored : .left
      case .landscapeRight:
        return cameraPosition == .front ? .upMirrored : .down
      case .faceDown, .faceUp, .unknown:
        return .up
      }
    }
          

    Objective-C

    - (UIImageOrientation)
      imageOrientationFromDeviceOrientation:(UIDeviceOrientation)deviceOrientation
                             cameraPosition:(AVCaptureDevicePosition)cameraPosition {
      switch (deviceOrientation) {
        case UIDeviceOrientationPortrait:
          return cameraPosition == AVCaptureDevicePositionFront ? UIImageOrientationLeftMirrored
                                                                : UIImageOrientationRight;
    
        case UIDeviceOrientationLandscapeLeft:
          return cameraPosition == AVCaptureDevicePositionFront ? UIImageOrientationDownMirrored
                                                                : UIImageOrientationUp;
        case UIDeviceOrientationPortraitUpsideDown:
          return cameraPosition == AVCaptureDevicePositionFront ? UIImageOrientationRightMirrored
                                                                : UIImageOrientationLeft;
        case UIDeviceOrientationLandscapeRight:
          return cameraPosition == AVCaptureDevicePositionFront ? UIImageOrientationUpMirrored
                                                                : UIImageOrientationDown;
        case UIDeviceOrientationUnknown:
        case UIDeviceOrientationFaceUp:
        case UIDeviceOrientationFaceDown:
          return UIImageOrientationUp;
      }
    }
          
  • 使用 CMSampleBuffer 对象和方向创建一个 VisionImage 对象:

    Swift

    let image = VisionImage(buffer: sampleBuffer)
    image.orientation = imageOrientation(
      deviceOrientation: UIDevice.current.orientation,
      cameraPosition: cameraPosition)

    Objective-C

     MLKVisionImage *image = [[MLKVisionImage alloc] initWithBuffer:sampleBuffer];
     image.orientation =
       [self imageOrientationFromDeviceOrientation:UIDevice.currentDevice.orientation
                                    cameraPosition:cameraPosition];

4. 创建并运行对象检测器

  1. 创建新的对象检测器:

    Swift

    let objectDetector = ObjectDetector.objectDetector(options: options)

    Objective-C

    MLKObjectDetector *objectDetector = [MLKObjectDetector objectDetectorWithOptions:options];
  2. 然后,使用检测器:

    异步:

    Swift

    objectDetector.process(image) { objects, error in
        guard error == nil, let objects = objects, !objects.isEmpty else {
            // Handle the error.
            return
        }
        // Show results.
    }

    Objective-C

    [objectDetector
        processImage:image
          completion:^(NSArray *_Nullable objects,
                       NSError *_Nullable error) {
            if (objects.count == 0) {
                // Handle the error.
                return;
            }
            // Show results.
         }];

    同步:

    Swift

    var objects: [Object]
    do {
        objects = try objectDetector.results(in: image)
    } catch let error {
        // Handle the error.
        return
    }
    // Show results.

    Objective-C

    NSError *error;
    NSArray *objects =
        [objectDetector resultsInImage:image error:&error];
    // Show results or handle the error.

5. 获取已加标签的对象的相关信息

如果对图像处理器的调用成功完成,则系统会将 Object 列表传递给完成处理程序或返回该列表,具体取决于您调用的是异步方法还是同步方法。

每个 Object 包含以下属性:

frame 一个 CGRect,指示图片中对象的位置。
trackingID 一个整数,用于跨图片识别对象;在 SINGLE_IMAGE_MODE 下为 `nil`。
labels
label.text 标签的文本说明。仅当 LiteRT 模型的元数据包含标签说明时返回。
label.index 相应标签在分类器支持的所有标签中的索引。
label.confidence 对象分类的置信度值。

Swift

// objects contains one item if multiple object detection wasn't enabled.
for object in objects {
  let frame = object.frame
  let trackingID = object.trackingID
  let description = object.labels.enumerated().map { (index, label) in
    "Label \(index): \(label.text), \(label.confidence), \(label.index)"
  }.joined(separator: "\n")
}

Objective-C

// The list of detected objects contains one item if multiple object detection
// wasn't enabled.
for (MLKObject *object in objects) {
  CGRect frame = object.frame;
  NSNumber *trackingID = object.trackingID;
  for (MLKObjectLabel *label in object.labels) {
    NSString *labelString =
        [NSString stringWithFormat:@"%@, %f, %lu",
                                   label.text,
                                   label.confidence,
                                   (unsigned long)label.index];
  }
}

确保出色的用户体验

如需获得最佳用户体验,请在您的应用中遵循以下准则:

  • 对象检测成功与否取决于对象的视觉复杂性。为了能够被检测到,具有较少视觉特征的对象可能需要占据待检测图片的较大部分区域。您应为用户提供有关捕获输入的指导,该输入应适用于您要检测的对象类型。
  • 使用分类时,如果您要检测不完全归于受支持类别的对象,请对未知对象执行特殊处理。

另请参阅机器学习套件 Material Design 展示应用适用于机器学习所支持功能集的 Material Design 模式

提高性能

如果要在实时应用中使用对象检测,请遵循以下准则以实现最佳帧速率:

  • 在实时应用中使用流式传输模式时,请勿使用多个对象检测,因为大多数设备无法产生足够高的帧速率。

  • 对于处理视频帧,请使用检测器的 results(in:) 同步 API。从 AVCaptureVideoDataOutputSampleBufferDelegate captureOutput(_, didOutput:from:) 函数调用此方法,以同步获取给定视频帧的结果。将 AVCaptureVideoDataOutput alwaysDiscardsLateVideoFrames 保持为 true,以限制对检测器的调用。如果在检测器运行时有新的视频帧可用,则会丢弃该帧。
  • 如果要将检测器的输出作为图形叠加在输入图片上,请先从机器学习套件获取结果,然后在一个步骤中完成图片的呈现和叠加。采用这一方法,每个处理后的输入帧只需在显示表面呈现一次。如需查看示例,请参阅机器学习套件快速入门示例中的 updatePreviewOverlayViewWithLastFrame