fix(server): use bigrams for cjk (#24285)
* use bigrams for cjk * update sql * linting * actually migrate ocr * fix backwards test * use array * tweaks
pull/24322/head
parent
d8ca210641
commit
95c29a8aea
|
|
@ -45,12 +45,12 @@ export class OcrRepository {
|
||||||
textScore: DummyValue.NUMBER,
|
textScore: DummyValue.NUMBER,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
DummyValue.STRING,
|
||||||
],
|
],
|
||||||
})
|
})
|
||||||
upsert(assetId: string, ocrDataList: Insertable<AssetOcrTable>[]) {
|
upsert(assetId: string, ocrDataList: Insertable<AssetOcrTable>[], searchText: string) {
|
||||||
let query = this.db.with('deleted_ocr', (db) => db.deleteFrom('asset_ocr').where('assetId', '=', assetId));
|
let query = this.db.with('deleted_ocr', (db) => db.deleteFrom('asset_ocr').where('assetId', '=', assetId));
|
||||||
if (ocrDataList.length > 0) {
|
if (ocrDataList.length > 0) {
|
||||||
const searchText = ocrDataList.map((item) => item.text.trim()).join(' ');
|
|
||||||
(query as any) = query
|
(query as any) = query
|
||||||
.with('inserted_ocr', (db) => db.insertInto('asset_ocr').values(ocrDataList))
|
.with('inserted_ocr', (db) => db.insertInto('asset_ocr').values(ocrDataList))
|
||||||
.with('inserted_search', (db) =>
|
.with('inserted_search', (db) =>
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,24 @@
|
||||||
|
import { Kysely, sql } from 'kysely';
|
||||||
|
import { tokenizeForSearch } from 'src/utils/database';
|
||||||
|
|
||||||
|
/**
 * Rebuilds the `ocr_search` table from the rows in `asset_ocr`.
 *
 * The table is truncated first. Then, for every asset, its OCR lines are
 * concatenated with `string_agg`, run through `tokenizeForSearch`, and
 * re-inserted in batches so the whole result set is never held in memory.
 *
 * @param db Kysely handle used for the truncate, the streamed select and the batched inserts.
 */
export async function up(db: Kysely<any>): Promise<void> {
  // Rows are flushed to `ocr_search` whenever this many have accumulated.
  const batchSize = 5000;

  await sql`truncate ${sql.table('ocr_search')}`.execute(db);

  const batch: { assetId: string; text: string }[] = [];
  for await (const { assetId, text } of db
    .selectFrom('asset_ocr')
    .select(['assetId', sql<string>`string_agg(text, ' ')`.as('text')])
    .groupBy('assetId')
    .stream()) {
    // tokenizeForSearch returns an array of tokens. Join them with single
    // spaces so the value is a plain string rather than a serialized array.
    batch.push({ assetId, text: tokenizeForSearch(text).join(' ') });
    if (batch.length >= batchSize) {
      await db.insertInto('ocr_search').values(batch).execute();
      batch.length = 0;
    }
  }

  // Flush the final partial batch, if any.
  if (batch.length > 0) {
    await db.insertInto('ocr_search').values(batch).execute();
  }
}
|
||||||
|
|
||||||
|
/**
 * Down migration. Performs no work and resolves immediately.
 */
export async function down(): Promise<void> {
  // Intentionally empty.
}
|
||||||
|
|
@ -12,8 +12,21 @@ describe(OcrService.name, () => {
|
||||||
({ sut, mocks } = newTestService(OcrService));
|
({ sut, mocks } = newTestService(OcrService));
|
||||||
|
|
||||||
mocks.config.getWorker.mockReturnValue(ImmichWorker.Microservices);
|
mocks.config.getWorker.mockReturnValue(ImmichWorker.Microservices);
|
||||||
|
mocks.assetJob.getForOcr.mockResolvedValue({
|
||||||
|
visibility: AssetVisibility.Timeline,
|
||||||
|
previewFile: assetStub.image.files[1].path,
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
 * Makes the mocked machine-learning OCR call resolve with one detection per
 * entry in `texts`.
 *
 * Detection `i` gets the eight box coordinates `i * 10 + 0 .. i * 10 + 7`;
 * every box score is 0.9 and every text score is 0.95.
 */
const mockOcrResult = (...texts: string[]) => {
  const box: number[] = [];
  for (let boxIndex = 0; boxIndex < texts.length; boxIndex++) {
    for (let coordinate = 0; coordinate < 8; coordinate++) {
      box.push(boxIndex * 10 + coordinate);
    }
  }

  mocks.machineLearning.ocr.mockResolvedValue({
    box,
    boxScore: texts.map(() => 0.9),
    text: texts,
    textScore: texts.map(() => 0.95),
  });
};
|
||||||
|
|
||||||
it('should work', () => {
|
it('should work', () => {
|
||||||
expect(sut).toBeDefined();
|
expect(sut).toBeDefined();
|
||||||
});
|
});
|
||||||
|
|
@ -72,10 +85,6 @@ describe(OcrService.name, () => {
|
||||||
text: ['One Two Three', 'Four Five'],
|
text: ['One Two Three', 'Four Five'],
|
||||||
textScore: [0.95, 0.85],
|
textScore: [0.95, 0.85],
|
||||||
});
|
});
|
||||||
mocks.assetJob.getForOcr.mockResolvedValue({
|
|
||||||
visibility: AssetVisibility.Timeline,
|
|
||||||
previewFile: assetStub.image.files[1].path,
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(await sut.handleOcr({ id: assetStub.image.id })).toEqual(JobStatus.Success);
|
expect(await sut.handleOcr({ id: assetStub.image.id })).toEqual(JobStatus.Success);
|
||||||
|
|
||||||
|
|
@ -88,36 +97,40 @@ describe(OcrService.name, () => {
|
||||||
maxResolution: 736,
|
maxResolution: 736,
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, [
|
expect(mocks.ocr.upsert).toHaveBeenCalledWith(
|
||||||
{
|
assetStub.image.id,
|
||||||
assetId: assetStub.image.id,
|
[
|
||||||
boxScore: 0.9,
|
{
|
||||||
text: 'One Two Three',
|
assetId: assetStub.image.id,
|
||||||
textScore: 0.95,
|
boxScore: 0.9,
|
||||||
x1: 10,
|
text: 'One Two Three',
|
||||||
y1: 20,
|
textScore: 0.95,
|
||||||
x2: 30,
|
x1: 10,
|
||||||
y2: 40,
|
y1: 20,
|
||||||
x3: 50,
|
x2: 30,
|
||||||
y3: 60,
|
y2: 40,
|
||||||
x4: 70,
|
x3: 50,
|
||||||
y4: 80,
|
y3: 60,
|
||||||
},
|
x4: 70,
|
||||||
{
|
y4: 80,
|
||||||
assetId: assetStub.image.id,
|
},
|
||||||
boxScore: 0.8,
|
{
|
||||||
text: 'Four Five',
|
assetId: assetStub.image.id,
|
||||||
textScore: 0.85,
|
boxScore: 0.8,
|
||||||
x1: 90,
|
text: 'Four Five',
|
||||||
y1: 100,
|
textScore: 0.85,
|
||||||
x2: 110,
|
x1: 90,
|
||||||
y2: 120,
|
y1: 100,
|
||||||
x3: 130,
|
x2: 110,
|
||||||
y3: 140,
|
y2: 120,
|
||||||
x4: 150,
|
x3: 130,
|
||||||
y4: 160,
|
y3: 140,
|
||||||
},
|
x4: 150,
|
||||||
]);
|
y4: 160,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
'One Two Three Four Five',
|
||||||
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should apply config settings', async () => {
|
it('should apply config settings', async () => {
|
||||||
|
|
@ -133,11 +146,7 @@ describe(OcrService.name, () => {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
mocks.machineLearning.ocr.mockResolvedValue({ box: [], boxScore: [], text: [], textScore: [] });
|
mockOcrResult();
|
||||||
mocks.assetJob.getForOcr.mockResolvedValue({
|
|
||||||
visibility: AssetVisibility.Timeline,
|
|
||||||
previewFile: assetStub.image.files[1].path,
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(await sut.handleOcr({ id: assetStub.image.id })).toEqual(JobStatus.Success);
|
expect(await sut.handleOcr({ id: assetStub.image.id })).toEqual(JobStatus.Success);
|
||||||
|
|
||||||
|
|
@ -150,7 +159,7 @@ describe(OcrService.name, () => {
|
||||||
maxResolution: 1500,
|
maxResolution: 1500,
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, []);
|
expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, [], '');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should skip invisible assets', async () => {
|
it('should skip invisible assets', async () => {
|
||||||
|
|
@ -173,5 +182,83 @@ describe(OcrService.name, () => {
|
||||||
expect(mocks.machineLearning.ocr).not.toHaveBeenCalled();
|
expect(mocks.machineLearning.ocr).not.toHaveBeenCalled();
|
||||||
expect(mocks.ocr.upsert).not.toHaveBeenCalled();
|
expect(mocks.ocr.upsert).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('search tokenization', () => {
  // Each case feeds `boxes` to the mocked OCR result (one string per detected
  // box) and checks the search text handed to the repository's upsert.
  const cases: { title: string; boxes: string[]; expected: string }[] = [
    { title: 'should generate bigrams for Chinese text', boxes: ['機器學習'], expected: '機器 器學 學習' },
    { title: 'should generate bigrams for Japanese text', boxes: ['テスト'], expected: 'テス スト' },
    { title: 'should generate bigrams for Korean text', boxes: ['한국어'], expected: '한국 국어' },
    { title: 'should pass through Latin text unchanged', boxes: ['Hello World'], expected: 'Hello World' },
    { title: 'should handle mixed CJK and Latin text', boxes: ['機器學習Model'], expected: '機器 器學 學習 Model' },
    { title: 'should handle year followed by CJK', boxes: ['2024年レポート'], expected: '2024 年レ レポ ポー ート' },
    { title: 'should join multiple OCR boxes', boxes: ['機器', 'Learning'], expected: '機器 Learning' },
    { title: 'should normalize whitespace', boxes: [' Hello World '], expected: 'Hello World' },
    { title: 'should keep single CJK characters', boxes: ['A', '中', 'B'], expected: 'A 中 B' },
  ];

  for (const { title, boxes, expected } of cases) {
    it(title, async () => {
      mockOcrResult(...boxes);

      await sut.handleOcr({ id: assetStub.image.id });

      expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, expect.any(Array), expected);
    });
  }
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import { AssetVisibility, JobName, JobStatus, QueueName } from 'src/enum';
|
||||||
import { OCR } from 'src/repositories/machine-learning.repository';
|
import { OCR } from 'src/repositories/machine-learning.repository';
|
||||||
import { BaseService } from 'src/services/base.service';
|
import { BaseService } from 'src/services/base.service';
|
||||||
import { JobItem, JobOf } from 'src/types';
|
import { JobItem, JobOf } from 'src/types';
|
||||||
|
import { tokenizeForSearch } from 'src/utils/database';
|
||||||
import { isOcrEnabled } from 'src/utils/misc';
|
import { isOcrEnabled } from 'src/utils/misc';
|
||||||
|
|
||||||
@Injectable()
|
@Injectable()
|
||||||
|
|
@ -53,8 +54,8 @@ export class OcrService extends BaseService {
|
||||||
}
|
}
|
||||||
|
|
||||||
const ocrResults = await this.machineLearningRepository.ocr(asset.previewFile, machineLearning.ocr);
|
const ocrResults = await this.machineLearningRepository.ocr(asset.previewFile, machineLearning.ocr);
|
||||||
|
const { ocrDataList, searchText } = this.parseOcrResults(id, ocrResults);
|
||||||
await this.ocrRepository.upsert(id, this.parseOcrResults(id, ocrResults));
|
await this.ocrRepository.upsert(id, ocrDataList, searchText);
|
||||||
|
|
||||||
await this.assetRepository.upsertJobStatus({ assetId: id, ocrAt: new Date() });
|
await this.assetRepository.upsertJobStatus({ assetId: id, ocrAt: new Date() });
|
||||||
|
|
||||||
|
|
@ -64,7 +65,9 @@ export class OcrService extends BaseService {
|
||||||
|
|
||||||
private parseOcrResults(id: string, { box, boxScore, text, textScore }: OCR) {
|
private parseOcrResults(id: string, { box, boxScore, text, textScore }: OCR) {
|
||||||
const ocrDataList = [];
|
const ocrDataList = [];
|
||||||
|
const searchTokens = [];
|
||||||
for (let i = 0; i < text.length; i++) {
|
for (let i = 0; i < text.length; i++) {
|
||||||
|
const rawText = text[i];
|
||||||
const boxOffset = i * 8;
|
const boxOffset = i * 8;
|
||||||
ocrDataList.push({
|
ocrDataList.push({
|
||||||
assetId: id,
|
assetId: id,
|
||||||
|
|
@ -78,9 +81,11 @@ export class OcrService extends BaseService {
|
||||||
y4: box[boxOffset + 7],
|
y4: box[boxOffset + 7],
|
||||||
boxScore: boxScore[i],
|
boxScore: boxScore[i],
|
||||||
textScore: textScore[i],
|
textScore: textScore[i],
|
||||||
text: text[i],
|
text: rawText,
|
||||||
});
|
});
|
||||||
|
searchTokens.push(...tokenizeForSearch(rawText));
|
||||||
}
|
}
|
||||||
return ocrDataList;
|
|
||||||
|
return { ocrDataList, searchText: searchTokens.join(' ') };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -306,6 +306,46 @@ export function withTagId<O>(qb: SelectQueryBuilder<DB, 'asset', O>, tagId: stri
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Returns true when the UTF-16 code unit `c` lies in one of the CJK blocks
 * that are indexed as bigrams.
 */
const isCJK = (c: number): boolean =>
  (c >= 0x4e_00 && c <= 0x9f_ff) || // CJK Unified Ideographs
  (c >= 0xac_00 && c <= 0xd7_af) || // Hangul Syllables
  (c >= 0x30_40 && c <= 0x30_9f) || // Hiragana
  (c >= 0x30_a0 && c <= 0x30_ff) || // Katakana
  (c >= 0x34_00 && c <= 0x4d_bf); // CJK Unified Ideographs Extension A

/**
 * Returns true when the UTF-16 code unit `c` separates tokens: ASCII control
 * characters and space (<= 0x20) plus the Unicode white-space characters.
 * This includes U+3000 IDEOGRAPHIC SPACE, the usual space in CJK text, and
 * U+00A0 NO-BREAK SPACE.
 */
const isSearchSeparator = (c: number): boolean =>
  c <= 0x20 ||
  c === 0x85 ||
  c === 0xa0 ||
  c === 0x16_80 ||
  (c >= 0x20_00 && c <= 0x20_0a) ||
  c === 0x20_28 ||
  c === 0x20_29 ||
  c === 0x20_2f ||
  c === 0x20_5f ||
  c === 0x30_00;

/**
 * Splits `text` into search tokens.
 *
 * - White space separates tokens and never appears in the output.
 * - A run of CJK characters is emitted as overlapping bigrams
 *   (`機器學習` -> `機器`, `器學`, `學習`); a run of length one is kept as-is.
 * - Any other run of non-white-space characters is emitted unchanged as one token.
 *
 * @param text Raw text, e.g. one OCR line or a user's search query.
 * @returns The tokens in order of appearance; empty for empty or all-white-space input.
 */
export const tokenizeForSearch = (text: string): string[] => {
  /* eslint-disable unicorn/prefer-code-point */
  // All ranges checked here are in the BMP, so UTF-16 code units are enough.
  // Surrogate halves are neither separators nor CJK, so a surrogate pair
  // always stays inside a single non-CJK token.
  const tokens: string[] = [];
  let i = 0;
  while (i < text.length) {
    const c = text.charCodeAt(i);
    if (isSearchSeparator(c)) {
      i++;
      continue;
    }

    const start = i;
    if (isCJK(c)) {
      while (i < text.length && isCJK(text.charCodeAt(i))) {
        i++;
      }
      if (i - start === 1) {
        // A lone CJK character cannot form a bigram; keep it searchable on its own.
        tokens.push(text.slice(start, i));
      } else {
        for (let k = start; k < i - 1; k++) {
          tokens.push(text.slice(k, k + 2));
        }
      }
    } else {
      while (i < text.length && !isSearchSeparator(text.charCodeAt(i)) && !isCJK(text.charCodeAt(i))) {
        i++;
      }
      tokens.push(text.slice(start, i));
    }
  }
  return tokens;
};
|
||||||
|
|
||||||
const joinDeduplicationPlugin = new DeduplicateJoinsPlugin();
|
const joinDeduplicationPlugin = new DeduplicateJoinsPlugin();
|
||||||
/** TODO: This should only be used for search-related queries, not as a general purpose query builder */
|
/** TODO: This should only be used for search-related queries, not as a general purpose query builder */
|
||||||
|
|
||||||
|
|
@ -391,7 +431,7 @@ export function searchAssetBuilder(kysely: Kysely<DB>, options: AssetSearchBuild
|
||||||
.$if(!!options.ocr, (qb) =>
|
.$if(!!options.ocr, (qb) =>
|
||||||
qb
|
qb
|
||||||
.innerJoin('ocr_search', 'asset.id', 'ocr_search.assetId')
|
.innerJoin('ocr_search', 'asset.id', 'ocr_search.assetId')
|
||||||
.where(() => sql`f_unaccent(ocr_search.text) %>> f_unaccent(${options.ocr!})`),
|
.where(() => sql`f_unaccent(ocr_search.text) %>> f_unaccent(${tokenizeForSearch(options.ocr!).join(' ')})`),
|
||||||
)
|
)
|
||||||
.$if(!!options.type, (qb) => qb.where('asset.type', '=', options.type!))
|
.$if(!!options.type, (qb) => qb.where('asset.type', '=', options.type!))
|
||||||
.$if(options.isFavorite !== undefined, (qb) => qb.where('asset.isFavorite', '=', options.isFavorite!))
|
.$if(options.isFavorite !== undefined, (qb) => qb.where('asset.isFavorite', '=', options.isFavorite!))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue