package nederhof.ocr.prob.experiments;

import com.itextpdf.text.html.HtmlTags;
import java.util.ArrayList;
import java.util.Iterator;
import nederhof.ocr.images.BinaryImage;
import nederhof.ocr.images.PartialBinaryImage;
import nederhof.ocr.prob.CrookedLine;
import nederhof.ocr.prob.ModelState;
import nederhof.ocr.prob.PixelInventory;
import nederhof.ocr.prob.ProtoPixel;
import nederhof.ocr.prob.ProtoPixelGain;
import nederhof.ocr.prob.SimpleDistModel;
import nederhof.ocr.prob.TextModel;
import nederhof.ocr.prob.UniLangModel;
import nederhof.ocr.prob.symbols.ProtoGlyph;
import nederhof.util.fsa.Fsa;
import nederhof.util.fsa.FsaBackward;
import nederhof.util.fsa.FsaForward;
import nederhof.util.fsa.FsaShortestPath;
import nederhof.util.fsa.FsaTrans;
import nederhof.util.math.NegLogProb;
import nederhof.util.math.SignedNegLogProb;

/* loaded from: input_file:nederhof/ocr/prob/experiments/Experiment.class */
/**
 * Toy experiments with the probabilistic OCR text model.
 *
 * <p>The methods build small synthetic line images and a tiny glyph inventory, then build the
 * model's automaton with {@link TextModel#createFsa}. From there they do one of two things:
 * <ul>
 *   <li>print the automaton's shortest path ({@link #basicTest()});
 *   <li>greedily flip single prototype pixels for as long as the forward sum keeps decreasing
 *       ({@link #bruteForceTweaking()}, {@link #forwardBackwardTweaking()}).
 * </ul>
 * All output goes to standard output. {@link #main} runs {@link #bruteForceTweaking()}.
 *
 * <p>NOTE(review): the forward/backward values are combined with {@link NegLogProb} helpers, so
 * they appear to be negative log probabilities (lower = more likely) — confirm.
 */
public class Experiment {

    /**
     * Forward/backward value that marks a state as unusable in {@link #forwardBackwardTweaking()}.
     * Transitions are skipped if the source's forward value or the target's backward value
     * equals it. NOTE(review): presumably the probability-zero encoding of {@link NegLogProb};
     * confirm.
     */
    private static final double NO_MASS = 1.7976931348623158E303d;

    /**
     * Gains whose absolute value reaches this bound are not recorded by {@link #recordGain}.
     * NOTE(review): the rationale for the threshold is not visible here.
     */
    private static final double GAIN_LIMIT = 10000.0d;

    /** Sets every pixel with {@code x0 <= x < x1} and {@code y0 <= y < y1} to {@code true}. */
    private static void fillRect(BinaryImage image, int x0, int x1, int y0, int y1) {
        for (int x = x0; x < x1; x++) {
            for (int y = y0; y < y1; y++) {
                image.set(x, y, true);
            }
        }
    }

    /** Inverts the pixel at ({@code x}, {@code y}) of {@code image}. */
    private static void flip(BinaryImage image, int x, int y) {
        image.set(x, y, !image.get(x, y));
    }

    /**
     * Wraps {@code image} as a {@link CrookedLine}. Every entry of the per-column list is the
     * image height, and the height and {@code 0} are passed as the remaining constructor
     * arguments. NOTE(review): presumably this describes a straight baseline along the bottom
     * edge — confirm against CrookedLine.
     */
    private static CrookedLine flatLine(BinaryImage image) {
        ArrayList<Integer> columns = new ArrayList<>(image.width());
        for (int x = 0; x < image.width(); x++) {
            columns.add(image.height());
        }
        return new CrookedLine(image, columns, image.height(), 0);
    }

    /** Builds a 100x10 test line containing several filled rectangles ("stripes"). */
    private BinaryImage stripedLineImage() {
        BinaryImage image = new BinaryImage(100, 10);
        fillRect(image, 0, 10, 0, 6);
        fillRect(image, 20, 30, 0, 6);
        fillRect(image, 42, 50, 3, 8);
        fillRect(image, 55, 70, 2, 4);
        fillRect(image, 55, 70, 8, 10);
        fillRect(image, 95, 100, 3, 4);
        return image;
    }

    /** Builds a 30x10 test line with a block in the upper-left and one in the lower-right. */
    private BinaryImage shortLineImage() {
        BinaryImage image = new BinaryImage(30, 10);
        fillRect(image, 0, 10, 0, 6);
        fillRect(image, 20, 30, 6, 10);
        return image;
    }

    /** The striped test image as a {@link CrookedLine}. */
    private CrookedLine stripedLine() {
        return flatLine(stripedLineImage());
    }

    /** The short test image as a {@link CrookedLine}. */
    private CrookedLine shortLine() {
        return flatLine(shortLineImage());
    }

    /** 10x10 prototype with every pixel set. */
    private BinaryImage protoImageBlack() {
        BinaryImage image = new BinaryImage(10, 10);
        fillRect(image, 0, 10, 0, 10);
        return image;
    }

    /** 5x10 prototype with every pixel set. */
    private BinaryImage protoImageBlackHalfSize() {
        BinaryImage image = new BinaryImage(5, 10);
        fillRect(image, 0, 5, 0, 10);
        return image;
    }

    /** 10x10 prototype with only the first and last columns set; the interior stays unset. */
    private BinaryImage protoImageWhite() {
        BinaryImage image = new BinaryImage(10, 10);
        fillRect(image, 0, 1, 0, 10);
        fillRect(image, 9, 10, 0, 10);
        return image;
    }

    /** 5x10 prototype with only the first and last columns set. */
    private BinaryImage protoImageWhiteHalfSize() {
        BinaryImage image = new BinaryImage(5, 10);
        fillRect(image, 0, 1, 0, 10);
        fillRect(image, 4, 5, 0, 10);
        return image;
    }

    /** 10x10 prototype with rows 0-4 set and rows 5-9 unset. */
    private BinaryImage protoImageBlackWhite() {
        BinaryImage image = new BinaryImage(10, 10);
        fillRect(image, 0, 10, 0, 5);
        return image;
    }

    /** The toy glyph inventory: five image prototypes plus one space glyph. */
    private ArrayList<ProtoGlyph> testGlyphs1() {
        ArrayList<ProtoGlyph> glyphs = new ArrayList<>();
        glyphs.add(new ProtoGlyph("b", protoImageBlack(), 10));
        glyphs.add(new ProtoGlyph("b/2", protoImageBlackHalfSize(), 10));
        glyphs.add(new ProtoGlyph("w", protoImageWhite(), 10));
        glyphs.add(new ProtoGlyph("w/2", protoImageWhiteHalfSize(), 10));
        glyphs.add(new ProtoGlyph("b/w", protoImageBlackWhite(), 10));
        glyphs.add(ProtoGlyph.getSpace(10.0d, 5.0d, 3, 10));
        return glyphs;
    }

    /**
     * Text model over {@link #testGlyphs1()}. It uses a pixel error probability of 0.1, a
     * {@link SimpleDistModel}, and a {@link UniLangModel} sized to the glyph count + 1.
     */
    private TextModel textModel1() {
        ArrayList<ProtoGlyph> glyphs = testGlyphs1();
        TextModel model = new TextModel();
        model.setGlyphs(glyphs);
        model.setPixelErrorProb(0.1d);
        model.setDistModel(new SimpleDistModel(1.0d, 3.0d, -5, 5));
        model.setLangModel(new UniLangModel(glyphs.size() + 1));
        return model;
    }

    /**
     * Prints every transition on the shortest path through the striped-line automaton, then
     * prints the forward sum.
     *
     * @return the forward sum of the automaton
     */
    private double basicTest() {
        Fsa<ModelState, Integer> fsa = textModel1().createFsa(stripedLine());
        for (FsaTrans trans : new FsaShortestPath(fsa).shortestPath()) {
            ModelState from = (ModelState) trans.fromState();
            ModelState to = (ModelState) trans.toState();
            int label = ((Integer) trans.label()).intValue();
            double weight = trans.weight();
            System.out.println(from.getPos() + " " + from.getGlyph().getName());
            System.out.println("    " + label + " " + weight);
            System.out.println("    " + to.getPos() + " " + to.getGlyph().getName());
        }
        // Computed once; the value is both printed and returned.
        double sum = new FsaForward(fsa).sum();
        System.out.println("SUM " + sum);
        return sum;
    }

    /** Forward sum of the automaton that {@code textModel} builds for {@code crookedLine}. */
    private double getLikelihood(CrookedLine crookedLine, TextModel textModel) {
        return new FsaForward(textModel.createFsa(crookedLine)).sum();
    }

    /**
     * Greedy hill-climbing over the prototype pixels of every glyph that has an image.
     *
     * <p>Each round tries every single-pixel flip and re-evaluates the full likelihood for it.
     * The round then permanently applies the flip with the lowest likelihood. It stops when no
     * flip lowers the likelihood.
     */
    private void bruteForceTweaking() {
        CrookedLine line = stripedLine();
        TextModel model = textModel1();
        double current = getLikelihood(line, model);
        while (true) {
            System.out.println(current);
            double best = current;
            ProtoGlyph bestGlyph = null;
            int bestX = 0;
            int bestY = 0;
            for (ProtoGlyph glyph : model.getGlyphs()) {
                BinaryImage image = glyph.getImage();
                if (image == null) {
                    continue;
                }
                for (int x = 0; x < image.width(); x++) {
                    for (int y = 0; y < image.height(); y++) {
                        boolean old = image.get(x, y);
                        image.set(x, y, !old);
                        double candidate = getLikelihood(line, model);
                        if (candidate < best) {
                            System.out.println(" " + candidate);
                            best = candidate;
                            bestGlyph = glyph;
                            bestX = x;
                            bestY = y;
                        }
                        image.set(x, y, old);
                    }
                }
            }
            // bestGlyph is set exactly when some flip beat `current`. Testing it, rather than
            // `best >= current`, also terminates cleanly when `current` is NaN.
            if (bestGlyph == null) {
                return;
            }
            flip(bestGlyph.getImage(), bestX, bestY);
            current = best;
            System.out.println("improved " + bestGlyph.getName() + " " + bestX + " " + bestY);
        }
    }

    /**
     * Hill-climbing on the short line that estimates per-pixel gains from forward/backward values.
     *
     * <p>Each round records gains for every usable transition into a {@link ProtoPixelGain}. It
     * applies the best single flip via {@link #tryImproveModel} and keeps that flip only if the
     * likelihood decreased.
     */
    private void forwardBackwardTweaking() {
        CrookedLine line = shortLine();
        TextModel model = textModel1();
        double current = getLikelihood(line, model);
        while (true) {
            System.out.println(SignedNegLogProb.from(current));
            ProtoPixelGain gain = new ProtoPixelGain();
            Fsa<ModelState, Integer> fsa = model.createFsa(line);
            FsaForward forward = new FsaForward(fsa);
            FsaBackward backward = new FsaBackward(fsa);
            for (ModelState from : fsa.getStates()) {
                for (FsaTrans<ModelState, Integer> trans : fsa.fromTransitions(from)) {
                    ModelState to = trans.toState();
                    int fromPos = from.getPos();
                    ProtoGlyph fromGlyph = from.getGlyph();
                    int toPos = to.getPos();
                    ProtoGlyph toGlyph = to.getGlyph();
                    double weight = trans.weight();
                    // Skip transitions whose source has no forward value or whose target has
                    // no backward value.
                    boolean usable = forward.get(from) != null
                            && forward.get(from).doubleValue() != NO_MASS
                            && backward.get(to) != null
                            && ((Double) backward.get(to)).doubleValue() != NO_MASS;
                    if (!usable) {
                        continue;
                    }
                    double outer = NegLogProb.mult(forward.get(from).doubleValue(),
                            ((Double) backward.get(to)).doubleValue());
                    System.out.println("=========================================================" + fromPos + " " + toPos);
                    recordTransGainEfficient(line, model, fromPos, toPos, fromGlyph, toGlyph, weight, outer, gain);
                    System.out.println("---------------------------------------------------------");
                    // Debug cross-check: the brute-force variant prints its gains into a
                    // throw-away accumulator so they can be compared with the lines above.
                    recordTransGainInefficient(line, model, fromPos, toPos, fromGlyph, toGlyph, weight, outer,
                            new ProtoPixelGain());
                }
            }
            if (!tryImproveModel(gain)) {
                return;
            }
            double updated = getLikelihood(line, model);
            if (updated >= current) {
                // The gains are unchanged, so this call (presumably) picks the same pixel again
                // and flips it back, undoing the unhelpful change.
                tryImproveModel(gain);
                return;
            }
            current = updated;
        }
    }

    /**
     * Records gains for one transition using the model's precomputed pixel inventory.
     *
     * @param i  position of the transition's source state
     * @param i2 position of the transition's target state
     * @param protoGlyph glyph of the source state
     * @param protoGlyph2 glyph of the target state
     * @param d  the transition's weight
     * @param d2 product of the source's forward value and the target's backward value
     * @param protoPixelGain accumulator that receives one entry per inventory pixel
     */
    private void recordTransGainEfficient(CrookedLine crookedLine, TextModel textModel, int i, int i2,
            ProtoGlyph protoGlyph, ProtoGlyph protoGlyph2, double d, double d2, ProtoPixelGain protoPixelGain) {
        double transMass = NegLogProb.mult(d, d2);
        double betterGain = SignedNegLogProb.invert(NegLogProb.mult(transMass, textModel.getPixelValidationScore()));
        double worseGain = NegLogProb.mult(transMass, textModel.getPixelInvalidationScore());
        PixelInventory inventory = textModel.transitionInventory(crookedLine, i, i2, protoGlyph, protoGlyph2);
        for (ProtoPixel pixel : inventory.flipIsBetter) {
            printAndRecord(protoPixelGain, pixel.getGlyph(), pixel.getX(), pixel.getY(), betterGain);
        }
        for (ProtoPixel pixel : inventory.flipIsWorse) {
            printAndRecord(protoPixelGain, pixel.getGlyph(), pixel.getX(), pixel.getY(), worseGain);
        }
    }

    /**
     * Brute-force counterpart of {@link #recordTransGainEfficient}. It flips, one at a time,
     * each pixel in the columns from {@code PartialBinaryImage.leftWidth} onward of the source
     * glyph's image, and each pixel in the columns before {@code leftWidth} of the target
     * glyph's image. For every flip it re-scores the transition and records the difference via
     * {@link #recordGain}. Every pixel is restored after scoring.
     */
    private void recordTransGainInefficient(CrookedLine crookedLine, TextModel textModel, int i, int i2,
            ProtoGlyph protoGlyph, ProtoGlyph protoGlyph2, double d, double d2, ProtoPixelGain protoPixelGain) {
        BinaryImage fromImage = protoGlyph.getImage();
        if (fromImage != null) {
            for (int x = PartialBinaryImage.leftWidth(fromImage); x < fromImage.width(); x++) {
                for (int y = 0; y < fromImage.height(); y++) {
                    boolean old = fromImage.get(x, y);
                    fromImage.set(x, y, !old);
                    recordGain(protoGlyph, x, y, d,
                            textModel.transitionScore(crookedLine, i, i2, new ProtoGlyph[] {protoGlyph}, protoGlyph2),
                            d2, protoPixelGain);
                    fromImage.set(x, y, old);
                }
            }
        }
        BinaryImage toImage = protoGlyph2.getImage();
        if (toImage != null) {
            // Evaluated while the image is unmodified; each iteration restores its pixel, so the
            // bound is the same as recomputing it per iteration.
            int end = PartialBinaryImage.leftWidth(toImage);
            for (int x = 0; x < end; x++) {
                for (int y = 0; y < toImage.height(); y++) {
                    boolean old = toImage.get(x, y);
                    toImage.set(x, y, !old);
                    recordGain(protoGlyph2, x, y, d,
                            textModel.transitionScore(crookedLine, i, i2, new ProtoGlyph[] {protoGlyph}, protoGlyph2),
                            d2, protoPixelGain);
                    toImage.set(x, y, old);
                }
            }
        }
    }

    /**
     * Records {@code (d - d2) * d3} in signed neg-log arithmetic for the given pixel. Results
     * with absolute value of at least {@link #GAIN_LIMIT} are dropped.
     */
    private void recordGain(ProtoGlyph protoGlyph, int i, int i2, double d, double d2, double d3,
            ProtoPixelGain protoPixelGain) {
        double value = SignedNegLogProb.mult(SignedNegLogProb.subtract(d, d2), d3);
        if (Math.abs(value) >= GAIN_LIMIT) {
            return;
        }
        printAndRecord(protoPixelGain, protoGlyph, i, i2, value);
    }

    /** Prints "name x y value" and records {@code value} for that pixel in {@code gain}. */
    private static void printAndRecord(ProtoPixelGain gain, ProtoGlyph glyph, int x, int y, double value) {
        System.out.println(glyph.getName() + " " + x + " " + y + " " + value);
        gain.recordGain(glyph, x, y, value);
    }

    /**
     * Flips the first pixel in {@code protoPixelGain.bestProtoPixels()}, but only if its
     * recorded diff is negative. The diff, glyph name, and coordinates of that pixel are
     * printed before the check.
     *
     * @return {@code true} iff a pixel was flipped
     */
    private boolean tryImproveModel(ProtoPixelGain protoPixelGain) {
        ArrayList<ProtoPixel> ranked = protoPixelGain.bestProtoPixels();
        if (ranked.isEmpty()) {
            return false;
        }
        ProtoPixel best = ranked.get(0);
        double diff = protoPixelGain.getDiff(best);
        System.out.println(SignedNegLogProb.from(diff));
        System.out.println(best.getGlyph().getName());
        System.out.println(best.getX());
        System.out.println(best.getY());
        if (diff >= 0.0d) {
            return false;
        }
        flip(best.getGlyph().getImage(), best.getX(), best.getY());
        return true;
    }

    /** Runs {@link #bruteForceTweaking()}. */
    public static void main(String[] strArr) {
        new Experiment().bruteForceTweaking();
    }
}
