package mpqa_seq_reranker; import se.lth.cs.nlp.nlputils.annotations.AnnotationLayer; import se.lth.cs.nlp.nlputils.annotations.Span; import se.lth.cs.nlp.nlputils.core.*; import se.lth.cs.nlp.nlputils.math.McNemarTest; import se.lth.cs.nlp.nlputils.ml_long.*; import java.util.*; import java.io.*; public class Reranker { static final boolean PRINT_TOP_FEATURES = false; //static final boolean ONLY_SUBJ = false; private static final double SCALE = 1.0; private static final double OFFSET = 0.0; private static final boolean F1_COST = true; public static void normalize(double[] scores, int length) { double m = CollectionUtils.max(scores, length); m = SCALE*m + OFFSET; double expSum = 0; for(int i = 0; i < length; i++) { scores[i] = SCALE*scores[i] + OFFSET; expSum += Math.exp(scores[i] - m); } double logExpSum = m + Math.log(expSum); for(int i = 0; i < length; i++) scores[i] -= logExpSum; } private static void saveTrainingSet(String fileName, RerankingFE fe, ArrayList l) throws IOException { ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(fileName)); oos.writeObject(fe); oos.writeObject(l); oos.close(); } private static Pair>> loadTrainingSet(String fileName, int nsen) throws IOException { ObjectInputStream ois = new ObjectInputStream(new FileInputStream(fileName)); try { RerankingFE fe = (RerankingFE) ois.readObject(); ArrayList> tmp = (ArrayList) ois.readObject(); ArrayList> out = new ArrayList(); int n = 0; for(Pair p: tmp) { out.add(p); n++; if(n == nsen) break; } return new Pair(fe, out); } catch(ClassNotFoundException e) { throw new IOException(e); } } private static int getMaxNCand(ArrayList> tr) { int max = 0; for(Pair p: tr) if(p.left.reps.length > max) max = p.left.reps.length; return max; } public static void train(String[] argv) { for(String av: argv) System.out.println(av); String mode = argv[1]; String model; String algName; String algArgs; ArrayList> ts; RerankingFE fe = null; int N; if(mode.equals("-compute")) { // alt 1: load normally String nbestFileName = argv[2]; String synsemFileName = argv[3]; N = Integer.parseInt(argv[4]); int nsen = Integer.parseInt(argv[5]); String cacheFileName = argv[6]; if(cacheFileName.equals("null") || cacheFileName.equals("NONE")) cacheFileName = null; boolean useBigrams = Boolean.parseBoolean(argv[7]); boolean useSyntax = Boolean.parseBoolean(argv[8]); boolean useSpans = Boolean.parseBoolean(argv[9]); boolean useSemantics = Boolean.parseBoolean(argv[10]); model = argv[11]; algName = argv[12]; algArgs = argv[13]; ts = new ArrayList(); try { fe = new RerankingFE(useBigrams, useSyntax, useSpans, useSemantics); System.out.println(fe); long t0 = System.currentTimeMillis(); int nTotal = 0; ArrayList sentences = SeqCandidateReader.readSentences(nbestFileName, synsemFileName, N, nsen); for(SentenceData sd: sentences) { nTotal++; if(nTotal % 100 == 0) { System.out.print("."); System.out.flush(); } ArrayList cands = sd.candidates; double[] costs = new double[cands.size()]; double[] lls = new double[cands.size()]; double minCost = Double.POSITIVE_INFINITY; double maxCost = Double.NEGATIVE_INFINITY; for(int i = 0; i < cands.size(); i++) { lls[i] = cands.get(i).baseScore; if(F1_COST) costs[i] = 1.0 - cands.get(i).stats.getPartialF1(); else costs[i] = cands.get(i).stats.getPropError(); if(costs[i] < minCost) minCost = costs[i]; if(costs[i] > maxCost) maxCost = costs[i]; } for(int i = 0; i < cands.size(); i++) costs[i] -= minCost; normalize(lls, cands.size()); int oraclePrediction = CollectionUtils.minIndex(costs, costs.length); if(minCost == maxCost) continue; NBestRepresentation nb = new NBestRepresentation(); nb.costs = costs; nb.reps = new SparseVector[cands.size()]; for(int i = 0; i < cands.size(); i++) { Candidate c = cands.get(i); SparseVector sv = new SparseVector(); fe.extractFeatures(c, sd.parse, lls[i], sv); nb.reps[i] = sv; } // subtract oracle prediction to save time and memory SparseVector op = nb.reps[oraclePrediction]; for(int i = 0; i < cands.size(); i++) { SparseVector sv = new SparseVector(); SparseVector.sortedSubtract(sv, nb.reps[i], op); sv.trim(); nb.reps[i] = sv; } Pair p = new Pair(nb, oraclePrediction); ts.add(p); } if(cacheFileName != null) saveTrainingSet(cacheFileName, fe, ts); long t1 = System.currentTimeMillis(); System.out.println("Preprocessing time: " + Util.toTimeString(t1 - t0, Util.ALL_TIME_FIELDS)); } catch(Exception e) { e.printStackTrace(); System.exit(1); } } else if(mode.equals("-loadPrecomputed")) { // alt 2: load precomputed training set try { int nsen = Integer.parseInt(argv[3]); Pair>> p = loadTrainingSet(argv[2], nsen); ts = p.right; fe = p.left; } catch(Exception e) { e.printStackTrace(); System.exit(1); ts = null; } N = getMaxNCand(ts); model = argv[4]; algName = argv[5]; algArgs = argv[6]; } else throw new IllegalArgumentException("unknown mode: " + mode); System.out.println(); AlgorithmFactory af = new AlgorithmFactory(); LearningAlgorithm alg = af.create(algName, algArgs); RerankingDefinition def = new RerankingDefinition(); Classifier cl = alg.train(def, ts); if(PRINT_TOP_FEATURES) { //LinkedList topFeatureIndices = def.getMaxFeaturesHack(); //RerankingFE.printProminentFeatures(topFeatureIndices); } try { ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(model)); oos.writeObject(fe); oos.writeObject(cl); oos.writeInt(N); oos.close(); } catch(Exception e) { e.printStackTrace(); System.exit(1); } System.out.println("Done."); } private static String spanString(Candidate c) { if(c.spans.isEmpty()) return "(EMPTY)"; StringBuilder sb = new StringBuilder(); for(Span s: c.spans) { if(sb.length() > 0) sb.append(","); sb.append("(" + s.label + "|" + s.tokenStart + "|" + s.tokenEnd + ")"); } return sb.toString(); } private static int checkOverlap(Span s, ArrayList ss) { for(Span s2: ss) if(s2.overlapsTokens(s)) return 2; for(Span s2: ss) if(s2.overlapsTokensUnlabeled(s)) return 1; return 0; } private static void checkChanges(SentenceData sd, int guess, Histogram hist) { if(guess == 0) return; ArrayList baseSpans = new ArrayList(); ArrayList guessSpans = new ArrayList(); for(Span s: sd.candidates.get(0).spans) { int ov = checkOverlap(s, sd.goldSpans); Span smod = new Span(); smod.label = s.label + ":" + ov; smod.tokenStart = s.tokenStart; smod.tokenEnd = s.tokenEnd; baseSpans.add(smod); } for(Span s: sd.candidates.get(guess).spans) { int ov = checkOverlap(s, sd.goldSpans); Span smod = new Span(); smod.label = s.label + ":" + ov; smod.tokenStart = s.tokenStart; smod.tokenEnd = s.tokenEnd; guessSpans.add(smod); } for(Span sb: baseSpans) { ArrayList matches = new ArrayList(); for(Span sg: guessSpans) { if(sb.overlapsTokensUnlabeled(sg)) matches.add(sg); } if(matches.isEmpty()) { hist.add(sb.label + " -> (EMPTY)"); } else { for(Span s: matches) hist.add(sb.label + " -> " + s.label); } } for(Span sg: guessSpans) { ArrayList matches = new ArrayList(); for(Span sb: baseSpans) { if(sb.overlapsTokensUnlabeled(sg)) matches.add(sb); } if(matches.isEmpty()) { hist.add("(EMPTY) -> " + sg.label); } } } public static ArrayList rerankSentence(RerankingFE fe, Classifier cl, ArrayList> l, SynSemParse parse) { //SentenceData sd = new SentenceData(); double[] lls = new double[l.size()]; ArrayList cands = new ArrayList(); for(int i = 0; i < l.size(); i++) { DoubleObjPair dp = l.get(i); Candidate c = new Candidate(); //NOT USED c.baseScore = dp.left; // TODO varför borttagen??? c.spans = dp.right.spans; cands.add(c); lls[i] = c.baseScore; } normalize(lls, cands.size()); NBestRepresentation nb = new NBestRepresentation(); nb.reps = new SparseVector[cands.size()]; for(int i = 0; i < cands.size(); i++) { Candidate c = cands.get(i); SparseVector sv = new SparseVector(); fe.extractFeatures(c, parse, lls[i], sv); nb.reps[i] = sv; } int prediction = cl.classify(nb); return l.get(prediction).right.spans; } public static void rerankSentence(RerankingFE fe, Classifier cl, ArrayList cands, double[] scores, SynSemParse parse) { if(scores.length < cands.size()) throw new IllegalArgumentException("score vector too short"); double[] lls = new double[cands.size()]; for(int i = 0; i < lls.length; i++) lls[i] = cands.get(i).baseScore; normalize(lls, lls.length); NBestRepresentation nb = new NBestRepresentation(); nb.reps = new SparseVector[cands.size()]; for(int i = 0; i < cands.size(); i++) { Candidate c = cands.get(i); SparseVector sv = new SparseVector(); fe.extractFeatures(c, parse, lls[i], sv); nb.reps[i] = sv; } ArrayList range = new ArrayList(); for(int i = 0; i < lls.length; i++) range.add(i); cl.computeScoresRestricted(nb, range, scores); } static HashMap parseClusters(String clString) { if(Util.isNullFile(clString)) return null; HashMap out = new HashMap(); String[] cls = clString.split("[/,]"); for(String cl: cls) { cl = cl.trim(); String[] types = cl.split("[&\\+]"); for(String t: types) out.put(t.trim(), cl); } return out; } private static void run(String[] argv) { try { //System.out.println(Arrays.toString(argv)); //PrintWriter debugPW = new PrintWriter("debug.xxx"); String nbestFileName = argv[1]; String synsemFileName = argv[2]; int N = Integer.parseInt(argv[3]); //boolean onlySubj = Boolean.parseBoolean(argv[4]); String clusterString = argv[4]; String model = argv[5]; String detailsFile = null; if(argv.length > 6) { detailsFile = argv[6]; if(detailsFile.toLowerCase().matches("null|none")) detailsFile = null; } int nTotal = 0; double totalFirstCost = 0; int nCorrectSelections = 0, nBaselineCorrect = 0; double totalCost = 0; int nNonzero = 0; ObjectInputStream ois = new ObjectInputStream(Util.openFileStream(model)); RerankingFE fe = (RerankingFE) ois.readObject(); //RerankingFE fe = new RerankingFE(); // NOT stateless Classifier cl = (Classifier) ois.readObject(); if(N < 1) { N = ois.readInt(); System.out.println("Setting N to default: " + N); } ois.close(); ArrayList sentences = SeqCandidateReader.readSentences(nbestFileName, synsemFileName, N, -1); // Stats systemStats = new Stats(); // Stats baselineStats = new Stats(); // Stats oracleStats = new Stats(); PRFStats systemStats = new PRFStats(detailsFile != null); PRFStats baselineStats = new PRFStats(); PRFStats oracleStats = new PRFStats(); int[][] mcnemarMatrix = new int[2][2]; Histogram changeStats = new Histogram(); for(SentenceData sd: sentences) { ArrayList cands = sd.candidates; //debugPW.println(cands.get(0).stats.toRawString()); nTotal++; if(nTotal % 100 == 0) { System.out.print("."); System.out.flush(); } double[] costs = new double[cands.size()]; double[] lls = new double[cands.size()]; double minCost = Double.POSITIVE_INFINITY; double maxCost = Double.NEGATIVE_INFINITY; for(int i = 0; i < cands.size(); i++) { lls[i] = cands.get(i).baseScore; if(F1_COST) costs[i] = 1.0 - cands.get(i).stats.getPartialF1(); else costs[i] = cands.get(i).stats.getPropError(); if(costs[i] < minCost) minCost = costs[i]; if(costs[i] > maxCost) maxCost = costs[i]; } for(int i = 0; i < cands.size(); i++) costs[i] -= minCost; normalize(lls, cands.size()); int oraclePrediction = CollectionUtils.minIndex(costs, costs.length); NBestRepresentation nb = new NBestRepresentation(); nb.costs = costs; nb.reps = new SparseVector[cands.size()]; for(int i = 0; i < cands.size(); i++) { Candidate c = cands.get(i); SparseVector sv = new SparseVector(); fe.extractFeatures(c, sd.parse, lls[i], sv); nb.reps[i] = sv; } //if(/*ONLY_SUBJ*/ onlySubj) // sd.simplifyStats(); if(!Util.isNullFile(clusterString)) { HashMap clusters = parseClusters(clusterString); sd.simplifyStats(clusters); } int prediction = cl.classify(nb); if(prediction > 0) nNonzero++; Candidate bc = sd.candidates.get(0); Candidate sc = sd.candidates.get(prediction); String endText = "\t" + spanString(sc) + "\t" + spanString(bc) + "\t" + sd.text; if(costs[0] == costs[prediction]) { if(costs[0] == 0) { mcnemarMatrix[0][0]++; //System.out.println("###\t(*)\t" + endText); } else { mcnemarMatrix[1][1]++; //System.out.println("###\t(_)\t" + endText); } } else { if(costs[0] < costs[prediction]) { mcnemarMatrix[0][1]++; //System.out.println("###\t(-)\t" + endText); } else { mcnemarMatrix[1][0]++; //System.out.println("###\t(+)\t" + endText); } } if(costs[0] == 0) nBaselineCorrect++; totalFirstCost += costs[0]; if(costs[prediction] == costs[oraclePrediction]) nCorrectSelections++; totalCost += costs[prediction]; systemStats.add(cands.get(prediction).stats); oracleStats.add(cands.get(oraclePrediction).stats); baselineStats.add(cands.get(0).stats); //systemStats.add(sd.goldSpans); //checkChanges(sd, prediction, changeStats); } //debugPW.close(); System.out.println(); double b_sa = (double) nBaselineCorrect / nTotal; double s_sa = (double) nCorrectSelections / nTotal; final boolean RAW = true; System.out.println("Baseline selection accuracy: " + b_sa); System.out.println("Baseline:"); baselineStats.print(); if(RAW) baselineStats.printRaw(); System.out.println(); System.out.println("Oracle:"); oracleStats.print(); if(RAW) oracleStats.printRaw(); System.out.println(); System.out.println("System selection accuracy: " + s_sa); System.out.println("System:"); systemStats.print(); if(RAW) systemStats.printRaw(); if(detailsFile != null) { System.out.println("Writing detailed results to " + detailsFile); PrintWriter resultSamples = new PrintWriter(new FileWriter(detailsFile)); systemStats.printSamples(resultSamples); resultSamples.close(); } /*System.out.println("-----"); int pi = 0; for(IntObjPair p: changeStats.asSortedList()) { pi++; System.out.println(pi + "\t" + p.left + "\t" + p.right); } */ // System.out.println(); // for(int i = 0; i < 2; i++) { // for(int j = 0; j < 2; j++) { // System.out.print(mcnemarMatrix[i][j] + " "); // } // System.out.println(); // } // // System.out.println(); // System.out.println("McNemar confidence (selection): " + McNemarTest.getConfidence(mcnemarMatrix)); //System.out.println(); //System.out.println(totalCost / nTotal); } catch(Exception e) { e.printStackTrace(); System.exit(1); } } public static void main(String[] argv) { if(argv[0].equals("-train")) train(argv); else if(argv[0].equals("-run")) run(argv); else throw new RuntimeException("unknown: " + Arrays.asList(argv)); } }