diff --git a/machine-learning/immich_ml/models/ocr/recognition.py b/machine-learning/immich_ml/models/ocr/recognition.py index 6408e4818f..94f40c9285 100644 --- a/machine-learning/immich_ml/models/ocr/recognition.py +++ b/machine-learning/immich_ml/models/ocr/recognition.py @@ -64,6 +64,7 @@ class TextRecognizer(InferenceModel): rec_batch_num=max_batch_size if max_batch_size else 6, rec_img_shape=(3, 48, 320), lang_type=self.language, + model_root_dir=self.cache_dir, ) ) return session diff --git a/machine-learning/test_main.py b/machine-learning/test_main.py index b281c0d417..5145be0045 100644 --- a/machine-learning/test_main.py +++ b/machine-learning/test_main.py @@ -1028,7 +1028,12 @@ class TestOcr: text_recognizer.load() rapid_recognizer.assert_called_once_with( - OcrOptions(session=ort_session.return_value, rec_batch_num=6, rec_img_shape=(3, 48, 320)) + OcrOptions( + session=ort_session.return_value, + rec_batch_num=6, + rec_img_shape=(3, 48, 320), + model_root_dir=text_recognizer.cache_dir, + ) ) def test_set_custom_max_batch_size(self, ort_session: mock.Mock, path: mock.Mock, mocker: MockerFixture) -> None: @@ -1041,7 +1046,12 @@ class TestOcr: text_recognizer.load() rapid_recognizer.assert_called_once_with( - OcrOptions(session=ort_session.return_value, rec_batch_num=4, rec_img_shape=(3, 48, 320)) + OcrOptions( + session=ort_session.return_value, + rec_batch_num=4, + rec_img_shape=(3, 48, 320), + model_root_dir=text_recognizer.cache_dir, + ) ) def test_ignore_other_custom_max_batch_size( @@ -1056,7 +1066,12 @@ class TestOcr: text_recognizer.load() rapid_recognizer.assert_called_once_with( - OcrOptions(session=ort_session.return_value, rec_batch_num=6, rec_img_shape=(3, 48, 320)) + OcrOptions( + session=ort_session.return_value, + rec_batch_num=6, + rec_img_shape=(3, 48, 320), + model_root_dir=text_recognizer.cache_dir, + ) )