/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.coref.neural;

import edu.stanford.nlp.coref.CorefAlgorithm;
import edu.stanford.nlp.coref.CorefProperties;
import edu.stanford.nlp.coref.CorefUtils;
import edu.stanford.nlp.coref.data.Dictionaries;
import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.coref.neural.CategoricalFeatureExtractor;
import edu.stanford.nlp.coref.neural.EmbeddingExtractor;
import edu.stanford.nlp.coref.neural.NeuralCorefModel;
import edu.stanford.nlp.coref.neural.NeuralCorefProperties;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.ejml.simple.SimpleMatrix;

/**
 * Coreference algorithm that resolves mentions with a pretrained {@link NeuralCorefModel}.
 *
 * <p>Every mention is embedded and given an anaphoricity score. Each mention that has
 * candidate antecedents is then linked to the highest-scoring candidate whose pairwise score
 * strictly exceeds the mention's anaphoricity score shifted by the configured greedyness. If no
 * candidate beats that threshold, the mention is left unlinked.
 */
public class NeuralCorefAlgorithm
implements CorefAlgorithm {
    private static final Redwood.RedwoodChannels log = Redwood.channels(NeuralCorefAlgorithm.class);

    /** Multiplier applied to {@code (greedyness - NEUTRAL_GREEDYNESS)} to shift the threshold. */
    private static final double GREEDYNESS_SCALE = 50.0;
    /** Greedyness value at which the anaphoricity threshold is left unshifted. */
    private static final double NEUTRAL_GREEDYNESS = 0.5;

    /** Values above 0.5 lower the bar for linking a mention to an antecedent; below 0.5 raise it. */
    private final double greedyness;
    private final int maxMentionDistance;
    private final int maxMentionDistanceWithStringMatch;
    private final CategoricalFeatureExtractor featureExtractor;
    private final EmbeddingExtractor embeddingExtractor;
    private final NeuralCorefModel model;

    /**
     * Loads the neural coref model and pretrained embeddings and sets up feature extraction.
     *
     * @param props properties supplying greedyness, mention-distance limits, the CoNLL flag,
     *              and the model / pretrained-embedding paths
     * @param dictionaries dictionaries handed to the categorical feature extractor
     */
    public NeuralCorefAlgorithm(Properties props, Dictionaries dictionaries) {
        this.greedyness = NeuralCorefProperties.greedyness(props);
        this.maxMentionDistance = CorefProperties.maxMentionDistance(props);
        this.maxMentionDistanceWithStringMatch = CorefProperties.maxMentionDistanceWithStringMatch(props);
        this.model = (NeuralCorefModel) IOUtils.readObjectAnnouncingTimingFromURLOrClasspathOrFileSystem(
            log, "Loading coref model", NeuralCorefProperties.modelPath(props));
        Embedding pretrainedEmbeddings = (Embedding) IOUtils.readObjectAnnouncingTimingFromURLOrClasspathOrFileSystem(
            log, "Loading coref embeddings", NeuralCorefProperties.pretrainedEmbeddingsPath(props));
        // NOTE(review): the final null argument's meaning is defined by EmbeddingExtractor; confirm there.
        this.embeddingExtractor = new EmbeddingExtractor(
            CorefProperties.conll(props), pretrainedEmbeddings, this.model.getWordEmbeddings(), null);
        this.featureExtractor = new CategoricalFeatureExtractor(props, dictionaries);
    }

    /**
     * Runs neural coreference on {@code document} and merges each linked mention into its
     * chosen antecedent's cluster via {@link CorefUtils#mergeCoreferenceClusters}.
     *
     * @param document the document whose mentions are resolved; its clusters are mutated
     */
    @Override
    public void runCoref(Document document) {
        List<Mention> sortedMentions = CorefUtils.getSortedMentions(document);
        Map<Integer, List<Mention>> mentionsByHeadIndex = groupByHeadIndex(sortedMentions);

        // Compute the per-mention embeddings and anaphoricity scores once, keyed by mention ID,
        // so the pairwise scoring loop below can reuse them.
        SimpleMatrix documentEmbedding = this.embeddingExtractor.getDocumentEmbedding(document);
        Map<Integer, SimpleMatrix> antecedentEmbeddings = new HashMap<>();
        Map<Integer, SimpleMatrix> anaphorEmbeddings = new HashMap<>();
        ClassicCounter<Integer> anaphoricityScores = new ClassicCounter<>();
        for (Mention mention : sortedMentions) {
            SimpleMatrix mentionEmbedding = this.embeddingExtractor.getMentionEmbeddings(mention, documentEmbedding);
            antecedentEmbeddings.put(mention.mentionID, this.model.getAntecedentEmbedding(mentionEmbedding));
            anaphorEmbeddings.put(mention.mentionID, this.model.getAnaphorEmbedding(mentionEmbedding));
            anaphoricityScores.incrementCount(mention.mentionID, this.model.getAnaphoricityScore(
                mentionEmbedding, this.featureExtractor.getAnaphoricityFeatures(mention, document, mentionsByHeadIndex)));
        }

        Map<Integer, List<Integer>> mentionToCandidateAntecedents = CorefUtils.heuristicFilter(
            sortedMentions, this.maxMentionDistance, this.maxMentionDistanceWithStringMatch);
        // Loop-invariant shift of the "leave unlinked" threshold derived from greedyness.
        double greedynessOffset = GREEDYNESS_SCALE * (this.greedyness - NEUTRAL_GREEDYNESS);
        for (Map.Entry<Integer, List<Integer>> e : mentionToCandidateAntecedents.entrySet()) {
            int m = e.getKey();
            // A candidate must strictly beat this score to be chosen; ties keep the earlier choice.
            double bestScore = anaphoricityScores.getCount(m) - greedynessOffset;
            Integer antecedent = null;
            for (int ca : e.getValue()) {
                double score = this.model.getPairwiseScore(
                    antecedentEmbeddings.get(ca),
                    anaphorEmbeddings.get(m),
                    this.featureExtractor.getPairFeatures(new Pair<>(ca, m), document, mentionsByHeadIndex));
                if (score > bestScore) {
                    bestScore = score;
                    antecedent = ca;
                }
            }
            if (antecedent != null) {
                // Diamond infers Pair<Integer, Integer>; the decompiled Pair<Object, Integer> did not type-check.
                CorefUtils.mergeCoreferenceClusters(new Pair<>(antecedent, m), document);
            }
        }
    }

    /**
     * Groups mentions by their {@code headIndex}, keeping the input order within each group.
     *
     * @param mentions mentions to group
     * @return a new mutable map from head index to the mentions with that head index
     */
    private static Map<Integer, List<Mention>> groupByHeadIndex(List<Mention> mentions) {
        Map<Integer, List<Mention>> byHeadIndex = new HashMap<>();
        for (Mention mention : mentions) {
            byHeadIndex.computeIfAbsent(mention.headIndex, k -> new ArrayList<>()).add(mention);
        }
        return byHeadIndex;
    }
}

