/** Number of `ocr_search` rows accumulated before they are written in one insert. */
const OCR_SEARCH_BATCH_SIZE = 5000;

/**
 * Rebuilds `ocr_search` from `asset_ocr` using the current search tokenizer.
 *
 * The table is truncated first, then one aggregated row per asset is streamed
 * from `asset_ocr`, tokenized, and re-inserted in batches.
 *
 * @param db - Kysely instance the migration runs against.
 */
export async function up(db: Kysely<any>): Promise<void> {
  await sql`truncate ${sql.table('ocr_search')}`.execute(db);

  const batch: { assetId: string; text: string }[] = [];
  const rows = db
    .selectFrom('asset_ocr')
    .select(['assetId', sql<string>`string_agg(text, ' ')`.as('text')])
    .groupBy('assetId')
    .stream();

  for await (const { assetId, text } of rows) {
    // tokenizeForSearch returns an array of tokens; join them into a single
    // space-separated string so the stored `text` value is plain text rather
    // than a serialized array.
    batch.push({ assetId, text: tokenizeForSearch(text).join(' ') });
    if (batch.length >= OCR_SEARCH_BATCH_SIZE) {
      await db.insertInto('ocr_search').values(batch).execute();
      // Reuse the same array for the next batch.
      batch.length = 0;
    }
  }

  // Write whatever is left over after the last full batch.
  if (batch.length > 0) {
    await db.insertInto('ocr_search').values(batch).execute();
  }
}

/** Intentionally a no-op: this migration does not restore the previous `ocr_search` contents. */
export async function down(): Promise<void> {}
/**
 * Makes the machine-learning OCR mock resolve with one detection per given text.
 *
 * Every detection gets eight box coordinates (`index * 10 + 0..7`), a box score
 * of 0.9 and a text score of 0.95. Calling it with no arguments yields an empty
 * OCR result.
 */
const mockOcrResult = (...texts: string[]) => {
  const box: number[] = [];
  for (const [index] of texts.entries()) {
    for (let corner = 0; corner < 8; corner++) {
      box.push(index * 10 + corner);
    }
  }

  const boxScore = texts.map(() => 0.9);
  const textScore = texts.map(() => 0.95);

  mocks.machineLearning.ocr.mockResolvedValue({ box, boxScore, text: texts, textScore });
};
// Table-driven checks of the search text that handleOcr passes to the OCR
// repository: each case lists the OCR box texts returned by the ML mock and the
// exact search string expected as the third `upsert` argument.
describe('search tokenization', () => {
  const cases: Array<{ name: string; texts: string[]; expected: string }> = [
    { name: 'should generate bigrams for Chinese text', texts: ['機器學習'], expected: '機器 器學 學習' },
    { name: 'should generate bigrams for Japanese text', texts: ['テスト'], expected: 'テス スト' },
    { name: 'should generate bigrams for Korean text', texts: ['한국어'], expected: '한국 국어' },
    { name: 'should pass through Latin text unchanged', texts: ['Hello World'], expected: 'Hello World' },
    { name: 'should handle mixed CJK and Latin text', texts: ['機器學習Model'], expected: '機器 器學 學習 Model' },
    { name: 'should handle year followed by CJK', texts: ['2024年レポート'], expected: '2024 年レ レポ ポー ート' },
    { name: 'should join multiple OCR boxes', texts: ['機器', 'Learning'], expected: '機器 Learning' },
    { name: 'should normalize whitespace', texts: [' Hello World '], expected: 'Hello World' },
    { name: 'should keep single CJK characters', texts: ['A', '中', 'B'], expected: 'A 中 B' },
  ];

  for (const { name, texts, expected } of cases) {
    it(name, async () => {
      mockOcrResult(...texts);

      await sut.handleOcr({ id: assetStub.image.id });

      expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, expect.any(Array), expected);
    });
  }
});
/**
 * Returns true when the UTF-16 code unit `c` lies in one of the blocks that are
 * indexed as bigrams. Only Basic Multilingual Plane ranges are listed, so the
 * surrogate halves of supplementary-plane ideographs are not matched.
 */
const isCJK = (c: number): boolean =>
  (c >= 0x4e_00 && c <= 0x9f_ff) || // CJK Unified Ideographs
  (c >= 0xac_00 && c <= 0xd7_af) || // Hangul Syllables
  (c >= 0x30_40 && c <= 0x30_9f) || // Hiragana
  (c >= 0x30_a0 && c <= 0x30_ff) || // Katakana
  (c >= 0x34_00 && c <= 0x4d_bf); // CJK Unified Ideographs Extension A

/**
 * Returns true when the UTF-16 code unit `c` separates tokens: ASCII space and
 * control characters (`<= 32`) plus the non-ASCII whitespace code units,
 * including U+00A0 NO-BREAK SPACE and U+3000 IDEOGRAPHIC SPACE.
 */
const isSearchSeparator = (c: number): boolean =>
  c <= 32 ||
  c === 0xa0 ||
  c === 0x16_80 ||
  (c >= 0x20_00 && c <= 0x20_0a) ||
  c === 0x20_28 ||
  c === 0x20_29 ||
  c === 0x20_2f ||
  c === 0x20_5f ||
  c === 0x30_00 ||
  c === 0xfe_ff;

/**
 * Splits text into search tokens.
 *
 * - Whitespace and control characters separate tokens and are never emitted.
 * - A run of CJK characters becomes its overlapping bigrams (`機器學習` ->
 *   `機器`, `器學`, `學習`); a run of length one is kept as a single character.
 * - Any other run of characters is kept as one token and ends at whitespace or
 *   at the first CJK character.
 *
 * @param text - Raw text, e.g. an OCR result or a user query.
 * @returns The tokens in input order; empty when `text` has no token characters.
 */
export const tokenizeForSearch = (text: string): string[] => {
  /* eslint-disable unicorn/prefer-code-point */
  const tokens: string[] = [];
  let i = 0;
  while (i < text.length) {
    const c = text.charCodeAt(i);
    if (isSearchSeparator(c)) {
      i++;
      continue;
    }

    const start = i;
    if (isCJK(c)) {
      // Consume the whole CJK run, then emit its bigrams.
      while (i < text.length && isCJK(text.charCodeAt(i))) {
        i++;
      }
      if (i - start === 1) {
        // A lone CJK character has no bigram; keep it so it stays searchable.
        tokens.push(text.slice(start, i));
      } else {
        for (let k = start; k < i - 1; k++) {
          tokens.push(text.slice(k, k + 2));
        }
      }
    } else {
      // Non-CJK run: stop at a separator or where a CJK run begins.
      while (i < text.length) {
        const next = text.charCodeAt(i);
        if (isSearchSeparator(next) || isCJK(next)) {
          break;
        }
        i++;
      }
      tokens.push(text.slice(start, i));
    }
  }
  /* eslint-enable unicorn/prefer-code-point */
  return tokens;
};