/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.trees.ud;

import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;
import edu.stanford.nlp.semgraph.semgrex.SemgrexMatcher;
import edu.stanford.nlp.semgraph.semgrex.SemgrexPattern;
import edu.stanford.nlp.trees.GrammaticalRelation;
import edu.stanford.nlp.trees.UniversalEnglishGrammaticalRelations;
import edu.stanford.nlp.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Pattern;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

/**
 * Adds enhanced dependencies for gapping constructions: conjuncts whose predicate was elided and
 * whose remaining dependents are attached to one of them with the {@code orphan} relation.
 *
 * <p>For each gapped conjunct, the argument dependents of the full conjunct's governor that precede
 * the gapped conjunct are aligned to the gapped conjunct's arguments with a global
 * (Needleman-Wunsch style) alignment. Aligned pairs are penalized by the distance between their
 * averaged word embeddings (when embeddings are supplied) and by a mismatch of their coarsened UPOS
 * tags. A soft copy of the full conjunct's governor is then inserted, and each gapped argument is
 * attached to it along copies of the path that connects the governor to its aligned counterpart.
 *
 * <p>The embeddings are passed explicitly through the alignment code rather than being stored in
 * a mutable static field, so one call never observes the embeddings of another.
 */
public class UniversalGappingEnhancer {
    /** Score added for every argument left unaligned (insertion or deletion) in {@link #align}. */
    private static final double GAP_PENALTY = -10.0;
    /** Score added for an aligned pair whose coarsened UPOS tags differ. */
    private static final double POS_MISMATCH_PENALTY = -2.0;
    /** Weight passed to {@code SemanticGraph.addEdge} for every edge this class adds. */
    private static final double EDGE_WEIGHT = Double.NEGATIVE_INFINITY;
    /**
     * Bound on the loop in {@link #addEnhancements}. The counter is pre-incremented in the loop
     * condition, so at most {@code MAX_ITERATIONS - 1} gaps are resolved per graph.
     */
    private static final int MAX_ITERATIONS = 10;

    /** Coarsens UPOS tags so that, e.g., a pronoun can align with a noun without a POS penalty. */
    private static final HashMap<String, String> coarserUPOSMap = new HashMap<>();

    static {
        coarserUPOSMap.put("PROPN", "NOUN");
        coarserUPOSMap.put("PRON", "NOUN");
        coarserUPOSMap.put("NUM", "NOUN");
        coarserUPOSMap.put("DET", "NOUN");
    }

    /** A node ({@code orphangov}) with a governor ({@code conjgov}) and at least one orphan dependent. */
    private static final SemgrexPattern ORPHAN_PATTERN =
        SemgrexPattern.compile("{}=orphangov < {}=conjgov >orphan {}");
    /** Relations whose dependents count as arguments of a predicate. */
    private static final Pattern ARGUMENT_PATTERN =
        Pattern.compile("^(i?obj|(n|c)subj.*|(x|c)comp|nmod(:tmod|:npadvmod)?|obl.*|advcl|acl|compound:prt)$");
    /** Argument relations whose dependents are themselves searched for further arguments. */
    private static final Pattern CLAUSAL_ARGUMENT_PATTERN =
        Pattern.compile("^(csubj.*|(x|c)comp|advcl|acl)$");
    /** Relations whose dependents' yields are folded into the gapped conjunct head's own sequence. */
    private static final Pattern MODIFIER_PATTERN =
        Pattern.compile("^(amod|advmod|nmod|obl|acl|mark|case|compound|flat)$");
    /** Relations that are propagated from an original node to its copy when the copy lacks them. */
    private static final Pattern CORE_ARGUMENTS_PATTERN =
        Pattern.compile("^((n|c)subj.*|(x|c)comp|i?obj|expl|compound:prt)$");
    /** A predicate with a child {@code arg1} that has a conj dependent, plus another child {@code arg2}. */
    private static final SemgrexPattern CONJ_PATTERN =
        SemgrexPattern.compile("{}=predicate > ({}=arg1 >conj {}=conjdep) > {}=arg2");

    /** Returns the coarser tag for {@code uPOS} from {@link #coarserUPOSMap}, or {@code uPOS} itself. */
    private static String coarsenUPOSTag(String uPOS) {
        return coarserUPOSMap.getOrDefault(uPOS, uPOS);
    }

    /**
     * Globally aligns the full conjunct's arguments with the gapped conjunct's arguments.
     *
     * @param fullArguments arguments of the full conjunct, in sentence order
     * @param gappedArguments arguments of the gapped conjunct, in sentence order
     * @param embeddings word embeddings used to score lexical similarity; may be {@code null}, in
     *     which case only POS agreement and gaps contribute to the score
     * @return the alignment score and, for each gapped argument (same order as
     *     {@code gappedArguments}), the index into {@code fullArguments} it is aligned to, or -1
     *     if it is unaligned
     */
    private static Pair<Double, List<Integer>> align(List<ArgumentSequence> fullArguments,
                                                     List<ArgumentSequence> gappedArguments,
                                                     Embedding embeddings) {
        final int n = fullArguments.size();
        final int m = gappedArguments.size();

        // Per-argument features do not depend on the DP cell, so compute them once per argument
        // rather than once per (row, col) pair.
        String[] fullTags = new String[n];
        SimpleMatrix[] fullVectors = new SimpleMatrix[n];
        for (int row = 0; row < n; ++row) {
            ArgumentSequence arg = fullArguments.get(row);
            fullTags[row] = coarsenUPOSTag(arg.head.get(CoreAnnotations.CoarseTagAnnotation.class));
            if (embeddings != null) {
                fullVectors[row] = arg.getAverageEmbeddings(embeddings);
            }
        }
        String[] gappedTags = new String[m];
        SimpleMatrix[] gappedVectors = new SimpleMatrix[m];
        for (int col = 0; col < m; ++col) {
            ArgumentSequence arg = gappedArguments.get(col);
            gappedTags[col] = coarsenUPOSTag(arg.head.get(CoreAnnotations.CoarseTagAnnotation.class));
            if (embeddings != null) {
                gappedVectors[col] = arg.getAverageEmbeddings(embeddings);
            }
        }

        double[][] scores = new double[n + 1][m + 1];
        // backtracing[row][col] = {prevRow, prevCol}: the cell the best score at (row, col) came from.
        int[][][] backtracing = new int[n + 1][m + 1][2];
        for (int row = 0; row <= n; ++row) {
            scores[row][0] = row * GAP_PENALTY;
            backtracing[row][0][0] = row - 1;
            backtracing[row][0][1] = 0;
        }
        for (int col = 0; col <= m; ++col) {
            scores[0][col] = col * GAP_PENALTY;
            backtracing[0][col][0] = 0;
            backtracing[0][col][1] = col - 1;
        }
        for (int row = 1; row <= n; ++row) {
            for (int col = 1; col <= m; ++col) {
                double distance = 0.0;
                if (embeddings != null) {
                    distance = fullVectors[row - 1].minus(gappedVectors[col - 1]).normF();
                }
                String fullTag = fullTags[row - 1];
                String gappedTag = gappedTags[col - 1];
                // Null-safe: a missing UPOS annotation must not crash the whole enhancement.
                boolean sameTag = fullTag == null ? gappedTag == null : fullTag.equals(gappedTag);
                double posScore = sameTag ? 0.0 : POS_MISMATCH_PENALTY;

                double match = scores[row - 1][col - 1] - distance + posScore;
                double del = scores[row - 1][col] + GAP_PENALTY;
                double ins = scores[row][col - 1] + GAP_PENALTY;
                // Ties prefer match, then deletion, then insertion.
                if (match >= del && match >= ins) {
                    backtracing[row][col][0] = row - 1;
                    backtracing[row][col][1] = col - 1;
                } else if (del > match && del >= ins) {
                    backtracing[row][col][0] = row - 1;
                    backtracing[row][col][1] = col;
                } else {
                    backtracing[row][col][0] = row;
                    backtracing[row][col][1] = col - 1;
                }
                scores[row][col] = Math.max(match, Math.max(del, ins));
            }
        }

        // Walk back from (n, m). Every step that consumes gapped argument prevCol records its
        // partner: full argument prevRow on a diagonal step, -1 on an insertion. Deletions consume
        // only a full argument and record nothing. Each gapped index is visited exactly once.
        int[] alignedTo = new int[m];
        int row = n;
        int col = m;
        while (row > 0 || col > 0) {
            int prevRow = backtracing[row][col][0];
            int prevCol = backtracing[row][col][1];
            boolean diagonal = prevRow == row - 1 && prevCol == col - 1;
            boolean deletion = prevRow == row - 1 && prevCol == col;
            if (diagonal) {
                alignedTo[prevCol] = prevRow;
            } else if (!deletion) {
                alignedTo[prevCol] = -1;
            }
            row = prevRow;
            col = prevCol;
        }

        List<Integer> alignment = new ArrayList<>(m);
        for (int target : alignedTo) {
            alignment.add(target);
        }
        return new Pair<>(scores[n][m], alignment);
    }

    /**
     * Finds the {@link #ORPHAN_PATTERN} match whose orphan governor has the smallest token index.
     *
     * @return (conjGov, orphanGov) for that match, or {@code null} if the graph has no match
     */
    private static Pair<IndexedWord, IndexedWord> getConjGovOrphanGovPair(SemanticGraph sg) {
        SemgrexMatcher matcher = ORPHAN_PATTERN.matcher(sg);
        IndexedWord firstConjGov = null;
        IndexedWord firstOrphanGov = null;
        while (matcher.find()) {
            IndexedWord orphanGov = matcher.getNode("orphangov");
            // Strictly smaller index only: on ties, the earliest match found wins.
            if (firstOrphanGov == null || orphanGov.index() < firstOrphanGov.index()) {
                firstConjGov = matcher.getNode("conjgov");
                firstOrphanGov = orphanGov;
            }
        }
        if (firstOrphanGov == null) {
            return null;
        }
        return new Pair<>(firstConjGov, firstOrphanGov);
    }

    /**
     * Returns true if {@code edge}'s relation matches {@code relationPattern} and its dependent
     * has no outgoing {@code orphan} edge (i.e. it is not itself the head of a gapped conjunct).
     */
    private static boolean matchesWithoutOrphanChild(Pattern relationPattern, SemanticGraph sg,
                                                     SemanticGraphEdge edge) {
        if (!relationPattern.matcher(edge.getRelation().toString()).matches()) {
            return false;
        }
        for (SemanticGraphEdge childEdge : sg.outgoingEdgeIterable(edge.getDependent())) {
            if (childEdge.getRelation().equals(UniversalEnglishGrammaticalRelations.ORPHAN)) {
                return false;
            }
        }
        return true;
    }

    /** Whether {@code edge} attaches an argument (see {@link #ARGUMENT_PATTERN}) with no orphan dependents. */
    private static boolean isArgument(SemanticGraph sg, SemanticGraphEdge edge) {
        return matchesWithoutOrphanChild(ARGUMENT_PATTERN, sg, edge);
    }

    /** Whether {@code edge} attaches a clausal argument (see {@link #CLAUSAL_ARGUMENT_PATTERN}) with no orphan dependents. */
    private static boolean isClausalArgument(SemanticGraph sg, SemanticGraphEdge edge) {
        return matchesWithoutOrphanChild(CLAUSAL_ARGUMENT_PATTERN, sg, edge);
    }

    /**
     * Appends an {@link ArgumentSequence} for every argument dependent of {@code currentHead},
     * recursing only through clausal-argument edges.
     */
    private static void getArgumentSubsequences(SemanticGraph sg, IndexedWord currentHead,
                                                List<ArgumentSequence> currentSequences) {
        for (SemanticGraphEdge edge : sg.outgoingEdgeIterable(currentHead)) {
            if (!isArgument(sg, edge)) {
                continue;
            }
            IndexedWord dependent = edge.getDependent();
            currentSequences.add(new ArgumentSequence(dependent, sg.yield(dependent)));
            if (isClausalArgument(sg, edge)) {
                getArgumentSubsequences(sg, dependent, currentSequences);
            }
        }
    }

    /**
     * For each argument dependent of {@code conjGov} positioned before {@code orphanGov}, collects
     * its variants: the argument itself followed by the arguments found beneath it by
     * {@link #getArgumentSubsequences}.
     *
     * @return one variant list per argument, ordered by the argument's token index
     */
    private static List<List<ArgumentSequence>> getFullConjunctArgumentsHelper(SemanticGraph sg,
                                                                              IndexedWord conjGov,
                                                                              IndexedWord orphanGov) {
        List<List<ArgumentSequence>> arguments = new ArrayList<>();
        for (SemanticGraphEdge edge : sg.outgoingEdgeIterable(conjGov)) {
            IndexedWord dependent = edge.getDependent();
            if (!isArgument(sg, edge) || !(dependent.pseudoPosition() < orphanGov.pseudoPosition())) {
                continue;
            }
            List<ArgumentSequence> variants = new ArrayList<>();
            variants.add(new ArgumentSequence(dependent, sg.yield(dependent)));
            getArgumentSubsequences(sg, dependent, variants);
            arguments.add(variants);
        }
        arguments.sort((a, b) -> Integer.compare(a.get(0).head.index(), b.get(0).head.index()));
        return arguments;
    }

    /**
     * Appends to {@code currentSequences} every combination that picks one variant per argument,
     * starting at {@code argIndex} and extending {@code prefix} (a Cartesian product).
     */
    private static void buildAllArgumentSequences(int argIndex, List<ArgumentSequence> prefix,
                                                  List<List<ArgumentSequence>> argumentVariants,
                                                  List<List<ArgumentSequence>> currentSequences) {
        int nArguments = argumentVariants.size();
        for (ArgumentSequence seq : argumentVariants.get(argIndex)) {
            List<ArgumentSequence> newPrefix = new ArrayList<>(nArguments);
            newPrefix.addAll(prefix);
            newPrefix.add(seq);
            if (argIndex + 1 == nArguments) {
                currentSequences.add(newPrefix);
            } else {
                buildAllArgumentSequences(argIndex + 1, newPrefix, argumentVariants, currentSequences);
            }
        }
    }

    /**
     * Enumerates every candidate argument sequence of the full conjunct (one variant per argument).
     *
     * @return all candidate sequences; empty if {@code conjGov} has no qualifying arguments
     */
    private static List<List<ArgumentSequence>> getFullConjunctArguments(SemanticGraph sg,
                                                                        IndexedWord conjGov,
                                                                        IndexedWord orphanGov) {
        List<List<ArgumentSequence>> argumentVariants = getFullConjunctArgumentsHelper(sg, conjGov, orphanGov);
        // Not presized: the product of variant counts can overflow int and make the
        // ArrayList constructor throw. Every variant list holds at least the argument itself,
        // so the product is positive exactly when the list is non-empty.
        List<List<ArgumentSequence>> argumentSequences = new ArrayList<>();
        if (!argumentVariants.isEmpty()) {
            buildAllArgumentSequences(0, new ArrayList<>(), argumentVariants, argumentSequences);
        }
        return argumentSequences;
    }

    /** Whether {@code edge}'s relation matches {@link #MODIFIER_PATTERN}. */
    private static boolean isModifier(SemanticGraphEdge edge) {
        return MODIFIER_PATTERN.matcher(edge.getRelation().toString()).matches();
    }

    /**
     * Builds the argument sequence for the gapped conjunct head itself: the head plus the yields of
     * its modifier dependents, sorted.
     */
    private static ArgumentSequence getOrphanGovSequence(SemanticGraph sg, IndexedWord orphanGov) {
        List<IndexedWord> words = new ArrayList<>();
        words.add(orphanGov);
        for (SemanticGraphEdge edge : sg.outgoingEdgeIterable(orphanGov)) {
            if (isModifier(edge)) {
                words.addAll(sg.yield(edge.getDependent()));
            }
        }
        Collections.sort(words);
        return new ArgumentSequence(orphanGov, words);
    }

    /**
     * Collects the gapped conjunct's arguments: each {@code orphan} dependent of {@code orphanGov}
     * plus {@code orphanGov} itself, sorted by head.
     */
    private static List<ArgumentSequence> getGappedConjunctArguments(SemanticGraph sg, IndexedWord orphanGov) {
        List<ArgumentSequence> arguments = new ArrayList<>();
        for (SemanticGraphEdge edge : sg.outgoingEdgeIterable(orphanGov)) {
            if (edge.getRelation().equals(UniversalEnglishGrammaticalRelations.ORPHAN)) {
                IndexedWord dependent = edge.getDependent();
                arguments.add(new ArgumentSequence(dependent, sg.yield(dependent)));
            }
        }
        arguments.add(getOrphanGovSequence(sg, orphanGov));
        arguments.sort((a, b) -> a.head.compareTo(b.head));
        return arguments;
    }

    /** Returns a soft copy of {@code word} whose pseudo-position is shifted by copyCount / 10. */
    private static IndexedWord makePositionedCopy(IndexedWord word) {
        IndexedWord copy = word.makeSoftCopy();
        copy.setPseudoPosition(copy.pseudoPosition() + copy.copyCount() / 10.0);
        return copy;
    }

    /**
     * Rewrites {@code sg} so that the gapped conjunct hangs off a copy of {@code conjGov}.
     *
     * @param fullConjunctArguments the chosen argument sequence of the full conjunct
     * @param orphanConjunctArguments the gapped conjunct's arguments, parallel to {@code alignment}
     * @param alignment for each gapped argument, the index of its aligned full argument, or -1
     */
    private static void doEnhancement(SemanticGraph sg, IndexedWord conjGov, IndexedWord orphanGov,
                                      List<ArgumentSequence> fullConjunctArguments,
                                      List<ArgumentSequence> orphanConjunctArguments,
                                      List<Integer> alignment) {
        // Original node -> its copy in the reconstructed conjunct.
        HashMap<IndexedWord, IndexedWord> copiedNodes = new HashMap<>();

        // Replace conjGov -> orphanGov with conjGov -> copy(conjGov), keeping the relation.
        IndexedWord conjGovCopy = makePositionedCopy(conjGov);
        SemanticGraphEdge conjEdge = sg.getEdge(conjGov, orphanGov);
        sg.removeEdge(conjEdge);
        sg.addEdge(conjGov, conjGovCopy, conjEdge.getRelation(), EDGE_WEIGHT, false);
        copiedNodes.put(conjGov, conjGovCopy);

        for (int i = 0; i < orphanConjunctArguments.size(); ++i) {
            IndexedWord dep = orphanConjunctArguments.get(i).head;
            if (sg.hasParentWithReln(dep, UniversalEnglishGrammaticalRelations.ORPHAN)) {
                SemanticGraphEdge orphanEdge = sg.getEdge(orphanGov, dep);
                // The orphan parent may not be orphanGov; removeEdge(null) would throw.
                if (orphanEdge != null) {
                    sg.removeEdge(orphanEdge);
                }
            }
            int alignmentIdx = alignment.get(i);
            if (alignmentIdx < 0) {
                // Unaligned: attach directly with the generic dependent relation.
                sg.addEdge(conjGovCopy, dep, GrammaticalRelation.DEPENDENT, EDGE_WEIGHT, false);
                continue;
            }
            IndexedWord parallelArgument = fullConjunctArguments.get(alignmentIdx).head;
            attachAlongParallelPath(sg, conjGov, parallelArgument, dep, copiedNodes);
        }

        copyCoreArguments(sg, copiedNodes);
        reattachConjuncts(sg, orphanGov, copiedNodes);
        moveCoordinators(sg, orphanGov, conjGovCopy);
    }

    /**
     * Mirrors the path conjGov -> ... -> parallelArgument with copied nodes, ending at {@code dep}
     * in place of {@code parallelArgument}. Intermediate copies are created on demand and recorded
     * in {@code copiedNodes}. An edge is added only when its target is {@code dep} or one of its
     * endpoints was newly copied, so shared path prefixes are not duplicated.
     */
    private static void attachAlongParallelPath(SemanticGraph sg, IndexedWord conjGov,
                                                IndexedWord parallelArgument, IndexedWord dep,
                                                HashMap<IndexedWord, IndexedWord> copiedNodes) {
        List<SemanticGraphEdge> parallelPath = sg.getShortestDirectedPathEdges(conjGov, parallelArgument);
        for (int j = 0; j < parallelPath.size(); ++j) {
            SemanticGraphEdge parallelEdge = parallelPath.get(j);
            boolean newCopyNode = false;

            IndexedWord sourceNode = copiedNodes.get(parallelEdge.getGovernor());
            if (sourceNode == null) {
                sourceNode = makePositionedCopy(parallelEdge.getGovernor());
                copiedNodes.put(parallelEdge.getGovernor(), sourceNode);
                newCopyNode = true;
            }

            boolean lastEdge = j == parallelPath.size() - 1;
            IndexedWord targetNode = lastEdge ? dep : copiedNodes.get(parallelEdge.getDependent());
            if (targetNode == null) {
                targetNode = makePositionedCopy(parallelEdge.getDependent());
                copiedNodes.put(parallelEdge.getDependent(), targetNode);
                newCopyNode = true;
            }

            if (targetNode == dep || newCopyNode) {
                sg.addEdge(sourceNode, targetNode, parallelEdge.getRelation(), EDGE_WEIGHT, false);
            }
        }
    }

    /**
     * For every copied node, adds each core-argument edge of its original (see
     * {@link #CORE_ARGUMENTS_PATTERN}) whose relation the copy does not already have. The new edge
     * points to the original's dependent.
     */
    private static void copyCoreArguments(SemanticGraph sg, HashMap<IndexedWord, IndexedWord> copiedNodes) {
        for (IndexedWord original : copiedNodes.keySet()) {
            IndexedWord copy = copiedNodes.get(original);
            for (SemanticGraphEdge originalEdge : sg.outgoingEdgeIterable(original)) {
                GrammaticalRelation reln = originalEdge.getRelation();
                if (CORE_ARGUMENTS_PATTERN.matcher(reln.toString()).matches() && !sg.hasChildWithReln(copy, reln)) {
                    sg.addEdge(copy, originalEdge.getDependent(), reln, EDGE_WEIGHT, false);
                }
            }
        }
    }

    /**
     * Moves conj edges hanging off {@code orphanGov} up to the copied predicate when another
     * dependent of that predicate sits between {@code orphanGov} and the conjunct. Matching runs on
     * a soft copy of the graph so {@code sg} can be modified during iteration.
     */
    private static void reattachConjuncts(SemanticGraph sg, IndexedWord orphanGov,
                                          HashMap<IndexedWord, IndexedWord> copiedNodes) {
        SemanticGraph sgCopy = sg.makeSoftCopy();
        for (IndexedWord copyNode : copiedNodes.values()) {
            SemgrexMatcher matcher = CONJ_PATTERN.matcher(sgCopy, copyNode);
            while (matcher.find()) {
                IndexedWord pred = matcher.getNode("predicate");
                IndexedWord arg1 = matcher.getNode("arg1");
                IndexedWord conjDep = matcher.getNode("conjdep");
                IndexedWord arg2 = matcher.getNode("arg2");
                if (pred != copyNode || arg1 != orphanGov
                        || !(arg2.pseudoPosition() > arg1.pseudoPosition())
                        || !(conjDep.pseudoPosition() > arg2.pseudoPosition())) {
                    continue;
                }
                SemanticGraphEdge conjEdge = sg.getEdge(arg1, conjDep);
                // Several arg2 bindings can match the same (pred, arg1, conjDep). Once the first
                // has moved the edge, getEdge returns null. Skip instead of passing null to
                // removeEdge and adding a duplicate conj edge.
                if (conjEdge == null) {
                    continue;
                }
                sg.removeEdge(conjEdge);
                sg.addEdge(pred, conjDep, UniversalEnglishGrammaticalRelations.CONJUNCT, EDGE_WEIGHT, false);
            }
        }
    }

    /** Moves every {@code cc} dependent of {@code orphanGov} to {@code conjGovCopy}. */
    private static void moveCoordinators(SemanticGraph sg, IndexedWord orphanGov, IndexedWord conjGovCopy) {
        // outgoingEdgeList returns a fresh list, so removing edges while iterating is safe.
        for (SemanticGraphEdge edge : sg.outgoingEdgeList(orphanGov)) {
            if (edge.getRelation().getShortName().equals("cc")) {
                sg.removeEdge(edge);
                sg.addEdge(conjGovCopy, edge.getDependent(), edge.getRelation(), EDGE_WEIGHT, false);
            }
        }
    }

    /**
     * Resolves gapping constructions in {@code sg} in place, one gapped conjunct at a time, from
     * the leftmost orphan governor onwards.
     *
     * @param sg the graph to modify
     * @param embeddingMatrix word embeddings used during alignment; may be {@code null}
     */
    public static final void addEnhancements(SemanticGraph sg, Embedding embeddingMatrix) {
        Pair<IndexedWord, IndexedWord> conjGovOrphanGov;
        int iterations = 0;
        while ((conjGovOrphanGov = getConjGovOrphanGovPair(sg)) != null && ++iterations < MAX_ITERATIONS) {
            IndexedWord conjGov = conjGovOrphanGov.first();
            IndexedWord orphanGov = conjGovOrphanGov.second();
            List<List<ArgumentSequence>> candidates = getFullConjunctArguments(sg, conjGov, orphanGov);
            List<ArgumentSequence> gappedConjunctArguments = getGappedConjunctArguments(sg, orphanGov);

            List<Integer> bestAlignment = null;
            List<ArgumentSequence> bestArgumentSequence = null;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (List<ArgumentSequence> candidate : candidates) {
                Pair<Double, List<Integer>> res = align(candidate, gappedConjunctArguments, embeddingMatrix);
                double score = res.first();
                if (score > bestScore) {
                    bestScore = score;
                    bestAlignment = res.second();
                    bestArgumentSequence = candidate;
                }
            }
            if (bestArgumentSequence == null) {
                // Nothing above modified the graph, so looping again would rediscover this same
                // pair until the iteration cap and falsely report a problem. Stop here.
                break;
            }
            doEnhancement(sg, conjGov, orphanGov, bestArgumentSequence, gappedConjunctArguments, bestAlignment);
        }
        if (iterations == MAX_ITERATIONS) {
            System.err.println("Problem with graph:");
            System.err.println(sg.toString(SemanticGraph.OutputFormat.READABLE));
        }
    }

    /** An argument: its head word plus the words (yield) scored for lexical similarity. */
    private static class ArgumentSequence {
        final IndexedWord head;
        final List<IndexedWord> sequence;

        private ArgumentSequence(IndexedWord gov, List<IndexedWord> seq) {
            this.head = gov;
            this.sequence = seq;
        }

        @Override
        public String toString() {
            return this.sequence.toString();
        }

        /**
         * Sums the embedding vectors of the sequence's lower-cased words and divides by the number
         * of words. Words without a vector contribute zero but still count in the denominator.
         */
        private SimpleMatrix getAverageEmbeddings(Embedding embeddings) {
            SimpleMatrix total = new SimpleMatrix(embeddings.getEmbeddingSize(), 1);
            for (IndexedWord w : this.sequence) {
                // Locale.ROOT: default-locale lowering (e.g. Turkish) would miss vocabulary entries.
                SimpleMatrix vector = embeddings.get(w.word().toLowerCase(java.util.Locale.ROOT));
                if (vector != null) {
                    total = total.plus(vector);
                }
            }
            return total.divide(this.sequence.size());
        }
    }
}

