package mpqa_seq_reranker; import java.util.*; import java.util.zip.GZIPInputStream; import java.io.*; import se.lth.cs.nlp.nlputils.core.*; import se.lth.cs.nlp.nlputils.depgraph.*; import se.lth.cs.nlp.nlputils.annotations.*; import se.lth.cs.nlp.depsrl.format.*; public class SeqCandidateReader { private static SentenceData readNextSentence(BufferedReader br, HashMap sm, int N) throws IOException { String text = br.readLine(); if(text == null) return null; boolean debug = text.equals("Now Mr. Ruckauf is considered a major presidential contender ."); SentenceData out = new SentenceData(); if(sm != null) { SynSemParse parse = sm.get(text); if(parse == null) throw new RuntimeException("No parse found for this sentence: |" + text + "|"); out.parse = parse; } String line = br.readLine().trim(); if(!line.equals("")) { String[] ss = line.split(" "); for(String s: ss) { String[] ss2 = s.split(","); Span sp = new Span(); sp.tokenStart = Integer.parseInt(ss2[1]); sp.tokenEnd = Integer.parseInt(ss2[2]); sp.label = ss2[0]; out.goldSpans.add(sp); } } int i = 0; line = br.readLine().trim(); while(!line.equals("---")) { if(Integer.parseInt(line) != i) throw new RuntimeException("wrong index, should be " + i); Candidate c = new Candidate(); c.baseScore = Double.parseDouble(br.readLine()); String[] ss = br.readLine().trim().split(" "); c.stats = new PRFStats(ss); line = br.readLine().trim(); if(!line.equals("")) { ss = line.split(" "); for(String s: ss) { String[] ss2 = s.split(","); Span sp = new Span(); sp.tokenStart = Integer.parseInt(ss2[1]); sp.tokenEnd = Integer.parseInt(ss2[2]); sp.label = ss2[0]; c.spans.add(sp); } } if(i < N) out.candidates.add(c); i++; line = br.readLine().trim(); } out.text = text; for(i = 0; i < out.candidates.size(); i++) { Candidate c1 = out.candidates.get(i); Iterator it = out.candidates.listIterator(i + 1); while(it.hasNext()) { Candidate c2 = it.next(); if(c1.spans.equals(c2.spans)) it.remove(); } } if(debug) { Candidate c0 = out.candidates.get(0); System.out.println(c0.stats.toRawString()); } return out; } private static void toSubjectivityLayer(AnnotationLayer l) { for(Iterator it = l.iterator(); it.hasNext(); ) { Span s = it.next(); if(s.label.equals("os")) it.remove(); else s.label = "ss"; } } private static void printStats(double[] stats, String prefix) { double nCorrectSpans = stats[0]; double nGuesses = stats[1]; double nInGold = stats[2]; double propGuessCorrect = stats[3]; double propFoundCorrect = stats[4]; double nOverlap = stats[5]; double pHard = nCorrectSpans / nGuesses; double rHard = nCorrectSpans / nInGold; double fHard = 2*pHard*rHard / (pHard + rHard); double pSoft = propGuessCorrect / nGuesses; double rSoft = propFoundCorrect / nInGold; double fSoft = 2*pSoft*rSoft / (pSoft + rSoft); double pOver = nOverlap / nGuesses; double rOver = nOverlap / nInGold; double fOver = 2*pOver*rOver / (pOver + rOver); System.err.println(prefix + " hard: p = " + pHard + ", r = " + rHard + ", f1 = " + fHard); System.err.println(prefix + " soft: p = " + pSoft + ", r = " + rSoft + ", f1 = " + fSoft); System.err.println(prefix + " overlap: p = " + pOver + ", r = " + rOver + ", f1 = " + fOver); } public static ArrayList readSentences(String candFileName, String synsemFileName, int N, int maxNSentences) throws IOException { HashMap senMap; if(synsemFileName.equals("NULL")) senMap = null; else { senMap = new HashMap(); Scanner ssc = synsemFileName.equals("NULL")? null: new Scanner(Util.openFileStream(synsemFileName)); Triple> g = CoNLL2008Format.readNextGraph(ssc); while(g != null) { SynSemParse ssp = new SynSemParse(); ssp.dg = g.first; ssp.pas = g.third; StringBuilder sb = new StringBuilder(); for(int i = 1; i < g.first.nodes.length; i++) { if(i > 1) sb.append(" "); sb.append(g.first.nodes[i].word); } senMap.put(sb.toString(), ssp); g = CoNLL2008Format.readNextGraph(ssc); } } System.out.println("Read trees."); ArrayList out = new ArrayList(); BufferedReader br = new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(candFileName)))); SentenceData sd = readNextSentence(br, senMap, N); while(sd != null && out.size() != maxNSentences) { out.add(sd); sd = readNextSentence(br, senMap, N); } return out; } public static void main(String[] argv) { try { String mode = argv[0]; String candFileName = argv[1]; String synsemFileName = argv[2]; int N = Integer.parseInt(argv[3]); String outFileName = argv[4]; final boolean onlySubjectivity = true; // false; HashMap senMap = new HashMap(); Scanner ssc = new Scanner(new FileReader(synsemFileName)); Triple> g = CoNLL2008Format.readNextGraph(ssc); while(g != null) { SynSemParse ssp = new SynSemParse(); ssp.dg = g.first; ssp.pas = g.third; StringBuilder sb = new StringBuilder(); for(int i = 1; i < g.first.nodes.length; i++) { if(i > 1) sb.append(" "); sb.append(g.first.nodes[i].word); } senMap.put(sb.toString(), ssp); g = CoNLL2008Format.readNextGraph(ssc); } double[] baselineStats = new double[6]; double[] oracleStats = new double[6]; PrintWriter out = new PrintWriter(new FileWriter(outFileName)); BufferedReader br = new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(candFileName)))); SentenceData sd = readNextSentence(br, senMap, N); while(sd != null) { // ... AnnotationLayer goldLayer = new AnnotationLayer(); goldLayer.spans = sd.goldSpans; if(onlySubjectivity) toSubjectivityLayer(goldLayer); AnnotationLayer baselineLayer = new AnnotationLayer(); baselineLayer.spans = sd.candidates.get(0).spans; if(onlySubjectivity) toSubjectivityLayer(baselineLayer); int bestIndex = -1; double minErr = Double.MAX_VALUE; for(int j = 0; j < sd.candidates.size(); j++) { Candidate c = sd.candidates.get(j); double err = c.stats.getPropError(); if(err < minErr) { minErr = err; bestIndex = j; } } AnnotationLayer oracleLayer = new AnnotationLayer(); oracleLayer.spans = sd.candidates.get(bestIndex).spans; if(onlySubjectivity) toSubjectivityLayer(oracleLayer); double[] tmp1 = goldLayer.compareApprox(baselineLayer); for(int i = 3; i < 9; i++) baselineStats[i-3] += tmp1[i]; double[] tmp2 = goldLayer.compareApprox(oracleLayer); for(int i = 3; i < 9; i++) oracleStats[i-3] += tmp2[i]; sd = readNextSentence(br, senMap, N); } printStats(baselineStats, "Baseline"); printStats(oracleStats, "Oracle"); } catch(Exception e) { e.printStackTrace(); } } }