diff --git a/machine-learning/test_main.py b/machine-learning/test_main.py index 5145be0045..be574c6397 100644 --- a/machine-learning/test_main.py +++ b/machine-learning/test_main.py @@ -816,6 +816,10 @@ class TestFaceRecognition: def test_recognition(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None: mocker.patch.object(FaceRecognizer, "load") + mocker.patch( + "immich_ml.models.facial_recognition.recognition.ort.get_available_providers", + return_value=["CPUExecutionProvider"], + ) face_recognizer = FaceRecognizer("buffalo_s", min_score=0.0, cache_dir="test_cache") num_faces = 2 @@ -860,6 +864,10 @@ class TestFaceRecognition: ) mocker.patch("immich_ml.models.base.InferenceModel.download") mocker.patch("immich_ml.models.facial_recognition.recognition.ArcFaceONNX") + mocker.patch( + "immich_ml.models.facial_recognition.recognition.ort.get_available_providers", + return_value=["CPUExecutionProvider"], + ) ort_session.return_value.get_inputs.return_value = [SimpleNamespace(name="input.1", shape=(1, 3, 224, 224))] ort_session.return_value.get_outputs.return_value = [SimpleNamespace(name="output.1", shape=(1, 800))] path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx" @@ -894,6 +902,10 @@ class TestFaceRecognition: ) mocker.patch("immich_ml.models.base.InferenceModel.download") mocker.patch("immich_ml.models.facial_recognition.recognition.ArcFaceONNX") + mocker.patch( + "immich_ml.models.facial_recognition.recognition.ort.get_available_providers", + return_value=["CPUExecutionProvider"], + ) path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx" inputs = [SimpleNamespace(name="input.1", shape=("batch", 3, 224, 224))] @@ -996,6 +1008,10 @@ class TestFaceRecognition: def test_ignore_other_custom_max_batch_size(self, mocker: MockerFixture) -> None: mocker.patch.object(settings, "max_batch_size", MaxBatchSize(ocr=2)) + mocker.patch( + "immich_ml.models.facial_recognition.recognition.ort.get_available_providers", + return_value=["CPUExecutionProvider"], + ) recognizer = FaceRecognizer("buffalo_l", cache_dir="test_cache")