Initial check-in
[him-cellwriter] / src / recognize.c
1
2 /*
3
4 cellwriter -- a character recognition input method
5 Copyright (C) 2007 Michael Levin <risujin@risujin.org>
6
7 This program is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public License
9 as published by the Free Software Foundation; either version 2
10 of the License, or (at your option) any later version.
11
12 This program is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with this program; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
20
21 */
22
23 #include "config.h"
24 #include <stdlib.h>
25 #include <string.h>
26 #include <math.h>
27 #include <gtk/gtk.h>
28 #include "common.h"
29 #include "recognize.h"
30
31 /* preprocess.c */
32 int prep_examined;
33
34 void engine_prep(void);
35
36 /*
37         Engines
38 */
39
40 Engine engines[] = {
41
42         /* Preprocessor engine must run first */
43         { "Key-point distance", engine_prep, MAX_RANGE, TRUE, -1, 0, 0 },
44
45         /* Averaging engines */
46         { "Average distance", engine_average, MAX_RANGE, TRUE, -1, 0, 0 },
47         { "Average angle", NULL, MAX_RANGE, TRUE, 0, 0, 0 },
48
49 #ifndef DISABLE_WORDFREQ
50         /* Word frequency engine */
51         { "Word context", engine_wordfreq, MAX_RANGE / 3, FALSE, -1, 0, 0 },
52 #endif
53 };
54
55 static int engine_rating(const Sample *sample, int j)
56 /* Get the processed rating for engine j on a sample */
57 {
58         int value;
59
60         if (!engines[j].range || engines[j].max < 1)
61                 return 0;
62         value = ((int)sample->ratings[j] - engines[j].average) *
63                 engines[j].range / engines[j].max;
64         if (engines[j].scale >= 0)
65                 value = value * engines[j].scale / ENGINE_SCALE;
66         return value;
67 }
68
69 /*
70         Sample chain wrapper
71 */
72
73 typedef struct SampleLink {
74         Sample sample;
75         struct SampleLink *prev, *next;
76 } SampleLink;
77
78 static SampleLink *samplelink_root = NULL, *samplelink_iter = NULL;
79 static int current = 1;
80
81 static Sample *sample_new(void)
82 /* Allocate a link in the sample linked list */
83 {
84         SampleLink *link;
85
86         link = g_malloc0(sizeof (*link));
87         link->next = samplelink_root;
88         if (samplelink_root)
89                 samplelink_root->prev = link;
90         samplelink_root = link;
91         return &link->sample;
92 }
93
94 void sampleiter_reset(void)
95 /* Reset the sample linked list iterator */
96 {
97         samplelink_iter = samplelink_root;
98 }
99
100 Sample *sampleiter_next(void)
101 /* Get the next sample link from the sample linked list iterator */
102 {
103         SampleLink *link;
104
105         if (!samplelink_iter)
106                 return NULL;
107         link = samplelink_iter;
108         samplelink_iter = samplelink_iter->next;
109         return &link->sample;
110 }
111
112 int samples_loaded(void)
113 {
114         return samplelink_root != NULL;
115 }
116
117 /*
118         Samples
119 */
120
121 int samples_max = 5, no_latin_alpha = FALSE;
122
123 void clear_sample(Sample *sample)
124 /* Free stroke data associated with a sample and reset its parameters */
125 {
126         int i;
127
128         for (i = 0; i < sample->len; i++) {
129                 stroke_free(sample->strokes[i]);
130                 stroke_free(sample->roughs[i]);
131         }
132         memset(sample, 0, sizeof (*sample));
133 }
134
135 void copy_sample(Sample *dest, const Sample *src)
136 /* Copy a sample, cloing its strokes, overwriting dest */
137 {
138         int i;
139
140         *dest = *src;
141         for (i = 0; i < src->len; i++) {
142                 dest->strokes[i] = stroke_clone(src->strokes[i], FALSE);
143                 dest->roughs[i] = stroke_clone(src->roughs[i], FALSE);
144         }
145 }
146
147 static void process_gluable(const Sample *sample, int stroke_num)
148 /* Calculates the lowest distance between the start or end of one stroke and any
149    other point on each other stroke in the sample */
150 {
151         Point point;
152         Stroke *s1;
153         int i, start;
154
155         /* Dots cannot be glued */
156         s1 = sample->strokes[stroke_num];
157         memset(s1->gluable_start, -1, sizeof (s1->gluable_start));
158         memset(s1->gluable_end, -1, sizeof (s1->gluable_end));
159         if (s1->spread < DOT_SPREAD)
160                 return;
161
162         start = TRUE;
163 scan:
164         point = start ? s1->points[0] : s1->points[s1->len - 1];
165         for (i = 0; i < sample->len; i++) {
166                 Vec2 v;
167                 Stroke *s2;
168                 float dist, min = GLUE_DIST;
169                 int j;
170                 char gluable;
171
172                 s2 = sample->strokes[i];
173                 if (i == stroke_num || s2->spread < DOT_SPREAD)
174                         continue;
175
176                 /* Check the distance to the first point */
177                 vec2_set(&v, s2->points[0].x - point.x,
178                          s2->points[0].y - point.y);
179                 dist = vec2_mag(&v);
180                 if (dist < min)
181                         min = dist;
182
183                 /* Find the lowest distance from the glue point to any other
184                    point on the other stroke */
185                 for (j = 0; j < s2->len - 1; j++) {
186                         Vec2 l, w;
187                         double dist, mag, dot;
188
189                         /* Vector l is a unit vector from point j to j + 1 */
190                         vec2_set(&l, s2->points[j].x - s2->points[j + 1].x,
191                                  s2->points[j].y - s2->points[j + 1].y);
192                         mag = vec2_norm(&l, &l);
193
194                         /* Vector w is a vector from point j to our point */
195                         vec2_set(&w, s2->points[j].x - point.x,
196                                  s2->points[j].y - point.y);
197
198                         /* For points that are not in between a segment,
199                            get the distance from the points themselves,
200                            otherwise get the distance from the segment line */
201                         dot = vec2_dot(&l, &w);
202                         if (dot < 0. || dot > mag) {
203                                 vec2_set(&v, s2->points[j + 1].x - point.x,
204                                          s2->points[j + 1].y - point.y);
205                                 dist = vec2_mag(&v);
206                         } else {
207                                 dist = vec2_cross(&w, &l);
208                                 if (dist < 0)
209                                         dist = -dist;
210                         }
211                         if (dist < min)
212                                 min = dist;
213                 }
214                 gluable = min * GLUABLE_MAX / GLUE_DIST;
215                 if (start)
216                         s1->gluable_start[i] = gluable;
217                 else
218                         s1->gluable_end[i] = gluable;
219         }
220         if (start) {
221                 start = FALSE;
222                 goto scan;
223         }
224 }
225
226 void process_sample(Sample *sample)
227 /* Generate cached properties of a sample */
228 {
229         int i;
230         float distance;
231
232         if (sample->processed)
233                 return;
234         sample->processed = TRUE;
235
236         /* Make sure all strokes have been processed first */
237         for (i = 0; i < sample->len; i++)
238                 process_stroke(sample->strokes[i]);
239
240         /* Compute properties for each stroke */
241         vec2_set(&sample->center, 0., 0.);
242         for (i = 0, distance = 0.; i < sample->len; i++) {
243                 Vec2 v;
244                 Stroke *stroke;
245                 float weight;
246                 int points;
247
248                 stroke = sample->strokes[i];
249
250                 /* Add the stroke center to the center vector, weighted by
251                    length */
252                 vec2_copy(&v, &stroke->center);
253                 weight = stroke->spread < DOT_SPREAD ?
254                          DOT_SPREAD : stroke->distance;
255                 vec2_scale(&v, &v, weight);
256                 vec2_sum(&sample->center, &sample->center, &v);
257                 distance += weight;
258
259                 /* Get gluing distances */
260                 process_gluable(sample, i);
261
262                 /* Create a rough-sampled version */
263                 points = stroke->distance / ROUGH_RESOLUTION + 0.5;
264                 if (points < 4)
265                         points = 4;
266                 sample->roughs[i] = sample_stroke(NULL, stroke, points, points);
267         }
268         vec2_scale(&sample->center, &sample->center, 1.f / distance);
269         sample->distance = distance;
270 }
271
272 void center_samples(Vec2 *ac_to_bc, Sample *a, Sample *b)
273 /* Adjust for the difference between two sample centers */
274 {
275         vec2_sub(ac_to_bc, &b->center, &a->center);
276 }
277
278 int char_disabled(int ch)
279 /* Returns TRUE if a character is not renderable or is explicity disabled by
280    a setting (not counting disabled Unicode blocks) */
281 {
282         return (no_latin_alpha && ch >= unicode_blocks[0].start &&
283                 ch <= unicode_blocks[0].end && g_ascii_isalpha(ch)) ||
284                !g_unichar_isgraph(ch);
285 }
286
287 int sample_disqualified(const Sample *sample)
288 /* Check disqualification conditions for a sample during recognition.
289    The preprocessor engine must run before any calls to this or
290    disqualification will not work. */
291 {
292         if ((!ignore_stroke_num && sample->len != input->len) ||
293             !sample->enabled)
294                 return 1;
295         if (sample->disqualified)
296                 return 2;
297         if (char_disabled(sample->ch))
298                 return 3;
299         return 0;
300 }
301
302 int sample_valid(const Sample *sample, int used)
303 /* Check if this sample has changed since it was last referenced */
304 {
305         if (!sample || !used)
306                 return FALSE;
307         return sample->used == used;
308 }
309
310 static void sample_rating(Sample *sample)
311 /* Get the composite processed rating on a sample */
312 {
313         int i, rating;
314
315         if (!sample->ch || sample_disqualified(sample) ||
316             sample->penalty >= 1.f) {
317                 sample->rating = RATING_MIN;
318                 return;
319         }
320         for (i = 0, rating = 0; i < ENGINES; i++)
321                 rating += engine_rating(sample, i);
322         rating *= 1.f - sample->penalty;
323         if (rating > RATING_MAX)
324                 rating = RATING_MAX;
325         if (rating < RATING_MIN)
326                 rating = RATING_MIN;
327         sample->rating = rating;
328 }
329
330 void update_enabled_samples(void)
331 /* Run through the samples list and enable samples in enabled blocks */
332 {
333         Sample *sample;
334
335         sampleiter_reset();
336         while ((sample = sampleiter_next())) {
337                 UnicodeBlock *block;
338
339                 sample->enabled = FALSE;
340                 if (!sample->ch)
341                         continue;
342                 block = unicode_blocks;
343                 while (block->name) {
344                         if (sample->ch >= block->start &&
345                             sample->ch <= block->end) {
346                                 sample->enabled = block->enabled;
347                                 break;
348                         }
349                         block++;
350                 }
351         }
352 }
353
354 void promote_sample(Sample *sample)
355 /* Update usage counter for a sample */
356 {
357         sample->used = current++;
358 }
359
360 void demote_sample(Sample *sample)
361 /* Remove the sample from our set if we can */
362 {
363         if (char_trained(sample->ch) > 1)
364                 clear_sample(sample);
365         else
366                 sample->used = 1;
367 }
368
369 Stroke *transform_stroke(Sample *src, Transform *tfm, int i)
370 /* Create a new stroke by applying the transformation to the source */
371 {
372         Stroke *stroke;
373         int k, j;
374
375         stroke = stroke_new(0);
376         for (k = 0, j = 0; k < STROKES_MAX && j < src->len; k++)
377                 for (j = 0; j < src->len; j++)
378                         if (tfm->order[j] - 1 == i && tfm->glue[j] == k) {
379                                 glue_stroke(&stroke, src->strokes[j],
380                                             tfm->reverse[j]);
381                                 break;
382                         }
383         process_stroke(stroke);
384         return stroke;
385 }
386
387 /*
388         Recognition and training
389 */
390
391 Sample *input = NULL;
392 int strength_sum = 0;
393
394 static GTimer *timer;
395
396 void recognize_init(void)
397 {
398 #ifndef DISABLE_WORDFREQ
399         load_wordfreq();
400 #endif
401         timer = g_timer_new();
402 }
403
404 void recognize_sample(Sample *sample, Sample **alts, int num_alts)
405 {
406         gulong microsec;
407         int i, range, strength, msec;
408
409         g_timer_start(timer);
410         input = sample;
411         process_sample(input);
412
413         /* Clear ratings */
414         sampleiter_reset();
415         while ((sample = sampleiter_next())) {
416                 memset(sample->ratings, 0, sizeof (sample->ratings));
417                 sample->rating = 0;
418         }
419
420         /* Run engines */
421         for (i = 0, range = 0; i < ENGINES; i++) {
422                 int rated = 0;
423
424                 if (engines[i].func)
425                         engines[i].func();
426
427                 /* Compute average and maximum value */
428                 engines[i].max = 0;
429                 engines[i].average = 0;
430                 sampleiter_reset();
431                 while ((sample = sampleiter_next())) {
432                         int value = 0;
433
434                         if (!sample->ch)
435                                 continue;
436                         if (sample->ratings[i] > value)
437                                 value = sample->ratings[i];
438                         if (!value && engines[i].ignore_zeros)
439                                 continue;
440                         if (value > engines[i].max)
441                                 engines[i].max = value;
442                         engines[i].average += value;
443                         rated++;
444                 }
445                 if (!rated)
446                         continue;
447                 engines[i].average /= rated;
448                 if (engines[i].max > 0)
449                         range += engines[i].range;
450                 if (engines[i].max == engines[i].average) {
451                         engines[i].average = 0;
452                         continue;
453                 }
454                 engines[i].max -= engines[i].average;
455         }
456         if (!range) {
457                 g_timer_elapsed(timer, &microsec);
458                 msec = microsec / 100;
459                 g_message("Recognized -- No ratings, %dms", msec);
460                 input->ch = 0;
461                 return;
462         }
463
464         /* Rank the top samples */
465         alts[0] = NULL;
466         sampleiter_reset();
467         while ((sample = sampleiter_next())) {
468                 int j;
469
470                 sample_rating(sample);
471                 if (sample->rating < 1)
472                         continue;
473
474                 /* Bubble-sort the new rating in */
475                 for (j = 0; j < num_alts; j++)
476                         if (!alts[j]) {
477                                 if (j < num_alts - 1)
478                                         alts[j + 1] = NULL;
479                                 break;
480                         } else if (alts[j]->ch == sample->ch) {
481                                 if (alts[j]->rating >= sample->rating)
482                                         j = num_alts;
483                                 break;
484                         } else if (alts[j]->rating < sample->rating) {
485                                 int k;
486
487                                 if (j == num_alts - 1)
488                                         break;
489
490                                 /* See if the character is in the list */
491                                 for (k = j + 1; k < num_alts - 1 && alts[k] &&
492                                      alts[k]->ch != sample->ch; k++);
493
494                                 /* Do not swallow zeroes */
495                                 if (!alts[k] && k < num_alts - 1)
496                                         alts[k + 1] = NULL;
497
498                                 memmove(alts + j + 1, alts + j,
499                                         sizeof (*alts) * (k - j));
500                                 break;
501                         }
502                 if (j >= num_alts)
503                         continue;
504                 alts[j] = sample;
505         }
506
507         /* Normalize the alternates' accuracies to 100 */
508         if (range)
509                 for (i = 0; i < num_alts && alts[i]; i++)
510                         alts[i]->rating = alts[i]->rating * 100 / range;
511
512         /* Keep track of strength stat */
513         strength = 0;
514         if (alts[0]) {
515                 strength = alts[1] ? alts[0]->rating - alts[1]->rating :
516                                         100;
517                 strength_sum += strength;
518         }
519
520         g_timer_elapsed(timer, &microsec);
521         msec = microsec / 100;
522         g_message("Recognized -- %d/%d (%d%%) disqualified, "
523                   "%dms (%dms/symbol), %d%% strong",
524                   num_disqualified, prep_examined,
525                   num_disqualified * 100 / prep_examined, msec,
526                   prep_examined - num_disqualified ?
527                   msec / (prep_examined - num_disqualified) : -1,
528                   strength);
529
530         /*  Print out the top candidate scores in detail */
531         if (log_level >= G_LOG_LEVEL_DEBUG)
532                 for (i = 0; i < num_alts && alts[i]; i++) {
533                         int j, len;
534
535                         len = input->len >= alts[i]->len ? input->len :
536                                                            alts[i]->len;
537                         log_print("| '%C' (", alts[i]->ch);
538                         for (j = 0; j < ENGINES; j++)
539                                 log_print("%4d [%5d]%s",
540                                         engine_rating(alts[i], j),
541                                         alts[i]->ratings[j],
542                                         j < ENGINES - 1 ? "," : "");
543                         log_print(") %3d%% [", alts[i]->rating);
544                         for (j = 0; j < len; j++)
545                                 log_print("%d",
546                                           alts[i]->transform.order[j] - 1);
547                         for (j = 0; j < len; j++)
548                                 log_print("%c", alts[i]->transform.reverse[j] ?
549                                                 'R' : '-');
550                         for (j = 0; j < len; j++)
551                                 log_print("%d", alts[i]->transform.glue[j]);
552                         log_print("]\n");
553                 }
554
555         /* Select the top result */
556         input->ch = alts[0] ? alts[0]->ch : 0;
557 }
558
559 static void insert_sample(const Sample *new_sample, int force_overwrite)
560 /* Insert a sample into the sample chain, possibly overwriting an older
561    sample */
562 {
563         int last_used, count = 0;
564         Sample *sample, *overwrite = NULL, *create = NULL;
565
566         last_used = force_overwrite ? current + 1 : new_sample->used;
567         sampleiter_reset();
568         while ((sample = sampleiter_next())) {
569                 if (!sample->used) {
570                         create = sample;
571                         continue;
572                 }
573                 if (sample->ch != new_sample->ch)
574                         continue;
575                 if (sample->used < last_used) {
576                         overwrite = sample;
577                         last_used = sample->used;
578                 }
579                 count++;
580         }
581         if (overwrite && count >= samples_max) {
582                 sample = overwrite;
583                 clear_sample(sample);
584         } else if (create)
585                 sample = create;
586         else
587                 sample = sample_new();
588         *sample = *new_sample;
589         process_sample(sample);
590 }
591
592 void train_sample(const Sample *sample, int trusted)
593 /* Overwrite a blank or least-recently-used slot in the samples set */
594 {
595         Sample new_sample;
596
597         /* Do not allow zero-length samples */
598         if (sample->len < 1) {
599                 g_warning("Attempted to train zero length sample for '%C'",
600                           sample->ch);
601                 return;
602         }
603
604         copy_sample(&new_sample, sample);
605         new_sample.used = trusted ? current++ : 1;
606         new_sample.enabled = TRUE;
607         insert_sample(&new_sample, TRUE);
608 }
609
610 int char_trained(int ch)
611 /* Count the number of samples for this character */
612 {
613         Sample *sample;
614         int count = 0;
615
616         sampleiter_reset();
617         while ((sample = sampleiter_next())) {
618                 if (sample->ch != ch)
619                         continue;
620                 count++;
621         }
622         return count;
623 }
624
625 void untrain_char(int ch)
626 /* Delete all samples for a character */
627 {
628         Sample *sample;
629
630         sampleiter_reset();
631         while ((sample = sampleiter_next()))
632                 if (sample->ch == ch)
633                         clear_sample(sample);
634 }
635
636 /*
637         Profile
638 */
639
640 void recognize_sync(void)
641 /* Sync params with the profile */
642 {
643         int i;
644
645         profile_write("recognize");
646         profile_sync_int(&current);
647         profile_sync_int(&samples_max);
648         if (samples_max < 1)
649                 samples_max = 1;
650         profile_sync_int(&no_latin_alpha);
651         for (i = 0; i < ENGINES; i++)
652                 profile_sync_int(&engines[i].range);
653         profile_write("\n");
654 }
655
656 void sample_read(void)
657 /* Read a sample from the profile */
658 {
659         Sample sample;
660         Stroke *stroke;
661
662         memset(&sample, 0, sizeof (sample));
663         sample.ch = atoi(profile_read());
664         if (!sample.ch) {
665                 g_warning("Sample on line %d has NULL symbol", profile_line);
666                 return;
667         }
668         sample.used = atoi(profile_read());
669         stroke = sample.strokes[0];
670         for (;;) {
671                 const char *str;
672                 int x, y;
673
674                 str = profile_read();
675                 if (!str[0]) {
676                         if (!sample.strokes[0]) {
677                                 g_warning("Sample on line %d ('%C') with no "
678                                           "point data", profile_line,
679                                           sample.ch);
680                                 break;
681                         }
682                         insert_sample(&sample, FALSE);
683                         break;
684                 }
685                 if (str[0] == ';') {
686                         stroke = sample.strokes[sample.len];
687                         continue;
688                 }
689                 if (sample.len >= STROKES_MAX) {
690                         g_warning("Sample on line %d ('%C') is oversize",
691                                   profile_line, sample.ch);
692                         clear_sample(&sample);
693                         break;
694                 }
695                 if (!stroke) {
696                         stroke = stroke_new(0);
697                         sample.strokes[sample.len++] = stroke;
698                 }
699                 if (stroke->len >= POINTS_MAX) {
700                         g_warning("Symbol '%C' stroke %d is oversize",
701                                   sample.ch, sample.len);
702                         clear_sample(&sample);
703                         break;
704                 }
705                 x = atoi(str);
706                 y = atoi(profile_read());
707                 draw_stroke(&stroke, x, y);
708         }
709 }
710
711 static void sample_write(Sample *sample)
712 /* Write a sample link to the profile */
713 {
714         int k, l;
715
716         profile_write(va("sample %5d %5d", sample->ch, sample->used));
717         for (k = 0; k < sample->len; k++) {
718                 for (l = 0; l < sample->strokes[k]->len; l++)
719                         profile_write(va(" %4d %4d",
720                                          sample->strokes[k]->points[l].x,
721                                          sample->strokes[k]->points[l].y));
722                 profile_write("    ;");
723         }
724         profile_write("\n");
725 }
726
727 void samples_write(void)
728 /* Write all of the samples to the profile */
729 {
730         Sample *sample;
731
732         sampleiter_reset();
733         while ((sample = sampleiter_next()))
734                 if (sample->ch && sample->used)
735                         sample_write(sample);
736 }
737